1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
15 "git.curoverse.com/arvados.git/lib/cloud"
16 "git.curoverse.com/arvados.git/sdk/go/arvados"
17 "github.com/mitchellh/mapstructure"
18 "golang.org/x/crypto/ssh"
21 type StubExecFunc func(instance cloud.Instance, command string, stdin io.Reader, stdout, stderr io.Writer) uint32
23 // A StubDriver implements cloud.Driver by setting up local SSH
24 // servers that pass their command execution requests to the provided
26 type StubDriver struct {
29 AuthorizedKeys []ssh.PublicKey
31 ErrorRateDestroy float64
33 instanceSets []*StubInstanceSet
36 // InstanceSet returns a new *StubInstanceSet.
37 func (sd *StubDriver) InstanceSet(params map[string]interface{}, id cloud.InstanceSetID) (cloud.InstanceSet, error) {
38 sis := StubInstanceSet{
40 servers: map[cloud.InstanceID]*stubServer{},
42 sd.instanceSets = append(sd.instanceSets, &sis)
43 return &sis, mapstructure.Decode(params, &sis)
46 // InstanceSets returns all instances that have been created by the
47 // driver. This can be used to test a component that uses the driver
48 // but doesn't expose the InstanceSets it has created.
49 func (sd *StubDriver) InstanceSets() []*StubInstanceSet {
50 return sd.instanceSets
53 type StubInstanceSet struct {
55 servers map[cloud.InstanceID]*stubServer
60 func (sis *StubInstanceSet) Create(it arvados.InstanceType, image cloud.ImageID, tags cloud.InstanceTags, authKey ssh.PublicKey) (cloud.Instance, error) {
62 defer sis.mtx.Unlock()
64 return nil, errors.New("StubInstanceSet: Create called after Stop")
66 ak := sis.driver.AuthorizedKeys
68 ak = append([]ssh.PublicKey{authKey}, ak...)
73 id: cloud.InstanceID(fmt.Sprintf("stub-%s-%x", it.ProviderType, math_rand.Int63())),
75 providerType: it.ProviderType,
76 SSHService: SSHService{
77 HostKey: sis.driver.HostKey,
79 Exec: func(command string, stdin io.Reader, stdout, stderr io.Writer) uint32 {
80 return sis.driver.Exec(ss.Instance(), command, stdin, stdout, stderr)
85 sis.servers[ss.id] = ss
86 return ss.Instance(), nil
89 func (sis *StubInstanceSet) Instances(cloud.InstanceTags) ([]cloud.Instance, error) {
91 defer sis.mtx.RUnlock()
92 var r []cloud.Instance
93 for _, ss := range sis.servers {
94 r = append(r, ss.Instance())
99 func (sis *StubInstanceSet) Stop() {
101 defer sis.mtx.Unlock()
103 panic("Stop called twice")
108 // stubServer is a fake server that runs an SSH service. It represents
109 // a VM running in a fake cloud.
111 // Note this is distinct from a stubInstance, which is a snapshot of
112 // the VM's metadata. As with a VM in a real cloud, the stubServer
113 // keeps running (and might change IP addresses, shut down, etc.)
114 // without updating any stubInstances that have been returned to
116 type stubServer struct {
119 tags cloud.InstanceTags
121 SSHService SSHService
125 func (ss *stubServer) Instance() stubInstance {
130 addr: ss.SSHService.Address(),
131 // We deliberately return a cached/stale copy of the
132 // real tags here, so that (Instance)Tags() sometimes
133 // returns old data after a call to
134 // (Instance)SetTags(). This is permitted by the
135 // driver interface, and this might help remind
136 // callers that they need to tolerate it.
137 tags: copyTags(ss.tags),
141 type stubInstance struct {
144 tags cloud.InstanceTags
147 func (si stubInstance) ID() cloud.InstanceID {
151 func (si stubInstance) Address() string {
155 func (si stubInstance) Destroy() error {
156 if math_rand.Float64() < si.ss.sis.driver.ErrorRateDestroy {
157 return errors.New("instance could not be destroyed")
159 si.ss.SSHService.Close()
162 defer sis.mtx.Unlock()
163 delete(sis.servers, si.ss.id)
167 func (si stubInstance) ProviderType() string {
168 return si.ss.providerType
171 func (si stubInstance) SetTags(tags cloud.InstanceTags) error {
172 tags = copyTags(tags)
182 func (si stubInstance) Tags() cloud.InstanceTags {
186 func (si stubInstance) String() string {
187 return string(si.ss.id)
190 func (si stubInstance) VerifyHostKey(key ssh.PublicKey, client *ssh.Client) error {
191 buf := make([]byte, 512)
192 _, err := io.ReadFull(rand.Reader, buf)
196 sig, err := si.ss.sis.driver.HostKey.Sign(rand.Reader, buf)
200 return key.Verify(buf, sig)
203 func copyTags(src cloud.InstanceTags) cloud.InstanceTags {
204 dst := cloud.InstanceTags{}
205 for k, v := range src {