Merge branch '17284-redact-railsapi-host'
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17         "sync/atomic"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cloud"
21         "git.arvados.org/arvados.git/sdk/go/arvados"
22         "github.com/aws/aws-sdk-go/aws"
23         "github.com/aws/aws-sdk-go/aws/credentials"
24         "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
25         "github.com/aws/aws-sdk-go/aws/ec2metadata"
26         "github.com/aws/aws-sdk-go/aws/request"
27         "github.com/aws/aws-sdk-go/aws/session"
28         "github.com/aws/aws-sdk-go/service/ec2"
29         "github.com/sirupsen/logrus"
30         "golang.org/x/crypto/ssh"
31 )
32
// Driver is the ec2 implementation of the cloud.Driver interface.
// It wraps newEC2InstanceSet, which builds an instance set from a
// JSON-encoded ec2InstanceSetConfig.
var Driver = cloud.DriverFunc(newEC2InstanceSet)
35
// Bounds applied by wrapError to the retry delay it reports after the
// cloud API returns a throttling error.
const (
	// throttleDelayMin is the shortest retry delay ever reported.
	throttleDelayMin = 1 * time.Second
	// throttleDelayMax is the longest retry delay ever reported.
	throttleDelayMax = 1 * time.Minute
)
40
41 type ec2InstanceSetConfig struct {
42         AccessKeyID      string
43         SecretAccessKey  string
44         Region           string
45         SecurityGroupIDs arvados.StringSet
46         SubnetID         string
47         AdminUsername    string
48         EBSVolumeType    string
49 }
50
51 type ec2Interface interface {
52         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
53         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
54         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
55         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
56         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
57         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
58 }
59
60 type ec2InstanceSet struct {
61         ec2config              ec2InstanceSetConfig
62         instanceSetID          cloud.InstanceSetID
63         logger                 logrus.FieldLogger
64         client                 ec2Interface
65         keysMtx                sync.Mutex
66         keys                   map[string]string
67         throttleDelayCreate    atomic.Value
68         throttleDelayInstances atomic.Value
69 }
70
71 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
72         instanceSet := &ec2InstanceSet{
73                 instanceSetID: instanceSetID,
74                 logger:        logger,
75         }
76         err = json.Unmarshal(config, &instanceSet.ec2config)
77         if err != nil {
78                 return nil, err
79         }
80
81         sess, err := session.NewSession()
82         if err != nil {
83                 return nil, err
84         }
85         // First try any static credentials, fall back to an IAM instance profile/role
86         creds := credentials.NewChainCredentials(
87                 []credentials.Provider{
88                         &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
89                         &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
90                 })
91
92         awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
93         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
94         instanceSet.keys = make(map[string]string)
95         if instanceSet.ec2config.EBSVolumeType == "" {
96                 instanceSet.ec2config.EBSVolumeType = "gp2"
97         }
98         return instanceSet, nil
99 }
100
101 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
102         // AWS key fingerprints don't use the usual key fingerprint
103         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
104         // (you can get that from md5.Sum(pk.Marshal())
105         //
106         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
107         // public key, so calculate those fingerprints here.
108         var rsaPub struct {
109                 Name string
110                 E    *big.Int
111                 N    *big.Int
112         }
113         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
114                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
115         }
116         rsaPk := rsa.PublicKey{
117                 E: int(rsaPub.E.Int64()),
118                 N: rsaPub.N,
119         }
120         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
121         md5pkix := md5.Sum([]byte(pkix))
122         sha1pkix := sha1.Sum([]byte(pkix))
123         md5fp = ""
124         sha1fp = ""
125         for i := 0; i < len(md5pkix); i++ {
126                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
127         }
128         for i := 0; i < len(sha1pkix); i++ {
129                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
130         }
131         return md5fp[1:], sha1fp[1:], nil
132 }
133
134 func (instanceSet *ec2InstanceSet) Create(
135         instanceType arvados.InstanceType,
136         imageID cloud.ImageID,
137         newTags cloud.InstanceTags,
138         initCommand cloud.InitCommand,
139         publicKey ssh.PublicKey) (cloud.Instance, error) {
140
141         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
142         if err != nil {
143                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
144         }
145         instanceSet.keysMtx.Lock()
146         var keyname string
147         var ok bool
148         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
149                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
150                         Filters: []*ec2.Filter{{
151                                 Name:   aws.String("fingerprint"),
152                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
153                         }},
154                 })
155                 if err != nil {
156                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
157                 }
158
159                 if len(keyout.KeyPairs) > 0 {
160                         keyname = *(keyout.KeyPairs[0].KeyName)
161                 } else {
162                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
163                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
164                                 KeyName:           &keyname,
165                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
166                         })
167                         if err != nil {
168                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
169                         }
170                 }
171                 instanceSet.keys[md5keyFingerprint] = keyname
172         }
173         instanceSet.keysMtx.Unlock()
174
175         ec2tags := []*ec2.Tag{}
176         for k, v := range newTags {
177                 ec2tags = append(ec2tags, &ec2.Tag{
178                         Key:   aws.String(k),
179                         Value: aws.String(v),
180                 })
181         }
182
183         var groups []string
184         for sg := range instanceSet.ec2config.SecurityGroupIDs {
185                 groups = append(groups, sg)
186         }
187
188         rii := ec2.RunInstancesInput{
189                 ImageId:      aws.String(string(imageID)),
190                 InstanceType: &instanceType.ProviderType,
191                 MaxCount:     aws.Int64(1),
192                 MinCount:     aws.Int64(1),
193                 KeyName:      &keyname,
194
195                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
196                         {
197                                 AssociatePublicIpAddress: aws.Bool(false),
198                                 DeleteOnTermination:      aws.Bool(true),
199                                 DeviceIndex:              aws.Int64(0),
200                                 Groups:                   aws.StringSlice(groups),
201                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
202                         }},
203                 DisableApiTermination:             aws.Bool(false),
204                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
205                 TagSpecifications: []*ec2.TagSpecification{
206                         {
207                                 ResourceType: aws.String("instance"),
208                                 Tags:         ec2tags,
209                         }},
210                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
211         }
212
213         if instanceType.AddedScratch > 0 {
214                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
215                         DeviceName: aws.String("/dev/xvdt"),
216                         Ebs: &ec2.EbsBlockDevice{
217                                 DeleteOnTermination: aws.Bool(true),
218                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
219                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
220                         }}}
221         }
222
223         if instanceType.Preemptible {
224                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
225                         MarketType: aws.String("spot"),
226                         SpotOptions: &ec2.SpotMarketOptions{
227                                 InstanceInterruptionBehavior: aws.String("terminate"),
228                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
229                         }}
230         }
231
232         rsv, err := instanceSet.client.RunInstances(&rii)
233         err = wrapError(err, &instanceSet.throttleDelayCreate)
234         if err != nil {
235                 return nil, err
236         }
237
238         return &ec2Instance{
239                 provider: instanceSet,
240                 instance: rsv.Instances[0],
241         }, nil
242 }
243
244 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
245         var filters []*ec2.Filter
246         for k, v := range tags {
247                 filters = append(filters, &ec2.Filter{
248                         Name:   aws.String("tag:" + k),
249                         Values: []*string{aws.String(v)},
250                 })
251         }
252         dii := &ec2.DescribeInstancesInput{Filters: filters}
253         for {
254                 dio, err := instanceSet.client.DescribeInstances(dii)
255                 err = wrapError(err, &instanceSet.throttleDelayInstances)
256                 if err != nil {
257                         return nil, err
258                 }
259
260                 for _, rsv := range dio.Reservations {
261                         for _, inst := range rsv.Instances {
262                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
263                                         instances = append(instances, &ec2Instance{instanceSet, inst})
264                                 }
265                         }
266                 }
267                 if dio.NextToken == nil {
268                         return instances, err
269                 }
270                 dii.NextToken = dio.NextToken
271         }
272 }
273
274 func (instanceSet *ec2InstanceSet) Stop() {
275 }
276
277 type ec2Instance struct {
278         provider *ec2InstanceSet
279         instance *ec2.Instance
280 }
281
282 func (inst *ec2Instance) ID() cloud.InstanceID {
283         return cloud.InstanceID(*inst.instance.InstanceId)
284 }
285
286 func (inst *ec2Instance) String() string {
287         return *inst.instance.InstanceId
288 }
289
290 func (inst *ec2Instance) ProviderType() string {
291         return *inst.instance.InstanceType
292 }
293
294 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
295         var ec2tags []*ec2.Tag
296         for k, v := range newTags {
297                 ec2tags = append(ec2tags, &ec2.Tag{
298                         Key:   aws.String(k),
299                         Value: aws.String(v),
300                 })
301         }
302
303         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
304                 Resources: []*string{inst.instance.InstanceId},
305                 Tags:      ec2tags,
306         })
307
308         return err
309 }
310
311 func (inst *ec2Instance) Tags() cloud.InstanceTags {
312         tags := make(map[string]string)
313
314         for _, t := range inst.instance.Tags {
315                 tags[*t.Key] = *t.Value
316         }
317
318         return tags
319 }
320
321 func (inst *ec2Instance) Destroy() error {
322         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
323                 InstanceIds: []*string{inst.instance.InstanceId},
324         })
325         return err
326 }
327
328 func (inst *ec2Instance) Address() string {
329         if inst.instance.PrivateIpAddress != nil {
330                 return *inst.instance.PrivateIpAddress
331         }
332         return ""
333 }
334
335 func (inst *ec2Instance) RemoteUser() string {
336         return inst.provider.ec2config.AdminUsername
337 }
338
339 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
340         return cloud.ErrNotImplemented
341 }
342
// rateLimitError decorates an error with the earliest time at which
// the failed operation should be attempted again. The embedded error
// supplies the Error method.
type rateLimitError struct {
	error
	// earliestRetry is the time reported by EarliestRetry.
	earliestRetry time.Time
}

// EarliestRetry returns the time before which the caller should not
// retry.
func (e rateLimitError) EarliestRetry() time.Time {
	return e.earliestRetry
}
351
352 func wrapError(err error, throttleValue *atomic.Value) error {
353         if request.IsErrorThrottle(err) {
354                 // Back off exponentially until an upstream call
355                 // either succeeds or returns a non-throttle error.
356                 d, _ := throttleValue.Load().(time.Duration)
357                 d = d*3/2 + time.Second
358                 if d < throttleDelayMin {
359                         d = throttleDelayMin
360                 } else if d > throttleDelayMax {
361                         d = throttleDelayMax
362                 }
363                 throttleValue.Store(d)
364                 return rateLimitError{error: err, earliestRetry: time.Now().Add(d)}
365         } else if err != nil {
366                 throttleValue.Store(time.Duration(0))
367                 return err
368         }
369         throttleValue.Store(time.Duration(0))
370         return nil
371 }