X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/f04693da1811e670d4cbb981debeecf14d79137c..7c430b8a41da3a66522d1ca08e3a9f637b609195:/services/keepstore/s3_volume.go

diff --git a/services/keepstore/s3_volume.go b/services/keepstore/s3_volume.go
index 22a38e2085..220377af28 100644
--- a/services/keepstore/s3_volume.go
+++ b/services/keepstore/s3_volume.go
@@ -5,6 +5,7 @@ package main
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"crypto/sha256"
@@ -45,8 +46,8 @@ func newS3Volume(cluster *arvados.Cluster, volume arvados.Volume, logger logrus.
 }
 
 func (v *S3Volume) check() error {
-	if v.Bucket == "" || v.AccessKey == "" || v.SecretKey == "" {
-		return errors.New("DriverParameters: Bucket, AccessKey, and SecretKey must be provided")
+	if v.Bucket == "" {
+		return errors.New("DriverParameters: Bucket must be provided")
 	}
 	if v.IndexPageSize == 0 {
 		v.IndexPageSize = 1000
@@ -55,7 +56,8 @@ func (v *S3Volume) check() error {
 		return errors.New("DriverParameters: RaceWindow must not be negative")
 	}
 
-	region, ok := aws.Regions[v.Region]
+	var ok bool
+	v.region, ok = aws.Regions[v.Region]
 	if v.Endpoint == "" {
 		if !ok {
 			return fmt.Errorf("unrecognized region %+q; try specifying endpoint instead", v.Region)
@@ -64,18 +66,13 @@ func (v *S3Volume) check() error {
 		return fmt.Errorf("refusing to use AWS region name %+q with endpoint %+q; "+
 			"specify empty endpoint or use a different region name", v.Region, v.Endpoint)
 	} else {
-		region = aws.Region{
+		v.region = aws.Region{
 			Name:                 v.Region,
 			S3Endpoint:           v.Endpoint,
 			S3LocationConstraint: v.LocationConstraint,
 		}
 	}
 
-	auth := aws.Auth{
-		AccessKey: v.AccessKey,
-		SecretKey: v.SecretKey,
-	}
-
 	// Zero timeouts mean "wait forever", which is a bad
 	// default. Default to long timeouts instead.
 	if v.ConnectTimeout == 0 {
@@ -85,16 +82,9 @@ func (v *S3Volume) check() error {
 		v.ReadTimeout = s3DefaultReadTimeout
 	}
 
-	client := s3.New(auth, region)
-	if region.EC2Endpoint.Signer == aws.V4Signature {
-		// Currently affects only eu-central-1
-		client.Signature = aws.V4Signature
-	}
-	client.ConnectTimeout = time.Duration(v.ConnectTimeout)
-	client.ReadTimeout = time.Duration(v.ReadTimeout)
 	v.bucket = &s3bucket{
-		Bucket: &s3.Bucket{
-			S3:   client,
+		bucket: &s3.Bucket{
+			S3:   v.newS3Client(),
 			Name: v.Bucket,
 		},
 	}
@@ -102,6 +92,11 @@ func (v *S3Volume) check() error {
 	lbls := prometheus.Labels{"device_id": v.GetDeviceID()}
 	v.bucket.stats.opsCounters, v.bucket.stats.errCounters, v.bucket.stats.ioBytes = v.metrics.getCounterVecsFor(lbls)
 
+	err := v.bootstrapIAMCredentials()
+	if err != nil {
+		return fmt.Errorf("error getting IAM credentials: %s", err)
+	}
+
	return nil
 }
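With the hunks above, check() no longer insists on static AccessKey/SecretKey in DriverParameters: client construction moves into newS3Client(), and bootstrapIAMCredentials() (added below) decides where credentials come from. A minimal standalone sketch of that decision rule; the helper name and string results are illustrative, not part of the patch:

package main

import (
	"errors"
	"fmt"
)

// credentialSource mirrors the rule bootstrapIAMCredentials enforces:
// static keys and an IAM role are mutually exclusive, and with no static
// keys the driver falls back to the EC2 instance metadata service.
func credentialSource(accessKey, secretKey, iamRole string) (string, error) {
	switch {
	case (accessKey != "" || secretKey != "") && iamRole != "":
		return "", errors.New("AccessKey and SecretKey must be blank if IAMRole is specified")
	case accessKey != "" || secretKey != "":
		return "static keys from DriverParameters", nil
	default:
		return "EC2 instance metadata (IAM role)", nil
	}
}

func main() {
	src, err := credentialSource("", "", "keepstore-role")
	fmt.Println(src, err) // EC2 instance metadata (IAM role) <nil>
}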
@@ -113,7 +108,7 @@ const (
 var (
 	// ErrS3TrashDisabled is returned by Trash if that operation
 	// is impossible with the current config.
-	ErrS3TrashDisabled = fmt.Errorf("trash function is disabled because -trash-lifetime=0 and -s3-unsafe-delete=false")
+	ErrS3TrashDisabled = fmt.Errorf("trash function is disabled because Collections.BlobTrashLifetime=0 and DriverParameters.UnsafeDelete=false")
 
 	s3ACL = s3.Private
 
@@ -136,6 +131,9 @@ func s3regions() (okList []string) {
 type S3Volume struct {
 	AccessKey          string
 	SecretKey          string
+	AuthToken          string    // populated automatically when IAMRole is used
+	AuthExpiration     time.Time // populated automatically when IAMRole is used
+	IAMRole            string
 	Endpoint           string
 	Region             string
 	Bucket             string
@@ -151,6 +149,7 @@ type S3Volume struct {
 	logger    logrus.FieldLogger
 	metrics   *volumeMetricsVecs
 	bucket    *s3bucket
+	region    aws.Region
 	startOnce sync.Once
 }
 
@@ -159,6 +158,141 @@ func (v *S3Volume) GetDeviceID() string {
 	return "s3://" + v.Endpoint + "/" + v.Bucket
 }
 
+func (v *S3Volume) bootstrapIAMCredentials() error {
+	if v.AccessKey != "" || v.SecretKey != "" {
+		if v.IAMRole != "" {
+			return errors.New("invalid DriverParameters: AccessKey and SecretKey must be blank if IAMRole is specified")
+		}
+		return nil
+	}
+	ttl, err := v.updateIAMCredentials()
+	if err != nil {
+		return err
+	}
+	go func() {
+		for {
+			time.Sleep(ttl)
+			ttl, err = v.updateIAMCredentials()
+			if err != nil {
+				v.logger.WithError(err).Warnf("failed to update credentials for IAM role %q", v.IAMRole)
+				ttl = time.Second
+			} else if ttl < time.Second {
+				v.logger.WithField("TTL", ttl).Warnf("received stale credentials for IAM role %q", v.IAMRole)
+				ttl = time.Second
+			}
+		}
+	}()
+	return nil
+}
+
+func (v *S3Volume) newS3Client() *s3.S3 {
+	auth := aws.NewAuth(v.AccessKey, v.SecretKey, v.AuthToken, v.AuthExpiration)
+	client := s3.New(*auth, v.region)
+	if v.region.EC2Endpoint.Signer == aws.V4Signature {
+		// Currently affects only eu-central-1
+		client.Signature = aws.V4Signature
+	}
+	client.ConnectTimeout = time.Duration(v.ConnectTimeout)
+	client.ReadTimeout = time.Duration(v.ReadTimeout)
+	return client
+}
+
+// returned by AWS metadata endpoint .../security-credentials/${rolename}
+type iamCredentials struct {
+	Code            string
+	LastUpdated     time.Time
+	Type            string
+	AccessKeyID     string
+	SecretAccessKey string
+	Token           string
+	Expiration      time.Time
+}
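The iamCredentials struct matches the JSON document served by the instance metadata endpoint. A self-contained sketch of the decoding; the sample document below is abridged and illustrative. Go's encoding/json matches fields case-insensitively, which is why the service's "AccessKeyId" key maps onto the AccessKeyID field:

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

type iamCredentials struct {
	Code            string
	LastUpdated     time.Time
	Type            string
	AccessKeyID     string
	SecretAccessKey string
	Token           string
	Expiration      time.Time
}

func main() {
	// Illustrative sample of the metadata service's response shape.
	sample := `{
		"Code": "Success",
		"LastUpdated": "2019-05-01T12:00:00Z",
		"Type": "AWS-HMAC",
		"AccessKeyId": "ASIAEXAMPLE",
		"SecretAccessKey": "secret",
		"Token": "token",
		"Expiration": "2019-05-01T18:00:00Z"
	}`
	var cred iamCredentials
	if err := json.Unmarshal([]byte(sample), &cred); err != nil {
		panic(err)
	}
	fmt.Println(cred.AccessKeyID, cred.Expiration)
}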
+
+// Returns TTL of updated credentials, i.e., time to sleep until next
+// update.
+func (v *S3Volume) updateIAMCredentials() (time.Duration, error) {
+	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
+	defer cancel()
+
+	metadataBaseURL := "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
+
+	var url string
+	if strings.Contains(v.IAMRole, "://") {
+		// Configuration provides complete URL (used by tests)
+		url = v.IAMRole
+	} else if v.IAMRole != "" {
+		// Configuration provides IAM role name and we use the
+		// AWS metadata endpoint
+		url = metadataBaseURL + v.IAMRole
+	} else {
+		url = metadataBaseURL
+		v.logger.WithField("URL", url).Debug("looking up IAM role name")
+		req, err := http.NewRequest("GET", url, nil)
+		if err != nil {
+			return 0, fmt.Errorf("error setting up request %s: %s", url, err)
+		}
+		resp, err := http.DefaultClient.Do(req.WithContext(ctx))
+		if err != nil {
+			return 0, fmt.Errorf("error getting %s: %s", url, err)
+		}
+		defer resp.Body.Close()
+		if resp.StatusCode == http.StatusNotFound {
+			return 0, fmt.Errorf("this instance does not have an IAM role assigned -- either assign a role, or configure AccessKey and SecretKey explicitly in DriverParameters (error getting %s: HTTP status %s)", url, resp.Status)
+		} else if resp.StatusCode != http.StatusOK {
+			return 0, fmt.Errorf("error getting %s: HTTP status %s", url, resp.Status)
+		}
+		body := bufio.NewReader(resp.Body)
+		var role string
+		_, err = fmt.Fscanf(body, "%s\n", &role)
+		if err != nil {
+			return 0, fmt.Errorf("error reading response from %s: %s", url, err)
+		}
+		if n, _ := body.Read(make([]byte, 64)); n > 0 {
+			v.logger.Warnf("ignoring additional data returned by metadata endpoint %s after the single role name that we expected", url)
+		}
+		v.logger.WithField("Role", role).Debug("looked up IAM role name")
+		url = url + role
+	}
+
+	v.logger.WithField("URL", url).Debug("getting credentials")
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return 0, fmt.Errorf("error setting up request %s: %s", url, err)
+	}
+	resp, err := http.DefaultClient.Do(req.WithContext(ctx))
+	if err != nil {
+		return 0, fmt.Errorf("error getting %s: %s", url, err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return 0, fmt.Errorf("error getting %s: HTTP status %s", url, resp.Status)
+	}
+	var cred iamCredentials
+	err = json.NewDecoder(resp.Body).Decode(&cred)
+	if err != nil {
+		return 0, fmt.Errorf("error decoding credentials from %s: %s", url, err)
+	}
+	v.AccessKey, v.SecretKey, v.AuthToken, v.AuthExpiration = cred.AccessKeyID, cred.SecretAccessKey, cred.Token, cred.Expiration
+	v.bucket.SetBucket(&s3.Bucket{
+		S3:   v.newS3Client(),
+		Name: v.Bucket,
+	})
+	// TTL is time from now to expiration, minus 5m. "We make new
+	// credentials available at least five minutes before the
+	// expiration of the old credentials." --
+	// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html#instance-metadata-security-credentials
+	// (If that's not true, the returned ttl might be zero or
+	// negative, which the caller can handle.)
+	ttl := cred.Expiration.Sub(time.Now()) - 5*time.Minute
+	v.logger.WithFields(logrus.Fields{
+		"AccessKeyID": cred.AccessKeyID,
+		"LastUpdated": cred.LastUpdated,
+		"Expiration":  cred.Expiration,
+		"TTL":         arvados.Duration(ttl),
+	}).Debug("updated credentials")
+	return ttl, nil
+}
+
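updateIAMCredentials returns the sleep interval for the refresh goroutine: time to expiration, minus the five-minute early-availability window AWS documents. A small standalone sketch of that arithmetic, including the one-second clamp the caller in bootstrapIAMCredentials applies; time.Until(t) is equivalent to the patch's t.Sub(time.Now()):

package main

import (
	"fmt"
	"time"
)

// refreshAfter computes how long to sleep before the next credential
// refresh: five minutes before expiry, clamped so that stale or
// already-expired credentials trigger a quick retry instead of a
// zero or negative sleep.
func refreshAfter(expiration time.Time) time.Duration {
	ttl := time.Until(expiration) - 5*time.Minute
	if ttl < time.Second {
		ttl = time.Second
	}
	return ttl
}

func main() {
	fmt.Println(refreshAfter(time.Now().Add(6 * time.Hour)))   // just under 5h55m
	fmt.Println(refreshAfter(time.Now().Add(2 * time.Minute))) // 1s (stale)
}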
 func (v *S3Volume) getReaderWithContext(ctx context.Context, loc string) (rdr io.ReadCloser, err error) {
 	ready := make(chan bool)
 	go func() {
@@ -410,13 +544,13 @@ func (v *S3Volume) Mtime(loc string) (time.Time, error) {
 func (v *S3Volume) IndexTo(prefix string, writer io.Writer) error {
 	// Use a merge sort to find matching sets of X and recent/X.
 	dataL := s3Lister{
-		Bucket:   v.bucket.Bucket,
+		Bucket:   v.bucket.Bucket(),
 		Prefix:   prefix,
 		PageSize: v.IndexPageSize,
 		Stats:    &v.bucket.stats,
 	}
 	recentL := s3Lister{
-		Bucket:   v.bucket.Bucket,
+		Bucket:   v.bucket.Bucket(),
 		Prefix:   "recent/" + prefix,
 		PageSize: v.IndexPageSize,
 		Stats:    &v.bucket.stats,
@@ -531,15 +665,15 @@ func (v *S3Volume) checkRaceWindow(loc string) error {
 // (PutCopy returns 200 OK if the request was received, even if the
 // copy failed).
 func (v *S3Volume) safeCopy(dst, src string) error {
-	resp, err := v.bucket.PutCopy(dst, s3ACL, s3.CopyOptions{
+	resp, err := v.bucket.Bucket().PutCopy(dst, s3ACL, s3.CopyOptions{
 		ContentType:       "application/octet-stream",
 		MetadataDirective: "REPLACE",
-	}, v.bucket.Name+"/"+src)
+	}, v.bucket.Bucket().Name+"/"+src)
 	err = v.translateError(err)
 	if os.IsNotExist(err) {
 		return err
 	} else if err != nil {
-		return fmt.Errorf("PutCopy(%q ← %q): %s", dst, v.bucket.Name+"/"+src, err)
+		return fmt.Errorf("PutCopy(%q ← %q): %s", dst, v.bucket.Bucket().Name+"/"+src, err)
 	}
 	if t, err := time.Parse(time.RFC3339Nano, resp.LastModified); err != nil {
 		return fmt.Errorf("PutCopy succeeded but did not return a timestamp: %q: %s", resp.LastModified, err)
@@ -663,7 +797,7 @@ func (v *S3Volume) translateError(err error) error {
 	return err
 }
 
-// EmptyTrash looks for trashed blocks that exceeded TrashLifetime
+// EmptyTrash looks for trashed blocks that exceeded BlobTrashLifetime
 // and deletes them from the volume.
 func (v *S3Volume) EmptyTrash() {
 	if v.cluster.Collections.BlobDeleteConcurrency < 1 {
@@ -712,8 +846,8 @@ func (v *S3Volume) EmptyTrash() {
 					// the raceWindow that starts if we
 					// delete trash/X now.
 					//
-					// Note this means (TrashCheckInterval
-					// < BlobSignatureTTL - raceWindow) is
+					// Note this means (TrashSweepInterval
+					// < BlobSigningTTL - raceWindow) is
 					// necessary to avoid starvation.
 					log.Printf("notice: %s: EmptyTrash: detected old race for %q, calling fixRace + Touch", v, loc)
 					v.fixRace(loc)
@@ -769,7 +903,7 @@ func (v *S3Volume) EmptyTrash() {
 	}
 
 	trashL := s3Lister{
-		Bucket:   v.bucket.Bucket,
+		Bucket:   v.bucket.Bucket(),
 		Prefix:   "trash/",
 		PageSize: v.IndexPageSize,
 		Stats:    &v.bucket.stats,
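The EmptyTrash hunks above only rename config knobs in a comment (TrashSweepInterval, BlobSigningTTL), but the invariant the comment states is worth spelling out. A hedged, dependency-free sketch of that check; the helper name and standalone form are illustrative, not part of keepstore:

package main

import (
	"fmt"
	"time"
)

// sweepIntervalOK reports whether trash sweeps recur often enough that a
// block caught by the trash/untrash race can still be collected: per the
// comment above, TrashSweepInterval < BlobSigningTTL - raceWindow.
func sweepIntervalOK(trashSweepInterval, blobSigningTTL, raceWindow time.Duration) bool {
	return trashSweepInterval < blobSigningTTL-raceWindow
}

func main() {
	// e.g. daily sweeps, two-week signing TTL, one-hour race window
	fmt.Println(sweepIntervalOK(24*time.Hour, 14*24*time.Hour, time.Hour)) // true
}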
@@ -848,14 +982,29 @@ func (lister *s3Lister) pop() (k *s3.Key) {
 	return
 }
 
-// s3bucket wraps s3.bucket and counts I/O and API usage stats.
+// s3bucket wraps s3.bucket and counts I/O and API usage stats. The
+// wrapped bucket can be replaced atomically with SetBucket in order
+// to update credentials.
 type s3bucket struct {
-	*s3.Bucket
-	stats s3bucketStats
+	bucket *s3.Bucket
+	stats  s3bucketStats
+	mu     sync.Mutex
+}
+
+func (b *s3bucket) Bucket() *s3.Bucket {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.bucket
+}
+
+func (b *s3bucket) SetBucket(bucket *s3.Bucket) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	b.bucket = bucket
 }
 
 func (b *s3bucket) GetReader(path string) (io.ReadCloser, error) {
-	rdr, err := b.Bucket.GetReader(path)
+	rdr, err := b.Bucket().GetReader(path)
 	b.stats.TickOps("get")
 	b.stats.Tick(&b.stats.Ops, &b.stats.GetOps)
 	b.stats.TickErr(err)
@@ -863,7 +1012,7 @@ func (b *s3bucket) GetReader(path string) (io.ReadCloser, error) {
 }
 
 func (b *s3bucket) Head(path string, headers map[string][]string) (*http.Response, error) {
-	resp, err := b.Bucket.Head(path, headers)
+	resp, err := b.Bucket().Head(path, headers)
 	b.stats.TickOps("head")
 	b.stats.Tick(&b.stats.Ops, &b.stats.HeadOps)
 	b.stats.TickErr(err)
@@ -882,7 +1031,7 @@ func (b *s3bucket) PutReader(path string, r io.Reader, length int64, contType st
 	} else {
 		r = NewCountingReader(r, b.stats.TickOutBytes)
 	}
-	err := b.Bucket.PutReader(path, r, length, contType, perm, options)
+	err := b.Bucket().PutReader(path, r, length, contType, perm, options)
 	b.stats.TickOps("put")
 	b.stats.Tick(&b.stats.Ops, &b.stats.PutOps)
 	b.stats.TickErr(err)
@@ -890,7 +1039,7 @@ func (b *s3bucket) PutReader(path string, r io.Reader, length int64, contType st
 }
 
 func (b *s3bucket) Del(path string) error {
-	err := b.Bucket.Del(path)
+	err := b.Bucket().Del(path)
 	b.stats.TickOps("delete")
 	b.stats.Tick(&b.stats.Ops, &b.stats.DelOps)
 	b.stats.TickErr(err)
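Every call site now goes through the Bucket() accessor, so the credential-refresh goroutine can swap in a client holding fresh credentials via SetBucket while requests are in flight. A dependency-free sketch of the same guarded-pointer pattern; the swappable type is illustrative, not part of the patch:

package main

import (
	"fmt"
	"sync"
)

// swappable guards a pointer with a mutex so readers always see either
// the old or the new value, never a torn update -- the pattern s3bucket
// uses for its wrapped *s3.Bucket.
type swappable struct {
	mu  sync.Mutex
	val *string
}

func (s *swappable) Get() *string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.val
}

func (s *swappable) Set(v *string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.val = v
}

func main() {
	expired, fresh := "client with expired credentials", "client with fresh credentials"
	s := &swappable{val: &expired}
	s.Set(&fresh) // credential refresh swaps the pointer under the lock
	fmt.Println(*s.Get())
}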