X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/b205525d0b7c7b9042513fe77d2e8061534208ae..3e9b9dcdcce3905fa33dde900ef99f27ba036fea:/lib/dispatchcloud/ssh_executor/executor.go diff --git a/lib/dispatchcloud/ssh_executor/executor.go b/lib/dispatchcloud/ssh_executor/executor.go index b5dba9870d..d608763cf5 100644 --- a/lib/dispatchcloud/ssh_executor/executor.go +++ b/lib/dispatchcloud/ssh_executor/executor.go @@ -36,9 +36,11 @@ func New(t cloud.ExecutorTarget) *Executor { // // An Executor must not be copied. type Executor struct { - target cloud.ExecutorTarget - signers []ssh.Signer - mtx sync.RWMutex // controls access to instance after creation + target cloud.ExecutorTarget + targetPort string + targetUser string + signers []ssh.Signer + mtx sync.RWMutex // controls access to instance after creation client *ssh.Client clientErr error @@ -67,6 +69,17 @@ func (exr *Executor) SetTarget(t cloud.ExecutorTarget) { exr.target = t } +// SetTargetPort sets the default port (name or number) to connect +// to. This is used only when the address returned by the target's +// Address() method does not specify a port. If the given port is +// empty (or SetTargetPort is not called at all), the default port is +// "ssh". +func (exr *Executor) SetTargetPort(port string) { + exr.mtx.Lock() + defer exr.mtx.Unlock() + exr.targetPort = port +} + // Target returns the current target. func (exr *Executor) Target() cloud.ExecutorTarget { exr.mtx.RLock() @@ -76,12 +89,18 @@ func (exr *Executor) Target() cloud.ExecutorTarget { // Execute runs cmd on the target. If an existing connection is not // usable, it sets up a new connection to the current target. -func (exr *Executor) Execute(cmd string, stdin io.Reader) ([]byte, []byte, error) { +func (exr *Executor) Execute(env map[string]string, cmd string, stdin io.Reader) ([]byte, []byte, error) { session, err := exr.newSession() if err != nil { return nil, nil, err } defer session.Close() + for k, v := range env { + err = session.Setenv(k, v) + if err != nil { + return nil, nil, err + } + } var stdout, stderr bytes.Buffer session.Stdin = stdin session.Stdout = &stdout @@ -154,16 +173,34 @@ func (exr *Executor) sshClient(create bool) (*ssh.Client, error) { return exr.client, exr.clientErr } +func (exr *Executor) TargetHostPort() (string, string) { + addr := exr.Target().Address() + if addr == "" { + return "", "" + } + h, p, err := net.SplitHostPort(addr) + if err != nil || p == "" { + // Target address does not specify a port. Use + // targetPort, or "ssh". + if h == "" { + h = addr + } + if p = exr.targetPort; p == "" { + p = "ssh" + } + } + return h, p +} + // Create a new SSH client. func (exr *Executor) setupSSHClient() (*ssh.Client, error) { - target := exr.Target() - addr := target.Address() - if addr == "" { + addr := net.JoinHostPort(exr.TargetHostPort()) + if addr == ":" { return nil, errors.New("instance has no address") } var receivedKey ssh.PublicKey client, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ - User: "root", + User: exr.Target().RemoteUser(), Auth: []ssh.AuthMethod{ ssh.PublicKeys(exr.signers...), }, @@ -180,7 +217,7 @@ func (exr *Executor) setupSSHClient() (*ssh.Client, error) { } if exr.hostKey == nil || !bytes.Equal(exr.hostKey.Marshal(), receivedKey.Marshal()) { - err = target.VerifyHostKey(receivedKey, client) + err = exr.Target().VerifyHostKey(receivedKey, client) if err != nil { return nil, err }