1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
15 "git.curoverse.com/arvados.git/lib/cloud"
16 "git.curoverse.com/arvados.git/sdk/go/arvados"
17 "github.com/mitchellh/mapstructure"
18 "golang.org/x/crypto/ssh"
// StubExecFunc is the signature of a command handler: it is given a
// cloud.Instance, a command string, and stdin/stdout/stderr streams,
// and returns a uint32.
//
// NOTE(review): the uint32 is presumably the command's exit status --
// confirm against the SSH service that consumes it.
21 type StubExecFunc func(instance cloud.Instance, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
23 // A StubDriver implements cloud.Driver by setting up local SSH
24 // servers that pass their command execution requests to the provided
// Exec function (see the Exec closure built in
// (*StubInstanceSet)Create, which forwards to the driver's Exec).
26 type StubDriver struct {
// AuthorizedKeys is read by (*StubInstanceSet)Create, which builds a
// new key slice from it with the per-instance authKey in front.
29 AuthorizedKeys []ssh.PublicKey
// ErrorRateDestroy is the threshold used by (stubInstance)Destroy:
// Destroy returns an error whenever math_rand.Float64() is below it.
// NOTE(review): math_rand is presumably math/rand, so 0 never fails
// and 1 always fails -- confirm the import alias.
31 ErrorRateDestroy float64
// instanceSets accumulates every set handed out by InstanceSet; it
// is what InstanceSets returns.
33 instanceSets []*StubInstanceSet
36 // InstanceSet returns a new *StubInstanceSet.
//
// The new set is appended to sd.instanceSets before params is
// decoded, and &sis is returned together with whatever error
// mapstructure.Decode reports. A decode failure therefore still
// yields a non-nil set that remains recorded on the driver.
37 func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID) (cloud.InstanceSet, error) {
38 sis := StubInstanceSet{
40 servers: map[cloud.InstanceID]*stubServer{},
42 sd.instanceSets = append(sd.instanceSets, &sis)
// Decode the caller-supplied params into the set itself; the decode
// error (if any) becomes this method's error result.
43 return &sis, mapstructure.Decode(params, &sis)
46 // InstanceSets returns all instance sets that have been created by the
47 // driver. This can be used to test a component that uses the driver
48 // but doesn't expose the InstanceSets it has created.
//
// The returned slice is sd.instanceSets itself, not a copy.
49 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
50 return sd.instanceSets
// StubInstanceSet is the instance set type handed out by
// (*StubDriver)InstanceSet.
53 type StubInstanceSet struct {
// servers holds the set's stub servers keyed by instance ID. Create
// adds entries, Instances iterates over them, and
// (stubInstance)Destroy deletes them.
55 servers map[cloud.InstanceID]*stubServer
// Create adds a new stub server to sis.servers and returns that
// server's Instance(). Its error return reports a Create call made
// after Stop.
//
// The server records it.ProviderType and is given an SSHService that
// uses the driver's HostKey and forwards commands to the driver's
// Exec.
60 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
62 defer sis.mtx.Unlock()
64 return nil, errors.New("StubInstanceSet: Create called after Stop")
// Start from the driver-wide key list. The append below builds a
// fresh slice with authKey first, so the driver's own slice is not
// modified.
66 ak := sis.driver.AuthorizedKeys
68 ak = append([]ssh.PublicKey{authKey}, ak...)
// The ID is "stub-<provider type>-<random Int63 in hex>".
73 id: cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
75 providerType: it.ProviderType,
76 SSHService: SSHService{
77 HostKey: sis.driver.HostKey,
// Forward each command to the driver's Exec, passing this server's
// current Instance() as the first argument.
79 Exec: func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
80 return sis.driver.Exec(ss.Instance(), command, stdin, stdout, stderr)
// Register the server under its ID so Instances can list it.
85 sis.servers[ss.id] = ss
86 return ss.Instance(), nil
// Instances collects the Instance() of every server currently in
// sis.servers. The cloud.InstanceTags parameter is unnamed and so is
// not used.
89 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
91 defer sis.mtx.RUnlock()
// r stays nil when the set has no servers.
92 var r []cloud.Instance
// Map iteration order is unspecified, so the order of r is too.
93 for _, ss := range sis.servers {
94 r = append(r, ss.Instance())
// Stop panics with "Stop called twice" when it detects a repeated
// call. The mutex unlock is deferred, so it also runs when the panic
// unwinds.
99 func (sis *StubInstanceSet) Stop() {
101 defer sis.mtx.Unlock()
103 panic("Stop called twice")
// stubServer is the per-instance state stored in
// StubInstanceSet.servers.
108 type stubServer struct {
// tags is the server's own tag map; Instance() hands out copies of
// it rather than the map itself.
111 tags cloud.InstanceTags
// SSHService supplies the instance's address (see Instance) and is
// closed by (stubInstance)Destroy.
113 SSHService SSHService
// Instance returns a stubInstance value for this server. The
// address and a copy of the tags are captured at call time.
117 func (ss *stubServer) Instance() stubInstance {
122 addr: ss.SSHService.Address(),
123 // We deliberately return a cached/stale copy of the
124 // real tags here, so that (Instance)Tags() sometimes
125 // returns old data after a call to
126 // (Instance)SetTags(). This is permitted by the
127 // driver interface, and this might help remind
128 // callers that they need to tolerate it.
129 tags: copyTags(ss.tags),
// stubInstance is the value returned by (*stubServer)Instance. Its
// methods use value receivers.
133 type stubInstance struct {
// tags is the copy of the server's tags taken when Instance() built
// this value.
136 tags cloud.InstanceTags
// ID returns the instance's cloud.InstanceID.
139 func (si stubInstance) ID() cloud.InstanceID {
// Address returns the instance's address as a string.
143 func (si stubInstance) Address() string {
// Destroy closes the server's SSHService and deletes the server from
// its set's servers map.
//
// Before doing either, it draws math_rand.Float64() and, when the
// draw is below the driver's ErrorRateDestroy, returns an error and
// leaves the server untouched.
147 func (si stubInstance) Destroy() error {
148 if math_rand.Float64() < si.ss.sis.driver.ErrorRateDestroy {
149 return errors.New("instance could not be destroyed")
151 si.ss.SSHService.Close()
// The unlock is deferred so the map delete below completes first.
154 defer sis.mtx.Unlock()
155 delete(sis.servers, si.ss.id)
// ProviderType returns the providerType recorded on the underlying
// stubServer (set from the InstanceType's ProviderType in Create).
159 func (si stubInstance) ProviderType() string {
160 return si.ss.providerType
// SetTags takes a new set of tags for the instance.
//
// See the note in (*stubServer)Instance: Tags() may keep returning
// older data after SetTags has been called.
163 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
// Replace the parameter with the result of copyTags, so the rest of
// the method works on that map rather than on the caller's
// (presumably an independent copy -- confirm against copyTags).
164 tags = copyTags(tags)
// Tags returns the instance's cloud.InstanceTags.
//
// See the note in (*stubServer)Instance: the data may be stale
// relative to a later SetTags call.
174 func (si stubInstance) Tags() cloud.InstanceTags {
// String returns the underlying server's instance ID converted to a
// string.
178 func (si stubInstance) String() string {
179 return string(si.ss.id)
// VerifyHostKey checks key against the driver's HostKey. It signs a
// random 512-byte message with the driver's HostKey and returns the
// result of verifying that signature with key, so a nil result means
// key matches the driver's host key.
//
// NOTE(review): client is not referenced in the lines shown here --
// confirm it is intentionally unused.
182 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
// A fresh random message is generated for every call.
183 buf := make([]byte, 512)
184 _, err := io.ReadFull(rand.Reader, buf)
// Sign with the driver's private host key...
188 sig, err := si.ss.sis.driver.HostKey.Sign(rand.Reader, buf)
// ...and check the signature against the public key being verified.
192 return key.Verify(buf, sig)
195 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
196 dst := cloud.InstanceTags{}
197 for k, v := range src {