14807: Update API endpoints: instance_id is always a query param.
[arvados.git] / lib / dispatchcloud / ssh_executor / executor.go
index 804ae6f15e2ca4c88ce92ab168101430af2c09c0..feed1c2a78b82a84821f22eee99e39e960dbd431 100644 (file)
@@ -36,9 +36,11 @@ func New(t cloud.ExecutorTarget) *Executor {
 //
 // An Executor must not be copied.
 type Executor struct {
-       target  cloud.ExecutorTarget
-       signers []ssh.Signer
-       mtx     sync.RWMutex // controls access to instance after creation
+       target     cloud.ExecutorTarget
+       targetPort string
+       targetUser string
+       signers    []ssh.Signer
+       mtx        sync.RWMutex // controls access to instance after creation
 
        client      *ssh.Client
        clientErr   error
@@ -67,6 +69,17 @@ func (exr *Executor) SetTarget(t cloud.ExecutorTarget) {
        exr.target = t
 }
 
+// SetTargetPort sets the default port (name or number) to connect
+// to. This is used only when the address returned by the target's
+// Address() method does not specify a port. If the given port is
+// empty (or SetTargetPort is not called at all), the default port is
+// "ssh".
+func (exr *Executor) SetTargetPort(port string) {
+       exr.mtx.Lock()
+       defer exr.mtx.Unlock()
+       exr.targetPort = port
+}
+
 // Target returns the current target.
 func (exr *Executor) Target() cloud.ExecutorTarget {
        exr.mtx.RLock()
@@ -76,12 +89,18 @@ func (exr *Executor) Target() cloud.ExecutorTarget {
 
 // Execute runs cmd on the target. If an existing connection is not
 // usable, it sets up a new connection to the current target.
-func (exr *Executor) Execute(cmd string, stdin io.Reader) ([]byte, []byte, error) {
+func (exr *Executor) Execute(env map[string]string, cmd string, stdin io.Reader) ([]byte, []byte, error) {
        session, err := exr.newSession()
        if err != nil {
                return nil, nil, err
        }
        defer session.Close()
+       for k, v := range env {
+               err = session.Setenv(k, v)
+               if err != nil {
+                       return nil, nil, err
+               }
+       }
        var stdout, stderr bytes.Buffer
        session.Stdin = stdin
        session.Stdout = &stdout
@@ -90,6 +109,19 @@ func (exr *Executor) Execute(cmd string, stdin io.Reader) ([]byte, []byte, error
        return stdout.Bytes(), stderr.Bytes(), err
 }
 
+// Close shuts down any active connections.
+func (exr *Executor) Close() {
+       // Ensure exr is initialized
+       exr.sshClient(false)
+
+       exr.clientSetup <- true
+       if exr.client != nil {
+               defer exr.client.Close()
+       }
+       exr.client, exr.clientErr = nil, errors.New("closed")
+       <-exr.clientSetup
+}
+
 // Create a new SSH session. If session setup fails or the SSH client
 // hasn't been setup yet, setup a new SSH client and try again.
 func (exr *Executor) newSession() (*ssh.Session, error) {
@@ -121,6 +153,11 @@ func (exr *Executor) sshClient(create bool) (*ssh.Client, error) {
                if create {
                        client, err := exr.setupSSHClient()
                        if err == nil || exr.client == nil {
+                               if exr.client != nil {
+                                       // Hang up the previous
+                                       // (non-working) client
+                                       go exr.client.Close()
+                               }
                                exr.client, exr.clientErr = client, err
                        }
                        if err != nil {
@@ -143,9 +180,20 @@ func (exr *Executor) setupSSHClient() (*ssh.Client, error) {
        if addr == "" {
                return nil, errors.New("instance has no address")
        }
+       if h, p, err := net.SplitHostPort(addr); err != nil || p == "" {
+               // Target address does not specify a port.  Use
+               // targetPort, or "ssh".
+               if h == "" {
+                       h = addr
+               }
+               if p = exr.targetPort; p == "" {
+                       p = "ssh"
+               }
+               addr = net.JoinHostPort(h, p)
+       }
        var receivedKey ssh.PublicKey
        client, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
-               User: "root",
+               User: target.RemoteUser(),
                Auth: []ssh.AuthMethod{
                        ssh.PublicKeys(exr.signers...),
                },