Merge branch '20235-probe-after-upgrade'
[arvados.git] / lib / dispatchcloud / worker / verify.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package worker
6
7 import (
8         "bytes"
9         "errors"
10         "fmt"
11
12         "git.arvados.org/arvados.git/lib/cloud"
13         "golang.org/x/crypto/ssh"
14 )
15
// errBadInstanceSecret is returned by VerifyHostKey when the content
// read back from the instance's secret file does not match the
// expected secret.
var errBadInstanceSecret = errors.New("bad instance secret")

// instanceSecretFilename is the location of the secret file on the
// instance. The value is spliced verbatim into shell command lines
// (see InitCommand and VerifyHostKey), so any shell quoting it needs
// must already be part of the value.
var instanceSecretFilename = "/var/run/arvados-instance-secret"

// instanceSecretLength is the expected size of an instance secret,
// measured in hex digits. NOTE(review): not referenced in this file;
// presumably used where secrets are generated — confirm.
var instanceSecretLength = 40
23
24 type TagVerifier struct {
25         cloud.Instance
26         Secret         string
27         ReportVerified func(cloud.Instance)
28 }
29
30 func (tv TagVerifier) InitCommand() cloud.InitCommand {
31         return cloud.InitCommand(fmt.Sprintf("umask 0177 && echo -n %q >%s", tv.Secret, instanceSecretFilename))
32 }
33
34 func (tv TagVerifier) VerifyHostKey(pubKey ssh.PublicKey, client *ssh.Client) error {
35         if tv.ReportVerified != nil {
36                 tv.ReportVerified(tv.Instance)
37         }
38         if err := tv.Instance.VerifyHostKey(pubKey, client); err != cloud.ErrNotImplemented || tv.Secret == "" {
39                 // If the wrapped instance indicates it has a way to
40                 // verify the key, return that decision.
41                 return err
42         }
43         session, err := client.NewSession()
44         if err != nil {
45                 return err
46         }
47         defer session.Close()
48         var stdout, stderr bytes.Buffer
49         session.Stdin = bytes.NewBuffer(nil)
50         session.Stdout = &stdout
51         session.Stderr = &stderr
52         cmd := fmt.Sprintf("cat %s", instanceSecretFilename)
53         if u := tv.RemoteUser(); u != "root" {
54                 cmd = "sudo " + cmd
55         }
56         err = session.Run(cmd)
57         if err != nil {
58                 return err
59         }
60         if stdout.String() != tv.Secret {
61                 return errBadInstanceSecret
62         }
63         return nil
64 }