X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/39f6e9f70f683237d9488faac1c549ca19ac9dae..04704ea80b294655fe14d0c8cddf4ec1a6b21b4d:/services/keepstore/keepstore.go diff --git a/services/keepstore/keepstore.go b/services/keepstore/keepstore.go index 62b6d15e56..60d062e1e3 100644 --- a/services/keepstore/keepstore.go +++ b/services/keepstore/keepstore.go @@ -223,7 +223,7 @@ func (ks *keepstore) signLocator(token, locator string) string { } func (ks *keepstore) BlockRead(ctx context.Context, opts arvados.BlockReadOptions) (n int, err error) { - li, err := parseLocator(opts.Locator) + li, err := getLocatorInfo(opts.Locator) if err != nil { return 0, err } @@ -243,13 +243,58 @@ func (ks *keepstore) BlockRead(ctx context.Context, opts arvados.BlockReadOption } else { out = io.MultiWriter(out, hashcheck) } + + buf, err := ks.bufferPool.GetContext(ctx) + if err != nil { + return 0, err + } + defer ks.bufferPool.Put(buf) + streamer := newStreamWriterAt(out, 65536, buf) + defer streamer.Close() + var errToCaller error = os.ErrNotExist for _, mnt := range ks.rendezvous(li.hash, ks.mountsR) { if ctx.Err() != nil { return 0, ctx.Err() } - n, err = mnt.BlockRead(ctx, li.hash, out) - if err == nil && li.size > 0 && n != li.size { + err := mnt.BlockRead(ctx, li.hash, streamer) + if err != nil { + if streamer.WroteAt() != 0 { + // BlockRead encountered an error + // after writing some data, so it's + // too late to try another + // volume. Flush streamer before + // calling Wrote() to ensure our + // return value accurately reflects + // the number of bytes written to + // opts.WriteTo. + streamer.Close() + return streamer.Wrote(), err + } + if !os.IsNotExist(err) { + errToCaller = err + } + continue + } + if li.size == 0 { + // hashCheckingWriter isn't in use because we + // don't know the expected size. All we can do + // is check after writing all the data, and + // trust the caller is doing a HEAD request so + // it's not too late to set an error code in + // the response header. + err = streamer.Close() + if hash := fmt.Sprintf("%x", hashcheck.Sum(nil)); hash != li.hash && err == nil { + err = errChecksum + } + if rw, ok := opts.WriteTo.(http.ResponseWriter); ok { + // We didn't set the content-length header + // above because we didn't know the block size + // until now. + rw.Header().Set("Content-Length", fmt.Sprintf("%d", streamer.WroteAt())) + } + return streamer.WroteAt(), err + } else if streamer.WroteAt() != li.size { // If the backend read fewer bytes than // expected but returns no error, we can // classify this as a checksum error (even @@ -260,49 +305,27 @@ func (ks *keepstore) BlockRead(ctx context.Context, opts arvados.BlockReadOption // it anyway, but if it's a HEAD request the // caller can still change the response status // code. - return n, errChecksum - } - if err == nil && li.size == 0 { - // hashCheckingWriter isn't in use because we - // don't know the expected size. All we can do - // is check after writing all the data, and - // trust the caller is doing a HEAD request so - // it's not too late to set an error code in - // the response header. - if hash := fmt.Sprintf("%x", hashcheck.Sum(nil)); hash != li.hash { - return n, errChecksum - } - } - if rw, ok := opts.WriteTo.(http.ResponseWriter); ok && li.size == 0 && err == nil { - // We didn't set the content-length header - // above because we didn't know the block size - // until now. - rw.Header().Set("Content-Length", fmt.Sprintf("%d", n)) - } - if n > 0 || err == nil { - // success, or there's an error but we can't - // retry because we've already sent some data. - return n, err - } - if !os.IsNotExist(err) { - // If some volume returns a transient error, - // return it to the caller instead of "Not - // found" so it can retry. - errToCaller = err + return streamer.WroteAt(), errChecksum } + // Ensure streamer flushes all buffered data without + // errors. + err = streamer.Close() + return streamer.Wrote(), err } return 0, errToCaller } func (ks *keepstore) blockReadRemote(ctx context.Context, opts arvados.BlockReadOptions) (int, error) { - ks.logger.Infof("blockReadRemote(%s)", opts.Locator) token := ctxToken(ctx) if token == "" { return 0, errNoTokenProvided } var remoteClient *keepclient.KeepClient var parts []string - var size int + li, err := getLocatorInfo(opts.Locator) + if err != nil { + return 0, err + } for i, part := range strings.Split(opts.Locator, "+") { switch { case i == 0: @@ -324,8 +347,6 @@ func (ks *keepstore) blockReadRemote(ctx context.Context, opts arvados.BlockRead } remoteClient = kc part = "A" + part[7:] - case len(part) > 0 && part[0] >= '0' && part[0] <= '9': - size, _ = strconv.Atoi(part) } parts = append(parts, part) } @@ -336,8 +357,8 @@ func (ks *keepstore) blockReadRemote(ctx context.Context, opts arvados.BlockRead if opts.LocalLocator == nil { // Read from remote cluster and stream response back // to caller - if rw, ok := opts.WriteTo.(http.ResponseWriter); ok && size > 0 { - rw.Header().Set("Content-Length", fmt.Sprintf("%d", size)) + if rw, ok := opts.WriteTo.(http.ResponseWriter); ok && li.size > 0 { + rw.Header().Set("Content-Length", fmt.Sprintf("%d", li.size)) } return remoteClient.BlockRead(ctx, arvados.BlockReadOptions{ Locator: locator, @@ -459,7 +480,7 @@ func (ks *keepstore) BlockWrite(ctx context.Context, opts arvados.BlockWriteOpti continue } cmp := &checkEqual{Expect: opts.Data} - if _, err := mnt.BlockRead(ctx, hash, cmp); err == nil { + if err := mnt.BlockRead(ctx, hash, cmp); err == nil { if !cmp.Equal() { return resp, errCollision } @@ -564,31 +585,34 @@ func (*keepstore) rendezvous(locator string, mnts []*mount) []*mount { return mnts } -// checkEqual reports whether the data written to it (via io.Writer +// checkEqual reports whether the data written to it (via io.WriterAt // interface) is equal to the expected data. // // Expect should not be changed after the first Write. +// +// Results are undefined if WriteAt is called with overlapping ranges. type checkEqual struct { - Expect []byte - equalUntil int + Expect []byte + equal atomic.Int64 + notequal atomic.Bool } func (ce *checkEqual) Equal() bool { - return ce.equalUntil == len(ce.Expect) + return !ce.notequal.Load() && ce.equal.Load() == int64(len(ce.Expect)) } -func (ce *checkEqual) Write(p []byte) (int, error) { - endpos := ce.equalUntil + len(p) - if ce.equalUntil >= 0 && endpos <= len(ce.Expect) && bytes.Equal(p, ce.Expect[ce.equalUntil:endpos]) { - ce.equalUntil = endpos +func (ce *checkEqual) WriteAt(p []byte, offset int64) (int, error) { + endpos := int(offset) + len(p) + if offset >= 0 && endpos <= len(ce.Expect) && bytes.Equal(p, ce.Expect[int(offset):endpos]) { + ce.equal.Add(int64(len(p))) } else { - ce.equalUntil = -1 + ce.notequal.Store(true) } return len(p), nil } func (ks *keepstore) BlockUntrash(ctx context.Context, locator string) error { - li, err := parseLocator(locator) + li, err := getLocatorInfo(locator) if err != nil { return err } @@ -608,7 +632,7 @@ func (ks *keepstore) BlockUntrash(ctx context.Context, locator string) error { } func (ks *keepstore) BlockTouch(ctx context.Context, locator string) error { - li, err := parseLocator(locator) + li, err := getLocatorInfo(locator) if err != nil { return err } @@ -632,7 +656,7 @@ func (ks *keepstore) BlockTrash(ctx context.Context, locator string) error { if !ks.cluster.Collections.BlobTrash { return errMethodNotAllowed } - li, err := parseLocator(locator) + li, err := getLocatorInfo(locator) if err != nil { return err } @@ -685,40 +709,55 @@ func ctxToken(ctx context.Context) string { } } +// locatorInfo expresses the attributes of a locator that are relevant +// for keepstore decision-making. type locatorInfo struct { hash string size int - remote bool - signed bool + remote bool // locator has a +R hint + signed bool // locator has a +A hint } -func parseLocator(loc string) (locatorInfo, error) { +func getLocatorInfo(loc string) (locatorInfo, error) { var li locatorInfo - for i, part := range strings.Split(loc, "+") { - if i == 0 { - if len(part) != 32 { + plus := 0 // number of '+' chars seen so far + partlen := 0 // chars since last '+' + for i, c := range loc + "+" { + if c == '+' { + if partlen == 0 { + // double/leading/trailing '+' return li, errInvalidLocator } - li.hash = part + if plus == 0 { + if i != 32 { + return li, errInvalidLocator + } + li.hash = loc[:i] + } + if plus == 1 { + if size, err := strconv.Atoi(loc[i-partlen : i]); err == nil { + li.size = size + } + } + plus++ + partlen = 0 continue } - if i == 1 { - if size, err := strconv.Atoi(part); err == nil { - li.size = size - continue + partlen++ + if partlen == 1 { + if c == 'A' { + li.signed = true + } + if c == 'R' { + li.remote = true + } + if plus > 1 && c >= '0' && c <= '9' { + // size, if present at all, must come first + return li, errInvalidLocator } } - if len(part) == 0 { - return li, errInvalidLocator - } - if part[0] == 'A' { - li.signed = true - } - if part[0] == 'R' { - li.remote = true - } - if part[0] >= '0' && part[0] <= '9' { - // size, if present at all, must come first + if plus == 0 && !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + // non-hexadecimal char in hash part return li, errInvalidLocator } }