5eb1afd3b58f189e4da3e75814f243d9b63227de
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "context"
9         "crypto/md5"
10         "crypto/rsa"
11         "crypto/sha1"
12         "crypto/x509"
13         "encoding/base64"
14         "encoding/json"
15         "errors"
16         "fmt"
17         "math/big"
18         "os"
19         "regexp"
20         "strconv"
21         "strings"
22         "sync"
23         "sync/atomic"
24         "time"
25
26         "git.arvados.org/arvados.git/lib/cloud"
27         "git.arvados.org/arvados.git/sdk/go/arvados"
28         "github.com/aws/aws-sdk-go-v2/aws"
29         "github.com/aws/aws-sdk-go-v2/aws/retry"
30         awsconfig "github.com/aws/aws-sdk-go-v2/config"
31         "github.com/aws/aws-sdk-go-v2/service/ec2"
32         "github.com/aws/aws-sdk-go-v2/service/ec2/types"
33         "github.com/aws/smithy-go"
34         "github.com/prometheus/client_golang/prometheus"
35         "github.com/sirupsen/logrus"
36         "golang.org/x/crypto/ssh"
37 )
38
39 // Driver is the ec2 implementation of the cloud.Driver interface.
40 var Driver = cloud.DriverFunc(newEC2InstanceSet)
41
// throttleDelayMin and throttleDelayMax are, going by their names,
// the lower and upper bounds for throttle delays; the code that
// applies them is not in this part of the file.
const throttleDelayMin, throttleDelayMax = time.Second, time.Minute
46
47 type ec2InstanceSetConfig struct {
48         AccessKeyID             string
49         SecretAccessKey         string
50         Region                  string
51         SecurityGroupIDs        arvados.StringSet
52         SubnetID                sliceOrSingleString
53         AdminUsername           string
54         EBSVolumeType           types.VolumeType
55         EBSPrice                float64
56         IAMInstanceProfile      string
57         SpotPriceUpdateInterval arvados.Duration
58 }
59
// sliceOrSingleString is a string slice with a lenient JSON
// representation: see UnmarshalJSON.
type sliceOrSingleString []string

// UnmarshalJSON decodes a JSON array of strings. It also accepts a
// bare JSON string, which is treated as a one-element array, and
// stores nil (rather than an empty slice) for "", [], and empty
// input.
//
// If decoding fails, the error from encoding/json is returned and
// the receiver is left unmodified.
func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
	var decoded []string
	switch {
	case len(data) == 0:
		// Nothing to decode; result is nil.
	case data[0] == '[':
		if err := json.Unmarshal(data, &decoded); err != nil {
			return err
		}
	default:
		var single string
		if err := json.Unmarshal(data, &single); err != nil {
			return err
		}
		if single != "" {
			decoded = []string{single}
		}
	}
	if len(decoded) == 0 {
		// Normalize "no entries" to nil.
		decoded = nil
	}
	*ss = decoded
	return nil
}
92
93 type ec2Interface interface {
94         DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error)
95         ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error)
96         RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
97         DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
98         DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
99         DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error)
100         CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
101         TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
102 }
103
104 type ec2InstanceSet struct {
105         ec2config              ec2InstanceSetConfig
106         currentSubnetIDIndex   int32
107         instanceSetID          cloud.InstanceSetID
108         logger                 logrus.FieldLogger
109         client                 ec2Interface
110         keysMtx                sync.Mutex
111         keys                   map[string]string
112         throttleDelayCreate    atomic.Value
113         throttleDelayInstances atomic.Value
114
115         prices        map[priceKey][]cloud.InstancePrice
116         pricesLock    sync.Mutex
117         pricesUpdated map[priceKey]time.Time
118
119         mInstances      *prometheus.GaugeVec
120         mInstanceStarts *prometheus.CounterVec
121 }
122
123 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) {
124         instanceSet := &ec2InstanceSet{
125                 instanceSetID: instanceSetID,
126                 logger:        logger,
127         }
128         err = json.Unmarshal(config, &instanceSet.ec2config)
129         if err != nil {
130                 return nil, err
131         }
132
133         if len(instanceSet.ec2config.AccessKeyID)+len(instanceSet.ec2config.SecretAccessKey) > 0 {
134                 // AWS SDK will use credentials in environment vars if
135                 // present.
136                 os.Setenv("AWS_ACCESS_KEY_ID", instanceSet.ec2config.AccessKeyID)
137                 os.Setenv("AWS_SECRET_ACCESS_KEY", instanceSet.ec2config.SecretAccessKey)
138         } else {
139                 os.Unsetenv("AWS_ACCESS_KEY_ID")
140                 os.Unsetenv("AWS_SECRET_ACCESS_KEY")
141         }
142         awsConfig, err := awsconfig.LoadDefaultConfig(context.TODO(),
143                 awsconfig.WithRegion(instanceSet.ec2config.Region))
144         if err != nil {
145                 return nil, err
146         }
147
148         instanceSet.client = ec2.NewFromConfig(awsConfig)
149         instanceSet.keys = make(map[string]string)
150         if instanceSet.ec2config.EBSVolumeType == "" {
151                 instanceSet.ec2config.EBSVolumeType = "gp2"
152         }
153
154         // Set up metrics
155         instanceSet.mInstances = prometheus.NewGaugeVec(prometheus.GaugeOpts{
156                 Namespace: "arvados",
157                 Subsystem: "dispatchcloud",
158                 Name:      "ec2_instances",
159                 Help:      "Number of instances running",
160         }, []string{"subnet_id"})
161         instanceSet.mInstanceStarts = prometheus.NewCounterVec(prometheus.CounterOpts{
162                 Namespace: "arvados",
163                 Subsystem: "dispatchcloud",
164                 Name:      "ec2_instance_starts_total",
165                 Help:      "Number of attempts to start a new instance",
166         }, []string{"subnet_id", "success"})
167         // Initialize all of the series we'll be reporting.  Otherwise
168         // the {subnet=A, success=0} series doesn't appear in metrics
169         // at all until there's a failure in subnet A.
170         for _, subnet := range instanceSet.ec2config.SubnetID {
171                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "0").Add(0)
172                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "1").Add(0)
173         }
174         if len(instanceSet.ec2config.SubnetID) == 0 {
175                 instanceSet.mInstanceStarts.WithLabelValues("", "0").Add(0)
176                 instanceSet.mInstanceStarts.WithLabelValues("", "1").Add(0)
177         }
178         if reg != nil {
179                 reg.MustRegister(instanceSet.mInstances)
180                 reg.MustRegister(instanceSet.mInstanceStarts)
181         }
182
183         return instanceSet, nil
184 }
185
186 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
187         // AWS key fingerprints don't use the usual key fingerprint
188         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
189         // (you can get that from md5.Sum(pk.Marshal())
190         //
191         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
192         // public key, so calculate those fingerprints here.
193         var rsaPub struct {
194                 Name string
195                 E    *big.Int
196                 N    *big.Int
197         }
198         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
199                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
200         }
201         rsaPk := rsa.PublicKey{
202                 E: int(rsaPub.E.Int64()),
203                 N: rsaPub.N,
204         }
205         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
206         md5pkix := md5.Sum([]byte(pkix))
207         sha1pkix := sha1.Sum([]byte(pkix))
208         md5fp = ""
209         sha1fp = ""
210         for i := 0; i < len(md5pkix); i++ {
211                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
212         }
213         for i := 0; i < len(sha1pkix); i++ {
214                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
215         }
216         return md5fp[1:], sha1fp[1:], nil
217 }
218
219 func (instanceSet *ec2InstanceSet) Create(
220         instanceType arvados.InstanceType,
221         imageID cloud.ImageID,
222         newTags cloud.InstanceTags,
223         initCommand cloud.InitCommand,
224         publicKey ssh.PublicKey) (cloud.Instance, error) {
225
226         ec2tags := []types.Tag{}
227         for k, v := range newTags {
228                 ec2tags = append(ec2tags, types.Tag{
229                         Key:   aws.String(k),
230                         Value: aws.String(v),
231                 })
232         }
233
234         var groups []string
235         for sg := range instanceSet.ec2config.SecurityGroupIDs {
236                 groups = append(groups, sg)
237         }
238
239         rii := ec2.RunInstancesInput{
240                 ImageId:      aws.String(string(imageID)),
241                 InstanceType: types.InstanceType(instanceType.ProviderType),
242                 MaxCount:     aws.Int32(1),
243                 MinCount:     aws.Int32(1),
244
245                 NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{
246                         AssociatePublicIpAddress: aws.Bool(false),
247                         DeleteOnTermination:      aws.Bool(true),
248                         DeviceIndex:              aws.Int32(0),
249                         Groups:                   groups,
250                 }},
251                 DisableApiTermination:             aws.Bool(false),
252                 InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate,
253                 TagSpecifications: []types.TagSpecification{
254                         {
255                                 ResourceType: types.ResourceTypeInstance,
256                                 Tags:         ec2tags,
257                         }},
258                 MetadataOptions: &types.InstanceMetadataOptionsRequest{
259                         // Require IMDSv2, as described at
260                         // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html
261                         HttpEndpoint: types.InstanceMetadataEndpointStateEnabled,
262                         HttpTokens:   types.HttpTokensStateRequired,
263                 },
264                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
265         }
266
267         if publicKey != nil {
268                 keyname, err := instanceSet.getKeyName(publicKey)
269                 if err != nil {
270                         return nil, err
271                 }
272                 rii.KeyName = &keyname
273         }
274
275         if instanceType.AddedScratch > 0 {
276                 rii.BlockDeviceMappings = []types.BlockDeviceMapping{{
277                         DeviceName: aws.String("/dev/xvdt"),
278                         Ebs: &types.EbsBlockDevice{
279                                 DeleteOnTermination: aws.Bool(true),
280                                 VolumeSize:          aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)),
281                                 VolumeType:          instanceSet.ec2config.EBSVolumeType,
282                         }}}
283         }
284
285         if instanceType.Preemptible {
286                 rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{
287                         MarketType: types.MarketTypeSpot,
288                         SpotOptions: &types.SpotMarketOptions{
289                                 InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate,
290                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
291                         }}
292         }
293
294         if instanceSet.ec2config.IAMInstanceProfile != "" {
295                 rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{
296                         Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
297                 }
298         }
299
300         var rsv *ec2.RunInstancesOutput
301         var errToReturn error
302         subnets := instanceSet.ec2config.SubnetID
303         currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
304         for tryOffset := 0; ; tryOffset++ {
305                 tryIndex := 0
306                 trySubnet := ""
307                 if len(subnets) > 0 {
308                         tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
309                         trySubnet = subnets[tryIndex]
310                         rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet)
311                 }
312                 var err error
313                 rsv, err = instanceSet.client.RunInstances(context.TODO(), &rii)
314                 instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1)
315                 if !isErrorCapacity(errToReturn) || isErrorCapacity(err) {
316                         // We want to return the last capacity error,
317                         // if any; otherwise the last non-capacity
318                         // error.
319                         errToReturn = err
320                 }
321                 if isErrorSubnetSpecific(err) &&
322                         tryOffset < len(subnets)-1 {
323                         instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
324                                 Warn("RunInstances failed, trying next subnet")
325                         continue
326                 }
327                 // Succeeded, or exhausted all subnets, or got a
328                 // non-subnet-related error.
329                 //
330                 // We intentionally update currentSubnetIDIndex even
331                 // in the non-retryable-failure case here to avoid a
332                 // situation where successive calls to Create() keep
333                 // returning errors for the same subnet (perhaps
334                 // "subnet full") and never reveal the errors for the
335                 // other configured subnets (perhaps "subnet ID
336                 // invalid").
337                 atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
338                 break
339         }
340         if rsv == nil || len(rsv.Instances) == 0 {
341                 return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate)
342         }
343         return &ec2Instance{
344                 provider: instanceSet,
345                 instance: rsv.Instances[0],
346         }, nil
347 }
348
349 func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
350         instanceSet.keysMtx.Lock()
351         defer instanceSet.keysMtx.Unlock()
352         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
353         if err != nil {
354                 return "", fmt.Errorf("Could not make key fingerprint: %v", err)
355         }
356         if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
357                 return keyname, nil
358         }
359         keyout, err := instanceSet.client.DescribeKeyPairs(context.TODO(), &ec2.DescribeKeyPairsInput{
360                 Filters: []types.Filter{{
361                         Name:   aws.String("fingerprint"),
362                         Values: []string{md5keyFingerprint, sha1keyFingerprint},
363                 }},
364         })
365         if err != nil {
366                 return "", fmt.Errorf("Could not search for keypair: %v", err)
367         }
368         if len(keyout.KeyPairs) > 0 {
369                 return *(keyout.KeyPairs[0].KeyName), nil
370         }
371         keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
372         _, err = instanceSet.client.ImportKeyPair(context.TODO(), &ec2.ImportKeyPairInput{
373                 KeyName:           &keyname,
374                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
375         })
376         if err != nil {
377                 return "", fmt.Errorf("Could not import keypair: %v", err)
378         }
379         instanceSet.keys[md5keyFingerprint] = keyname
380         return keyname, nil
381 }
382
383 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
384         var filters []types.Filter
385         for k, v := range tags {
386                 filters = append(filters, types.Filter{
387                         Name:   aws.String("tag:" + k),
388                         Values: []string{v},
389                 })
390         }
391         needAZs := false
392         dii := &ec2.DescribeInstancesInput{Filters: filters}
393         for {
394                 dio, err := instanceSet.client.DescribeInstances(context.TODO(), dii)
395                 err = wrapError(err, &instanceSet.throttleDelayInstances)
396                 if err != nil {
397                         return nil, err
398                 }
399
400                 for _, rsv := range dio.Reservations {
401                         for _, inst := range rsv.Instances {
402                                 switch inst.State.Name {
403                                 case types.InstanceStateNameShuttingDown:
404                                 case types.InstanceStateNameTerminated:
405                                 default:
406                                         instances = append(instances, &ec2Instance{
407                                                 provider: instanceSet,
408                                                 instance: inst,
409                                         })
410                                         if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
411                                                 needAZs = true
412                                         }
413                                 }
414                         }
415                 }
416                 if dio.NextToken == nil {
417                         break
418                 }
419                 dii.NextToken = dio.NextToken
420         }
421         if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
422                 az := map[string]string{}
423                 disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)}
424                 for {
425                         page, err := instanceSet.client.DescribeInstanceStatus(context.TODO(), disi)
426                         if err != nil {
427                                 instanceSet.logger.Warnf("error getting instance statuses: %s", err)
428                                 break
429                         }
430                         for _, ent := range page.InstanceStatuses {
431                                 az[*ent.InstanceId] = *ent.AvailabilityZone
432                         }
433                         if page.NextToken == nil {
434                                 break
435                         }
436                         disi.NextToken = page.NextToken
437                 }
438                 for _, inst := range instances {
439                         inst := inst.(*ec2Instance)
440                         inst.availabilityZone = az[*inst.instance.InstanceId]
441                 }
442                 instanceSet.updateSpotPrices(instances)
443         }
444
445         // Count instances in each subnet, and report in metrics.
446         subnetInstances := map[string]int{"": 0}
447         for _, subnet := range instanceSet.ec2config.SubnetID {
448                 subnetInstances[subnet] = 0
449         }
450         for _, inst := range instances {
451                 subnet := inst.(*ec2Instance).instance.SubnetId
452                 if subnet != nil {
453                         subnetInstances[*subnet]++
454                 } else {
455                         subnetInstances[""]++
456                 }
457         }
458         for subnet, count := range subnetInstances {
459                 instanceSet.mInstances.WithLabelValues(subnet).Set(float64(count))
460         }
461
462         return instances, err
463 }
464
// priceKey is the key of the price maps in ec2InstanceSet.
type priceKey struct {
	// instanceType is the EC2 instance type name.
	instanceType string
	// spot is true for spot instance pricing.
	spot bool
	// availabilityZone is the zone the price applies to.
	availabilityZone string
}
470
471 // Refresh recent spot instance pricing data for the given instances,
472 // unless we already have recent pricing data for all relevant types.
473 func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) {
474         if len(instances) == 0 {
475                 return
476         }
477
478         instanceSet.pricesLock.Lock()
479         defer instanceSet.pricesLock.Unlock()
480         if instanceSet.prices == nil {
481                 instanceSet.prices = map[priceKey][]cloud.InstancePrice{}
482                 instanceSet.pricesUpdated = map[priceKey]time.Time{}
483         }
484
485         updateTime := time.Now()
486         staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
487         needUpdate := false
488         allTypes := map[types.InstanceType]bool{}
489
490         for _, inst := range instances {
491                 ec2inst := inst.(*ec2Instance).instance
492                 if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
493                         pk := priceKey{
494                                 instanceType:     string(ec2inst.InstanceType),
495                                 spot:             true,
496                                 availabilityZone: inst.(*ec2Instance).availabilityZone,
497                         }
498                         if instanceSet.pricesUpdated[pk].Before(staleTime) {
499                                 needUpdate = true
500                         }
501                         allTypes[ec2inst.InstanceType] = true
502                 }
503         }
504         if !needUpdate {
505                 return
506         }
507         var typeFilterValues []string
508         for instanceType := range allTypes {
509                 typeFilterValues = append(typeFilterValues, string(instanceType))
510         }
511         // Get 3x update interval worth of pricing data. (Ideally the
512         // AWS API would tell us "we have shown you all of the price
513         // changes up to time T", but it doesn't, so we'll just ask
514         // for 3 intervals worth of data on each update, de-duplicate
515         // the data points, and not worry too much about occasionally
516         // missing some data points when our lookups fail twice in a
517         // row.
518         dsphi := &ec2.DescribeSpotPriceHistoryInput{
519                 StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
520                 Filters: []types.Filter{
521                         types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
522                         types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}},
523                 },
524         }
525         for {
526                 page, err := instanceSet.client.DescribeSpotPriceHistory(context.TODO(), dsphi)
527                 if err != nil {
528                         instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
529                         break
530                 }
531                 for _, ent := range page.SpotPriceHistory {
532                         if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil {
533                                 // bogus record?
534                                 continue
535                         }
536                         price, err := strconv.ParseFloat(*ent.SpotPrice, 64)
537                         if err != nil {
538                                 // bogus record?
539                                 continue
540                         }
541                         pk := priceKey{
542                                 instanceType:     string(ent.InstanceType),
543                                 spot:             true,
544                                 availabilityZone: *ent.AvailabilityZone,
545                         }
546                         instanceSet.prices[pk] = append(instanceSet.prices[pk], cloud.InstancePrice{
547                                 StartTime: *ent.Timestamp,
548                                 Price:     price,
549                         })
550                         instanceSet.pricesUpdated[pk] = updateTime
551                 }
552                 if page.NextToken == nil {
553                         break
554                 }
555                 dsphi.NextToken = page.NextToken
556         }
557
558         expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
559         for pk, last := range instanceSet.pricesUpdated {
560                 if last.Before(expiredTime) {
561                         delete(instanceSet.pricesUpdated, pk)
562                         delete(instanceSet.prices, pk)
563                 }
564         }
565         for pk, prices := range instanceSet.prices {
566                 instanceSet.prices[pk] = cloud.NormalizePriceHistory(prices)
567         }
568 }
569
570 func (instanceSet *ec2InstanceSet) Stop() {
571 }
572
573 type ec2Instance struct {
574         provider         *ec2InstanceSet
575         instance         types.Instance
576         availabilityZone string // sometimes available for spot instances
577 }
578
579 func (inst *ec2Instance) ID() cloud.InstanceID {
580         return cloud.InstanceID(*inst.instance.InstanceId)
581 }
582
583 func (inst *ec2Instance) String() string {
584         return *inst.instance.InstanceId
585 }
586
587 func (inst *ec2Instance) ProviderType() string {
588         return string(inst.instance.InstanceType)
589 }
590
591 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
592         var ec2tags []types.Tag
593         for k, v := range newTags {
594                 ec2tags = append(ec2tags, types.Tag{
595                         Key:   aws.String(k),
596                         Value: aws.String(v),
597                 })
598         }
599
600         _, err := inst.provider.client.CreateTags(context.TODO(), &ec2.CreateTagsInput{
601                 Resources: []string{*inst.instance.InstanceId},
602                 Tags:      ec2tags,
603         })
604
605         return err
606 }
607
608 func (inst *ec2Instance) Tags() cloud.InstanceTags {
609         tags := make(map[string]string)
610
611         for _, t := range inst.instance.Tags {
612                 tags[*t.Key] = *t.Value
613         }
614
615         return tags
616 }
617
618 func (inst *ec2Instance) Destroy() error {
619         _, err := inst.provider.client.TerminateInstances(context.TODO(), &ec2.TerminateInstancesInput{
620                 InstanceIds: []string{*inst.instance.InstanceId},
621         })
622         return err
623 }
624
625 func (inst *ec2Instance) Address() string {
626         if inst.instance.PrivateIpAddress != nil {
627                 return *inst.instance.PrivateIpAddress
628         }
629         return ""
630 }
631
632 func (inst *ec2Instance) RemoteUser() string {
633         return inst.provider.ec2config.AdminUsername
634 }
635
636 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
637         return cloud.ErrNotImplemented
638 }
639
640 // PriceHistory returns the price history for this specific instance.
641 //
642 // AWS documentation is elusive about whether the hourly cost of a
643 // given spot instance changes as the current spot price changes for
644 // the corresponding instance type and availability zone. Our
645 // implementation assumes the answer is yes, based on the following
646 // hints.
647 //
648 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html
649 // says: "After your Spot Instance is running, if the Spot price rises
650 // above your maximum price, Amazon EC2 interrupts your Spot
651 // Instance." (This doesn't address what happens when the spot price
652 // rises *without* exceeding your maximum price.)
653 //
654 // https://docs.aws.amazon.com/whitepapers/latest/cost-optimization-leveraging-ec2-spot-instances/how-spot-instances-work.html
655 // says: "You pay the Spot price that's in effect, billed to the
656 // nearest second." (But it's not explicitly stated whether "the price
657 // in effect" changes over time for a given instance.)
658 //
659 // The same page also says, in a discussion about the effect of
660 // specifying a maximum price: "Note that you never pay more than the
661 // Spot price that is in effect when your Spot Instance is running."
662 // (The use of the phrase "is running", as opposed to "was launched",
663 // hints that pricing is dynamic.)
664 func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
665         inst.provider.pricesLock.Lock()
666         defer inst.provider.pricesLock.Unlock()
667         // Note updateSpotPrices currently populates
668         // inst.provider.prices only for spot instances, so if
669         // spot==false here, we will return no data.
670         pk := priceKey{
671                 instanceType:     string(inst.instance.InstanceType),
672                 spot:             inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot,
673                 availabilityZone: inst.availabilityZone,
674         }
675         var prices []cloud.InstancePrice
676         for _, price := range inst.provider.prices[pk] {
677                 // ceil(added scratch space in GiB)
678                 gib := (instType.AddedScratch + 1<<30 - 1) >> 30
679                 monthly := inst.provider.ec2config.EBSPrice * float64(gib)
680                 hourly := monthly / 30 / 24
681                 price.Price += hourly
682                 prices = append(prices, price)
683         }
684         return prices
685 }
686
// rateLimitError is an error that also carries the earliest time at
// which the failed operation should be retried.
type rateLimitError struct {
	error
	earliestRetry time.Time
}

// EarliestRetry returns the retry time stored in the error.
func (e rateLimitError) EarliestRetry() time.Time {
	retryAt := e.earliestRetry
	return retryAt
}
695
// capacityError is an error that reports itself as a capacity error,
// along with a flag saying whether the problem is specific to the
// instance type.
type capacityError struct {
	error
	isInstanceTypeSpecific bool
}

// IsCapacityError always returns true.
func (ce *capacityError) IsCapacityError() bool { return true }

// IsInstanceTypeSpecific returns the flag stored in the error.
func (ce *capacityError) IsInstanceTypeSpecific() bool {
	specific := ce.isInstanceTypeSpecific
	return specific
}
708
709 var isCodeQuota = map[string]bool{
710         "InstanceLimitExceeded":             true,
711         "InsufficientAddressCapacity":       true,
712         "InsufficientFreeAddressesInSubnet": true,
713         "InsufficientVolumeCapacity":        true,
714         "MaxSpotInstanceCountExceeded":      true,
715         "VcpuLimitExceeded":                 true,
716 }
717
718 // isErrorQuota returns whether the error indicates we have reached
719 // some usage quota/limit -- i.e., immediately retrying with an equal
720 // or larger instance type will probably not work.
721 //
722 // Returns false if error is nil.
723 func isErrorQuota(err error) bool {
724         var aerr smithy.APIError
725         if errors.As(err, &aerr) {
726                 if _, ok := isCodeQuota[aerr.ErrorCode()]; ok {
727                         return true
728                 }
729         }
730         return false
731 }
732
733 var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`)
734
735 // isErrorSubnetSpecific returns true if the problem encountered by
736 // RunInstances might be avoided by trying a different subnet.
737 func isErrorSubnetSpecific(err error) bool {
738         var aerr smithy.APIError
739         if !errors.As(err, &aerr) {
740                 return false
741         }
742         code := aerr.ErrorCode()
743         return strings.Contains(code, "Subnet") ||
744                 code == "InsufficientInstanceCapacity" ||
745                 code == "InsufficientVolumeCapacity" ||
746                 code == "Unsupported" ||
747                 // See TestIsErrorSubnetSpecific for examples of why
748                 // we look for substrings in code/message instead of
749                 // only using specific codes here.
750                 (strings.Contains(code, "InvalidParameter") &&
751                         reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage()))
752 }
753
754 // isErrorCapacity returns true if the error indicates lack of
755 // capacity (either temporary or permanent) to run a specific instance
756 // type -- i.e., retrying with a different instance type might
757 // succeed.
758 func isErrorCapacity(err error) bool {
759         var aerr smithy.APIError
760         if !errors.As(err, &aerr) {
761                 return false
762         }
763         code := aerr.ErrorCode()
764         return code == "InsufficientInstanceCapacity" ||
765                 (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type"))
766 }
767
// ec2QuotaError wraps an underlying error to mark it as a quota
// error.
type ec2QuotaError struct {
	error
}

// IsQuotaError always returns true: the type itself marks the
// wrapped error as a quota error.
func (e *ec2QuotaError) IsQuotaError() bool {
	const quota = true
	return quota
}
775
776 func isThrottleError(err error) bool {
777         var aerr smithy.APIError
778         if !errors.As(err, &aerr) {
779                 return false
780         }
781         _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()]
782         return is
783 }
784
785 func wrapError(err error, throttleValue *atomic.Value) error {
786         if isThrottleError(err) {
787                 // Back off exponentially until an upstream call
788                 // either succeeds or returns a non-throttle error.
789                 d, _ := throttleValue.Load().(time.Duration)
790                 d = d*3/2 + time.Second
791                 if d < throttleDelayMin {
792                         d = throttleDelayMin
793                 } else if d > throttleDelayMax {
794                         d = throttleDelayMax
795                 }
796                 throttleValue.Store(d)
797                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
798         } else if isErrorQuota(err) {
799                 return &ec2QuotaError{err}
800         } else if isErrorCapacity(err) {
801                 return &capacityError{err, true}
802         } else if err != nil {
803                 throttleValue.Store(time.Duration(0))
804                 return err
805         }
806         throttleValue.Store(time.Duration(0))
807         return nil
808 }
809
// boolLabelValue maps a bool to its string form: "0" for false and
// "1" for true.
var boolLabelValue = map[bool]string{
	false: "0",
	true:  "1",
}