Merge branch '15397-remove-obsolete-apis'
[arvados.git] / lib / dispatchcloud / sshexecutor / executor.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 // Package sshexecutor provides an implementation of pool.Executor
6 // using a long-lived multiplexed SSH session.
7 package sshexecutor
8
9 import (
10         "bytes"
11         "errors"
12         "io"
13         "net"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "golang.org/x/crypto/ssh"
19 )
20
21 var ErrNoAddress = errors.New("instance has no address")
22
23 // New returns a new Executor, using the given target.
24 func New(t cloud.ExecutorTarget) *Executor {
25         return &Executor{target: t}
26 }
27
28 // An Executor uses a multiplexed SSH connection to execute shell
29 // commands on a remote target. It reconnects automatically after
30 // errors.
31 //
32 // When setting up a connection, the Executor accepts whatever host
33 // key is provided by the remote server, then passes the received key
34 // and the SSH connection to the target's VerifyHostKey method before
35 // executing commands on the connection.
36 //
37 // A zero Executor must not be used before calling SetTarget.
38 //
39 // An Executor must not be copied.
40 type Executor struct {
41         target     cloud.ExecutorTarget
42         targetPort string
43         targetUser string
44         signers    []ssh.Signer
45         mtx        sync.RWMutex // controls access to instance after creation
46
47         client      *ssh.Client
48         clientErr   error
49         clientOnce  sync.Once     // initialized private state
50         clientSetup chan bool     // len>0 while client setup is in progress
51         hostKey     ssh.PublicKey // most recent host key that passed verification, if any
52 }
53
54 // SetSigners updates the set of private keys that will be offered to
55 // the target next time the Executor sets up a new connection.
56 func (exr *Executor) SetSigners(signers ...ssh.Signer) {
57         exr.mtx.Lock()
58         defer exr.mtx.Unlock()
59         exr.signers = signers
60 }
61
62 // SetTarget sets the current target. The new target will be used next
63 // time a new connection is set up; until then, the Executor will
64 // continue to use the existing target.
65 //
66 // The new target is assumed to represent the same host as the
67 // previous target, although its address and host key might differ.
68 func (exr *Executor) SetTarget(t cloud.ExecutorTarget) {
69         exr.mtx.Lock()
70         defer exr.mtx.Unlock()
71         exr.target = t
72 }
73
74 // SetTargetPort sets the default port (name or number) to connect
75 // to. This is used only when the address returned by the target's
76 // Address() method does not specify a port. If the given port is
77 // empty (or SetTargetPort is not called at all), the default port is
78 // "ssh".
79 func (exr *Executor) SetTargetPort(port string) {
80         exr.mtx.Lock()
81         defer exr.mtx.Unlock()
82         exr.targetPort = port
83 }
84
85 // Target returns the current target.
86 func (exr *Executor) Target() cloud.ExecutorTarget {
87         exr.mtx.RLock()
88         defer exr.mtx.RUnlock()
89         return exr.target
90 }
91
92 // Execute runs cmd on the target. If an existing connection is not
93 // usable, it sets up a new connection to the current target.
94 func (exr *Executor) Execute(env map[string]string, cmd string, stdin io.Reader) ([]byte, []byte, error) {
95         session, err := exr.newSession()
96         if err != nil {
97                 return nil, nil, err
98         }
99         defer session.Close()
100         for k, v := range env {
101                 err = session.Setenv(k, v)
102                 if err != nil {
103                         return nil, nil, err
104                 }
105         }
106         var stdout, stderr bytes.Buffer
107         session.Stdin = stdin
108         session.Stdout = &stdout
109         session.Stderr = &stderr
110         err = session.Run(cmd)
111         return stdout.Bytes(), stderr.Bytes(), err
112 }
113
114 // Close shuts down any active connections.
115 func (exr *Executor) Close() {
116         // Ensure exr is initialized
117         exr.sshClient(false)
118
119         exr.clientSetup <- true
120         if exr.client != nil {
121                 defer exr.client.Close()
122         }
123         exr.client, exr.clientErr = nil, errors.New("closed")
124         <-exr.clientSetup
125 }
126
127 // Create a new SSH session. If session setup fails or the SSH client
128 // hasn't been setup yet, setup a new SSH client and try again.
129 func (exr *Executor) newSession() (*ssh.Session, error) {
130         try := func(create bool) (*ssh.Session, error) {
131                 client, err := exr.sshClient(create)
132                 if err != nil {
133                         return nil, err
134                 }
135                 return client.NewSession()
136         }
137         session, err := try(false)
138         if err != nil {
139                 session, err = try(true)
140         }
141         return session, err
142 }
143
144 // Get the latest SSH client. If another goroutine is in the process
145 // of setting one up, wait for it to finish and return its result (or
146 // the last successfully setup client, if it fails).
147 func (exr *Executor) sshClient(create bool) (*ssh.Client, error) {
148         exr.clientOnce.Do(func() {
149                 exr.clientSetup = make(chan bool, 1)
150                 exr.clientErr = errors.New("client not yet created")
151         })
152         defer func() { <-exr.clientSetup }()
153         select {
154         case exr.clientSetup <- true:
155                 if create {
156                         client, err := exr.setupSSHClient()
157                         if err == nil || exr.client == nil {
158                                 if exr.client != nil {
159                                         // Hang up the previous
160                                         // (non-working) client
161                                         go exr.client.Close()
162                                 }
163                                 exr.client, exr.clientErr = client, err
164                         }
165                         if err != nil {
166                                 return nil, err
167                         }
168                 }
169         default:
170                 // Another goroutine is doing the above case.  Wait
171                 // for it to finish and return whatever it leaves in
172                 // wkr.client.
173                 exr.clientSetup <- true
174         }
175         return exr.client, exr.clientErr
176 }
177
178 func (exr *Executor) TargetHostPort() (string, string) {
179         addr := exr.Target().Address()
180         if addr == "" {
181                 return "", ""
182         }
183         h, p, err := net.SplitHostPort(addr)
184         if err != nil || p == "" {
185                 // Target address does not specify a port.  Use
186                 // targetPort, or "ssh".
187                 if h == "" {
188                         h = addr
189                 }
190                 if p = exr.targetPort; p == "" {
191                         p = "ssh"
192                 }
193         }
194         return h, p
195 }
196
197 // Create a new SSH client.
198 func (exr *Executor) setupSSHClient() (*ssh.Client, error) {
199         addr := net.JoinHostPort(exr.TargetHostPort())
200         if addr == ":" {
201                 return nil, ErrNoAddress
202         }
203         var receivedKey ssh.PublicKey
204         client, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
205                 User: exr.Target().RemoteUser(),
206                 Auth: []ssh.AuthMethod{
207                         ssh.PublicKeys(exr.signers...),
208                 },
209                 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
210                         receivedKey = key
211                         return nil
212                 },
213                 Timeout: time.Minute,
214         })
215         if err != nil {
216                 return nil, err
217         } else if receivedKey == nil {
218                 return nil, errors.New("BUG: key was never provided to HostKeyCallback")
219         }
220
221         if exr.hostKey == nil || !bytes.Equal(exr.hostKey.Marshal(), receivedKey.Marshal()) {
222                 err = exr.Target().VerifyHostKey(receivedKey, client)
223                 if err != nil {
224                         return nil, err
225                 }
226                 exr.hostKey = receivedKey
227         }
228         return client, nil
229 }