14291: Report errors from ImportKeyPair
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "encoding/base64"
9         "encoding/json"
10         "fmt"
11         "log"
12         "strings"
13         "sync"
14
15         "git.curoverse.com/arvados.git/lib/cloud"
16         "git.curoverse.com/arvados.git/sdk/go/arvados"
17         "github.com/aws/aws-sdk-go/aws"
18         "github.com/aws/aws-sdk-go/aws/credentials"
19         "github.com/aws/aws-sdk-go/aws/session"
20         "github.com/aws/aws-sdk-go/service/ec2"
21         "github.com/sirupsen/logrus"
22         "golang.org/x/crypto/ssh"
23 )
24
const (
	// arvadosDispatchID is the EC2 tag key whose value is the owning
	// dispatcher's ID; Create sets it and Instances filters on it.
	arvadosDispatchID = "arvados-dispatch-id"

	// tagPrefix is prepended to caller-supplied tag keys when they are
	// written to EC2, and stripped again when tags are read back.
	tagPrefix = "arvados-dispatch-tag-"
)
27
28 // Driver is the ec2 implementation of the cloud.Driver interface.
29 var Driver = cloud.DriverFunc(newEC2InstanceSet)
30
31 type ec2InstanceSetConfig struct {
32         AccessKeyID      string
33         SecretAccessKey  string
34         Region           string
35         SecurityGroupIDs []string
36         SubnetID         string
37         AdminUsername    string
38 }
39
40 type ec2Interface interface {
41         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
42         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
43         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
44         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
45         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
46 }
47
// ec2InstanceSet is the cloud.InstanceSet returned by newEC2InstanceSet.
type ec2InstanceSet struct {
	// ec2config is the decoded driver configuration.
	ec2config ec2InstanceSetConfig
	// dispatcherID is written to each new instance's arvados-dispatch-id
	// tag, and Instances lists only instances carrying that tag value.
	dispatcherID cloud.InstanceSetID
	logger       logrus.FieldLogger
	// client issues the EC2 API calls.
	client ec2Interface
	// keysMtx guards keys.
	keysMtx sync.Mutex
	// keys maps an SSH public key's SHA256 fingerprint to the name of the
	// EC2 key pair used for it.
	keys map[string]string
}
56
57 func newEC2InstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
58         instanceSet := &ec2InstanceSet{
59                 dispatcherID: dispatcherID,
60                 logger:       logger,
61         }
62         err = json.Unmarshal(config, &instanceSet.ec2config)
63         if err != nil {
64                 return nil, err
65         }
66         awsConfig := aws.NewConfig().
67                 WithCredentials(credentials.NewStaticCredentials(
68                         instanceSet.ec2config.AccessKeyID,
69                         instanceSet.ec2config.SecretAccessKey,
70                         "")).
71                 WithRegion(instanceSet.ec2config.Region)
72         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
73         instanceSet.keys = make(map[string]string)
74         return instanceSet, nil
75 }
76
77 func (instanceSet *ec2InstanceSet) Create(
78         instanceType arvados.InstanceType,
79         imageID cloud.ImageID,
80         newTags cloud.InstanceTags,
81         initCommand cloud.InitCommand,
82         publicKey ssh.PublicKey) (cloud.Instance, error) {
83
84         keyFingerprint := ssh.FingerprintSHA256(publicKey)
85         instanceSet.keysMtx.Lock()
86         var keyname string
87         var ok bool
88         if keyname, ok = instanceSet.keys[keyFingerprint]; !ok {
89                 keyname = "arvados-dispatch-keypair-" + keyFingerprint
90                 _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
91                         KeyName:           &keyname,
92                         PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
93                 })
94                 if err != nil {
95                         return nil, fmt.Errorf("Could not import keypair: %v", err)
96                 }
97                 instanceSet.keys[keyFingerprint] = keyname
98         }
99         instanceSet.keysMtx.Unlock()
100
101         ec2tags := []*ec2.Tag{
102                 &ec2.Tag{
103                         Key:   aws.String(arvadosDispatchID),
104                         Value: aws.String(string(instanceSet.dispatcherID)),
105                 },
106                 &ec2.Tag{
107                         Key:   aws.String("arvados-class"),
108                         Value: aws.String("dynamic-compute"),
109                 },
110         }
111         for k, v := range newTags {
112                 ec2tags = append(ec2tags, &ec2.Tag{
113                         Key:   aws.String(tagPrefix + k),
114                         Value: aws.String(v),
115                 })
116         }
117
118         rii := ec2.RunInstancesInput{
119                 ImageId:      aws.String(string(imageID)),
120                 InstanceType: &instanceType.ProviderType,
121                 MaxCount:     aws.Int64(1),
122                 MinCount:     aws.Int64(1),
123                 KeyName:      &keyname,
124
125                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
126                         &ec2.InstanceNetworkInterfaceSpecification{
127                                 AssociatePublicIpAddress: aws.Bool(false),
128                                 DeleteOnTermination:      aws.Bool(true),
129                                 DeviceIndex:              aws.Int64(0),
130                                 Groups:                   aws.StringSlice(instanceSet.ec2config.SecurityGroupIDs),
131                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
132                         }},
133                 DisableApiTermination:             aws.Bool(false),
134                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
135                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
136                 TagSpecifications: []*ec2.TagSpecification{
137                         &ec2.TagSpecification{
138                                 ResourceType: aws.String("instance"),
139                                 Tags:         ec2tags,
140                         }},
141         }
142
143         if instanceType.AddedScratch > 0 {
144                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
145                         DeviceName: aws.String("/dev/xvdt"),
146                         Ebs: &ec2.EbsBlockDevice{
147                                 DeleteOnTermination: aws.Bool(true),
148                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) / 1000000000) + 1),
149                                 VolumeType:          aws.String("gp2"),
150                         }}}
151         }
152
153         if instanceType.Preemptible {
154                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
155                         MarketType: aws.String("spot"),
156                         SpotOptions: &ec2.SpotMarketOptions{
157                                 InstanceInterruptionBehavior: aws.String("terminate"),
158                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
159                         }}
160         }
161
162         rsv, err := instanceSet.client.RunInstances(&rii)
163
164         if err != nil {
165                 return nil, err
166         }
167
168         return &ec2Instance{
169                 provider: instanceSet,
170                 instance: rsv.Instances[0],
171         }, nil
172 }
173
174 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
175         dii := &ec2.DescribeInstancesInput{
176                 Filters: []*ec2.Filter{&ec2.Filter{
177                         Name:   aws.String("tag:" + arvadosDispatchID),
178                         Values: []*string{aws.String(string(instanceSet.dispatcherID))},
179                 }}}
180
181         for {
182                 dio, err := instanceSet.client.DescribeInstances(dii)
183                 if err != nil {
184                         return nil, err
185                 }
186
187                 for _, rsv := range dio.Reservations {
188                         for _, inst := range rsv.Instances {
189                                 instances = append(instances, &ec2Instance{instanceSet, inst})
190                         }
191                 }
192                 if dio.NextToken == nil {
193                         return instances, err
194                 }
195                 dii.NextToken = dio.NextToken
196         }
197 }
198
199 func (az *ec2InstanceSet) Stop() {
200 }
201
// ec2Instance is the cloud.Instance returned by Create and Instances.
type ec2Instance struct {
	// provider is the instance set that created or listed this instance.
	// Its client and configuration are used by the methods below.
	provider *ec2InstanceSet
	// instance is the EC2 description returned by RunInstances or
	// DescribeInstances. Nothing in this type refreshes it afterwards.
	instance *ec2.Instance
}
206
207 func (inst *ec2Instance) ID() cloud.InstanceID {
208         return cloud.InstanceID(*inst.instance.InstanceId)
209 }
210
211 func (inst *ec2Instance) String() string {
212         return *inst.instance.InstanceId
213 }
214
215 func (inst *ec2Instance) ProviderType() string {
216         return *inst.instance.InstanceType
217 }
218
219 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
220         ec2tags := []*ec2.Tag{
221                 &ec2.Tag{
222                         Key:   aws.String(arvadosDispatchID),
223                         Value: aws.String(string(inst.provider.dispatcherID)),
224                 },
225         }
226         for k, v := range newTags {
227                 ec2tags = append(ec2tags, &ec2.Tag{
228                         Key:   aws.String(tagPrefix + k),
229                         Value: aws.String(v),
230                 })
231         }
232
233         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
234                 Resources: []*string{inst.instance.InstanceId},
235                 Tags:      ec2tags,
236         })
237
238         return err
239 }
240
241 func (inst *ec2Instance) Tags() cloud.InstanceTags {
242         tags := make(map[string]string)
243
244         for _, t := range inst.instance.Tags {
245                 if strings.HasPrefix(*t.Key, tagPrefix) {
246                         tags[(*t.Key)[len(tagPrefix):]] = *t.Value
247                 }
248         }
249
250         return tags
251 }
252
253 func (inst *ec2Instance) Destroy() error {
254         log.Printf("terminating %v", *inst.instance.InstanceId)
255         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
256                 InstanceIds: []*string{inst.instance.InstanceId},
257         })
258         return err
259 }
260
261 func (inst *ec2Instance) Address() string {
262         if inst.instance.PrivateIpAddress != nil {
263                 return *inst.instance.PrivateIpAddress
264         } else {
265                 return ""
266         }
267 }
268
269 func (inst *ec2Instance) RemoteUser() string {
270         return inst.provider.ec2config.AdminUsername
271 }
272
273 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
274         return cloud.ErrNotImplemented
275 }