17417: Add arm64 packages for our Golang components.
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "crypto/md5"
9         "crypto/rsa"
10         "crypto/sha1"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/json"
14         "fmt"
15         "math/big"
16         "sync"
17
18         "git.arvados.org/arvados.git/lib/cloud"
19         "git.arvados.org/arvados.git/sdk/go/arvados"
20         "github.com/aws/aws-sdk-go/aws"
21         "github.com/aws/aws-sdk-go/aws/credentials"
22         "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
23         "github.com/aws/aws-sdk-go/aws/ec2metadata"
24         "github.com/aws/aws-sdk-go/aws/session"
25         "github.com/aws/aws-sdk-go/service/ec2"
26         "github.com/sirupsen/logrus"
27         "golang.org/x/crypto/ssh"
28 )
29
30 // Driver is the ec2 implementation of the cloud.Driver interface.
31 var Driver = cloud.DriverFunc(newEC2InstanceSet)
32
33 type ec2InstanceSetConfig struct {
34         AccessKeyID      string
35         SecretAccessKey  string
36         Region           string
37         SecurityGroupIDs arvados.StringSet
38         SubnetID         string
39         AdminUsername    string
40         EBSVolumeType    string
41 }
42
43 type ec2Interface interface {
44         DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
45         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
46         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
47         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
48         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
49         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
50 }
51
52 type ec2InstanceSet struct {
53         ec2config     ec2InstanceSetConfig
54         instanceSetID cloud.InstanceSetID
55         logger        logrus.FieldLogger
56         client        ec2Interface
57         keysMtx       sync.Mutex
58         keys          map[string]string
59 }
60
61 func newEC2InstanceSet(config json.RawMessage, instanceSetID cloud.InstanceSetID, _ cloud.SharedResourceTags, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
62         instanceSet := &ec2InstanceSet{
63                 instanceSetID: instanceSetID,
64                 logger:        logger,
65         }
66         err = json.Unmarshal(config, &instanceSet.ec2config)
67         if err != nil {
68                 return nil, err
69         }
70
71         sess, err := session.NewSession()
72         if err != nil {
73                 return nil, err
74         }
75         // First try any static credentials, fall back to an IAM instance profile/role
76         creds := credentials.NewChainCredentials(
77                 []credentials.Provider{
78                         &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: instanceSet.ec2config.AccessKeyID, SecretAccessKey: instanceSet.ec2config.SecretAccessKey}},
79                         &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
80                 })
81
82         awsConfig := aws.NewConfig().WithCredentials(creds).WithRegion(instanceSet.ec2config.Region)
83         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
84         instanceSet.keys = make(map[string]string)
85         if instanceSet.ec2config.EBSVolumeType == "" {
86                 instanceSet.ec2config.EBSVolumeType = "gp2"
87         }
88         return instanceSet, nil
89 }
90
91 func awsKeyFingerprint(pk ssh.PublicKey) (md5fp string, sha1fp string, err error) {
92         // AWS key fingerprints don't use the usual key fingerprint
93         // you get from ssh-keygen or ssh.FingerprintLegacyMD5()
94         // (you can get that from md5.Sum(pk.Marshal())
95         //
96         // AWS uses the md5 or sha1 of the PKIX DER encoding of the
97         // public key, so calculate those fingerprints here.
98         var rsaPub struct {
99                 Name string
100                 E    *big.Int
101                 N    *big.Int
102         }
103         if err := ssh.Unmarshal(pk.Marshal(), &rsaPub); err != nil {
104                 return "", "", fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
105         }
106         rsaPk := rsa.PublicKey{
107                 E: int(rsaPub.E.Int64()),
108                 N: rsaPub.N,
109         }
110         pkix, _ := x509.MarshalPKIXPublicKey(&rsaPk)
111         md5pkix := md5.Sum([]byte(pkix))
112         sha1pkix := sha1.Sum([]byte(pkix))
113         md5fp = ""
114         sha1fp = ""
115         for i := 0; i < len(md5pkix); i++ {
116                 md5fp += fmt.Sprintf(":%02x", md5pkix[i])
117         }
118         for i := 0; i < len(sha1pkix); i++ {
119                 sha1fp += fmt.Sprintf(":%02x", sha1pkix[i])
120         }
121         return md5fp[1:], sha1fp[1:], nil
122 }
123
124 func (instanceSet *ec2InstanceSet) Create(
125         instanceType arvados.InstanceType,
126         imageID cloud.ImageID,
127         newTags cloud.InstanceTags,
128         initCommand cloud.InitCommand,
129         publicKey ssh.PublicKey) (cloud.Instance, error) {
130
131         md5keyFingerprint, sha1keyFingerprint, err := awsKeyFingerprint(publicKey)
132         if err != nil {
133                 return nil, fmt.Errorf("Could not make key fingerprint: %v", err)
134         }
135         instanceSet.keysMtx.Lock()
136         var keyname string
137         var ok bool
138         if keyname, ok = instanceSet.keys[md5keyFingerprint]; !ok {
139                 keyout, err := instanceSet.client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
140                         Filters: []*ec2.Filter{{
141                                 Name:   aws.String("fingerprint"),
142                                 Values: []*string{&md5keyFingerprint, &sha1keyFingerprint},
143                         }},
144                 })
145                 if err != nil {
146                         return nil, fmt.Errorf("Could not search for keypair: %v", err)
147                 }
148
149                 if len(keyout.KeyPairs) > 0 {
150                         keyname = *(keyout.KeyPairs[0].KeyName)
151                 } else {
152                         keyname = "arvados-dispatch-keypair-" + md5keyFingerprint
153                         _, err := instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
154                                 KeyName:           &keyname,
155                                 PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
156                         })
157                         if err != nil {
158                                 return nil, fmt.Errorf("Could not import keypair: %v", err)
159                         }
160                 }
161                 instanceSet.keys[md5keyFingerprint] = keyname
162         }
163         instanceSet.keysMtx.Unlock()
164
165         ec2tags := []*ec2.Tag{}
166         for k, v := range newTags {
167                 ec2tags = append(ec2tags, &ec2.Tag{
168                         Key:   aws.String(k),
169                         Value: aws.String(v),
170                 })
171         }
172
173         var groups []string
174         for sg := range instanceSet.ec2config.SecurityGroupIDs {
175                 groups = append(groups, sg)
176         }
177
178         rii := ec2.RunInstancesInput{
179                 ImageId:      aws.String(string(imageID)),
180                 InstanceType: &instanceType.ProviderType,
181                 MaxCount:     aws.Int64(1),
182                 MinCount:     aws.Int64(1),
183                 KeyName:      &keyname,
184
185                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
186                         {
187                                 AssociatePublicIpAddress: aws.Bool(false),
188                                 DeleteOnTermination:      aws.Bool(true),
189                                 DeviceIndex:              aws.Int64(0),
190                                 Groups:                   aws.StringSlice(groups),
191                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
192                         }},
193                 DisableApiTermination:             aws.Bool(false),
194                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
195                 TagSpecifications: []*ec2.TagSpecification{
196                         {
197                                 ResourceType: aws.String("instance"),
198                                 Tags:         ec2tags,
199                         }},
200                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
201         }
202
203         if instanceType.AddedScratch > 0 {
204                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{{
205                         DeviceName: aws.String("/dev/xvdt"),
206                         Ebs: &ec2.EbsBlockDevice{
207                                 DeleteOnTermination: aws.Bool(true),
208                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) + (1<<30 - 1)) >> 30),
209                                 VolumeType:          &instanceSet.ec2config.EBSVolumeType,
210                         }}}
211         }
212
213         if instanceType.Preemptible {
214                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
215                         MarketType: aws.String("spot"),
216                         SpotOptions: &ec2.SpotMarketOptions{
217                                 InstanceInterruptionBehavior: aws.String("terminate"),
218                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
219                         }}
220         }
221
222         rsv, err := instanceSet.client.RunInstances(&rii)
223
224         if err != nil {
225                 return nil, err
226         }
227
228         return &ec2Instance{
229                 provider: instanceSet,
230                 instance: rsv.Instances[0],
231         }, nil
232 }
233
234 func (instanceSet *ec2InstanceSet) Instances(tags cloud.InstanceTags) (instances []cloud.Instance, err error) {
235         var filters []*ec2.Filter
236         for k, v := range tags {
237                 filters = append(filters, &ec2.Filter{
238                         Name:   aws.String("tag:" + k),
239                         Values: []*string{aws.String(v)},
240                 })
241         }
242         dii := &ec2.DescribeInstancesInput{Filters: filters}
243         for {
244                 dio, err := instanceSet.client.DescribeInstances(dii)
245                 if err != nil {
246                         return nil, err
247                 }
248
249                 for _, rsv := range dio.Reservations {
250                         for _, inst := range rsv.Instances {
251                                 if *inst.State.Name != "shutting-down" && *inst.State.Name != "terminated" {
252                                         instances = append(instances, &ec2Instance{instanceSet, inst})
253                                 }
254                         }
255                 }
256                 if dio.NextToken == nil {
257                         return instances, err
258                 }
259                 dii.NextToken = dio.NextToken
260         }
261 }
262
263 func (instanceSet *ec2InstanceSet) Stop() {
264 }
265
266 type ec2Instance struct {
267         provider *ec2InstanceSet
268         instance *ec2.Instance
269 }
270
271 func (inst *ec2Instance) ID() cloud.InstanceID {
272         return cloud.InstanceID(*inst.instance.InstanceId)
273 }
274
275 func (inst *ec2Instance) String() string {
276         return *inst.instance.InstanceId
277 }
278
279 func (inst *ec2Instance) ProviderType() string {
280         return *inst.instance.InstanceType
281 }
282
283 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
284         var ec2tags []*ec2.Tag
285         for k, v := range newTags {
286                 ec2tags = append(ec2tags, &ec2.Tag{
287                         Key:   aws.String(k),
288                         Value: aws.String(v),
289                 })
290         }
291
292         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
293                 Resources: []*string{inst.instance.InstanceId},
294                 Tags:      ec2tags,
295         })
296
297         return err
298 }
299
300 func (inst *ec2Instance) Tags() cloud.InstanceTags {
301         tags := make(map[string]string)
302
303         for _, t := range inst.instance.Tags {
304                 tags[*t.Key] = *t.Value
305         }
306
307         return tags
308 }
309
310 func (inst *ec2Instance) Destroy() error {
311         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
312                 InstanceIds: []*string{inst.instance.InstanceId},
313         })
314         return err
315 }
316
317 func (inst *ec2Instance) Address() string {
318         if inst.instance.PrivateIpAddress != nil {
319                 return *inst.instance.PrivateIpAddress
320         }
321         return ""
322 }
323
324 func (inst *ec2Instance) RemoteUser() string {
325         return inst.provider.ec2config.AdminUsername
326 }
327
328 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
329         return cloud.ErrNotImplemented
330 }