14988: Fixes functional tests (WIP)
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "strings"
17         "sync"
18
19         "git.curoverse.com/arvados.git/lib/cloud"
20         "git.curoverse.com/arvados.git/sdk/go/arvados"
21         "github.com/aws/aws-sdk-go/aws"
22         "github.com/aws/aws-sdk-go/aws/credentials"
23         "github.com/aws/aws-sdk-go/aws/session"
24         "github.com/aws/aws-sdk-go/service/ec2"
25         "github.com/sirupsen/logrus"
26         "golang.org/x/crypto/ssh"
27 )
28
// EC2 tag names written and read by this driver.
const (
	// arvadosDispatchID is the EC2 tag key whose value is the
	// dispatcher's cloud.InstanceSetID. Instances() lists only
	// instances carrying this tag with our dispatcher ID.
	arvadosDispatchID = "arvados-dispatch-id"

	// tagPrefix is prepended to every caller-supplied
	// cloud.InstanceTags key before it is stored as an EC2 tag, and
	// stripped again when tags are read back.
	tagPrefix = "arvados-dispatch-tag-"
)
31
32 // Driver is the ec2 implementation of the cloud.Driver interface.
33 var Driver = cloud.DriverFunc(newEC2InstanceSet)
34
// ec2InstanceSetConfig is the driver configuration, decoded from the
// JSON passed to newEC2InstanceSet.
type ec2InstanceSetConfig struct {
	// Static AWS credentials used for every API call.
	AccessKeyID     string
	SecretAccessKey string

	// AWS region the API client talks to.
	Region string

	// Network placement for new instances. These are applied to the
	// primary network interface.
	SecurityGroupIDs []string
	SubnetID         string

	// Login name reported by (*ec2Instance).RemoteUser.
	AdminUsername string

	// EBS volume type for added scratch volumes. newEC2InstanceSet
	// substitutes "gp2" when this is empty.
	EBSVolumeType string
}
44
45 type ec2Interface interface {
46         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
47         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
48         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
49         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
50         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
51         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
52 }
53
54 type ec2InstanceSet struct {
55         ec2config    ec2InstanceSetConfig
56         dispatcherID cloud.InstanceSetID
57         logger       logrus.FieldLogger
58         client       ec2Interface
59         keysMtx      sync.Mutex
60         keys         map[string]string
61 }
62
63 func newEC2InstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
64         instanceSet := &ec2InstanceSet{
65                 dispatcherID: dispatcherID,
66                 logger:       logger,
67         }
68         err = json.Unmarshal(config, &instanceSet.ec2config)
69         if err != nil {
70                 return nil, err
71         }
72         awsConfig := aws.NewConfig().
73                 WithCredentials(credentials.NewStaticCredentials(
74                         instanceSet.ec2config.AccessKeyID,
75                         instanceSet.ec2config.SecretAccessKey,
76                         "")).
77                 WithRegion(instanceSet.ec2config.Region)
78         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
79         instanceSet.keys = make(map[string]string)
80         if instanceSet.ec2config.EBSVolumeType == "" {
81                 instanceSet.ec2config.EBSVolumeType = "gp2"
82         }
83         return instanceSet, nil
84 }
85
86 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
87         // AWS key fingerprints don't use the usual key fingerprint
88         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
89         // (you can get that from md5.Sum(pk.Marshal())
90         //
91         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
92         // public key, so calculate those fingerprints here.
93         var rsaPub struct {
94                 Name string
95                 E    *big.Int
96                 N    *big.Int
97         }
98         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
99                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
100         }
101         rsaPk := rsa.PublicKey{
102                 E: int(rsaPub.E.Int64()),
103                 N: rsaPub.N,
104         }
105         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
106         md5pkix := md5.Sum([]byte(pkix))
107         sha1pkix := sha1.Sum([]byte(pkix))
108         md5fp = ""
109         sha1fp = ""
110         for i := 0; i < len(md5pkix); i += 1 {
111                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
112         }
113         for i := 0; i < len(sha1pkix); i += 1 {
114                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
115         }
116         return md5fp[1:], sha1fp[1:], nil
117 }
118
119 func (instanceSet *ec2InstanceSet) Create(
120         instanceType arvados.InstanceType,
121         imageID cloud.ImageID,
122         newTags cloud.InstanceTags,
123         initCommand cloud.InitCommand,
124         publicKey ssh.PublicKey) (cloud.Instance, error) {
125
126         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
127         if err != nil {
128                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
129         }
130         instanceSet.keysMtx.Lock()
131         var keyname string
132         var ok bool
133         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
134                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
135                         Filters: []*ec2.Filter{&ec2.Filter{
136                                 Name:   aws.String("fingerprint"),
137                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
138                         }},
139                 })
140                 if err != nil {
141                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
142                 }
143
144                 if len(keyout.KeyPairs) > 0 {
145                         keyname = *(keyout.KeyPairs[0].KeyName)
146                 } else {
147                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
148                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
149                                 KeyName:           &keyname,
150                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
151                         })
152                         if err != nil {
153                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
154                         }
155                 }
156                 instanceSet.keys[md5keyFingerprint] = keyname
157         }
158         instanceSet.keysMtx.Unlock()
159
160         ec2tags := []*ec2.Tag{
161                 &ec2.Tag{
162                         Key:   aws.String(arvadosDispatchID),
163                         Value: aws.String(string(instanceSet.dispatcherID)),
164                 },
165                 &ec2.Tag{
166                         Key:   aws.String("arvados-class"),
167                         Value: aws.String("dynamic-compute"),
168                 },
169         }
170         for k, v := range newTags {
171                 ec2tags = append(ec2tags, &ec2.Tag{
172                         Key:   aws.String(tagPrefix + k),
173                         Value: aws.String(v),
174                 })
175         }
176
177         rii := ec2.RunInstancesInput{
178                 ImageId:      aws.String(string(imageID)),
179                 InstanceType: &instanceType.ProviderType,
180                 MaxCount:     aws.Int64(1),
181                 MinCount:     aws.Int64(1),
182                 KeyName:      &keyname,
183
184                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
185                         &ec2.InstanceNetworkInterfaceSpecification{
186                                 AssociatePublicIpAddress: aws.Bool(false),
187                                 DeleteOnTermination:      aws.Bool(true),
188                                 DeviceIndex:              aws.Int64(0),
189                                 Groups:                   aws.StringSlice(instanceSet.ec2config.SecurityGroupIDs),
190                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
191                         }},
192                 DisableApiTermination:             aws.Bool(false),
193                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
194                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
195                 TagSpecifications: []*ec2.TagSpecification{
196                         &ec2.TagSpecification{
197                                 ResourceType: aws.String("instance"),
198                                 Tags:         ec2tags,
199                         }},
200         }
201
202         if instanceType.AddedScratch > 0 {
203                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
204                         DeviceName: aws.String("/dev/xvdt"),
205                         Ebs: &ec2.EbsBlockDevice{
206                                 DeleteOnTermination: aws.Bool(true),
207                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
208                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
209                         }}}
210         }
211
212         if instanceType.Preemptible {
213                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
214                         MarketType: aws.String("spot"),
215                         SpotOptions: &ec2.SpotMarketOptions{
216                                 InstanceInterruptionBehavior: aws.String("terminate"),
217                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
218                         }}
219         }
220
221         rsv, err := instanceSet.client.RunInstances(&rii)
222
223         if err != nil {
224                 return nil, err
225         }
226
227         return &ec2Instance{
228                 provider: instanceSet,
229                 instance: rsv.Instances[0],
230         }, nil
231 }
232
233 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
234         dii := &ec2.DescribeInstancesInput{
235                 Filters: []*ec2.Filter{&ec2.Filter{
236                         Name:   aws.String("tag:" + arvadosDispatchID),
237                         Values: []*string{aws.String(string(instanceSet.dispatcherID))},
238                 }}}
239
240         for {
241                 dio, err := instanceSet.client.DescribeInstances(dii)
242                 if err != nil {
243                         return nil, err
244                 }
245
246                 for _, rsv := range dio.Reservations {
247                         for _, inst := range rsv.Instances {
248                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
249                                         instances = append(instances, &ec2Instance{instanceSet, inst})
250                                 }
251                         }
252                 }
253                 if dio.NextToken == nil {
254                         return instances, err
255                 }
256                 dii.NextToken = dio.NextToken
257         }
258 }
259
260 func (az *ec2InstanceSet) Stop() {
261 }
262
263 type ec2Instance struct {
264         provider *ec2InstanceSet
265         instance *ec2.Instance
266 }
267
268 func (inst *ec2Instance) ID() cloud.InstanceID {
269         return cloud.InstanceID(*inst.instance.InstanceId)
270 }
271
272 func (inst *ec2Instance) String() string {
273         return *inst.instance.InstanceId
274 }
275
276 func (inst *ec2Instance) ProviderType() string {
277         return *inst.instance.InstanceType
278 }
279
280 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
281         ec2tags := []*ec2.Tag{
282                 &ec2.Tag{
283                         Key:   aws.String(arvadosDispatchID),
284                         Value: aws.String(string(inst.provider.dispatcherID)),
285                 },
286         }
287         for k, v := range newTags {
288                 ec2tags = append(ec2tags, &ec2.Tag{
289                         Key:   aws.String(tagPrefix + k),
290                         Value: aws.String(v),
291                 })
292         }
293
294         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
295                 Resources: []*string{inst.instance.InstanceId},
296                 Tags:      ec2tags,
297         })
298
299         return err
300 }
301
302 func (inst *ec2Instance) Tags() cloud.InstanceTags {
303         tags := make(map[string]string)
304
305         for _, t := range inst.instance.Tags {
306                 if strings.HasPrefix(*t.Key, tagPrefix) {
307                         tags[(*t.Key)[len(tagPrefix):]] = *t.Value
308                 }
309         }
310
311         return tags
312 }
313
314 func (inst *ec2Instance) Destroy() error {
315         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
316                 InstanceIds: []*string{inst.instance.InstanceId},
317         })
318         return err
319 }
320
321 func (inst *ec2Instance) Address() string {
322         if inst.instance.PrivateIpAddress != nil {
323                 return *inst.instance.PrivateIpAddress
324         } else {
325                 return ""
326         }
327 }
328
329 func (inst *ec2Instance) RemoteUser() string {
330         return inst.provider.ec2config.AdminUsername
331 }
332
333 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
334         return cloud.ErrNotImplemented
335 }