14360: Comment stubServer.
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "errors"
10         "fmt"
11         "io"
12         math_rand "math/rand"
13         "sync"
14
15         "git.curoverse.com/arvados.git/lib/cloud"
16         "git.curoverse.com/arvados.git/sdk/go/arvados"
17         "github.com/mitchellh/mapstructure"
18         "golang.org/x/crypto/ssh"
19 )
20
21 type StubExecFunc func(instance cloud.Instance, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
22
23 // A StubDriver implements cloud.Driver by setting up local SSH
24 // servers that pass their command execution requests to the provided
25 // SSHExecFunc.
26 type StubDriver struct {
27         Exec           StubExecFunc
28         HostKey        ssh.Signer
29         AuthorizedKeys []ssh.PublicKey
30
31         ErrorRateDestroy float64
32
33         instanceSets []*StubInstanceSet
34 }
35
36 // InstanceSet returns a new *StubInstanceSet.
37 func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID) (cloud.InstanceSet, error) {
38         sis := StubInstanceSet{
39                 driver:  sd,
40                 servers: map[cloud.InstanceID]*stubServer{},
41         }
42         sd.instanceSets = append(sd.instanceSets, &sis)
43         return &sis, mapstructure.Decode(params, &sis)
44 }
45
46 // InstanceSets returns all instances that have been created by the
47 // driver. This can be used to test a component that uses the driver
48 // but doesn't expose the InstanceSets it has created.
49 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
50         return sd.instanceSets
51 }
52
53 type StubInstanceSet struct {
54         driver  *StubDriver
55         servers map[cloud.InstanceID]*stubServer
56         mtx     sync.RWMutex
57         stopped bool
58 }
59
60 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
61         sis.mtx.Lock()
62         defer sis.mtx.Unlock()
63         if sis.stopped {
64                 return nil, errors.New("StubInstanceSet: Create called after Stop")
65         }
66         ak := sis.driver.AuthorizedKeys
67         if authKey != nil {
68                 ak = append([]ssh.PublicKey{authKey}, ak...)
69         }
70         var ss *stubServer
71         ss = &stubServer{
72                 sis:          sis,
73                 id:           cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
74                 tags:         copyTags(tags),
75                 providerType: it.ProviderType,
76                 SSHService: SSHService{
77                         HostKey:        sis.driver.HostKey,
78                         AuthorizedKeys: ak,
79                         Exec: func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
80                                 return sis.driver.Exec(ss.Instance(), command, stdin, stdout, stderr)
81                         },
82                 },
83         }
84
85         sis.servers[ss.id] = ss
86         return ss.Instance(), nil
87 }
88
89 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
90         sis.mtx.RLock()
91         defer sis.mtx.RUnlock()
92         var r []cloud.Instance
93         for _, ss := range sis.servers {
94                 r = append(r, ss.Instance())
95         }
96         return r, nil
97 }
98
99 func (sis *StubInstanceSet) Stop() {
100         sis.mtx.Lock()
101         defer sis.mtx.Unlock()
102         if sis.stopped {
103                 panic("Stop called twice")
104         }
105         sis.stopped = true
106 }
107
108 // stubServer is a fake server that runs an SSH service. It represents
109 // a VM running in a fake cloud.
110 //
111 // Note this is distinct from a stubInstance, which is a snapshot of
112 // the VM's metadata. As with a VM in a real cloud, the stubServer
113 // keeps running (and might change IP addresses, shut down, etc.)
114 // without updating any stubInstances that have been returned to
115 // callers.
116 type stubServer struct {
117         sis          *StubInstanceSet
118         id           cloud.InstanceID
119         tags         cloud.InstanceTags
120         providerType string
121         SSHService   SSHService
122         sync.Mutex
123 }
124
125 func (ss *stubServer) Instance() stubInstance {
126         ss.Lock()
127         defer ss.Unlock()
128         return stubInstance{
129                 ss:   ss,
130                 addr: ss.SSHService.Address(),
131                 // We deliberately return a cached/stale copy of the
132                 // real tags here, so that (Instance)Tags() sometimes
133                 // returns old data after a call to
134                 // (Instance)SetTags().  This is permitted by the
135                 // driver interface, and this might help remind
136                 // callers that they need to tolerate it.
137                 tags: copyTags(ss.tags),
138         }
139 }
140
141 type stubInstance struct {
142         ss   *stubServer
143         addr string
144         tags cloud.InstanceTags
145 }
146
147 func (si stubInstance) ID() cloud.InstanceID {
148         return si.ss.id
149 }
150
151 func (si stubInstance) Address() string {
152         return si.addr
153 }
154
155 func (si stubInstance) Destroy() error {
156         if math_rand.Float64() < si.ss.sis.driver.ErrorRateDestroy {
157                 return errors.New("instance could not be destroyed")
158         }
159         si.ss.SSHService.Close()
160         sis := si.ss.sis
161         sis.mtx.Lock()
162         defer sis.mtx.Unlock()
163         delete(sis.servers, si.ss.id)
164         return nil
165 }
166
167 func (si stubInstance) ProviderType() string {
168         return si.ss.providerType
169 }
170
171 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
172         tags = copyTags(tags)
173         ss := si.ss
174         go func() {
175                 ss.Lock()
176                 defer ss.Unlock()
177                 ss.tags = tags
178         }()
179         return nil
180 }
181
182 func (si stubInstance) Tags() cloud.InstanceTags {
183         return si.tags
184 }
185
186 func (si stubInstance) String() string {
187         return string(si.ss.id)
188 }
189
190 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
191         buf := make([]byte, 512)
192         _, err := io.ReadFull(rand.Reader, buf)
193         if err != nil {
194                 return err
195         }
196         sig, err := si.ss.sis.driver.HostKey.Sign(rand.Reader, buf)
197         if err != nil {
198                 return err
199         }
200         return key.Verify(buf, sig)
201 }
202
203 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
204         dst := cloud.InstanceTags{}
205         for k, v := range src {
206                 dst[k] = v
207         }
208         return dst
209 }