17778: Merge branch 'master' into 17778-doc-update
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17         "sync/atomic"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/aws/aws-sdk-go/aws"
23         "github.com/aws/aws-sdk-go/aws/awserr"
24         "github.com/aws/aws-sdk-go/aws/credentials"
25         "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
26         "github.com/aws/aws-sdk-go/aws/ec2metadata"
27         "github.com/aws/aws-sdk-go/aws/request"
28         "github.com/aws/aws-sdk-go/aws/session"
29         "github.com/aws/aws-sdk-go/service/ec2"
30         "github.com/sirupsen/logrus"
31         "golang.org/x/crypto/ssh"
32 )
33
34 // Driver is the ec2 implementation of the cloud.Driver interface.
35 var Driver = cloud.DriverFunc(newEC2InstanceSet)
36
// Bounds applied to the retry delay that is computed after a
// throttled AWS API call.
const (
	// throttleDelayMin is the shortest retry delay ever reported.
	throttleDelayMin = time.Second
	// throttleDelayMax is the longest retry delay ever reported.
	throttleDelayMax = time.Minute
)
41
42 type ec2InstanceSetConfig struct {
43         AccessKeyID      string
44         SecretAccessKey  string
45         Region           string
46         SecurityGroupIDs arvados.StringSet
47         SubnetID         string
48         AdminUsername    string
49         EBSVolumeType    string
50 }
51
52 type ec2Interface interface {
53         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
54         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
55         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
56         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
57         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
58         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
59 }
60
61 type ec2InstanceSet struct {
62         ec2config              ec2InstanceSetConfig
63         instanceSetID          cloud.InstanceSetID
64         logger                 logrus.FieldLogger
65         client                 ec2Interface
66         keysMtx                sync.Mutex
67         keys                   map[string]string
68         throttleDelayCreate    atomic.Value
69         throttleDelayInstances atomic.Value
70 }
71
72 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
73         instanceSet := &ec2InstanceSet{
74                 instanceSetID: instanceSetID,
75                 logger:        logger,
76         }
77         err = json.Unmarshal(config, &instanceSet.ec2config)
78         if err != nil {
79                 return nil, err
80         }
81
82         sess, err := session.NewSession()
83         if err != nil {
84                 return nil, err
85         }
86         // First try any static credentials, fall back to an IAM instance profile/role
87         creds := credentials.NewChainCredentials(
88                 []credentials.Provider{
89                         &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
90                         &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
91                 })
92
93         awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
94         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
95         instanceSet.keys = make(map[string]string)
96         if instanceSet.ec2config.EBSVolumeType == "" {
97                 instanceSet.ec2config.EBSVolumeType = "gp2"
98         }
99         return instanceSet, nil
100 }
101
102 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
103         // AWS key fingerprints don't use the usual key fingerprint
104         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
105         // (you can get that from md5.Sum(pk.Marshal())
106         //
107         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
108         // public key, so calculate those fingerprints here.
109         var rsaPub struct {
110                 Name string
111                 E    *big.Int
112                 N    *big.Int
113         }
114         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
115                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
116         }
117         rsaPk := rsa.PublicKey{
118                 E: int(rsaPub.E.Int64()),
119                 N: rsaPub.N,
120         }
121         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
122         md5pkix := md5.Sum([]byte(pkix))
123         sha1pkix := sha1.Sum([]byte(pkix))
124         md5fp = ""
125         sha1fp = ""
126         for i := 0; i < len(md5pkix); i++ {
127                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
128         }
129         for i := 0; i < len(sha1pkix); i++ {
130                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
131         }
132         return md5fp[1:], sha1fp[1:], nil
133 }
134
135 func (instanceSet *ec2InstanceSet) Create(
136         instanceType arvados.InstanceType,
137         imageID cloud.ImageID,
138         newTags cloud.InstanceTags,
139         initCommand cloud.InitCommand,
140         publicKey ssh.PublicKey) (cloud.Instance, error) {
141
142         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
143         if err != nil {
144                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
145         }
146         instanceSet.keysMtx.Lock()
147         var keyname string
148         var ok bool
149         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
150                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
151                         Filters: []*ec2.Filter{{
152                                 Name:   aws.String("fingerprint"),
153                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
154                         }},
155                 })
156                 if err != nil {
157                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
158                 }
159
160                 if len(keyout.KeyPairs) > 0 {
161                         keyname = *(keyout.KeyPairs[0].KeyName)
162                 } else {
163                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
164                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
165                                 KeyName:           &keyname,
166                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
167                         })
168                         if err != nil {
169                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
170                         }
171                 }
172                 instanceSet.keys[md5keyFingerprint] = keyname
173         }
174         instanceSet.keysMtx.Unlock()
175
176         ec2tags := []*ec2.Tag{}
177         for k, v := range newTags {
178                 ec2tags = append(ec2tags, &ec2.Tag{
179                         Key:   aws.String(k),
180                         Value: aws.String(v),
181                 })
182         }
183
184         var groups []string
185         for sg := range instanceSet.ec2config.SecurityGroupIDs {
186                 groups = append(groups, sg)
187         }
188
189         rii := ec2.RunInstancesInput{
190                 ImageId:      aws.String(string(imageID)),
191                 InstanceType: &instanceType.ProviderType,
192                 MaxCount:     aws.Int64(1),
193                 MinCount:     aws.Int64(1),
194                 KeyName:      &keyname,
195
196                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
197                         {
198                                 AssociatePublicIpAddress: aws.Bool(false),
199                                 DeleteOnTermination:      aws.Bool(true),
200                                 DeviceIndex:              aws.Int64(0),
201                                 Groups:                   aws.StringSlice(groups),
202                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
203                         }},
204                 DisableApiTermination:             aws.Bool(false),
205                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
206                 TagSpecifications: []*ec2.TagSpecification{
207                         {
208                                 ResourceType: aws.String("instance"),
209                                 Tags:         ec2tags,
210                         }},
211                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
212         }
213
214         if instanceType.AddedScratch > 0 {
215                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
216                         DeviceName: aws.String("/dev/xvdt"),
217                         Ebs: &ec2.EbsBlockDevice{
218                                 DeleteOnTermination: aws.Bool(true),
219                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
220                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
221                         }}}
222         }
223
224         if instanceType.Preemptible {
225                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
226                         MarketType: aws.String("spot"),
227                         SpotOptions: &ec2.SpotMarketOptions{
228                                 InstanceInterruptionBehavior: aws.String("terminate"),
229                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
230                         }}
231         }
232
233         rsv, err := instanceSet.client.RunInstances(&rii)
234         err = wrapError(err, &instanceSet.throttleDelayCreate)
235         if err != nil {
236                 return nil, err
237         }
238
239         return &ec2Instance{
240                 provider: instanceSet,
241                 instance: rsv.Instances[0],
242         }, nil
243 }
244
245 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
246         var filters []*ec2.Filter
247         for k, v := range tags {
248                 filters = append(filters, &ec2.Filter{
249                         Name:   aws.String("tag:" + k),
250                         Values: []*string{aws.String(v)},
251                 })
252         }
253         dii := &ec2.DescribeInstancesInput{Filters: filters}
254         for {
255                 dio, err := instanceSet.client.DescribeInstances(dii)
256                 err = wrapError(err, &instanceSet.throttleDelayInstances)
257                 if err != nil {
258                         return nil, err
259                 }
260
261                 for _, rsv := range dio.Reservations {
262                         for _, inst := range rsv.Instances {
263                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
264                                         instances = append(instances, &ec2Instance{instanceSet, inst})
265                                 }
266                         }
267                 }
268                 if dio.NextToken == nil {
269                         return instances, err
270                 }
271                 dii.NextToken = dio.NextToken
272         }
273 }
274
275 func (instanceSet *ec2InstanceSet) Stop() {
276 }
277
278 type ec2Instance struct {
279         provider *ec2InstanceSet
280         instance *ec2.Instance
281 }
282
283 func (inst *ec2Instance) ID() cloud.InstanceID {
284         return cloud.InstanceID(*inst.instance.InstanceId)
285 }
286
287 func (inst *ec2Instance) String() string {
288         return *inst.instance.InstanceId
289 }
290
291 func (inst *ec2Instance) ProviderType() string {
292         return *inst.instance.InstanceType
293 }
294
295 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
296         var ec2tags []*ec2.Tag
297         for k, v := range newTags {
298                 ec2tags = append(ec2tags, &ec2.Tag{
299                         Key:   aws.String(k),
300                         Value: aws.String(v),
301                 })
302         }
303
304         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
305                 Resources: []*string{inst.instance.InstanceId},
306                 Tags:      ec2tags,
307         })
308
309         return err
310 }
311
312 func (inst *ec2Instance) Tags() cloud.InstanceTags {
313         tags := make(map[string]string)
314
315         for _, t := range inst.instance.Tags {
316                 tags[*t.Key] = *t.Value
317         }
318
319         return tags
320 }
321
322 func (inst *ec2Instance) Destroy() error {
323         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
324                 InstanceIds: []*string{inst.instance.InstanceId},
325         })
326         return err
327 }
328
329 func (inst *ec2Instance) Address() string {
330         if inst.instance.PrivateIpAddress != nil {
331                 return *inst.instance.PrivateIpAddress
332         }
333         return ""
334 }
335
336 func (inst *ec2Instance) RemoteUser() string {
337         return inst.provider.ec2config.AdminUsername
338 }
339
340 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
341         return cloud.ErrNotImplemented
342 }
343
// rateLimitError wraps an error returned by a throttled API call
// together with the earliest time at which the call should be
// retried.
type rateLimitError struct {
	error
	earliestRetry time.Time
}

// EarliestRetry returns the time before which the failed call should
// not be retried.
func (e rateLimitError) EarliestRetry() time.Time {
	return e.earliestRetry
}
352
// isCodeCapacity is the set of AWS error codes that isErrorCapacity
// recognizes.
var isCodeCapacity = map[string]bool{
	"InsufficientInstanceCapacity": true,
	"MaxSpotInstanceCountExceeded": true,
	"VcpuLimitExceeded":            true,
}
358
359 // isErrorCapacity returns whether the error is to be throttled based on its code.
360 // Returns false if error is nil.
361 func isErrorCapacity(err error) bool {
362         if aerr, ok := err.(awserr.Error); ok && aerr != nil {
363                 if _, ok := isCodeCapacity[aerr.Code()]; ok {
364                         return true
365                 }
366         }
367         return false
368 }
369
// ec2QuotaError wraps an error so that it can be identified as a
// quota error through the IsQuotaError method.
type ec2QuotaError struct {
	error
}

// IsQuotaError always returns true.
func (e *ec2QuotaError) IsQuotaError() bool {
	return true
}
377
378 func wrapError(err error, throttleValue *atomic.Value) error {
379         if request.IsErrorThrottle(err) {
380                 // Back off exponentially until an upstream call
381                 // either succeeds or returns a non-throttle error.
382                 d, _ := throttleValue.Load().(time.Duration)
383                 d = d*3/2 + time.Second
384                 if d < throttleDelayMin {
385                         d = throttleDelayMin
386                 } else if d > throttleDelayMax {
387                         d = throttleDelayMax
388                 }
389                 throttleValue.Store(d)
390                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
391         } else if isErrorCapacity(err) {
392                 return &ec2QuotaError{err}
393         } else if err != nil {
394                 throttleValue.Store(time.Duration(0))
395                 return err
396         }
397         throttleValue.Store(time.Duration(0))
398         return nil
399 }