Merge branch '21640-max-nofile'
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "context"
9         "crypto/md5"
10         "crypto/rsa"
11         "crypto/sha1"
12         "crypto/x509"
13         "encoding/base64"
14         "encoding/json"
15         "errors"
16         "fmt"
17         "math/big"
18         "regexp"
19         "strconv"
20         "strings"
21         "sync"
22         "sync/atomic"
23         "time"
24
25         "git.arvados.org/arvados.git/lib/cloud"
26         "git.arvados.org/arvados.git/sdk/go/arvados"
27         "github.com/aws/aws-sdk-go-v2/aws"
28         "github.com/aws/aws-sdk-go-v2/aws/retry"
29         config "github.com/aws/aws-sdk-go-v2/config"
30         "github.com/aws/aws-sdk-go-v2/credentials"
31         "github.com/aws/aws-sdk-go-v2/service/ec2"
32         "github.com/aws/aws-sdk-go-v2/service/ec2/types"
33         "github.com/aws/smithy-go"
34         "github.com/prometheus/client_golang/prometheus"
35         "github.com/sirupsen/logrus"
36         "golang.org/x/crypto/ssh"
37 )
38
// Driver is the ec2 implementation of the cloud.Driver interface.
//
// It is built by passing newEC2InstanceSet, the constructor for
// EC2-backed instance sets, to cloud.DriverFunc.
var Driver = cloud.DriverFunc(newEC2InstanceSet)
41
// Lower and upper bounds for a throttling delay. The code that
// applies them is not in this part of the file; presumably they clamp
// the backoff tracked in ec2InstanceSet.throttleDelayCreate and
// ec2InstanceSet.throttleDelayInstances -- TODO confirm.
const (
	throttleDelayMin = time.Second
	throttleDelayMax = time.Minute
)
46
// ec2InstanceSetConfig holds the driver parameters that
// newEC2InstanceSet decodes from the JSON driver configuration.
type ec2InstanceSetConfig struct {
	// AccessKeyID and SecretAccessKey are used as static AWS
	// credentials. If both are empty, the AWS SDK's default
	// credential lookup is used instead.
	AccessKeyID             string
	SecretAccessKey         string
	// Region is the AWS region passed to the SDK configuration.
	Region                  string
	// SecurityGroupIDs are attached to the network interface of
	// each new instance.
	SecurityGroupIDs        arvados.StringSet
	// SubnetID lists the subnets Create may launch instances
	// in. In JSON it may be a single string or an array of
	// strings.
	SubnetID                sliceOrSingleString
	// AdminUsername is reported by ec2Instance.RemoteUser.
	AdminUsername           string
	// EBSVolumeType is the volume type of the added scratch
	// volume. newEC2InstanceSet defaults it to "gp2".
	EBSVolumeType           types.VolumeType
	// EBSPrice is multiplied by the scratch volume size in GiB
	// to get a monthly cost, which PriceHistory converts to an
	// hourly surcharge.
	EBSPrice                float64
	// IAMInstanceProfile, if not empty, is the name of the IAM
	// instance profile attached to new instances.
	IAMInstanceProfile      string
	// SpotPriceUpdateInterval controls how often spot price
	// data is refreshed. Spot prices are not fetched at all
	// unless it is greater than zero.
	SpotPriceUpdateInterval arvados.Duration
}
59
// sliceOrSingleString is a list of strings that can be written in
// JSON either as an array of strings or as a single string.
type sliceOrSingleString []string

// UnmarshalJSON decodes a JSON array of strings. It also accepts a
// bare string, which is treated as a one-element list, and treats ""
// and [] (as well as empty input) as a nil list.
//
// If decoding fails, the error is returned and the receiver is left
// unmodified.
func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
	var decoded []string
	switch {
	case len(data) == 0:
		// Nothing to decode; the result is an empty list.
	case data[0] == '[':
		if err := json.Unmarshal(data, &decoded); err != nil {
			return err
		}
	default:
		var single string
		if err := json.Unmarshal(data, &single); err != nil {
			return err
		}
		if single != "" {
			decoded = []string{single}
		}
	}
	// Normalize "no entries" to nil, regardless of how it was
	// spelled in the input.
	if len(decoded) == 0 {
		decoded = nil
	}
	*ss = decoded
	return nil
}
92
// ec2Interface is the set of EC2 API operations this driver calls.
// newEC2InstanceSet stores the client returned by ec2.NewFromConfig
// in a field of this type; any other value with these methods can be
// substituted for it.
type ec2Interface interface {
	DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error)
	ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error)
	RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
	DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
	DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
	DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error)
	CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
	TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
}
103
// ec2InstanceSet is the instance set returned by newEC2InstanceSet.
// It creates and lists EC2 instances through client, and caches key
// pair names and spot price data between calls.
type ec2InstanceSet struct {
	// ec2config is the decoded driver configuration.
	ec2config              ec2InstanceSetConfig
	// currentSubnetIDIndex is the index in ec2config.SubnetID
	// of the subnet Create tries first. It is read and written
	// with sync/atomic.
	currentSubnetIDIndex   int32
	instanceSetID          cloud.InstanceSetID
	logger                 logrus.FieldLogger
	client                 ec2Interface
	// keysMtx guards keys, which maps the MD5 fingerprint of a
	// public key to the name of the corresponding EC2 key pair.
	keysMtx                sync.Mutex
	keys                   map[string]string
	// throttleDelayCreate and throttleDelayInstances are handed
	// to wrapError (defined outside this part of the file)
	// together with the errors from RunInstances and
	// DescribeInstances, respectively.
	throttleDelayCreate    atomic.Value
	throttleDelayInstances atomic.Value

	// prices holds the spot price history collected by
	// updateSpotPrices, and pricesUpdated the time each key was
	// last refreshed. Both are guarded by pricesLock.
	prices        map[priceKey][]cloud.InstancePrice
	pricesLock    sync.Mutex
	pricesUpdated map[priceKey]time.Time

	// mInstances reports the number of instances per subnet;
	// mInstanceStarts counts RunInstances attempts per subnet
	// and outcome.
	mInstances      *prometheus.GaugeVec
	mInstanceStarts *prometheus.CounterVec
}
122
123 func newEC2InstanceSet(confRaw json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) {
124         instanceSet := &ec2InstanceSet{
125                 instanceSetID: instanceSetID,
126                 logger:        logger,
127         }
128         err = json.Unmarshal(confRaw, &instanceSet.ec2config)
129         if err != nil {
130                 return nil, err
131         }
132         awsConfig, err := config.LoadDefaultConfig(context.TODO(),
133                 config.WithRegion(instanceSet.ec2config.Region),
134                 config.WithCredentialsCacheOptions(func(o *aws.CredentialsCacheOptions) {
135                         o.ExpiryWindow = 5 * time.Minute
136                 }),
137                 func(o *config.LoadOptions) error {
138                         if instanceSet.ec2config.AccessKeyID == "" && instanceSet.ec2config.SecretAccessKey == "" {
139                                 // Use default SDK behavior (IAM role
140                                 // via IMDSv2)
141                                 return nil
142                         }
143                         o.Credentials = credentials.StaticCredentialsProvider{
144                                 Value: aws.Credentials{
145                                         AccessKeyID:     instanceSet.ec2config.AccessKeyID,
146                                         SecretAccessKey: instanceSet.ec2config.SecretAccessKey,
147                                         Source:          "Arvados configuration",
148                                 },
149                         }
150                         return nil
151                 })
152         if err != nil {
153                 return nil, err
154         }
155
156         instanceSet.client = ec2.NewFromConfig(awsConfig)
157         instanceSet.keys = make(map[string]string)
158         if instanceSet.ec2config.EBSVolumeType == "" {
159                 instanceSet.ec2config.EBSVolumeType = "gp2"
160         }
161
162         // Set up metrics
163         instanceSet.mInstances = prometheus.NewGaugeVec(prometheus.GaugeOpts{
164                 Namespace: "arvados",
165                 Subsystem: "dispatchcloud",
166                 Name:      "ec2_instances",
167                 Help:      "Number of instances running",
168         }, []string{"subnet_id"})
169         instanceSet.mInstanceStarts = prometheus.NewCounterVec(prometheus.CounterOpts{
170                 Namespace: "arvados",
171                 Subsystem: "dispatchcloud",
172                 Name:      "ec2_instance_starts_total",
173                 Help:      "Number of attempts to start a new instance",
174         }, []string{"subnet_id", "success"})
175         // Initialize all of the series we'll be reporting.  Otherwise
176         // the {subnet=A, success=0} series doesn't appear in metrics
177         // at all until there's a failure in subnet A.
178         for _, subnet := range instanceSet.ec2config.SubnetID {
179                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "0").Add(0)
180                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "1").Add(0)
181         }
182         if len(instanceSet.ec2config.SubnetID) == 0 {
183                 instanceSet.mInstanceStarts.WithLabelValues("", "0").Add(0)
184                 instanceSet.mInstanceStarts.WithLabelValues("", "1").Add(0)
185         }
186         if reg != nil {
187                 reg.MustRegister(instanceSet.mInstances)
188                 reg.MustRegister(instanceSet.mInstanceStarts)
189         }
190
191         return instanceSet, nil
192 }
193
194 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
195         // AWS key fingerprints don't use the usual key fingerprint
196         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
197         // (you can get that from md5.Sum(pk.Marshal())
198         //
199         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
200         // public key, so calculate those fingerprints here.
201         var rsaPub struct {
202                 Name string
203                 E    *big.Int
204                 N    *big.Int
205         }
206         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
207                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
208         }
209         rsaPk := rsa.PublicKey{
210                 E: int(rsaPub.E.Int64()),
211                 N: rsaPub.N,
212         }
213         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
214         md5pkix := md5.Sum([]byte(pkix))
215         sha1pkix := sha1.Sum([]byte(pkix))
216         md5fp = ""
217         sha1fp = ""
218         for i := 0; i < len(md5pkix); i++ {
219                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
220         }
221         for i := 0; i < len(sha1pkix); i++ {
222                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
223         }
224         return md5fp[1:], sha1fp[1:], nil
225 }
226
// Create launches a single EC2 instance of the given type from the
// given image, tagged with newTags, and returns it.
//
// The instance gets no public IP address, requires IMDSv2, and
// terminates when it shuts itself down. Its user data is a /bin/sh
// script that runs initCommand. If publicKey is not nil, the
// corresponding EC2 key pair (see getKeyName) is attached. If
// instanceType.AddedScratch is greater than zero, an EBS volume of
// the configured type is attached as /dev/xvdt. If
// instanceType.Preemptible is true, a spot instance is requested with
// instanceType.Price as the maximum price.
//
// When several subnets are configured, they are tried in turn,
// starting with the one recorded in currentSubnetIDIndex, for as long
// as isErrorSubnetSpecific reports that the failure is specific to
// the subnet.
//
// On failure, the returned error is the result of passing the
// selected RunInstances error through wrapError.
func (instanceSet *ec2InstanceSet) Create(
	instanceType arvados.InstanceType,
	imageID cloud.ImageID,
	newTags cloud.InstanceTags,
	initCommand cloud.InitCommand,
	publicKey ssh.PublicKey) (cloud.Instance, error) {

	ec2tags := []types.Tag{}
	for k, v := range newTags {
		ec2tags = append(ec2tags, types.Tag{
			Key:   aws.String(k),
			Value: aws.String(v),
		})
	}

	var groups []string
	for sg := range instanceSet.ec2config.SecurityGroupIDs {
		groups = append(groups, sg)
	}

	rii := ec2.RunInstancesInput{
		ImageId:      aws.String(string(imageID)),
		InstanceType: types.InstanceType(instanceType.ProviderType),
		MaxCount:     aws.Int32(1),
		MinCount:     aws.Int32(1),

		NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{
			AssociatePublicIpAddress: aws.Bool(false),
			DeleteOnTermination:      aws.Bool(true),
			DeviceIndex:              aws.Int32(0),
			Groups:                   groups,
		}},
		DisableApiTermination:             aws.Bool(false),
		InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate,
		TagSpecifications: []types.TagSpecification{
			{
				ResourceType: types.ResourceTypeInstance,
				Tags:         ec2tags,
			}},
		MetadataOptions: &types.InstanceMetadataOptionsRequest{
			// Require IMDSv2, as described at
			// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html
			HttpEndpoint: types.InstanceMetadataEndpointStateEnabled,
			HttpTokens:   types.HttpTokensStateRequired,
		},
		UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
	}

	if publicKey != nil {
		keyname, err := instanceSet.getKeyName(publicKey)
		if err != nil {
			return nil, err
		}
		rii.KeyName = &keyname
	}

	if instanceType.AddedScratch > 0 {
		rii.BlockDeviceMappings = []types.BlockDeviceMapping{{
			DeviceName: aws.String("/dev/xvdt"),
			Ebs: &types.EbsBlockDevice{
				DeleteOnTermination: aws.Bool(true),
				// ceil(added scratch space in GiB)
				VolumeSize:          aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)),
				VolumeType:          instanceSet.ec2config.EBSVolumeType,
			}}}
	}

	if instanceType.Preemptible {
		rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{
			MarketType: types.MarketTypeSpot,
			SpotOptions: &types.SpotMarketOptions{
				InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate,
				MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
			}}
	}

	if instanceSet.ec2config.IAMInstanceProfile != "" {
		rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{
			Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
		}
	}

	var rsv *ec2.RunInstancesOutput
	var errToReturn error
	subnets := instanceSet.ec2config.SubnetID
	currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
	for tryOffset := 0; ; tryOffset++ {
		// With no subnets configured, a single attempt is made
		// without a subnet ID, and it is counted in the metrics
		// under the empty subnet label.
		tryIndex := 0
		trySubnet := ""
		if len(subnets) > 0 {
			tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
			trySubnet = subnets[tryIndex]
			rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet)
		}
		var err error
		rsv, err = instanceSet.client.RunInstances(context.TODO(), &rii)
		instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1)
		if !isErrorCapacity(errToReturn) || isErrorCapacity(err) {
			// We want to return the last capacity error,
			// if any; otherwise the last non-capacity
			// error.
			errToReturn = err
		}
		if isErrorSubnetSpecific(err) &&
			tryOffset < len(subnets)-1 {
			instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
				Warn("RunInstances failed, trying next subnet")
			continue
		}
		// Succeeded, or exhausted all subnets, or got a
		// non-subnet-related error.
		//
		// We intentionally update currentSubnetIDIndex even
		// in the non-retryable-failure case here to avoid a
		// situation where successive calls to Create() keep
		// returning errors for the same subnet (perhaps
		// "subnet full") and never reveal the errors for the
		// other configured subnets (perhaps "subnet ID
		// invalid").
		atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
		break
	}
	if rsv == nil || len(rsv.Instances) == 0 {
		return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate)
	}
	return &ec2Instance{
		provider: instanceSet,
		instance: rsv.Instances[0],
	}, nil
}
356
357 func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
358         instanceSet.keysMtx.Lock()
359         defer instanceSet.keysMtx.Unlock()
360         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
361         if err != nil {
362                 return "", fmt.Errorf("Could not make key fingerprint: %v", err)
363         }
364         if keyname, ok := instanceSet.keys[md5keyFingerprint]; ok {
365                 return keyname, nil
366         }
367         keyout, err := instanceSet.client.DescribeKeyPairs(context.TODO(), &ec2.DescribeKeyPairsInput{
368                 Filters: []types.Filter{{
369                         Name:   aws.String("fingerprint"),
370                         Values: []string{md5keyFingerprint, sha1keyFingerprint},
371                 }},
372         })
373         if err != nil {
374                 return "", fmt.Errorf("Could not search for keypair: %v", err)
375         }
376         if len(keyout.KeyPairs) > 0 {
377                 return *(keyout.KeyPairs[0].KeyName), nil
378         }
379         keyname := "arvados-dispatch-keypair-" + md5keyFingerprint
380         _, err = instanceSet.client.ImportKeyPair(context.TODO(), &ec2.ImportKeyPairInput{
381                 KeyName:           &keyname,
382                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
383         })
384         if err != nil {
385                 return "", fmt.Errorf("Could not import keypair: %v", err)
386         }
387         instanceSet.keys[md5keyFingerprint] = keyname
388         return keyname, nil
389 }
390
391 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
392         var filters []types.Filter
393         for k, v := range tags {
394                 filters = append(filters, types.Filter{
395                         Name:   aws.String("tag:" + k),
396                         Values: []string{v},
397                 })
398         }
399         needAZs := false
400         dii := &ec2.DescribeInstancesInput{Filters: filters}
401         for {
402                 dio, err := instanceSet.client.DescribeInstances(context.TODO(), dii)
403                 err = wrapError(err, &instanceSet.throttleDelayInstances)
404                 if err != nil {
405                         return nil, err
406                 }
407
408                 for _, rsv := range dio.Reservations {
409                         for _, inst := range rsv.Instances {
410                                 switch inst.State.Name {
411                                 case types.InstanceStateNameShuttingDown:
412                                 case types.InstanceStateNameTerminated:
413                                 default:
414                                         instances = append(instances, &ec2Instance{
415                                                 provider: instanceSet,
416                                                 instance: inst,
417                                         })
418                                         if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
419                                                 needAZs = true
420                                         }
421                                 }
422                         }
423                 }
424                 if dio.NextToken == nil {
425                         break
426                 }
427                 dii.NextToken = dio.NextToken
428         }
429         if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
430                 az := map[string]string{}
431                 disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)}
432                 for {
433                         page, err := instanceSet.client.DescribeInstanceStatus(context.TODO(), disi)
434                         if err != nil {
435                                 instanceSet.logger.Warnf("error getting instance statuses: %s", err)
436                                 break
437                         }
438                         for _, ent := range page.InstanceStatuses {
439                                 az[*ent.InstanceId] = *ent.AvailabilityZone
440                         }
441                         if page.NextToken == nil {
442                                 break
443                         }
444                         disi.NextToken = page.NextToken
445                 }
446                 for _, inst := range instances {
447                         inst := inst.(*ec2Instance)
448                         inst.availabilityZone = az[*inst.instance.InstanceId]
449                 }
450                 instanceSet.updateSpotPrices(instances)
451         }
452
453         // Count instances in each subnet, and report in metrics.
454         subnetInstances := map[string]int{"": 0}
455         for _, subnet := range instanceSet.ec2config.SubnetID {
456                 subnetInstances[subnet] = 0
457         }
458         for _, inst := range instances {
459                 subnet := inst.(*ec2Instance).instance.SubnetId
460                 if subnet != nil {
461                         subnetInstances[*subnet]++
462                 } else {
463                         subnetInstances[""]++
464                 }
465         }
466         for subnet, count := range subnetInstances {
467                 instanceSet.mInstances.WithLabelValues(subnet).Set(float64(count))
468         }
469
470         return instances, err
471 }
472
// priceKey identifies one series of price data in
// ec2InstanceSet.prices and ec2InstanceSet.pricesUpdated: an instance
// type in an availability zone, either spot or not.
type priceKey struct {
	instanceType     string
	spot             bool
	availabilityZone string
}
478
479 // Refresh recent spot instance pricing data for the given instances,
480 // unless we already have recent pricing data for all relevant types.
481 func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) {
482         if len(instances) == 0 {
483                 return
484         }
485
486         instanceSet.pricesLock.Lock()
487         defer instanceSet.pricesLock.Unlock()
488         if instanceSet.prices == nil {
489                 instanceSet.prices = map[priceKey][]cloud.InstancePrice{}
490                 instanceSet.pricesUpdated = map[priceKey]time.Time{}
491         }
492
493         updateTime := time.Now()
494         staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
495         needUpdate := false
496         allTypes := map[types.InstanceType]bool{}
497
498         for _, inst := range instances {
499                 ec2inst := inst.(*ec2Instance).instance
500                 if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
501                         pk := priceKey{
502                                 instanceType:     string(ec2inst.InstanceType),
503                                 spot:             true,
504                                 availabilityZone: inst.(*ec2Instance).availabilityZone,
505                         }
506                         if instanceSet.pricesUpdated[pk].Before(staleTime) {
507                                 needUpdate = true
508                         }
509                         allTypes[ec2inst.InstanceType] = true
510                 }
511         }
512         if !needUpdate {
513                 return
514         }
515         var typeFilterValues []string
516         for instanceType := range allTypes {
517                 typeFilterValues = append(typeFilterValues, string(instanceType))
518         }
519         // Get 3x update interval worth of pricing data. (Ideally the
520         // AWS API would tell us "we have shown you all of the price
521         // changes up to time T", but it doesn't, so we'll just ask
522         // for 3 intervals worth of data on each update, de-duplicate
523         // the data points, and not worry too much about occasionally
524         // missing some data points when our lookups fail twice in a
525         // row.
526         dsphi := &ec2.DescribeSpotPriceHistoryInput{
527                 StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
528                 Filters: []types.Filter{
529                         types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
530                         types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}},
531                 },
532         }
533         for {
534                 page, err := instanceSet.client.DescribeSpotPriceHistory(context.TODO(), dsphi)
535                 if err != nil {
536                         instanceSet.logger.Warnf("error retrieving spot instance prices: %s", err)
537                         break
538                 }
539                 for _, ent := range page.SpotPriceHistory {
540                         if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil {
541                                 // bogus record?
542                                 continue
543                         }
544                         price, err := strconv.ParseFloat(*ent.SpotPrice, 64)
545                         if err != nil {
546                                 // bogus record?
547                                 continue
548                         }
549                         pk := priceKey{
550                                 instanceType:     string(ent.InstanceType),
551                                 spot:             true,
552                                 availabilityZone: *ent.AvailabilityZone,
553                         }
554                         instanceSet.prices[pk] = append(instanceSet.prices[pk], cloud.InstancePrice{
555                                 StartTime: *ent.Timestamp,
556                                 Price:     price,
557                         })
558                         instanceSet.pricesUpdated[pk] = updateTime
559                 }
560                 if page.NextToken == nil {
561                         break
562                 }
563                 dsphi.NextToken = page.NextToken
564         }
565
566         expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
567         for pk, last := range instanceSet.pricesUpdated {
568                 if last.Before(expiredTime) {
569                         delete(instanceSet.pricesUpdated, pk)
570                         delete(instanceSet.prices, pk)
571                 }
572         }
573         for pk, prices := range instanceSet.prices {
574                 instanceSet.prices[pk] = cloud.NormalizePriceHistory(prices)
575         }
576 }
577
// Stop does nothing. Presumably it exists to satisfy the
// cloud.InstanceSet interface -- the interface definition is not
// visible here.
func (instanceSet *ec2InstanceSet) Stop() {
}
580
// ec2Instance is the cloud.Instance returned by ec2InstanceSet's
// Create and Instances methods. It wraps the instance description
// received from the EC2 API.
type ec2Instance struct {
	provider         *ec2InstanceSet // instance set that returned this instance
	instance         types.Instance  // description as returned by RunInstances or DescribeInstances
	availabilityZone string // sometimes available for spot instances
}
586
// ID returns the EC2 instance ID.
func (inst *ec2Instance) ID() cloud.InstanceID {
	return cloud.InstanceID(*inst.instance.InstanceId)
}
590
// String returns the EC2 instance ID.
func (inst *ec2Instance) String() string {
	return *inst.instance.InstanceId
}
594
// ProviderType returns the EC2 instance type name reported for this
// instance.
func (inst *ec2Instance) ProviderType() string {
	return string(inst.instance.InstanceType)
}
598
599 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
600         var ec2tags []types.Tag
601         for k, v := range newTags {
602                 ec2tags = append(ec2tags, types.Tag{
603                         Key:   aws.String(k),
604                         Value: aws.String(v),
605                 })
606         }
607
608         _, err := inst.provider.client.CreateTags(context.TODO(), &ec2.CreateTagsInput{
609                 Resources: []string{*inst.instance.InstanceId},
610                 Tags:      ec2tags,
611         })
612
613         return err
614 }
615
616 func (inst *ec2Instance) Tags() cloud.InstanceTags {
617         tags := make(map[string]string)
618
619         for _, t := range inst.instance.Tags {
620                 tags[*t.Key] = *t.Value
621         }
622
623         return tags
624 }
625
626 func (inst *ec2Instance) Destroy() error {
627         _, err := inst.provider.client.TerminateInstances(context.TODO(), &ec2.TerminateInstancesInput{
628                 InstanceIds: []string{*inst.instance.InstanceId},
629         })
630         return err
631 }
632
633 func (inst *ec2Instance) Address() string {
634         if inst.instance.PrivateIpAddress != nil {
635                 return *inst.instance.PrivateIpAddress
636         }
637         return ""
638 }
639
// RemoteUser returns the AdminUsername from the driver
// configuration.
func (inst *ec2Instance) RemoteUser() string {
	return inst.provider.ec2config.AdminUsername
}
643
// VerifyHostKey always returns cloud.ErrNotImplemented: this driver
// does not verify host keys itself.
func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
	return cloud.ErrNotImplemented
}
647
648 // PriceHistory returns the price history for this specific instance.
649 //
650 // AWS documentation is elusive about whether the hourly cost of a
651 // given spot instance changes as the current spot price changes for
652 // the corresponding instance type and availability zone. Our
653 // implementation assumes the answer is yes, based on the following
654 // hints.
655 //
656 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html
657 // says: "After your Spot Instance is running, if the Spot price rises
658 // above your maximum price, Amazon EC2 interrupts your Spot
659 // Instance." (This doesn't address what happens when the spot price
660 // rises *without* exceeding your maximum price.)
661 //
662 // https://docs.aws.amazon.com/whitepapers/latest/cost-optimization-leveraging-ec2-spot-instances/how-spot-instances-work.html
663 // says: "You pay the Spot price that's in effect, billed to the
664 // nearest second." (But it's not explicitly stated whether "the price
665 // in effect" changes over time for a given instance.)
666 //
667 // The same page also says, in a discussion about the effect of
668 // specifying a maximum price: "Note that you never pay more than the
669 // Spot price that is in effect when your Spot Instance is running."
670 // (The use of the phrase "is running", as opposed to "was launched",
671 // hints that pricing is dynamic.)
672 func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
673         inst.provider.pricesLock.Lock()
674         defer inst.provider.pricesLock.Unlock()
675         // Note updateSpotPrices currently populates
676         // inst.provider.prices only for spot instances, so if
677         // spot==false here, we will return no data.
678         pk := priceKey{
679                 instanceType:     string(inst.instance.InstanceType),
680                 spot:             inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot,
681                 availabilityZone: inst.availabilityZone,
682         }
683         var prices []cloud.InstancePrice
684         for _, price := range inst.provider.prices[pk] {
685                 // ceil(added scratch space in GiB)
686                 gib := (instType.AddedScratch + 1<<30 - 1) >> 30
687                 monthly := inst.provider.ec2config.EBSPrice * float64(gib)
688                 hourly := monthly / 30 / 24
689                 price.Price += hourly
690                 prices = append(prices, price)
691         }
692         return prices
693 }
694
// rateLimitError decorates an underlying error with the time that
// EarliestRetry reports back to the caller.
type rateLimitError struct {
	error
	earliestRetry time.Time
}

// EarliestRetry returns the time recorded in the earliestRetry field
// when the rateLimitError was constructed.
func (rle rateLimitError) EarliestRetry() time.Time {
	return rle.earliestRetry
}
703
// capacityError decorates an underlying error, marking it as a
// capacity error and recording whether it is specific to the
// requested instance type.
type capacityError struct {
	error
	isInstanceTypeSpecific bool
}

// IsCapacityError reports true for every capacityError.
func (ce *capacityError) IsCapacityError() bool {
	return true
}

// IsInstanceTypeSpecific returns the isInstanceTypeSpecific flag the
// error was constructed with.
func (ce *capacityError) IsInstanceTypeSpecific() bool {
	return ce.isInstanceTypeSpecific
}
716
// isCodeQuota is the set of API error codes that isErrorQuota treats
// as quota/limit errors. Membership is what matters: isErrorQuota
// only checks whether a code is present as a key.
var isCodeQuota = map[string]bool{
	// Codes named "...Exceeded".
	"InstanceLimitExceeded":        true,
	"MaxSpotInstanceCountExceeded": true,
	"VcpuLimitExceeded":            true,

	// Codes named "Insufficient...".
	"InsufficientAddressCapacity":       true,
	"InsufficientFreeAddressesInSubnet": true,
	"InsufficientVolumeCapacity":        true,
}
725
726 // isErrorQuota returns whether the error indicates we have reached
727 // some usage quota/limit -- i.e., immediately retrying with an equal
728 // or larger instance type will probably not work.
729 //
730 // Returns false if error is nil.
731 func isErrorQuota(err error) bool {
732         var aerr smithy.APIError
733         if errors.As(err, &aerr) {
734                 if _, ok := isCodeQuota[aerr.ErrorCode()]; ok {
735                         return true
736                 }
737         }
738         return false
739 }
740
741 var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`)
742
743 // isErrorSubnetSpecific returns true if the problem encountered by
744 // RunInstances might be avoided by trying a different subnet.
745 func isErrorSubnetSpecific(err error) bool {
746         var aerr smithy.APIError
747         if !errors.As(err, &aerr) {
748                 return false
749         }
750         code := aerr.ErrorCode()
751         return strings.Contains(code, "Subnet") ||
752                 code == "InsufficientInstanceCapacity" ||
753                 code == "InsufficientVolumeCapacity" ||
754                 code == "Unsupported" ||
755                 // See TestIsErrorSubnetSpecific for examples of why
756                 // we look for substrings in code/message instead of
757                 // only using specific codes here.
758                 (strings.Contains(code, "InvalidParameter") &&
759                         reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage()))
760 }
761
762 // isErrorCapacity returns true if the error indicates lack of
763 // capacity (either temporary or permanent) to run a specific instance
764 // type -- i.e., retrying with a different instance type might
765 // succeed.
766 func isErrorCapacity(err error) bool {
767         var aerr smithy.APIError
768         if !errors.As(err, &aerr) {
769                 return false
770         }
771         code := aerr.ErrorCode()
772         return code == "InsufficientInstanceCapacity" ||
773                 (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type"))
774 }
775
// ec2QuotaError decorates an underlying error, marking it as a quota
// error via IsQuotaError.
type ec2QuotaError struct {
	error
}

// IsQuotaError reports true for every ec2QuotaError.
func (qe *ec2QuotaError) IsQuotaError() bool {
	return true
}
783
784 func isThrottleError(err error) bool {
785         var aerr smithy.APIError
786         if !errors.As(err, &aerr) {
787                 return false
788         }
789         _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()]
790         return is
791 }
792
793 func wrapError(err error, throttleValue *atomic.Value) error {
794         if isThrottleError(err) {
795                 // Back off exponentially until an upstream call
796                 // either succeeds or returns a non-throttle error.
797                 d, _ := throttleValue.Load().(time.Duration)
798                 d = d*3/2 + time.Second
799                 if d < throttleDelayMin {
800                         d = throttleDelayMin
801                 } else if d > throttleDelayMax {
802                         d = throttleDelayMax
803                 }
804                 throttleValue.Store(d)
805                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
806         } else if isErrorQuota(err) {
807                 return &ec2QuotaError{err}
808         } else if isErrorCapacity(err) {
809                 return &capacityError{err, true}
810         } else if err != nil {
811                 throttleValue.Store(time.Duration(0))
812                 return err
813         }
814         throttleValue.Store(time.Duration(0))
815         return nil
816 }
817
// boolLabelValue maps a bool to its string form: "1" for true and
// "0" for false (presumably for use as a metric label value --
// confirm against callers).
var boolLabelValue = map[bool]string{
	true:  "1",
	false: "0",
}