14360: Locking comment.
[arvados.git] / lib / dispatchcloud / test / stub_driver.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package test
6
7 import (
8         "crypto/rand"
9         "errors"
10         "fmt"
11         "io"
12         math_rand "math/rand"
13         "sync"
14
15         "git.curoverse.com/arvados.git/lib/cloud"
16         "git.curoverse.com/arvados.git/sdk/go/arvados"
17         "github.com/mitchellh/mapstructure"
18         "golang.org/x/crypto/ssh"
19 )
20
21 type StubExecFunc func(instance cloud.Instance, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
22
23 // A StubDriver implements cloud.Driver by setting up local SSH
24 // servers that pass their command execution requests to the provided
25 // SSHExecFunc.
26 type StubDriver struct {
27         Exec           StubExecFunc
28         HostKey        ssh.Signer
29         AuthorizedKeys []ssh.PublicKey
30
31         ErrorRateDestroy float64
32
33         instanceSets []*StubInstanceSet
34 }
35
36 // InstanceSet returns a new *StubInstanceSet.
37 func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID) (cloud.InstanceSet, error) {
38         sis := StubInstanceSet{
39                 driver:  sd,
40                 servers: map[cloud.InstanceID]*stubServer{},
41         }
42         sd.instanceSets = append(sd.instanceSets, &sis)
43         return &sis, mapstructure.Decode(params, &sis)
44 }
45
46 // InstanceSets returns all instances that have been created by the
47 // driver. This can be used to test a component that uses the driver
48 // but doesn't expose the InstanceSets it has created.
49 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
50         return sd.instanceSets
51 }
52
53 type StubInstanceSet struct {
54         driver  *StubDriver
55         servers map[cloud.InstanceID]*stubServer
56         mtx     sync.RWMutex
57         stopped bool
58 }
59
60 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
61         sis.mtx.Lock()
62         defer sis.mtx.Unlock()
63         if sis.stopped {
64                 return nil, errors.New("StubInstanceSet: Create called after Stop")
65         }
66         ak := sis.driver.AuthorizedKeys
67         if authKey != nil {
68                 ak = append([]ssh.PublicKey{authKey}, ak...)
69         }
70         var ss *stubServer
71         ss = &stubServer{
72                 sis:          sis,
73                 id:           cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
74                 tags:         copyTags(tags),
75                 providerType: it.ProviderType,
76                 SSHService: SSHService{
77                         HostKey:        sis.driver.HostKey,
78                         AuthorizedKeys: ak,
79                         Exec: func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
80                                 return sis.driver.Exec(ss.Instance(), command, stdin, stdout, stderr)
81                         },
82                 },
83         }
84
85         sis.servers[ss.id] = ss
86         return ss.Instance(), nil
87 }
88
89 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
90         sis.mtx.RLock()
91         defer sis.mtx.RUnlock()
92         var r []cloud.Instance
93         for _, ss := range sis.servers {
94                 r = append(r, ss.Instance())
95         }
96         return r, nil
97 }
98
99 func (sis *StubInstanceSet) Stop() {
100         sis.mtx.Lock()
101         defer sis.mtx.Unlock()
102         if sis.stopped {
103                 panic("Stop called twice")
104         }
105         sis.stopped = true
106 }
107
108 type stubServer struct {
109         sis          *StubInstanceSet
110         id           cloud.InstanceID
111         tags         cloud.InstanceTags
112         providerType string
113         SSHService   SSHService
114         sync.Mutex
115 }
116
117 func (ss *stubServer) Instance() stubInstance {
118         ss.Lock()
119         defer ss.Unlock()
120         return stubInstance{
121                 ss:   ss,
122                 addr: ss.SSHService.Address(),
123                 // We deliberately return a cached/stale copy of the
124                 // real tags here, so that (Instance)Tags() sometimes
125                 // returns old data after a call to
126                 // (Instance)SetTags().  This is permitted by the
127                 // driver interface, and this might help remind
128                 // callers that they need to tolerate it.
129                 tags: copyTags(ss.tags),
130         }
131 }
132
133 type stubInstance struct {
134         ss   *stubServer
135         addr string
136         tags cloud.InstanceTags
137 }
138
139 func (si stubInstance) ID() cloud.InstanceID {
140         return si.ss.id
141 }
142
143 func (si stubInstance) Address() string {
144         return si.addr
145 }
146
147 func (si stubInstance) Destroy() error {
148         if math_rand.Float64() < si.ss.sis.driver.ErrorRateDestroy {
149                 return errors.New("instance could not be destroyed")
150         }
151         si.ss.SSHService.Close()
152         sis := si.ss.sis
153         sis.mtx.Lock()
154         defer sis.mtx.Unlock()
155         delete(sis.servers, si.ss.id)
156         return nil
157 }
158
159 func (si stubInstance) ProviderType() string {
160         return si.ss.providerType
161 }
162
163 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
164         tags = copyTags(tags)
165         ss := si.ss
166         go func() {
167                 ss.Lock()
168                 defer ss.Unlock()
169                 ss.tags = tags
170         }()
171         return nil
172 }
173
174 func (si stubInstance) Tags() cloud.InstanceTags {
175         return si.tags
176 }
177
178 func (si stubInstance) String() string {
179         return string(si.ss.id)
180 }
181
182 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
183         buf := make([]byte, 512)
184         _, err := io.ReadFull(rand.Reader, buf)
185         if err != nil {
186                 return err
187         }
188         sig, err := si.ss.sis.driver.HostKey.Sign(rand.Reader, buf)
189         if err != nil {
190                 return err
191         }
192         return key.Verify(buf, sig)
193 }
194
195 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
196         dst := cloud.InstanceTags{}
197         for k, v := range src {
198                 dst[k] = v
199         }
200         return dst
201 }