15954: Fix process start/stop race.
[arvados.git] / lib / dispatchcloud / ssh_executor / executor.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 // Package ssh_executor provides an implementation of pool.Executor
6 // using a long-lived multiplexed SSH session.
7 package ssh_executor
8
9 import (
10         "bytes"
11         "errors"
12         "io"
13         "net"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "golang.org/x/crypto/ssh"
19 )
20
21 // New returns a new Executor, using the given target.
22 func New(t cloud.ExecutorTarget) *Executor {
23         return &Executor{target: t}
24 }
25
26 // An Executor uses a multiplexed SSH connection to execute shell
27 // commands on a remote target. It reconnects automatically after
28 // errors.
29 //
30 // When setting up a connection, the Executor accepts whatever host
31 // key is provided by the remote server, then passes the received key
32 // and the SSH connection to the target's VerifyHostKey method before
33 // executing commands on the connection.
34 //
35 // A zero Executor must not be used before calling SetTarget.
36 //
37 // An Executor must not be copied.
38 type Executor struct {
39         target     cloud.ExecutorTarget
40         targetPort string
41         targetUser string
42         signers    []ssh.Signer
43         mtx        sync.RWMutex // controls access to instance after creation
44
45         client      *ssh.Client
46         clientErr   error
47         clientOnce  sync.Once     // initialized private state
48         clientSetup chan bool     // len>0 while client setup is in progress
49         hostKey     ssh.PublicKey // most recent host key that passed verification, if any
50 }
51
52 // SetSigners updates the set of private keys that will be offered to
53 // the target next time the Executor sets up a new connection.
54 func (exr *Executor) SetSigners(signers ...ssh.Signer) {
55         exr.mtx.Lock()
56         defer exr.mtx.Unlock()
57         exr.signers = signers
58 }
59
60 // SetTarget sets the current target. The new target will be used next
61 // time a new connection is set up; until then, the Executor will
62 // continue to use the existing target.
63 //
64 // The new target is assumed to represent the same host as the
65 // previous target, although its address and host key might differ.
66 func (exr *Executor) SetTarget(t cloud.ExecutorTarget) {
67         exr.mtx.Lock()
68         defer exr.mtx.Unlock()
69         exr.target = t
70 }
71
72 // SetTargetPort sets the default port (name or number) to connect
73 // to. This is used only when the address returned by the target's
74 // Address() method does not specify a port. If the given port is
75 // empty (or SetTargetPort is not called at all), the default port is
76 // "ssh".
77 func (exr *Executor) SetTargetPort(port string) {
78         exr.mtx.Lock()
79         defer exr.mtx.Unlock()
80         exr.targetPort = port
81 }
82
83 // Target returns the current target.
84 func (exr *Executor) Target() cloud.ExecutorTarget {
85         exr.mtx.RLock()
86         defer exr.mtx.RUnlock()
87         return exr.target
88 }
89
90 // Execute runs cmd on the target. If an existing connection is not
91 // usable, it sets up a new connection to the current target.
92 func (exr *Executor) Execute(env map[string]string, cmd string, stdin io.Reader) ([]byte, []byte, error) {
93         session, err := exr.newSession()
94         if err != nil {
95                 return nil, nil, err
96         }
97         defer session.Close()
98         for k, v := range env {
99                 err = session.Setenv(k, v)
100                 if err != nil {
101                         return nil, nil, err
102                 }
103         }
104         var stdout, stderr bytes.Buffer
105         session.Stdin = stdin
106         session.Stdout = &stdout
107         session.Stderr = &stderr
108         err = session.Run(cmd)
109         return stdout.Bytes(), stderr.Bytes(), err
110 }
111
112 // Close shuts down any active connections.
113 func (exr *Executor) Close() {
114         // Ensure exr is initialized
115         exr.sshClient(false)
116
117         exr.clientSetup <- true
118         if exr.client != nil {
119                 defer exr.client.Close()
120         }
121         exr.client, exr.clientErr = nil, errors.New("closed")
122         <-exr.clientSetup
123 }
124
125 // Create a new SSH session. If session setup fails or the SSH client
126 // hasn't been setup yet, setup a new SSH client and try again.
127 func (exr *Executor) newSession() (*ssh.Session, error) {
128         try := func(create bool) (*ssh.Session, error) {
129                 client, err := exr.sshClient(create)
130                 if err != nil {
131                         return nil, err
132                 }
133                 return client.NewSession()
134         }
135         session, err := try(false)
136         if err != nil {
137                 session, err = try(true)
138         }
139         return session, err
140 }
141
142 // Get the latest SSH client. If another goroutine is in the process
143 // of setting one up, wait for it to finish and return its result (or
144 // the last successfully setup client, if it fails).
145 func (exr *Executor) sshClient(create bool) (*ssh.Client, error) {
146         exr.clientOnce.Do(func() {
147                 exr.clientSetup = make(chan bool, 1)
148                 exr.clientErr = errors.New("client not yet created")
149         })
150         defer func() { <-exr.clientSetup }()
151         select {
152         case exr.clientSetup <- true:
153                 if create {
154                         client, err := exr.setupSSHClient()
155                         if err == nil || exr.client == nil {
156                                 if exr.client != nil {
157                                         // Hang up the previous
158                                         // (non-working) client
159                                         go exr.client.Close()
160                                 }
161                                 exr.client, exr.clientErr = client, err
162                         }
163                         if err != nil {
164                                 return nil, err
165                         }
166                 }
167         default:
168                 // Another goroutine is doing the above case.  Wait
169                 // for it to finish and return whatever it leaves in
170                 // wkr.client.
171                 exr.clientSetup <- true
172         }
173         return exr.client, exr.clientErr
174 }
175
176 func (exr *Executor) TargetHostPort() (string, string) {
177         addr := exr.Target().Address()
178         if addr == "" {
179                 return "", ""
180         }
181         h, p, err := net.SplitHostPort(addr)
182         if err != nil || p == "" {
183                 // Target address does not specify a port.  Use
184                 // targetPort, or "ssh".
185                 if h == "" {
186                         h = addr
187                 }
188                 if p = exr.targetPort; p == "" {
189                         p = "ssh"
190                 }
191         }
192         return h, p
193 }
194
195 // Create a new SSH client.
196 func (exr *Executor) setupSSHClient() (*ssh.Client, error) {
197         addr := net.JoinHostPort(exr.TargetHostPort())
198         if addr == ":" {
199                 return nil, errors.New("instance has no address")
200         }
201         var receivedKey ssh.PublicKey
202         client, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
203                 User: exr.Target().RemoteUser(),
204                 Auth: []ssh.AuthMethod{
205                         ssh.PublicKeys(exr.signers...),
206                 },
207                 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
208                         receivedKey = key
209                         return nil
210                 },
211                 Timeout: time.Minute,
212         })
213         if err != nil {
214                 return nil, err
215         } else if receivedKey == nil {
216                 return nil, errors.New("BUG: key was never provided to HostKeyCallback")
217         }
218
219         if exr.hostKey == nil || !bytes.Equal(exr.hostKey.Marshal(), receivedKey.Marshal()) {
220                 err = exr.Target().VerifyHostKey(receivedKey, client)
221                 if err != nil {
222                         return nil, err
223                 }
224                 exr.hostKey = receivedKey
225         }
226         return client, nil
227 }