20318: Route (*KeepClient)Get() through disk cache layer.
authorTom Clegg <tom@curii.com>
Thu, 11 Jan 2024 17:01:11 +0000 (12:01 -0500)
committerTom Clegg <tom@curii.com>
Thu, 11 Jan 2024 20:13:28 +0000 (15:13 -0500)
The updated implementation maintains the existing calling signature
but has some semantic differences:

The block size returned by Get is now -1 if the block size is not
indicated by the supplied locator. All locators generated by Arvados
include sizes, and Get() doesn't work on bare hashes because they have
no permission signature, so in practice this only comes up for test
cases and keepstore's pull worker.

The url returned by Get is now "" because
* in general it is not necessarily available in principle (data often
  comes from the cache instead of a backend server)
* in the cases where it is available in principle, it would be a lot
  of trouble to propagate it
* the only known reason to reveal the url is to provide detail about
  an error, in which case the error message should already include the
  relevant url.

In some cases where the upstream server responds 200 but an error is
detected early in the response content, Get() now returns the error
itself, where the old Get() implementation would have returned a
reader whose Read method returns an error.

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

sdk/go/keepclient/keepclient.go
sdk/go/keepclient/keepclient_test.go
services/keepproxy/keepproxy.go
services/keepproxy/keepproxy_test.go
services/keepstore/pull_worker.go

index 3dc0aa0389158268af38861ec202574d8f20211c..2bd7996b59c0260caf1d61560316c3bc42e09357 100644 (file)
@@ -7,6 +7,7 @@
 package keepclient
 
 import (
+       "bufio"
        "bytes"
        "context"
        "crypto/md5"
@@ -408,16 +409,65 @@ func (kc *KeepClient) LocalLocator(locator string) (string, error) {
        return kc.upstreamGateway().LocalLocator(locator)
 }
 
-// Get retrieves a block, given a locator. Returns a reader, the
-// expected data length, the URL the block is being fetched from, and
-// an error.
+// Get retrieves the specified block from the local cache or a backend
+// server. Returns a reader, the expected data length (or -1 if not
+// known), and an error.
+//
+// The third return value (formerly a source URL in previous versions)
+// is an empty string.
 //
 // If the block checksum does not match, the final Read() on the
 // reader returned by this method will return a BadChecksum error
 // instead of EOF.
+//
+// New code should use BlockRead and/or ReadAt instead of Get.
 func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error) {
-       rdr, size, url, _, err := kc.getOrHead("GET", locator, nil)
-       return rdr, size, url, err
+       loc, err := MakeLocator(locator)
+       if err != nil {
+               return nil, 0, "", err
+       }
+       pr, pw := io.Pipe()
+       go func() {
+               n, err := kc.BlockRead(context.Background(), arvados.BlockReadOptions{
+                       Locator: locator,
+                       WriteTo: pw,
+               })
+               if err != nil {
+                       pw.CloseWithError(err)
+               } else if loc.Size >= 0 && n != loc.Size {
+                       pw.CloseWithError(fmt.Errorf("expected block size %d but read %d bytes", loc.Size, n))
+               } else {
+                       pw.Close()
+               }
+       }()
+       // Wait for the first byte to arrive, so that, if there's an
+       // error before we receive any data, we can return the error
+       // directly, instead of indirectly via a reader that returns
+       // an error.
+       bufr := bufio.NewReader(pr)
+       _, err = bufr.Peek(1)
+       if err != nil && err != io.EOF {
+               pr.CloseWithError(err)
+               return nil, 0, "", err
+       }
+       if err == io.EOF && (loc.Size == 0 || loc.Hash == "d41d8cd98f00b204e9800998ecf8427e") {
+               // In the special case of the zero-length block, EOF
+               // error from Peek() is normal.
+               return pr, 0, "", nil
+       }
+       return struct {
+               io.Reader
+               io.Closer
+       }{
+               Reader: bufr,
+               Closer: pr,
+       }, int64(loc.Size), "", err
+}
+
+// BlockRead retrieves a block from the cache if it's present, otherwise
+// from the network.
+func (kc *KeepClient) BlockRead(ctx context.Context, opts arvados.BlockReadOptions) (int, error) {
+       return kc.upstreamGateway().BlockRead(ctx, opts)
 }
 
 // ReadAt retrieves a portion of block from the cache if it's
index ad5d12b505c0fa17e79efc4ce35c5aa365c7da3a..1e217e53624471c0081c4277ac9d3ddf96f3649d 100644 (file)
@@ -41,8 +41,16 @@ type ServerRequiredSuite struct{}
 // Standalone tests
 type StandaloneSuite struct{}
 
+var origHOME = os.Getenv("HOME")
+
 func (s *StandaloneSuite) SetUpTest(c *C) {
        RefreshServiceDiscovery()
+       // Prevent cache state from leaking between test cases
+       os.Setenv("HOME", c.MkDir())
+}
+
+func (s *StandaloneSuite) TearDownTest(c *C) {
+       os.Setenv("HOME", origHOME)
 }
 
 func pythonDir() string {
@@ -56,10 +64,13 @@ func (s *ServerRequiredSuite) SetUpSuite(c *C) {
 
 func (s *ServerRequiredSuite) TearDownSuite(c *C) {
        arvadostest.StopKeep(2)
+       os.Setenv("HOME", origHOME)
 }
 
 func (s *ServerRequiredSuite) SetUpTest(c *C) {
        RefreshServiceDiscovery()
+       // Prevent cache state from leaking between test cases
+       os.Setenv("HOME", c.MkDir())
 }
 
 func (s *ServerRequiredSuite) TestMakeKeepClient(c *C) {
@@ -715,15 +726,14 @@ func (s *StandaloneSuite) TestGet(c *C) {
        arv.ApiToken = "abc123"
        kc.SetServiceRoots(map[string]string{"x": ks.url}, nil, nil)
 
-       r, n, url2, err := kc.Get(hash)
-       defer r.Close()
-       c.Check(err, Equals, nil)
+       r, n, _, err := kc.Get(hash)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(url2, Equals, fmt.Sprintf("%s/%s", ks.url, hash))
 
        content, err2 := ioutil.ReadAll(r)
-       c.Check(err2, Equals, nil)
+       c.Check(err2, IsNil)
        c.Check(content, DeepEquals, []byte("foo"))
+       c.Check(r.Close(), IsNil)
 }
 
 func (s *StandaloneSuite) TestGet404(c *C) {
@@ -740,11 +750,10 @@ func (s *StandaloneSuite) TestGet404(c *C) {
        arv.ApiToken = "abc123"
        kc.SetServiceRoots(map[string]string{"x": ks.url}, nil, nil)
 
-       r, n, url2, err := kc.Get(hash)
+       r, n, _, err := kc.Get(hash)
        c.Check(err, Equals, BlockNotFound)
        c.Check(n, Equals, int64(0))
-       c.Check(url2, Equals, "")
-       c.Check(r, Equals, nil)
+       c.Check(r, IsNil)
 }
 
 func (s *StandaloneSuite) TestGetEmptyBlock(c *C) {
@@ -759,14 +768,14 @@ func (s *StandaloneSuite) TestGetEmptyBlock(c *C) {
        arv.ApiToken = "abc123"
        kc.SetServiceRoots(map[string]string{"x": ks.url}, nil, nil)
 
-       r, n, url2, err := kc.Get("d41d8cd98f00b204e9800998ecf8427e+0")
+       r, n, _, err := kc.Get("d41d8cd98f00b204e9800998ecf8427e+0")
        c.Check(err, IsNil)
        c.Check(n, Equals, int64(0))
-       c.Check(url2, Equals, "")
        c.Assert(r, NotNil)
        buf, err := ioutil.ReadAll(r)
        c.Check(err, IsNil)
        c.Check(buf, DeepEquals, []byte{})
+       c.Check(r.Close(), IsNil)
 }
 
 func (s *StandaloneSuite) TestGetFail(c *C) {
@@ -784,14 +793,14 @@ func (s *StandaloneSuite) TestGetFail(c *C) {
        kc.SetServiceRoots(map[string]string{"x": ks.url}, nil, nil)
        kc.Retries = 0
 
-       r, n, url2, err := kc.Get(hash)
+       r, n, _, err := kc.Get(hash)
        errNotFound, _ := err.(*ErrNotFound)
-       c.Check(errNotFound, NotNil)
-       c.Check(strings.Contains(errNotFound.Error(), "HTTP 500"), Equals, true)
-       c.Check(errNotFound.Temporary(), Equals, true)
+       if c.Check(errNotFound, NotNil) {
+               c.Check(strings.Contains(errNotFound.Error(), "HTTP 500"), Equals, true)
+               c.Check(errNotFound.Temporary(), Equals, true)
+       }
        c.Check(n, Equals, int64(0))
-       c.Check(url2, Equals, "")
-       c.Check(r, Equals, nil)
+       c.Check(r, IsNil)
 }
 
 func (s *StandaloneSuite) TestGetFailRetry(c *C) {
@@ -815,15 +824,14 @@ func (s *StandaloneSuite) TestGetFailRetry(c *C) {
        arv.ApiToken = "abc123"
        kc.SetServiceRoots(map[string]string{"x": ks.url}, nil, nil)
 
-       r, n, url2, err := kc.Get(hash)
-       defer r.Close()
-       c.Check(err, Equals, nil)
+       r, n, _, err := kc.Get(hash)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(url2, Equals, fmt.Sprintf("%s/%s", ks.url, hash))
 
-       content, err2 := ioutil.ReadAll(r)
-       c.Check(err2, Equals, nil)
+       content, err := ioutil.ReadAll(r)
+       c.Check(err, IsNil)
        c.Check(content, DeepEquals, []byte("foo"))
+       c.Check(r.Close(), IsNil)
 
        c.Logf("%q", st.reqIDs)
        c.Assert(len(st.reqIDs) > 1, Equals, true)
@@ -842,14 +850,14 @@ func (s *StandaloneSuite) TestGetNetError(c *C) {
        arv.ApiToken = "abc123"
        kc.SetServiceRoots(map[string]string{"x": "http://localhost:62222"}, nil, nil)
 
-       r, n, url2, err := kc.Get(hash)
+       r, n, _, err := kc.Get(hash)
        errNotFound, _ := err.(*ErrNotFound)
-       c.Check(errNotFound, NotNil)
-       c.Check(strings.Contains(errNotFound.Error(), "connection refused"), Equals, true)
-       c.Check(errNotFound.Temporary(), Equals, true)
+       if c.Check(errNotFound, NotNil) {
+               c.Check(strings.Contains(errNotFound.Error(), "connection refused"), Equals, true)
+               c.Check(errNotFound.Temporary(), Equals, true)
+       }
        c.Check(n, Equals, int64(0))
-       c.Check(url2, Equals, "")
-       c.Check(r, Equals, nil)
+       c.Check(r, IsNil)
 }
 
 func (s *StandaloneSuite) TestGetWithServiceHint(c *C) {
@@ -882,15 +890,14 @@ func (s *StandaloneSuite) TestGetWithServiceHint(c *C) {
                nil,
                map[string]string{uuid: ks.url})
 
-       r, n, uri, err := kc.Get(hash + "+K@" + uuid)
-       defer r.Close()
-       c.Check(err, Equals, nil)
+       r, n, _, err := kc.Get(hash + "+K@" + uuid)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(uri, Equals, fmt.Sprintf("%s/%s", ks.url, hash+"+K@"+uuid))
 
        content, err := ioutil.ReadAll(r)
-       c.Check(err, Equals, nil)
+       c.Check(err, IsNil)
        c.Check(content, DeepEquals, []byte("foo"))
+       c.Check(r.Close(), IsNil)
 }
 
 // Use a service hint to fetch from a local disk service, overriding
@@ -905,8 +912,8 @@ func (s *StandaloneSuite) TestGetWithLocalServiceHint(c *C) {
                c,
                "error if used",
                "abc123",
-               http.StatusOK,
-               []byte("foo")})
+               http.StatusBadGateway,
+               nil})
        defer ks0.listener.Close()
        // This one should be used:
        ks := RunFakeKeepServer(StubGetHandler{
@@ -935,15 +942,14 @@ func (s *StandaloneSuite) TestGetWithLocalServiceHint(c *C) {
                        uuid:                          ks.url},
        )
 
-       r, n, uri, err := kc.Get(hash + "+K@" + uuid)
-       defer r.Close()
-       c.Check(err, Equals, nil)
+       r, n, _, err := kc.Get(hash + "+K@" + uuid)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(uri, Equals, fmt.Sprintf("%s/%s", ks.url, hash+"+K@"+uuid))
 
        content, err := ioutil.ReadAll(r)
-       c.Check(err, Equals, nil)
+       c.Check(err, IsNil)
        c.Check(content, DeepEquals, []byte("foo"))
+       c.Check(r.Close(), IsNil)
 }
 
 func (s *StandaloneSuite) TestGetWithServiceHintFailoverToLocals(c *C) {
@@ -974,15 +980,14 @@ func (s *StandaloneSuite) TestGetWithServiceHintFailoverToLocals(c *C) {
                nil,
                map[string]string{uuid: ksGateway.url})
 
-       r, n, uri, err := kc.Get(hash + "+K@" + uuid)
-       c.Assert(err, Equals, nil)
-       defer r.Close()
+       r, n, _, err := kc.Get(hash + "+K@" + uuid)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(uri, Equals, fmt.Sprintf("%s/%s", ksLocal.url, hash+"+K@"+uuid))
 
        content, err := ioutil.ReadAll(r)
-       c.Check(err, Equals, nil)
+       c.Check(err, IsNil)
        c.Check(content, DeepEquals, []byte("foo"))
+       c.Check(r.Close(), IsNil)
 }
 
 type BarHandler struct {
@@ -1018,9 +1023,11 @@ func (s *StandaloneSuite) TestChecksum(c *C) {
        <-st.handled
 
        r, n, _, err = kc.Get(foohash)
-       c.Check(err, IsNil)
-       _, err = ioutil.ReadAll(r)
-       c.Check(n, Equals, int64(3))
+       if err == nil {
+               buf, readerr := ioutil.ReadAll(r)
+               c.Logf("%q", buf)
+               err = readerr
+       }
        c.Check(err, Equals, BadChecksum)
 
        <-st.handled
@@ -1072,16 +1079,16 @@ func (s *StandaloneSuite) TestGetWithFailures(c *C) {
        // an example that passes this Assert.)
        c.Assert(NewRootSorter(localRoots, hash).GetSortedRoots()[0], Not(Equals), ks1[0].url)
 
-       r, n, url2, err := kc.Get(hash)
+       r, n, _, err := kc.Get(hash)
 
        <-fh.handled
-       c.Check(err, Equals, nil)
+       c.Assert(err, IsNil)
        c.Check(n, Equals, int64(3))
-       c.Check(url2, Equals, fmt.Sprintf("%s/%s", ks1[0].url, hash))
 
        readContent, err2 := ioutil.ReadAll(r)
-       c.Check(err2, Equals, nil)
+       c.Check(err2, IsNil)
        c.Check(readContent, DeepEquals, content)
+       c.Check(r.Close(), IsNil)
 }
 
 func (s *ServerRequiredSuite) TestPutGetHead(c *C) {
@@ -1106,14 +1113,16 @@ func (s *ServerRequiredSuite) TestPutGetHead(c *C) {
                c.Check(replicas, Equals, 2)
        }
        {
-               r, n, url2, err := kc.Get(hash)
-               c.Check(err, Equals, nil)
+               r, n, _, err := kc.Get(hash)
+               c.Check(err, IsNil)
                c.Check(n, Equals, int64(len(content)))
-               c.Check(url2, Matches, fmt.Sprintf("http://localhost:\\d+/%s", hash))
                if c.Check(r, NotNil) {
-                       readContent, err2 := ioutil.ReadAll(r)
-                       c.Check(err2, Equals, nil)
-                       c.Check(readContent, DeepEquals, content)
+                       readContent, err := ioutil.ReadAll(r)
+                       c.Check(err, IsNil)
+                       if c.Check(len(readContent), Equals, len(content)) {
+                               c.Check(readContent, DeepEquals, content)
+                       }
+                       c.Check(r.Close(), IsNil)
                }
        }
        {
index 6d173ab00dcd4fc407b095f2d1b54f9f1f7cb127..39ffd45cbe37b69f663dc6093acd4dfe221c74a1 100644 (file)
@@ -304,7 +304,6 @@ func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
        var err error
        var status int
        var expectLength, responseLength int64
-       var proxiedURI = "-"
 
        logger := ctxlog.FromContext(req.Context())
        defer func() {
@@ -312,7 +311,6 @@ func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
                        "locator":        locator,
                        "expectLength":   expectLength,
                        "responseLength": responseLength,
-                       "proxiedURI":     proxiedURI,
                        "err":            err,
                })
                if status != http.StatusOK {
@@ -346,9 +344,9 @@ func (h *proxyHandler) Get(resp http.ResponseWriter, req *http.Request) {
 
        switch req.Method {
        case "HEAD":
-               expectLength, proxiedURI, err = kc.Ask(locator)
+               expectLength, _, err = kc.Ask(locator)
        case "GET":
-               reader, expectLength, proxiedURI, err = kc.Get(locator)
+               reader, expectLength, _, err = kc.Get(locator)
                if reader != nil {
                        defer reader.Close()
                }
index f57aeb3617cc9a8c59197c520563eb23a79465f2..a24eb752f9c01817adfd1c5fc5b40fa1ddcfcc6b 100644 (file)
@@ -641,7 +641,7 @@ func getIndexWorker(c *C, useConfig bool) {
        c.Check(rep, Equals, 2)
        c.Check(err, Equals, nil)
 
-       reader, blocklen, _, err := kc.Get(hash)
+       reader, blocklen, _, err := kc.Get(hash2)
        c.Assert(err, IsNil)
        c.Check(blocklen, Equals, int64(10))
        all, err := ioutil.ReadAll(reader)
@@ -783,10 +783,12 @@ func (s *NoKeepServerSuite) TestAskGetNoKeepServerError(c *C) {
                },
        } {
                err := f()
-               c.Assert(err, NotNil)
+               c.Check(err, NotNil)
                errNotFound, _ := err.(*keepclient.ErrNotFound)
-               c.Check(errNotFound.Temporary(), Equals, true)
-               c.Check(err, ErrorMatches, `.*HTTP 502.*`)
+               if c.Check(errNotFound, NotNil) {
+                       c.Check(errNotFound.Temporary(), Equals, true)
+                       c.Check(err, ErrorMatches, `.*HTTP 502.*`)
+               }
        }
 }
 
index b9194fe6f66b029d4262122aa384788f80eff4eb..348bfb4df00087a1726ef36cbd186fe0eb5ea4c7 100644 (file)
@@ -59,7 +59,7 @@ func (h *handler) pullItemAndProcess(pullRequest PullRequest) error {
 
        signedLocator := SignLocator(h.Cluster, pullRequest.Locator, keepClient.Arvados.ApiToken, time.Now().Add(time.Minute))
 
-       reader, contentLen, _, err := GetContent(signedLocator, keepClient)
+       reader, _, _, err := GetContent(signedLocator, keepClient)
        if err != nil {
                return err
        }
@@ -73,7 +73,7 @@ func (h *handler) pullItemAndProcess(pullRequest PullRequest) error {
                return err
        }
 
-       if (readContent == nil) || (int64(len(readContent)) != contentLen) {
+       if readContent == nil {
                return fmt.Errorf("Content not found for: %s", signedLocator)
        }