14291: Introduce "AddedScratch" and "IncludedScratch" to InstanceType
[arvados.git] / lib / cloud / ec2 / ec2.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package ec2
6
7 import (
8         "encoding/base64"
9         "encoding/json"
10         "fmt"
11         "log"
12         "strings"
13         "sync"
14
15         "git.curoverse.com/arvados.git/lib/cloud"
16         "git.curoverse.com/arvados.git/sdk/go/arvados"
17         "github.com/aws/aws-sdk-go/aws"
18         "github.com/aws/aws-sdk-go/aws/credentials"
19         "github.com/aws/aws-sdk-go/aws/session"
20         "github.com/aws/aws-sdk-go/service/ec2"
21         "github.com/sirupsen/logrus"
22         "golang.org/x/crypto/ssh"
23 )
24
25 const arvadosDispatchID = "arvados-dispatch-id"
26 const tagPrefix = "arvados-dispatch-tag-"
27
28 // Driver is the ec2 implementation of the cloud.Driver interface.
29 var Driver = cloud.DriverFunc(newEC2InstanceSet)
30
// ec2InstanceSetConfig is the driver configuration, decoded from the JSON
// message passed to newEC2InstanceSet.
type ec2InstanceSetConfig struct {
	// AccessKeyID and SecretAccessKey are used as static AWS credentials
	// for the EC2 client.
	AccessKeyID      string
	SecretAccessKey  string
	// Region is the AWS region the EC2 client is configured for.
	Region           string
	// SecurityGroupIDs are assigned to the network interface of every
	// instance created by Create.
	SecurityGroupIDs []string
	// SubnetID is the subnet of the network interface of every instance
	// created by Create.
	SubnetID         string
	// AdminUsername is what RemoteUser reports for every instance.
	AdminUsername    string
}
39
40 type ec2Interface interface {
41         ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
42         RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error)
43         DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
44         CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
45         TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
46 }
47
// ec2InstanceSet is the cloud.InstanceSet returned by newEC2InstanceSet.
type ec2InstanceSet struct {
	// ec2config is the decoded driver configuration.
	ec2config    ec2InstanceSetConfig
	// dispatcherID is written to the arvadosDispatchID tag of every
	// instance this set creates, and is the tag value Instances filters on.
	dispatcherID cloud.InstanceSetID
	// logger is the logger supplied to newEC2InstanceSet.
	logger       logrus.FieldLogger
	// client performs the EC2 API calls.
	client       ec2Interface
	// keysMtx guards keys.
	keysMtx      sync.Mutex
	// keys maps an SSH public key's SHA256 fingerprint to the name of the
	// EC2 key pair it was imported as.
	keys         map[string]string
}
56
57 func newEC2InstanceSet(config json.RawMessage, dispatcherID cloud.InstanceSetID, logger logrus.FieldLogger) (prv cloud.InstanceSet, err error) {
58         instanceSet := &ec2InstanceSet{
59                 dispatcherID: dispatcherID,
60                 logger:       logger,
61         }
62         err = json.Unmarshal(config, &instanceSet.ec2config)
63         if err != nil {
64                 return nil, err
65         }
66         awsConfig := aws.NewConfig().
67                 WithCredentials(credentials.NewStaticCredentials(
68                         instanceSet.ec2config.AccessKeyID,
69                         instanceSet.ec2config.SecretAccessKey,
70                         "")).
71                 WithRegion(instanceSet.ec2config.Region)
72         instanceSet.client = ec2.New(session.Must(session.NewSession(awsConfig)))
73         instanceSet.keys = make(map[string]string)
74         return instanceSet, nil
75 }
76
77 func (instanceSet *ec2InstanceSet) Create(
78         instanceType arvados.InstanceType,
79         imageID cloud.ImageID,
80         newTags cloud.InstanceTags,
81         initCommand cloud.InitCommand,
82         publicKey ssh.PublicKey) (cloud.Instance, error) {
83
84         keyFingerprint := ssh.FingerprintSHA256(publicKey)
85         instanceSet.keysMtx.Lock()
86         var keyname string
87         var ok bool
88         if keyname, ok = instanceSet.keys[keyFingerprint]; !ok {
89                 keyname = "arvados-dispatch-keypair-" + keyFingerprint
90                 instanceSet.client.ImportKeyPair(&ec2.ImportKeyPairInput{
91                         KeyName:           &keyname,
92                         PublicKeyMaterial: ssh.MarshalAuthorizedKey(publicKey),
93                 })
94                 instanceSet.keys[keyFingerprint] = keyname
95         }
96         instanceSet.keysMtx.Unlock()
97
98         ec2tags := []*ec2.Tag{
99                 &ec2.Tag{
100                         Key:   aws.String(arvadosDispatchID),
101                         Value: aws.String(string(instanceSet.dispatcherID)),
102                 },
103                 &ec2.Tag{
104                         Key:   aws.String("arvados-class"),
105                         Value: aws.String("dynamic-compute"),
106                 },
107         }
108         for k, v := range newTags {
109                 ec2tags = append(ec2tags, &ec2.Tag{
110                         Key:   aws.String(tagPrefix + k),
111                         Value: aws.String(v),
112                 })
113         }
114
115         rii := ec2.RunInstancesInput{
116                 ImageId:      aws.String(string(imageID)),
117                 InstanceType: &instanceType.ProviderType,
118                 MaxCount:     aws.Int64(1),
119                 MinCount:     aws.Int64(1),
120                 KeyName:      &keyname,
121
122                 NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
123                         &ec2.InstanceNetworkInterfaceSpecification{
124                                 AssociatePublicIpAddress: aws.Bool(false),
125                                 DeleteOnTermination:      aws.Bool(true),
126                                 DeviceIndex:              aws.Int64(0),
127                                 Groups:                   aws.StringSlice(instanceSet.ec2config.SecurityGroupIDs),
128                                 SubnetId:                 &instanceSet.ec2config.SubnetID,
129                         }},
130                 DisableApiTermination:             aws.Bool(false),
131                 InstanceInitiatedShutdownBehavior: aws.String("terminate"),
132                 UserData: aws.String(base64.StdEncoding.EncodeToString([]byte("#!/bin/sh\n" + initCommand + "\n"))),
133                 TagSpecifications: []*ec2.TagSpecification{
134                         &ec2.TagSpecification{
135                                 ResourceType: aws.String("instance"),
136                                 Tags:         ec2tags,
137                         }},
138         }
139
140         if instanceType.AddedScratch > 0 {
141                 rii.BlockDeviceMappings = []*ec2.BlockDeviceMapping{&ec2.BlockDeviceMapping{
142                         DeviceName: aws.String("/dev/xvdt"),
143                         Ebs: &ec2.EbsBlockDevice{
144                                 DeleteOnTermination: aws.Bool(true),
145                                 VolumeSize:          aws.Int64((int64(instanceType.AddedScratch) / 1000000000) + 1),
146                                 VolumeType:          aws.String("gp2"),
147                         }}}
148         }
149
150         if instanceType.Preemptible {
151                 rii.InstanceMarketOptions = &ec2.InstanceMarketOptionsRequest{
152                         MarketType: aws.String("spot"),
153                         SpotOptions: &ec2.SpotMarketOptions{
154                                 InstanceInterruptionBehavior: aws.String("terminate"),
155                                 MaxPrice:                     aws.String(fmt.Sprintf("%v", instanceType.Price)),
156                         }}
157         }
158
159         rsv, err := instanceSet.client.RunInstances(&rii)
160
161         if err != nil {
162                 return nil, err
163         }
164
165         return &ec2Instance{
166                 provider: instanceSet,
167                 instance: rsv.Instances[0],
168         }, nil
169 }
170
171 func (instanceSet *ec2InstanceSet) Instances(cloud.InstanceTags) (instances []cloud.Instance, err error) {
172         dii := &ec2.DescribeInstancesInput{
173                 Filters: []*ec2.Filter{&ec2.Filter{
174                         Name:   aws.String("tag:" + arvadosDispatchID),
175                         Values: []*string{aws.String(string(instanceSet.dispatcherID))},
176                 }}}
177
178         for {
179                 dio, err := instanceSet.client.DescribeInstances(dii)
180                 if err != nil {
181                         return nil, err
182                 }
183
184                 for _, rsv := range dio.Reservations {
185                         for _, inst := range rsv.Instances {
186                                 instances = append(instances, &ec2Instance{instanceSet, inst})
187                         }
188                 }
189                 if dio.NextToken == nil {
190                         return instances, err
191                 }
192                 dii.NextToken = dio.NextToken
193         }
194 }
195
196 func (az *ec2InstanceSet) Stop() {
197 }
198
// ec2Instance is the cloud.Instance returned by Create and Instances for a
// single EC2 instance.
type ec2Instance struct {
	// provider is the InstanceSet that created or listed this instance; its
	// client and config are used by the methods below.
	provider *ec2InstanceSet
	// instance is the EC2 description from the RunInstances or
	// DescribeInstances call that produced this value; nothing in this
	// file refreshes it afterwards.
	instance *ec2.Instance
}
203
204 func (inst *ec2Instance) ID() cloud.InstanceID {
205         return cloud.InstanceID(*inst.instance.InstanceId)
206 }
207
208 func (inst *ec2Instance) String() string {
209         return *inst.instance.InstanceId
210 }
211
212 func (inst *ec2Instance) ProviderType() string {
213         return *inst.instance.InstanceType
214 }
215
216 func (inst *ec2Instance) SetTags(newTags cloud.InstanceTags) error {
217         ec2tags := []*ec2.Tag{
218                 &ec2.Tag{
219                         Key:   aws.String(arvadosDispatchID),
220                         Value: aws.String(string(inst.provider.dispatcherID)),
221                 },
222         }
223         for k, v := range newTags {
224                 ec2tags = append(ec2tags, &ec2.Tag{
225                         Key:   aws.String(tagPrefix + k),
226                         Value: aws.String(v),
227                 })
228         }
229
230         _, err := inst.provider.client.CreateTags(&ec2.CreateTagsInput{
231                 Resources: []*string{inst.instance.InstanceId},
232                 Tags:      ec2tags,
233         })
234
235         return err
236 }
237
238 func (inst *ec2Instance) Tags() cloud.InstanceTags {
239         tags := make(map[string]string)
240
241         for _, t := range inst.instance.Tags {
242                 if strings.HasPrefix(*t.Key, tagPrefix) {
243                         tags[(*t.Key)[len(tagPrefix):]] = *t.Value
244                 }
245         }
246
247         return tags
248 }
249
250 func (inst *ec2Instance) Destroy() error {
251         log.Printf("terminating %v", *inst.instance.InstanceId)
252         _, err := inst.provider.client.TerminateInstances(&ec2.TerminateInstancesInput{
253                 InstanceIds: []*string{inst.instance.InstanceId},
254         })
255         return err
256 }
257
258 func (inst *ec2Instance) Address() string {
259         if inst.instance.PrivateIpAddress != nil {
260                 return *inst.instance.PrivateIpAddress
261         } else {
262                 return ""
263         }
264 }
265
266 func (inst *ec2Instance) RemoteUser() string {
267         return inst.provider.ec2config.AdminUsername
268 }
269
270 func (inst *ec2Instance) VerifyHostKey(ssh.PublicKey, *ssh.Client) error {
271         return cloud.ErrNotImplemented
272 }