]> git.arvados.org - arvados.git/blob - lib/cloud/ec2/ec2.go
23044: De-dup ContainerWebServices routing logic.
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "context"
9         "crypto/md5"
10         "crypto/rsa"
11         "crypto/sha1"
12         "crypto/x509"
13         "encoding/base64"
14         "encoding/hex"
15         "encoding/json"
16         "errors"
17         "fmt"
18         "math/big"
19         "regexp"
20         "strconv"
21         "strings"
22         "sync"
23         "sync/atomic"
24         "time"
25         "unicode"
26
27         "git.arvados.org/arvados.git/lib/cloud"
28         "git.arvados.org/arvados.git/sdk/go/arvados"
29         "github.com/aws/aws-sdk-go-v2/aws"
30         "github.com/aws/aws-sdk-go-v2/aws/retry"
31         config "github.com/aws/aws-sdk-go-v2/config"
32         "github.com/aws/aws-sdk-go-v2/credentials"
33         "github.com/aws/aws-sdk-go-v2/service/ec2"
34         "github.com/aws/aws-sdk-go-v2/service/ec2/types"
35         "github.com/aws/smithy-go"
36         "github.com/prometheus/client_golang/prometheus"
37         "github.com/sirupsen/logrus"
38         "golang.org/x/crypto/ssh"
39 )
40
// Driver is the ec2 implementation of the cloud.Driver interface.
// It constructs instance sets with newEC2InstanceSet.
var Driver = cloud.DriverFunc(newEC2InstanceSet)

// Lower and upper bounds for throttling delays.
//
// NOTE(review): these are not referenced in the visible part of this
// file; presumably they are consumed by wrapError when the EC2 API
// signals rate limiting — confirm.
const (
	throttleDelayMin = time.Second
	throttleDelayMax = time.Minute
)
48
// ec2InstanceSetConfig holds the driver-specific configuration,
// decoded from the JSON passed to newEC2InstanceSet.
type ec2InstanceSetConfig struct {
	// Static AWS credentials. If both are empty, the AWS SDK's
	// default credential chain is used instead (e.g., an IAM role
	// via IMDSv2).
	AccessKeyID             string
	SecretAccessKey         string
	// AWS region used by the EC2 API client.
	Region                  string
	// Security groups attached to each new instance's network
	// interface.
	SecurityGroupIDs        arvados.StringSet
	// Subnet(s) to launch instances in; accepts a JSON array or a
	// single string. When several are given, Create moves on to
	// the next subnet after a subnet-specific error.
	SubnetID                sliceOrSingleString
	// Username returned by ec2Instance.RemoteUser.
	AdminUsername           string
	// EBS volume type for added scratch space. newEC2InstanceSet
	// sets it to "gp2" if empty.
	EBSVolumeType           types.VolumeType
	// NOTE(review): not referenced in the visible code; presumably
	// the price of EBS scratch storage — confirm units.
	EBSPrice                float64
	// If non-empty, name of the IAM instance profile assigned to
	// new instances.
	IAMInstanceProfile      string
	// How often cached spot price history is considered stale.
	// Zero disables spot price tracking in Instances.
	SpotPriceUpdateInterval arvados.Duration
	// Maps lowercase instance type prefixes (like "trn1") to
	// quota group names; see InstanceQuotaGroup.
	InstanceTypeQuotaGroups map[string]string
}
62
// sliceOrSingleString is a list of strings that may be written in JSON
// either as an array or as a single string.
type sliceOrSingleString []string

// UnmarshalJSON decodes a JSON array of strings, and also accepts a
// single JSON string: "" decodes as an empty (nil) list, and "foo"
// decodes as ["foo"]. An empty array, or empty input, also decodes as
// nil. On a decoding error, the receiver is left unchanged.
func (ss *sliceOrSingleString) UnmarshalJSON(data []byte) error {
	switch {
	case len(data) == 0:
		*ss = nil
		return nil
	case data[0] == '[':
		var list []string
		if err := json.Unmarshal(data, &list); err != nil {
			return err
		}
		if len(list) == 0 {
			// Normalize [] to nil.
			list = nil
		}
		*ss = list
		return nil
	}

	var single string
	if err := json.Unmarshal(data, &single); err != nil {
		return err
	}
	if single == "" {
		*ss = nil
		return nil
	}
	*ss = []string{single}
	return nil
}
95
// ec2Interface is the subset of the EC2 API client used by this
// driver. newEC2InstanceSet fills it with ec2.NewFromConfig(...).
//
// NOTE(review): presumably declared as an interface so tests can
// substitute a stub client — confirm.
type ec2Interface interface {
	DescribeKeyPairs(context.Context, *ec2.DescribeKeyPairsInput, ...func(*ec2.Options)) (*ec2.DescribeKeyPairsOutput, error)
	ImportKeyPair(context.Context, *ec2.ImportKeyPairInput, ...func(*ec2.Options)) (*ec2.ImportKeyPairOutput, error)
	RunInstances(context.Context, *ec2.RunInstancesInput, ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
	DescribeInstances(context.Context, *ec2.DescribeInstancesInput, ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
	DescribeInstanceStatus(context.Context, *ec2.DescribeInstanceStatusInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
	DescribeSpotPriceHistory(context.Context, *ec2.DescribeSpotPriceHistoryInput, ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error)
	CreateTags(context.Context, *ec2.CreateTagsInput, ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
	TerminateInstances(context.Context, *ec2.TerminateInstancesInput, ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
}
106
// ec2InstanceSet is the ec2 implementation of cloud.InstanceSet,
// returned by newEC2InstanceSet.
type ec2InstanceSet struct {
	ec2config              ec2InstanceSetConfig
	// Index into ec2config.SubnetID of the subnet Create tries
	// first. Read and written only via sync/atomic.
	currentSubnetIDIndex   int32
	instanceSetID          cloud.InstanceSetID
	logger                 logrus.FieldLogger
	client                 ec2Interface
	// keysMtx guards keys, which maps an AWS key fingerprint to
	// the name of the EC2 key pair holding that key (see
	// getKeyName).
	keysMtx                sync.Mutex
	keys                   map[string]string
	// Throttling state passed to wrapError for errors from Create
	// and Instances, respectively.
	throttleDelayCreate    atomic.Value
	throttleDelayInstances atomic.Value

	// Spot price history per priceKey, and the time each key was
	// last refreshed. Both maps are guarded by pricesLock (see
	// updateSpotPrices).
	prices        map[priceKey][]cloud.InstancePrice
	pricesLock    sync.Mutex
	pricesUpdated map[priceKey]time.Time

	// Prometheus metrics: current instances per subnet, and
	// instance start attempts per subnet and outcome.
	mInstances      *prometheus.GaugeVec
	mInstanceStarts *prometheus.CounterVec
}
125
126 func newEC2InstanceSet(confRaw json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger, reg *prometheus.Registry) (prv cloud.InstanceSet, err error) {
127         instanceSet := &ec2InstanceSet{
128                 instanceSetID: instanceSetID,
129                 logger:        logger,
130         }
131         err = json.Unmarshal(confRaw, &instanceSet.ec2config)
132         if err != nil {
133                 return nil, err
134         }
135         awsConfig, err := config.LoadDefaultConfig(context.Background(),
136                 config.WithRegion(instanceSet.ec2config.Region),
137                 config.WithCredentialsCacheOptions(func(o *aws.CredentialsCacheOptions) {
138                         o.ExpiryWindow = 5 * time.Minute
139                 }),
140                 func(o *config.LoadOptions) error {
141                         if instanceSet.ec2config.AccessKeyID == "" && instanceSet.ec2config.SecretAccessKey == "" {
142                                 // Use default SDK behavior (IAM role
143                                 // via IMDSv2)
144                                 return nil
145                         }
146                         o.Credentials = credentials.StaticCredentialsProvider{
147                                 Value: aws.Credentials{
148                                         AccessKeyID:     instanceSet.ec2config.AccessKeyID,
149                                         SecretAccessKey: instanceSet.ec2config.SecretAccessKey,
150                                         Source:          "Arvados configuration",
151                                 },
152                         }
153                         return nil
154                 })
155         if err != nil {
156                 return nil, err
157         }
158
159         instanceSet.client = ec2.NewFromConfig(awsConfig)
160         instanceSet.keys = make(map[string]string)
161         if instanceSet.ec2config.EBSVolumeType == "" {
162                 instanceSet.ec2config.EBSVolumeType = "gp2"
163         }
164
165         // Set up metrics
166         instanceSet.mInstances = prometheus.NewGaugeVec(prometheus.GaugeOpts{
167                 Namespace: "arvados",
168                 Subsystem: "dispatchcloud",
169                 Name:      "ec2_instances",
170                 Help:      "Number of instances running",
171         }, []string{"subnet_id"})
172         instanceSet.mInstanceStarts = prometheus.NewCounterVec(prometheus.CounterOpts{
173                 Namespace: "arvados",
174                 Subsystem: "dispatchcloud",
175                 Name:      "ec2_instance_starts_total",
176                 Help:      "Number of attempts to start a new instance",
177         }, []string{"subnet_id", "success"})
178         // Initialize all of the series we'll be reporting.  Otherwise
179         // the {subnet=A, success=0} series doesn't appear in metrics
180         // at all until there's a failure in subnet A.
181         for _, subnet := range instanceSet.ec2config.SubnetID {
182                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "0").Add(0)
183                 instanceSet.mInstanceStarts.WithLabelValues(subnet, "1").Add(0)
184         }
185         if len(instanceSet.ec2config.SubnetID) == 0 {
186                 instanceSet.mInstanceStarts.WithLabelValues("", "0").Add(0)
187                 instanceSet.mInstanceStarts.WithLabelValues("", "1").Add(0)
188         }
189         if reg != nil {
190                 reg.MustRegister(instanceSet.mInstances)
191                 reg.MustRegister(instanceSet.mInstanceStarts)
192         }
193
194         return instanceSet, nil
195 }
196
197 // Calculate the public key fingerprints that AWS might use for a
198 // given key.  For an rsa key, return the AWS MD5 and SHA-1
199 // fingerprints in that order, like
200 // {"02:d8:ca:c4:67:58:7b:46:64:50:41:59:3d:90:33:40",
201 // "da:39:a3:ee:5e:6b:4b:0d:32:55:bf:ef:95:60:18:90:af:d8:07:09"}.
202 // For an ed25519 key, return the SHA-256 fingerprint with and without
203 // padding, like
204 // {"SHA256:jgxbPn8JspgUBbZo3nRPWJ5e2h4v6FbiwlTe49NsNKE=",
205 // "SHA256:jgxbPn8JspgUBbZo3nRPWJ5e2h4v6FbiwlTe49NsNKE"}.
206 //
207 // "When Amazon EC2 calculates a fingerprint, Amazon EC2 might append
208 // padding to the fingerprint with = characters."
209 //
210 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-keys.html
211 func awsKeyFingerprints(pk ssh.PublicKey) ([]string, error) {
212         if pk.Type() != "ssh-rsa" {
213                 // sha256 is always 256 bits, so the padded base64
214                 // encoding will always be the unpadded encoding (as
215                 // returned by ssh.FingerprintSHA256) plus a final
216                 // "=".
217                 hash2 := ssh.FingerprintSHA256(pk)
218                 hash1 := hash2 + "="
219                 return []string{hash1, hash2}, nil
220         }
221         // AWS key fingerprints don't use the usual key fingerprint
222         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
223         // (you can get that from md5.Sum(pk.Marshal())
224         //
225         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
226         // public key, so calculate those fingerprints here.
227         var rsaPub struct {
228                 Name string
229                 E    *big.Int
230                 N    *big.Int
231         }
232         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
233                 return nil, fmt.Errorf("Unmarshal failed to parse public key: %w", err)
234         }
235         rsaPk := rsa.PublicKey{
236                 E: int(rsaPub.E.Int64()),
237                 N: rsaPub.N,
238         }
239         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
240         sum1 := md5.Sum(pkix)
241         sum2 := sha1.Sum(pkix)
242         return []string{
243                 hexFingerprint(sum1[:]),
244                 hexFingerprint(sum2[:]),
245         }, nil
246 }
247
// hexFingerprint formats sum as colon-separated pairs of lowercase hex
// digits, like "12:34:56:...". An empty sum yields "".
func hexFingerprint(sum []byte) string {
	digits := hex.EncodeToString(sum)
	var out strings.Builder
	for pos := 0; pos < len(digits); pos += 2 {
		if pos > 0 {
			out.WriteByte(':')
		}
		out.WriteString(digits[pos : pos+2])
	}
	return out.String()
}
256
257 func (instanceSet *ec2InstanceSet) Create(
258         instanceType arvados.InstanceType,
259         imageID cloud.ImageID,
260         newTags cloud.InstanceTags,
261         initCommand cloud.InitCommand,
262         publicKey ssh.PublicKey) (cloud.Instance, error) {
263
264         ec2tags := []types.Tag{}
265         for k, v := range newTags {
266                 ec2tags = append(ec2tags, types.Tag{
267                         Key:   aws.String(k),
268                         Value: aws.String(v),
269                 })
270         }
271
272         var groups []string
273         for sg := range instanceSet.ec2config.SecurityGroupIDs {
274                 groups = append(groups, sg)
275         }
276
277         rii := ec2.RunInstancesInput{
278                 ImageId:      aws.String(string(imageID)),
279                 InstanceType: types.InstanceType(instanceType.ProviderType),
280                 MaxCount:     aws.Int32(1),
281                 MinCount:     aws.Int32(1),
282
283                 NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{{
284                         AssociatePublicIpAddress: aws.Bool(false),
285                         DeleteOnTermination:      aws.Bool(true),
286                         DeviceIndex:              aws.Int32(0),
287                         Groups:                   groups,
288                 }},
289                 DisableApiTermination:             aws.Bool(false),
290                 InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate,
291                 TagSpecifications: []types.TagSpecification{
292                         {
293                                 ResourceType: types.ResourceTypeInstance,
294                                 Tags:         ec2tags,
295                         }},
296                 MetadataOptions: &types.InstanceMetadataOptionsRequest{
297                         // Require IMDSv2, as described at
298                         // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-IMDS-new-instances.html
299                         HttpEndpoint: types.InstanceMetadataEndpointStateEnabled,
300                         HttpTokens:   types.HttpTokensStateRequired,
301                 },
302                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
303         }
304
305         if publicKey != nil {
306                 keyname, err := instanceSet.getKeyName(publicKey)
307                 if err != nil {
308                         return nil, err
309                 }
310                 rii.KeyName = &keyname
311         }
312
313         if instanceType.AddedScratch > 0 {
314                 rii.BlockDeviceMappings = []types.BlockDeviceMapping{{
315                         DeviceName: aws.String("/dev/xvdt"),
316                         Ebs: &types.EbsBlockDevice{
317                                 DeleteOnTermination: aws.Bool(true),
318                                 VolumeSize:          aws.Int32(int32((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30)),
319                                 VolumeType:          instanceSet.ec2config.EBSVolumeType,
320                         }}}
321         }
322
323         if instanceType.Preemptible {
324                 rii.InstanceMarketOptions = &types.InstanceMarketOptionsRequest{
325                         MarketType: types.MarketTypeSpot,
326                         SpotOptions: &types.SpotMarketOptions{
327                                 InstanceInterruptionBehavior: types.InstanceInterruptionBehaviorTerminate,
328                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
329                         }}
330         }
331
332         if instanceSet.ec2config.IAMInstanceProfile != "" {
333                 rii.IamInstanceProfile = &types.IamInstanceProfileSpecification{
334                         Name: aws.String(instanceSet.ec2config.IAMInstanceProfile),
335                 }
336         }
337
338         var rsv *ec2.RunInstancesOutput
339         var errToReturn error
340         var returningCapacityError bool
341         subnets := instanceSet.ec2config.SubnetID
342         currentSubnetIDIndex := int(atomic.LoadInt32(&instanceSet.currentSubnetIDIndex))
343         for tryOffset := 0; ; tryOffset++ {
344                 tryIndex := 0
345                 trySubnet := ""
346                 if len(subnets) > 0 {
347                         tryIndex = (currentSubnetIDIndex + tryOffset) % len(subnets)
348                         trySubnet = subnets[tryIndex]
349                         rii.NetworkInterfaces[0].SubnetId = aws.String(trySubnet)
350                 }
351                 var err error
352                 rsv, err = instanceSet.client.RunInstances(context.Background(), &rii)
353                 instanceSet.mInstanceStarts.WithLabelValues(trySubnet, boolLabelValue[err == nil]).Add(1)
354                 if instcap, groupcap := isErrorCapacity(err); !returningCapacityError || instcap || groupcap {
355                         // We want to return the last capacity error,
356                         // if any; otherwise the last non-capacity
357                         // error.
358                         errToReturn = err
359                         returningCapacityError = instcap || groupcap
360                 }
361                 if isErrorSubnetSpecific(err) &&
362                         tryOffset < len(subnets)-1 {
363                         instanceSet.logger.WithError(err).WithField("SubnetID", subnets[tryIndex]).
364                                 Warn("RunInstances failed, trying next subnet")
365                         continue
366                 }
367                 // Succeeded, or exhausted all subnets, or got a
368                 // non-subnet-related error.
369                 //
370                 // We intentionally update currentSubnetIDIndex even
371                 // in the non-retryable-failure case here to avoid a
372                 // situation where successive calls to Create() keep
373                 // returning errors for the same subnet (perhaps
374                 // "subnet full") and never reveal the errors for the
375                 // other configured subnets (perhaps "subnet ID
376                 // invalid").
377                 atomic.StoreInt32(&instanceSet.currentSubnetIDIndex, int32(tryIndex))
378                 break
379         }
380         if rsv == nil || len(rsv.Instances) == 0 {
381                 return nil, wrapError(errToReturn, &instanceSet.throttleDelayCreate)
382         }
383         return &ec2Instance{
384                 provider: instanceSet,
385                 instance: rsv.Instances[0],
386         }, nil
387 }
388
389 func (instanceSet *ec2InstanceSet) getKeyName(publicKey ssh.PublicKey) (string, error) {
390         instanceSet.keysMtx.Lock()
391         defer instanceSet.keysMtx.Unlock()
392         fingerprints, err := awsKeyFingerprints(publicKey)
393         if err != nil {
394                 return "", fmt.Errorf("Could not make key fingerprint: %w", err)
395         }
396         if keyname, ok := instanceSet.keys[fingerprints[0]]; ok {
397                 return keyname, nil
398         }
399         keyout, err := instanceSet.client.DescribeKeyPairs(context.Background(), &ec2.DescribeKeyPairsInput{
400                 Filters: []types.Filter{{
401                         Name:   aws.String("fingerprint"),
402                         Values: fingerprints,
403                 }},
404         })
405         if err != nil {
406                 return "", fmt.Errorf("Could not search for keypair: %w", err)
407         }
408         if len(keyout.KeyPairs) > 0 {
409                 return *(keyout.KeyPairs[0].KeyName), nil
410         }
411         keyname := "arvados-dispatch-keypair-" + fingerprints[0]
412         _, err = instanceSet.client.ImportKeyPair(context.Background(), &ec2.ImportKeyPairInput{
413                 KeyName:           &keyname,
414                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
415         })
416         if err != nil {
417                 return "", fmt.Errorf("Could not import keypair: %w", err)
418         }
419         instanceSet.keys[fingerprints[0]] = keyname
420         return keyname, nil
421 }
422
423 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
424         var filters []types.Filter
425         for k, v := range tags {
426                 filters = append(filters, types.Filter{
427                         Name:   aws.String("tag:" + k),
428                         Values: []string{v},
429                 })
430         }
431         needAZs := false
432         dii := &ec2.DescribeInstancesInput{Filters: filters}
433         for {
434                 dio, err := instanceSet.client.DescribeInstances(context.Background(), dii)
435                 err = wrapError(err, &instanceSet.throttleDelayInstances)
436                 if err != nil {
437                         return nil, err
438                 }
439
440                 for _, rsv := range dio.Reservations {
441                         for _, inst := range rsv.Instances {
442                                 switch inst.State.Name {
443                                 case types.InstanceStateNameShuttingDown:
444                                 case types.InstanceStateNameTerminated:
445                                 default:
446                                         instances = append(instances, &ec2Instance{
447                                                 provider: instanceSet,
448                                                 instance: inst,
449                                         })
450                                         if inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
451                                                 needAZs = true
452                                         }
453                                 }
454                         }
455                 }
456                 if dio.NextToken == nil || *dio.NextToken == "" {
457                         break
458                 }
459                 dii.NextToken = dio.NextToken
460         }
461         if needAZs && instanceSet.ec2config.SpotPriceUpdateInterval > 0 {
462                 az := map[string]string{}
463                 disi := &ec2.DescribeInstanceStatusInput{IncludeAllInstances: aws.Bool(true)}
464                 for {
465                         page, err := instanceSet.client.DescribeInstanceStatus(context.Background(), disi)
466                         if err != nil {
467                                 instanceSet.logger.WithError(err).Warn("error getting instance statuses")
468                                 break
469                         }
470                         for _, ent := range page.InstanceStatuses {
471                                 az[*ent.InstanceId] = *ent.AvailabilityZone
472                         }
473                         if page.NextToken == nil || *page.NextToken == "" {
474                                 break
475                         }
476                         disi.NextToken = page.NextToken
477                 }
478                 for _, inst := range instances {
479                         inst := inst.(*ec2Instance)
480                         inst.availabilityZone = az[*inst.instance.InstanceId]
481                 }
482                 instanceSet.updateSpotPrices(instances)
483         }
484
485         // Count instances in each subnet, and report in metrics.
486         subnetInstances := map[string]int{"": 0}
487         for _, subnet := range instanceSet.ec2config.SubnetID {
488                 subnetInstances[subnet] = 0
489         }
490         for _, inst := range instances {
491                 subnet := inst.(*ec2Instance).instance.SubnetId
492                 if subnet != nil {
493                         subnetInstances[*subnet]++
494                 } else {
495                         subnetInstances[""]++
496                 }
497         }
498         for subnet, count := range subnetInstances {
499                 instanceSet.mInstances.WithLabelValues(subnet).Set(float64(count))
500         }
501
502         return instances, err
503 }
504
// priceKey identifies one price history series in
// ec2InstanceSet.prices: an instance type, in one availability zone,
// for spot or non-spot pricing. (The visible code only creates keys
// with spot == true.)
type priceKey struct {
	instanceType     string
	spot             bool
	availabilityZone string
}
510
511 // Refresh recent spot instance pricing data for the given instances,
512 // unless we already have recent pricing data for all relevant types.
513 func (instanceSet *ec2InstanceSet) updateSpotPrices(instances []cloud.Instance) {
514         if len(instances) == 0 {
515                 return
516         }
517
518         instanceSet.pricesLock.Lock()
519         defer instanceSet.pricesLock.Unlock()
520         if instanceSet.prices == nil {
521                 instanceSet.prices = map[priceKey][]cloud.InstancePrice{}
522                 instanceSet.pricesUpdated = map[priceKey]time.Time{}
523         }
524
525         updateTime := time.Now()
526         staleTime := updateTime.Add(-instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
527         needUpdate := false
528         allTypes := map[types.InstanceType]bool{}
529
530         for _, inst := range instances {
531                 ec2inst := inst.(*ec2Instance).instance
532                 if ec2inst.InstanceLifecycle == types.InstanceLifecycleTypeSpot {
533                         pk := priceKey{
534                                 instanceType:     string(ec2inst.InstanceType),
535                                 spot:             true,
536                                 availabilityZone: inst.(*ec2Instance).availabilityZone,
537                         }
538                         if instanceSet.pricesUpdated[pk].Before(staleTime) {
539                                 needUpdate = true
540                         }
541                         allTypes[ec2inst.InstanceType] = true
542                 }
543         }
544         if !needUpdate {
545                 return
546         }
547         var typeFilterValues []string
548         for instanceType := range allTypes {
549                 typeFilterValues = append(typeFilterValues, string(instanceType))
550         }
551         // Get 3x update interval worth of pricing data. (Ideally the
552         // AWS API would tell us "we have shown you all of the price
553         // changes up to time T", but it doesn't, so we'll just ask
554         // for 3 intervals worth of data on each update, de-duplicate
555         // the data points, and not worry too much about occasionally
556         // missing some data points when our lookups fail twice in a
557         // row.
558         dsphi := &ec2.DescribeSpotPriceHistoryInput{
559                 StartTime: aws.Time(updateTime.Add(-3 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())),
560                 Filters: []types.Filter{
561                         types.Filter{Name: aws.String("instance-type"), Values: typeFilterValues},
562                         types.Filter{Name: aws.String("product-description"), Values: []string{"Linux/UNIX"}},
563                 },
564         }
565         for {
566                 page, err := instanceSet.client.DescribeSpotPriceHistory(context.Background(), dsphi)
567                 if err != nil {
568                         instanceSet.logger.WithError(err).Warn("error retrieving spot instance prices")
569                         break
570                 }
571                 for _, ent := range page.SpotPriceHistory {
572                         if ent.InstanceType == "" || ent.SpotPrice == nil || ent.Timestamp == nil {
573                                 // bogus record?
574                                 continue
575                         }
576                         price, err := strconv.ParseFloat(*ent.SpotPrice, 64)
577                         if err != nil {
578                                 // bogus record?
579                                 continue
580                         }
581                         pk := priceKey{
582                                 instanceType:     string(ent.InstanceType),
583                                 spot:             true,
584                                 availabilityZone: *ent.AvailabilityZone,
585                         }
586                         instanceSet.prices[pk] = append(instanceSet.prices[pk], cloud.InstancePrice{
587                                 StartTime: *ent.Timestamp,
588                                 Price:     price,
589                         })
590                         instanceSet.pricesUpdated[pk] = updateTime
591                 }
592                 if page.NextToken == nil || *page.NextToken == "" {
593                         break
594                 }
595                 dsphi.NextToken = page.NextToken
596         }
597
598         expiredTime := updateTime.Add(-64 * instanceSet.ec2config.SpotPriceUpdateInterval.Duration())
599         for pk, last := range instanceSet.pricesUpdated {
600                 if last.Before(expiredTime) {
601                         delete(instanceSet.pricesUpdated, pk)
602                         delete(instanceSet.prices, pk)
603                 }
604         }
605         for pk, prices := range instanceSet.prices {
606                 instanceSet.prices[pk] = cloud.NormalizePriceHistory(prices)
607         }
608 }
609
// Stop implements cloud.InstanceSet. It does nothing: newEC2InstanceSet
// starts no background work that would need to be stopped.
func (instanceSet *ec2InstanceSet) Stop() {
}
612
613 func (instanceSet *ec2InstanceSet) InstanceQuotaGroup(it arvados.InstanceType) cloud.InstanceQuotaGroup {
614         // https://docs.aws.amazon.com/ec2/latest/instancetypes/ec2-instance-quotas.html
615         // 2024-09-09
616         var quotaGroup string
617         pt := strings.ToLower(it.ProviderType)
618         for i, c := range pt {
619                 if !unicode.IsLower(c) && quotaGroup == "" {
620                         // Fall back to the alphabetic prefix of
621                         // ProviderType.
622                         quotaGroup = pt[:i]
623                 }
624                 if conf := instanceSet.ec2config.InstanceTypeQuotaGroups[pt[:i]]; conf != "" && quotaGroup != "" {
625                         // Prefer the longest prefix of ProviderType
626                         // that is listed explicitly in config.
627                         //
628                         // (But don't look up a too-short prefix --
629                         // for an instance type like "trn1.234", use
630                         // the config for "trn" or "trn1" but not
631                         // "t".)
632                         quotaGroup = conf
633                 }
634         }
635         if it.Preemptible {
636                 // Spot instance quotas are separate from demand
637                 // quotas.
638                 quotaGroup += "-spot"
639         }
640         return cloud.InstanceQuotaGroup(quotaGroup)
641 }
642
// ec2Instance is the ec2 implementation of cloud.Instance. It wraps
// the EC2 API's description of one instance, as returned by
// RunInstances (in Create) or DescribeInstances (in Instances).
type ec2Instance struct {
	provider         *ec2InstanceSet
	instance         types.Instance
	availabilityZone string // sometimes available for spot instances (set by Instances; may be empty)
}
648
// ID returns the EC2 instance ID as a cloud.InstanceID.
func (inst *ec2Instance) ID() cloud.InstanceID {
	return cloud.InstanceID(*inst.instance.InstanceId)
}

// String implements fmt.Stringer; it returns the EC2 instance ID.
func (inst *ec2Instance) String() string {
	return *inst.instance.InstanceId
}

// ProviderType returns the EC2 instance type, i.e. the value Create
// passes from arvados.InstanceType.ProviderType.
func (inst *ec2Instance) ProviderType() string {
	return string(inst.instance.InstanceType)
}
660
661 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
662         var ec2tags []types.Tag
663         for k, v := range newTags {
664                 ec2tags = append(ec2tags, types.Tag{
665                         Key:   aws.String(k),
666                         Value: aws.String(v),
667                 })
668         }
669
670         _, err := inst.provider.client.CreateTags(context.Background(), &ec2.CreateTagsInput{
671                 Resources: []string{*inst.instance.InstanceId},
672                 Tags:      ec2tags,
673         })
674
675         return err
676 }
677
678 func (inst *ec2Instance) Tags() cloud.InstanceTags {
679         tags := make(map[string]string)
680
681         for _, t := range inst.instance.Tags {
682                 tags[*t.Key] = *t.Value
683         }
684
685         return tags
686 }
687
688 func (inst *ec2Instance) Destroy() error {
689         _, err := inst.provider.client.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{
690                 InstanceIds: []string{*inst.instance.InstanceId},
691         })
692         return err
693 }
694
695 func (inst *ec2Instance) Address() string {
696         if inst.instance.PrivateIpAddress != nil {
697                 return *inst.instance.PrivateIpAddress
698         }
699         return ""
700 }
701
702 func (inst *ec2Instance) RemoteUser() string {
703         return inst.provider.ec2config.AdminUsername
704 }
705
706 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
707         return cloud.ErrNotImplemented
708 }
709
710 // PriceHistory returns the price history for this specific instance.
711 //
712 // AWS documentation is elusive about whether the hourly cost of a
713 // given spot instance changes as the current spot price changes for
714 // the corresponding instance type and availability zone. Our
715 // implementation assumes the answer is yes, based on the following
716 // hints.
717 //
718 // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html
719 // says: "After your Spot Instance is running, if the Spot price rises
720 // above your maximum price, Amazon EC2 interrupts your Spot
721 // Instance." (This doesn't address what happens when the spot price
722 // rises *without* exceeding your maximum price.)
723 //
724 // https://docs.aws.amazon.com/whitepapers/latest/cost-optimization-leveraging-ec2-spot-instances/how-spot-instances-work.html
725 // says: "You pay the Spot price that's in effect, billed to the
726 // nearest second." (But it's not explicitly stated whether "the price
727 // in effect" changes over time for a given instance.)
728 //
729 // The same page also says, in a discussion about the effect of
730 // specifying a maximum price: "Note that you never pay more than the
731 // Spot price that is in effect when your Spot Instance is running."
732 // (The use of the phrase "is running", as opposed to "was launched",
733 // hints that pricing is dynamic.)
734 func (inst *ec2Instance) PriceHistory(instType arvados.InstanceType) []cloud.InstancePrice {
735         inst.provider.pricesLock.Lock()
736         defer inst.provider.pricesLock.Unlock()
737         // Note updateSpotPrices currently populates
738         // inst.provider.prices only for spot instances, so if
739         // spot==false here, we will return no data.
740         pk := priceKey{
741                 instanceType:     string(inst.instance.InstanceType),
742                 spot:             inst.instance.InstanceLifecycle == types.InstanceLifecycleTypeSpot,
743                 availabilityZone: inst.availabilityZone,
744         }
745         var prices []cloud.InstancePrice
746         for _, price := range inst.provider.prices[pk] {
747                 // ceil(added scratch space in GiB)
748                 gib := (instType.AddedScratch + 1<<30 - 1) >> 30
749                 monthly := inst.provider.ec2config.EBSPrice * float64(gib)
750                 hourly := monthly / 30 / 24
751                 price.Price += hourly
752                 prices = append(prices, price)
753         }
754         return prices
755 }
756
// rateLimitError wraps an API error that was classified as throttling,
// along with the earliest time at which the caller should retry.
// wrapError sets earliestRetry to the current time plus the current
// backoff delay.
type rateLimitError struct {
	error
	earliestRetry time.Time
}
761
762 func (err rateLimitError) EarliestRetry() time.Time {
763         return err.earliestRetry
764 }
765
// capacityError wraps an API error that was classified as a lack of
// capacity, recording what kind of retry might succeed.
type capacityError struct {
	error
	// true if retrying with an instance type in a different
	// instance quota group might succeed
	isInstanceQuotaGroupSpecific bool
	// true if retrying with any other instance type might succeed
	isInstanceTypeSpecific       bool
}
771
772 func (er *capacityError) IsCapacityError() bool {
773         return true
774 }
775
776 func (er *capacityError) IsInstanceQuotaGroupSpecific() bool {
777         return er.isInstanceQuotaGroupSpecific
778 }
779
780 func (er *capacityError) IsInstanceTypeSpecific() bool {
781         return er.isInstanceTypeSpecific
782 }
783
// isCodeQuota is the set of EC2 API error codes that isErrorQuota
// treats as "reached a usage quota/limit".
var isCodeQuota = map[string]bool{
	// "Insufficient..." codes
	"InsufficientAddressCapacity":       true,
	"InsufficientFreeAddressesInSubnet": true,
	"InsufficientVolumeCapacity":        true,
	// "...Exceeded" codes
	"InstanceLimitExceeded":             true,
	"MaxSpotInstanceCountExceeded":      true,
}
791
792 // isErrorQuota returns whether the error indicates we have reached
793 // some usage quota/limit -- i.e., immediately retrying with an equal
794 // or larger instance type will probably not work.
795 //
796 // Returns false if error is nil.
797 func isErrorQuota(err error) bool {
798         var aerr smithy.APIError
799         if errors.As(err, &aerr) {
800                 if _, ok := isCodeQuota[aerr.ErrorCode()]; ok {
801                         return true
802                 }
803         }
804         return false
805 }
806
// reSubnetSpecificInvalidParameterMessage matches API error messages
// that mention " subnet " or "sufficient free ipv4/ipv6 addresses"
// (either case of the leading I). isErrorSubnetSpecific applies it to
// the messages of errors whose code contains "InvalidParameter".
// (?s) lets "." match newlines, so multi-line messages are searched too.
var reSubnetSpecificInvalidParameterMessage = regexp.MustCompile(`(?ms).*( subnet |sufficient free [Ii]pv[46] addresses).*`)
808
809 // isErrorSubnetSpecific returns true if the problem encountered by
810 // RunInstances might be avoided by trying a different subnet.
811 func isErrorSubnetSpecific(err error) bool {
812         var aerr smithy.APIError
813         if !errors.As(err, &aerr) {
814                 return false
815         }
816         code := aerr.ErrorCode()
817         return strings.Contains(code, "Subnet") ||
818                 code == "InsufficientInstanceCapacity" ||
819                 code == "InsufficientVolumeCapacity" ||
820                 code == "Unsupported" ||
821                 // See TestIsErrorSubnetSpecific for examples of why
822                 // we look for substrings in code/message instead of
823                 // only using specific codes here.
824                 (strings.Contains(code, "InvalidParameter") &&
825                         reSubnetSpecificInvalidParameterMessage.MatchString(aerr.ErrorMessage()))
826 }
827
828 // isErrorCapacity determines whether the given error indicates lack
829 // of capacity -- either temporary or permanent -- to run a specific
830 // instance type (i.e., retrying with any other instance type might
831 // succeed) or an instance quota group (i.e., retrying with an
832 // instance type in a different instance quota group might succeed).
833 func isErrorCapacity(err error) (instcap bool, groupcap bool) {
834         var aerr smithy.APIError
835         if !errors.As(err, &aerr) {
836                 return false, false
837         }
838         code := aerr.ErrorCode()
839         if code == "VcpuLimitExceeded" {
840                 return false, true
841         }
842         if code == "InsufficientInstanceCapacity" ||
843                 (code == "Unsupported" && strings.Contains(aerr.ErrorMessage(), "requested instance type")) {
844                 return true, false
845         }
846         return false, false
847 }
848
// ec2QuotaError wraps an API error that isErrorQuota classified as
// having hit a usage quota/limit; wrapError returns it so callers can
// recognize that condition via IsQuotaError.
type ec2QuotaError struct {
	error
}
852
853 func (er *ec2QuotaError) IsQuotaError() bool {
854         return true
855 }
856
857 func isThrottleError(err error) bool {
858         var aerr smithy.APIError
859         if !errors.As(err, &aerr) {
860                 return false
861         }
862         _, is := retry.DefaultThrottleErrorCodes[aerr.ErrorCode()]
863         return is
864 }
865
866 func wrapError(err error, throttleValue *atomic.Value) error {
867         if isThrottleError(err) {
868                 // Back off exponentially until an upstream call
869                 // either succeeds or returns a non-throttle error.
870                 d, _ := throttleValue.Load().(time.Duration)
871                 d = d*3/2 + time.Second
872                 if d < throttleDelayMin {
873                         d = throttleDelayMin
874                 } else if d > throttleDelayMax {
875                         d = throttleDelayMax
876                 }
877                 throttleValue.Store(d)
878                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
879         } else if isErrorQuota(err) {
880                 return &ec2QuotaError{error: err}
881         } else if instcap, groupcap := isErrorCapacity(err); instcap || groupcap {
882                 return &capacityError{
883                         error:                        err,
884                         isInstanceTypeSpecific:       !groupcap,
885                         isInstanceQuotaGroupSpecific: groupcap,
886                 }
887         } else if err != nil {
888                 throttleValue.Store(time.Duration(0))
889                 return err
890         }
891         throttleValue.Store(time.Duration(0))
892         return nil
893 }
894
// boolLabelValue maps true to "1" and false to "0" (presumably for use
// as Prometheus label values -- confirm against callers).
var boolLabelValue = map[bool]string{
	true:  "1",
	false: "0",
}