// An SSHExecFunc handles an "exec" session on a multiplexed SSH
// connection.
-type SSHExecFunc func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32
+//
+// When invoked by SSHService, env holds the variables the client set
+// with "env" requests before its "exec" request (an empty, non-nil map
+// if it set none), and command is the requested command line. stdin and
+// stdout are the session channel itself and stderr is its extended
+// stderr stream. The returned value is sent to the client as the
+// session's "exit-status" before the channel is closed.
+type SSHExecFunc func(env map[string]string, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
// An SSHService accepts SSH connections on an available TCP port and
// passes clients' "exec" sessions to the provided SSHExecFunc.
log.Printf("accept channel: %s", err)
return
}
- var execReq struct {
- Command string
- }
+ // Per-session state, owned by the request goroutine below.
+ didExec := false
+ sessionEnv := map[string]string{}
go func() {
for req := range reqs {
- if req.Type == "exec" && execReq.Command == "" {
+ switch {
+ case didExec:
+ // Only one command runs per session; refuse every request
+ // (including further "env" and "exec") once it has started.
+ req.Reply(false, nil)
+ case req.Type == "exec":
+ var execReq struct {
+ Command string
+ }
- req.Reply(true, nil)
- ssh.Unmarshal(req.Payload, &execReq)
+ if err := ssh.Unmarshal(req.Payload, &execReq); err != nil {
+ // Malformed payload: refuse it rather than acknowledging
+ // it and running an empty command.
+ log.Printf("exec request: %s", err)
+ req.Reply(false, nil)
+ continue
+ }
+ req.Reply(true, nil)
+ // Every later request hits the didExec case, so sessionEnv is
+ // never written again and the goroutine below can read it
+ // without further synchronization.
+ didExec = true
go func() {
var resp struct {
Status uint32
}
- resp.Status = ss.Exec(execReq.Command, ch, ch, ch.Stderr())
+ resp.Status = ss.Exec(sessionEnv, execReq.Command, ch, ch, ch.Stderr())
ch.SendRequest("exit-status", false, ssh.Marshal(&resp))
ch.Close()
}()
+ case req.Type == "env":
+ var envReq struct {
+ Name string
+ Value string
+ }
+ if err := ssh.Unmarshal(req.Payload, &envReq); err != nil {
+ log.Printf("env request: %s", err)
+ req.Reply(false, nil)
+ continue
+ }
+ // NOTE(review): names and values come straight from the client
+ // and reach ss.Exec unfiltered; if Exec exports them to a child
+ // process, consider an allowlist (cf. OpenSSH AcceptEnv).
+ sessionEnv[envReq.Name] = envReq.Value
+ req.Reply(true, nil)
+ default:
+ // Unsupported requests (pty-req, shell, subsystem, ...) must
+ // still be answered, or a client that asked for a reply waits
+ // forever. Reply is a no-op when WantReply is false.
+ req.Reply(false, nil)
}
}
}()