5414: Add client support for Keep service hints.
author Tom Clegg <tom@curoverse.com>
Sun, 29 Mar 2015 00:26:00 +0000 (20:26 -0400)
committer Tom Clegg <tom@curoverse.com>
Wed, 8 Apr 2015 15:45:38 +0000 (11:45 -0400)
Also, some incidental improvements in nearby code:

* Consistent logging in keepproxy, with one reusable logging statement
  instead of a different statement/format for each outcome.

* In sdk/go/keepclient, remove public AuthorizedGet and AuthorizedAsk
  methods. Instead, Get() and Ask() accept a locator (with or without
  a permission token) and do the right thing. Callers don't have to
  parse locators to decide which method to call.

* In sdk/go/keepclient, use an RWMutex instead of atomic.LoadPointer()
  and unsafe.Pointer() to update KeepClient root maps safely.

* In sdk/go/keepclient, DiscoverKeepServers() doesn't return the new
  root maps, just an error. In normal usage, the caller only cares
  whether discovery was successful.

Also, some Go style fixes in nearby code:

* Use pointer receivers for all KeepClient methods.
  https://golang.org/doc/faq#methods_on_values_or_pointers

* Use receiver name "kc", not "this".
  https://github.com/golang/go/wiki/CodeReviewComments#receiver-names

* Handle errors first, use minimal indentation for normal code path.
  https://github.com/golang/go/wiki/CodeReviewComments#indent-error-flow

12 files changed:
sdk/go/keepclient/keepclient.go
sdk/go/keepclient/keepclient_test.go
sdk/go/keepclient/support.go
sdk/python/arvados/keep.py
sdk/python/tests/arvados_testutil.py
sdk/python/tests/test_keep_client.py
services/keepproxy/keepproxy.go
services/keepproxy/keepproxy_test.go
services/keepstore/keepstore.go
services/keepstore/pull_worker.go
services/keepstore/pull_worker_integration_test.go
services/keepstore/pull_worker_test.go

index 5d791948dcb808f3373555d183d61f7df5a22100..29352b42bba0390acd4ba109e387d3ab28f1f5fe 100644 (file)
@@ -14,11 +14,9 @@ import (
        "net/http"
        "os"
        "regexp"
+       "strconv"
        "strings"
        "sync"
-       "sync/atomic"
-       "time"
-       "unsafe"
 )
 
 // A Keep "block" is 64MB.
@@ -26,7 +24,7 @@ const BLOCKSIZE = 64 * 1024 * 1024
 
 var BlockNotFound = errors.New("Block not found")
 var InsufficientReplicasError = errors.New("Could not write sufficient replicas")
-var OversizeBlockError = errors.New("Block too big")
+var OversizeBlockError = errors.New("Exceeded maximum block size ("+strconv.Itoa(BLOCKSIZE)+")")
 var MissingArvadosApiHost = errors.New("Missing required environment variable ARVADOS_API_HOST")
 var MissingArvadosApiToken = errors.New("Missing required environment variable ARVADOS_API_TOKEN")
 
@@ -38,42 +36,43 @@ type KeepClient struct {
        Arvados       *arvadosclient.ArvadosClient
        Want_replicas int
        Using_proxy   bool
-       service_roots *map[string]string
-       lock          sync.Mutex
+       localRoots    *map[string]string
+       gatewayRoots  *map[string]string
+       lock          sync.RWMutex
        Client        *http.Client
 }
 
 // Create a new KeepClient.  This will contact the API server to discover Keep
 // servers.
-func MakeKeepClient(arv *arvadosclient.ArvadosClient) (kc KeepClient, err error) {
+func MakeKeepClient(arv *arvadosclient.ArvadosClient) (*KeepClient, error) {
        var matchTrue = regexp.MustCompile("^(?i:1|yes|true)$")
        insecure := matchTrue.MatchString(os.Getenv("ARVADOS_API_HOST_INSECURE"))
-       kc KeepClient{
+       kc := &KeepClient{
                Arvados:       arv,
                Want_replicas: 2,
                Using_proxy:   false,
                Client: &http.Client{Transport: &http.Transport{
                        TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}}},
        }
-       _, err = (&kc).DiscoverKeepServers()
-
-       return kc, err
+       return kc, kc.DiscoverKeepServers()
 }
 
-// Put a block given the block hash, a reader with the block data, and the
-// expected length of that data.  The desired number of replicas is given in
-// KeepClient.Want_replicas.  Returns the number of replicas that were written
-// and if there was an error.  Note this will return InsufficientReplias
-// whenever 0 <= replicas < this.Wants_replicas.
-func (this KeepClient) PutHR(hash string, r io.Reader, expectedLength int64) (locator string, replicas int, err error) {
-
+// Put a block given the block hash, a reader, and the number of bytes
+// to read from the reader (which must be between 0 and BLOCKSIZE).
+//
+// Returns the locator for the written block, the number of replicas
+// written, and an error.
+//
+// Returns an InsufficientReplicasError if 0 <= replicas <
+// kc.Want_replicas.
+func (kc *KeepClient) PutHR(hash string, r io.Reader, dataBytes int64) (string, int, error) {
        // Buffer for reads from 'r'
        var bufsize int
-       if expectedLength > 0 {
-               if expectedLength > BLOCKSIZE {
+       if dataBytes > 0 {
+               if dataBytes > BLOCKSIZE {
                        return "", 0, OversizeBlockError
                }
-               bufsize = int(expectedLength)
+               bufsize = int(dataBytes)
        } else {
                bufsize = BLOCKSIZE
        }
@@ -81,171 +80,167 @@ func (this KeepClient) PutHR(hash string, r io.Reader, expectedLength int64) (lo
        t := streamer.AsyncStreamFromReader(bufsize, HashCheckingReader{r, md5.New(), hash})
        defer t.Close()
 
-       return this.putReplicas(hash, t, expectedLength)
+       return kc.putReplicas(hash, t, dataBytes)
 }
 
-// Put a block given the block hash and a byte buffer.  The desired number of
-// replicas is given in KeepClient.Want_replicas.  Returns the number of
-// replicas that were written and if there was an error.  Note this will return
-// InsufficientReplias whenever 0 <= replicas < this.Wants_replicas.
-func (this KeepClient) PutHB(hash string, buf []byte) (locator string, replicas int, err error) {
+// PutHB writes a block to Keep. The hash of the bytes is given in
+// hash, and the data is given in buf.
+//
+// Return values are the same as for PutHR.
+func (kc *KeepClient) PutHB(hash string, buf []byte) (string, int, error) {
        t := streamer.AsyncStreamFromSlice(buf)
        defer t.Close()
-
-       return this.putReplicas(hash, t, int64(len(buf)))
+       return kc.putReplicas(hash, t, int64(len(buf)))
 }
 
-// Put a block given a buffer.  The hash will be computed.  The desired number
-// of replicas is given in KeepClient.Want_replicas.  Returns the number of
-// replicas that were written and if there was an error.  Note this will return
-// InsufficientReplias whenever 0 <= replicas < this.Wants_replicas.
-func (this KeepClient) PutB(buffer []byte) (locator string, replicas int, err error) {
+// PutB writes a block to Keep. It computes the hash itself.
+//
+// Return values are the same as for PutHR.
+func (kc *KeepClient) PutB(buffer []byte) (string, int, error) {
        hash := fmt.Sprintf("%x", md5.Sum(buffer))
-       return this.PutHB(hash, buffer)
+       return kc.PutHB(hash, buffer)
 }
 
-// Put a block, given a Reader.  This will read the entire reader into a buffer
-// to compute the hash.  The desired number of replicas is given in
-// KeepClient.Want_replicas.  Returns the number of replicas that were written
-// and if there was an error.  Note this will return InsufficientReplias
-// whenever 0 <= replicas < this.Wants_replicas.  Also nhote that if the block
-// hash and data size are available, PutHR() is more efficient.
-func (this KeepClient) PutR(r io.Reader) (locator string, replicas int, err error) {
+// PutR writes a block to Keep. It first reads all data from r into a buffer
+// in order to compute the hash.
+//
+// Return values are the same as for PutHR.
+//
+// If the block hash and data size are known, PutHR is more efficient.
+func (kc *KeepClient) PutR(r io.Reader) (locator string, replicas int, err error) {
        if buffer, err := ioutil.ReadAll(r); err != nil {
                return "", 0, err
        } else {
-               return this.PutB(buffer)
+               return kc.PutB(buffer)
        }
 }
 
-// Get a block given a hash.  Return a reader, the expected data length, the
-// URL the block was fetched from, and if there was an error.  If the block
-// checksum does not match, the final Read() on the reader returned by this
-// method will return a BadChecksum error instead of EOF.
-func (this KeepClient) Get(hash string) (reader io.ReadCloser,
-       contentLength int64, url string, err error) {
-       return this.AuthorizedGet(hash, "", "")
-}
-
-// Get a block given a hash, with additional authorization provided by
-// signature and timestamp.  Return a reader, the expected data length, the URL
-// the block was fetched from, and if there was an error.  If the block
-// checksum does not match, the final Read() on the reader returned by this
-// method will return a BadChecksum error instead of EOF.
-func (this KeepClient) AuthorizedGet(hash string,
-       signature string,
-       timestamp string) (reader io.ReadCloser,
-       contentLength int64, url string, err error) {
-
-       // Take the hash of locator and timestamp in order to identify this
-       // specific transaction in log statements.
-       requestId := fmt.Sprintf("%x", md5.Sum([]byte(hash+time.Now().String())))[0:8]
-
-       // Calculate the ordering for asking servers
-       sv := NewRootSorter(this.ServiceRoots(), hash).GetSortedRoots()
-
-       for _, host := range sv {
-               var req *http.Request
-               var err error
-               var url string
-               if signature != "" {
-                       url = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
-                               signature, timestamp)
-               } else {
-                       url = fmt.Sprintf("%s/%s", host, hash)
-               }
-               if req, err = http.NewRequest("GET", url, nil); err != nil {
+// Get() retrieves a block, given a locator. Returns a reader, the
+// expected data length, the URL the block is being fetched from, and
+// an error.
+//
+// If the block checksum does not match, the final Read() on the
+// reader returned by this method will return a BadChecksum error
+// instead of EOF.
+func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error) {
+       var errs []string
+       for _, host := range kc.getSortedRoots(locator) {
+               url := host+"/"+locator
+               req, err := http.NewRequest("GET", url, nil)
+               if err != nil {
                        continue
                }
-
-               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.Arvados.ApiToken))
-
-               log.Printf("[%v] Begin download %s", requestId, url)
-
-               var resp *http.Response
-               if resp, err = this.Client.Do(req); err != nil || resp.StatusCode != http.StatusOK {
-                       statusCode := -1
-                       var respbody []byte
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
+               resp, err := kc.Client.Do(req)
+               if err != nil || resp.StatusCode != http.StatusOK {
                        if resp != nil {
-                               statusCode = resp.StatusCode
+                               var respbody []byte
                                if resp.Body != nil {
                                        respbody, _ = ioutil.ReadAll(&io.LimitedReader{resp.Body, 4096})
                                }
+                               errs = append(errs, fmt.Sprintf("%s: %d %s",
+                                       url, resp.StatusCode, strings.TrimSpace(string(respbody))))
+                       } else {
+                               errs = append(errs, fmt.Sprintf("%s: %v", url, err))
                        }
-                       response := strings.TrimSpace(string(respbody))
-                       log.Printf("[%v] Download %v status code: %v error: \"%v\" response: \"%v\"",
-                               requestId, url, statusCode, err, response)
                        continue
                }
-
-               if resp.StatusCode == http.StatusOK {
-                       log.Printf("[%v] Download %v status code: %v", requestId, url, resp.StatusCode)
-                       return HashCheckingReader{resp.Body, md5.New(), hash}, resp.ContentLength, url, nil
-               }
+               return HashCheckingReader{
+                       Reader: resp.Body,
+                       Hash: md5.New(),
+                       Check: locator[0:32],
+               }, resp.ContentLength, url, nil
        }
-
+       log.Printf("DEBUG: GET %s failed: %v", locator, errs)
        return nil, 0, "", BlockNotFound
 }
 
-// Determine if a block with the given hash is available and readable, but does
-// not return the block contents.
-func (this KeepClient) Ask(hash string) (contentLength int64, url string, err error) {
-       return this.AuthorizedAsk(hash, "", "")
-}
-
-// Determine if a block with the given hash is available and readable with the
-// given signature and timestamp, but does not return the block contents.
-func (this KeepClient) AuthorizedAsk(hash string, signature string,
-       timestamp string) (contentLength int64, url string, err error) {
-       // Calculate the ordering for asking servers
-       sv := NewRootSorter(this.ServiceRoots(), hash).GetSortedRoots()
-
-       for _, host := range sv {
-               var req *http.Request
-               var err error
-               if signature != "" {
-                       url = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
-                               signature, timestamp)
-               } else {
-                       url = fmt.Sprintf("%s/%s", host, hash)
-               }
-
-               if req, err = http.NewRequest("HEAD", url, nil); err != nil {
-                       continue
-               }
-
-               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.Arvados.ApiToken))
-
-               var resp *http.Response
-               if resp, err = this.Client.Do(req); err != nil {
+// Ask() verifies that a block with the given hash is available and
+// readable, according to at least one Keep service. Unlike Get, it
+// does not retrieve the data or verify that the data content matches
+// the hash specified by the locator.
+//
+// Returns the data size (content length) reported by the Keep service
+// and the URI reporting the data size.
+func (kc *KeepClient) Ask(locator string) (int64, string, error) {
+       for _, host := range kc.getSortedRoots(locator) {
+               url := host+"/"+locator
+               req, err := http.NewRequest("HEAD", url, nil)
+               if err != nil {
                        continue
                }
-
-               if resp.StatusCode == http.StatusOK {
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
+               if resp, err := kc.Client.Do(req); err == nil && resp.StatusCode == http.StatusOK {
                        return resp.ContentLength, url, nil
                }
        }
-
        return 0, "", BlockNotFound
+}
 
+// LocalRoots() returns the map of local (i.e., disk and proxy) Keep
+// services: uuid -> baseURI.
+func (kc *KeepClient) LocalRoots() map[string]string {
+       kc.lock.RLock()
+       defer kc.lock.RUnlock()
+       return *kc.localRoots
 }
 
-// Atomically read the service_roots field.
-func (this *KeepClient) ServiceRoots() map[string]string {
-       r := (*map[string]string)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&this.service_roots))))
-       return *r
+// GatewayRoots() returns the map of Keep remote gateway services:
+// uuid -> baseURI.
+func (kc *KeepClient) GatewayRoots() map[string]string {
+       kc.lock.RLock()
+       defer kc.lock.RUnlock()
+       return *kc.gatewayRoots
 }
 
-// Atomically update the service_roots field.  Enables you to update
-// service_roots without disrupting any GET or PUT operations that might
-// already be in progress.
-func (this *KeepClient) SetServiceRoots(new_roots map[string]string) {
-       roots := make(map[string]string)
-       for uuid, root := range new_roots {
-               roots[uuid] = root
+// SetServiceRoots updates the localRoots and gatewayRoots maps,
+// without risk of disrupting operations that are already in progress.
+//
+// The KeepClient makes its own copy of the supplied maps, so the
+// caller can reuse/modify them after SetServiceRoots returns, but
+// they should not be modified by any other goroutine while
+// SetServiceRoots is running.
+func (kc *KeepClient) SetServiceRoots(newLocals, newGateways map[string]string) {
+       locals := make(map[string]string)
+       for uuid, root := range newLocals {
+               locals[uuid] = root
        }
-       atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&this.service_roots)),
-               unsafe.Pointer(&roots))
+       gateways := make(map[string]string)
+       for uuid, root := range newGateways {
+               gateways[uuid] = root
+       }
+       kc.lock.Lock()
+       defer kc.lock.Unlock()
+       kc.localRoots = &locals
+       kc.gatewayRoots = &gateways
+}
+
+// getSortedRoots returns a list of base URIs of Keep services, in the
+// order they should be attempted in order to retrieve content for the
+// given locator.
+func (kc *KeepClient) getSortedRoots(locator string) []string {
+       var found []string
+       for _, hint := range strings.Split(locator, "+") {
+               if len(hint) < 7 || hint[0:2] != "K@" {
+                       // Not a service hint.
+                       continue
+               }
+               if len(hint) == 7 {
+                       // +K@abcde means fetch from proxy at
+                       // keep.abcde.arvadosapi.com
+                       found = append(found, "https://keep."+hint[2:]+".arvadosapi.com")
+               } else if len(hint) == 29 {
+                       // +K@abcde-abcde-abcdeabcdeabcde means fetch
+                       // from gateway with given uuid
+                       if gwURI, ok := kc.GatewayRoots()[hint[2:]]; ok {
+                               found = append(found, gwURI)
+                       }
+                       // else this hint is no use to us; carry on.
+               }
+       }
+       // After trying all usable service hints, fall back to local roots.
+       found = append(found, NewRootSorter(kc.LocalRoots(), locator[0:32]).GetSortedRoots()...)
+       return found
 }
 
 type Locator struct {
@@ -279,12 +274,9 @@ func MakeLocator2(hash string, hints string) (locator Locator) {
        return locator
 }
 
-func MakeLocator(path string) Locator {
-       pathpattern, err := regexp.Compile("^([0-9a-f]{32})([+].*)?$")
-       if err != nil {
-               log.Print("Don't like regexp", err)
-       }
+var pathpattern = regexp.MustCompile("^([0-9a-f]{32})([+].*)?$")
 
+func MakeLocator(path string) Locator {
        sm := pathpattern.FindStringSubmatch(path)
        if sm == nil {
                log.Print("Failed match ", path)
index cbd27d72e7c7e9310de1ed027e47912b7a187baa..825696bbfc3f9309be81eca6c3d250c833b72745 100644 (file)
@@ -63,8 +63,8 @@ func (s *ServerRequiredSuite) TestMakeKeepClient(c *C) {
        kc, err := MakeKeepClient(&arv)
 
        c.Assert(err, Equals, nil)
-       c.Check(len(kc.ServiceRoots()), Equals, 2)
-       for _, root := range kc.ServiceRoots() {
+       c.Check(len(kc.LocalRoots()), Equals, 2)
+       for _, root := range kc.LocalRoots() {
                c.Check(root, Matches, "http://localhost:\\d+")
        }
 }
@@ -77,14 +77,14 @@ type StubPutHandler struct {
        handled        chan string
 }
 
-func (this StubPutHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
-       this.c.Check(req.URL.Path, Equals, "/"+this.expectPath)
-       this.c.Check(req.Header.Get("Authorization"), Equals, fmt.Sprintf("OAuth2 %s", this.expectApiToken))
+func (sph StubPutHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+       sph.c.Check(req.URL.Path, Equals, "/"+sph.expectPath)
+       sph.c.Check(req.Header.Get("Authorization"), Equals, fmt.Sprintf("OAuth2 %s", sph.expectApiToken))
        body, err := ioutil.ReadAll(req.Body)
-       this.c.Check(err, Equals, nil)
-       this.c.Check(body, DeepEquals, []byte(this.expectBody))
+       sph.c.Check(err, Equals, nil)
+       sph.c.Check(body, DeepEquals, []byte(sph.expectBody))
        resp.WriteHeader(200)
-       this.handled <- fmt.Sprintf("http://%s", req.Host)
+       sph.handled <- fmt.Sprintf("http://%s", req.Host)
 }
 
 func RunFakeKeepServer(st http.Handler) (ks KeepServer) {
@@ -98,7 +98,7 @@ func RunFakeKeepServer(st http.Handler) (ks KeepServer) {
        return
 }
 
-func UploadToStubHelper(c *C, st http.Handler, f func(KeepClient, string,
+func UploadToStubHelper(c *C, st http.Handler, f func(*KeepClient, string,
        io.ReadCloser, io.WriteCloser, chan uploadStatus)) {
 
        ks := RunFakeKeepServer(st)
@@ -126,7 +126,7 @@ func (s *StandaloneSuite) TestUploadToStubKeepServer(c *C) {
                make(chan string)}
 
        UploadToStubHelper(c, st,
-               func(kc KeepClient, url string, reader io.ReadCloser,
+               func(kc *KeepClient, url string, reader io.ReadCloser,
                        writer io.WriteCloser, upload_status chan uploadStatus) {
 
                        go kc.uploadToKeepServer(url, st.expectPath, reader, upload_status, int64(len("foo")), "TestUploadToStubKeepServer")
@@ -153,7 +153,7 @@ func (s *StandaloneSuite) TestUploadToStubKeepServerBufferReader(c *C) {
                make(chan string)}
 
        UploadToStubHelper(c, st,
-               func(kc KeepClient, url string, reader io.ReadCloser,
+               func(kc *KeepClient, url string, reader io.ReadCloser,
                        writer io.WriteCloser, upload_status chan uploadStatus) {
 
                        tr := streamer.AsyncStreamFromReader(512, reader)
@@ -179,9 +179,9 @@ type FailHandler struct {
        handled chan string
 }
 
-func (this FailHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+func (fh FailHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        resp.WriteHeader(500)
-       this.handled <- fmt.Sprintf("http://%s", req.Host)
+       fh.handled <- fmt.Sprintf("http://%s", req.Host)
 }
 
 func (s *StandaloneSuite) TestFailedUploadToStubKeepServer(c *C) {
@@ -193,7 +193,7 @@ func (s *StandaloneSuite) TestFailedUploadToStubKeepServer(c *C) {
        hash := "acbd18db4cc2f85cedef654fccc4a4d8"
 
        UploadToStubHelper(c, st,
-               func(kc KeepClient, url string, reader io.ReadCloser,
+               func(kc *KeepClient, url string, reader io.ReadCloser,
                        writer io.WriteCloser, upload_status chan uploadStatus) {
 
                        go kc.uploadToKeepServer(url, hash, reader, upload_status, 3, "TestFailedUploadToStubKeepServer")
@@ -242,21 +242,21 @@ func (s *StandaloneSuite) TestPutB(c *C) {
 
        kc.Want_replicas = 2
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks := RunSomeFakeKeepServers(st, 5)
 
        for i, k := range ks {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        kc.PutB([]byte("foo"))
 
        shuff := NewRootSorter(
-               kc.ServiceRoots(), Md5String("foo")).GetSortedRoots()
+               kc.LocalRoots(), Md5String("foo")).GetSortedRoots()
 
        s1 := <-st.handled
        s2 := <-st.handled
@@ -285,16 +285,16 @@ func (s *StandaloneSuite) TestPutHR(c *C) {
 
        kc.Want_replicas = 2
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks := RunSomeFakeKeepServers(st, 5)
 
        for i, k := range ks {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        reader, writer := io.Pipe()
 
@@ -305,7 +305,7 @@ func (s *StandaloneSuite) TestPutHR(c *C) {
 
        kc.PutHR(hash, reader, 3)
 
-       shuff := NewRootSorter(kc.ServiceRoots(), hash).GetSortedRoots()
+       shuff := NewRootSorter(kc.LocalRoots(), hash).GetSortedRoots()
        log.Print(shuff)
 
        s1 := <-st.handled
@@ -339,24 +339,24 @@ func (s *StandaloneSuite) TestPutWithFail(c *C) {
 
        kc.Want_replicas = 2
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks1 := RunSomeFakeKeepServers(st, 4)
        ks2 := RunSomeFakeKeepServers(fh, 1)
 
        for i, k := range ks1 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
        for i, k := range ks2 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        shuff := NewRootSorter(
-               kc.ServiceRoots(), Md5String("foo")).GetSortedRoots()
+               kc.LocalRoots(), Md5String("foo")).GetSortedRoots()
 
        phash, replicas, err := kc.PutB([]byte("foo"))
 
@@ -395,21 +395,21 @@ func (s *StandaloneSuite) TestPutWithTooManyFail(c *C) {
 
        kc.Want_replicas = 2
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks1 := RunSomeFakeKeepServers(st, 1)
        ks2 := RunSomeFakeKeepServers(fh, 4)
 
        for i, k := range ks1 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
        for i, k := range ks2 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        _, replicas, err := kc.PutB([]byte("foo"))
 
@@ -424,14 +424,16 @@ type StubGetHandler struct {
        c              *C
        expectPath     string
        expectApiToken string
-       returnBody     []byte
+       httpStatus     int
+       body           []byte
 }
 
-func (this StubGetHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
-       this.c.Check(req.URL.Path, Equals, "/"+this.expectPath)
-       this.c.Check(req.Header.Get("Authorization"), Equals, fmt.Sprintf("OAuth2 %s", this.expectApiToken))
-       resp.Header().Set("Content-Length", fmt.Sprintf("%d", len(this.returnBody)))
-       resp.Write(this.returnBody)
+func (sgh StubGetHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+       sgh.c.Check(req.URL.Path, Equals, "/"+sgh.expectPath)
+       sgh.c.Check(req.Header.Get("Authorization"), Equals, fmt.Sprintf("OAuth2 %s", sgh.expectApiToken))
+       resp.WriteHeader(sgh.httpStatus)
+       resp.Header().Set("Content-Length", fmt.Sprintf("%d", len(sgh.body)))
+       resp.Write(sgh.body)
 }
 
 func (s *StandaloneSuite) TestGet(c *C) {
@@ -443,6 +445,7 @@ func (s *StandaloneSuite) TestGet(c *C) {
                c,
                hash,
                "abc123",
+               http.StatusOK,
                []byte("foo")}
 
        ks := RunFakeKeepServer(st)
@@ -451,7 +454,7 @@ func (s *StandaloneSuite) TestGet(c *C) {
        arv, err := arvadosclient.MakeArvadosClient()
        kc, _ := MakeKeepClient(&arv)
        arv.ApiToken = "abc123"
-       kc.SetServiceRoots(map[string]string{"x": ks.url})
+       kc.SetServiceRoots(map[string]string{"x": ks.url}, nil)
 
        r, n, url2, err := kc.Get(hash)
        defer r.Close()
@@ -477,7 +480,7 @@ func (s *StandaloneSuite) TestGetFail(c *C) {
        arv, err := arvadosclient.MakeArvadosClient()
        kc, _ := MakeKeepClient(&arv)
        arv.ApiToken = "abc123"
-       kc.SetServiceRoots(map[string]string{"x": ks.url})
+       kc.SetServiceRoots(map[string]string{"x": ks.url}, nil)
 
        r, n, url2, err := kc.Get(hash)
        c.Check(err, Equals, BlockNotFound)
@@ -486,6 +489,82 @@ func (s *StandaloneSuite) TestGetFail(c *C) {
        c.Check(r, Equals, nil)
 }
 
+func (s *StandaloneSuite) TestGetWithServiceHint(c *C) {
+       uuid := "zzzzz-bi6l4-123451234512345"
+       hash := fmt.Sprintf("%x", md5.Sum([]byte("foo")))
+
+       // This one shouldn't be used:
+       ks0 := RunFakeKeepServer(StubGetHandler{
+               c,
+               "error if used",
+               "abc123",
+               http.StatusOK,
+               []byte("foo")})
+       defer ks0.listener.Close()
+       // This one should be used:
+       ks := RunFakeKeepServer(StubGetHandler{
+               c,
+               hash+"+K@"+uuid,
+               "abc123",
+               http.StatusOK,
+               []byte("foo")})
+       defer ks.listener.Close()
+
+       arv, err := arvadosclient.MakeArvadosClient()
+       kc, _ := MakeKeepClient(&arv)
+       arv.ApiToken = "abc123"
+       kc.SetServiceRoots(
+               map[string]string{"x": ks0.url},
+               map[string]string{uuid: ks.url})
+
+       r, n, uri, err := kc.Get(hash+"+K@"+uuid)
+       defer r.Close()
+       c.Check(err, Equals, nil)
+       c.Check(n, Equals, int64(3))
+       c.Check(uri, Equals, fmt.Sprintf("%s/%s", ks.url, hash+"+K@"+uuid))
+
+       content, err := ioutil.ReadAll(r)
+       c.Check(err, Equals, nil)
+       c.Check(content, DeepEquals, []byte("foo"))
+}
+
+func (s *StandaloneSuite) TestGetWithServiceHintFailoverToLocals(c *C) {
+       uuid := "zzzzz-bi6l4-123451234512345"
+       hash := fmt.Sprintf("%x", md5.Sum([]byte("foo")))
+
+       ksLocal := RunFakeKeepServer(StubGetHandler{
+               c,
+               hash+"+K@"+uuid,
+               "abc123",
+               http.StatusOK,
+               []byte("foo")})
+       defer ksLocal.listener.Close()
+       ksGateway := RunFakeKeepServer(StubGetHandler{
+               c,
+               hash+"+K@"+uuid,
+               "abc123",
+               http.StatusInternalServerError,
+               []byte("Error")})
+       defer ksGateway.listener.Close()
+
+       arv, err := arvadosclient.MakeArvadosClient()
+       kc, _ := MakeKeepClient(&arv)
+       arv.ApiToken = "abc123"
+       kc.SetServiceRoots(
+               map[string]string{"zzzzz-bi6l4-keepdisk0000000": ksLocal.url},
+               map[string]string{uuid: ksGateway.url})
+
+       r, n, uri, err := kc.Get(hash+"+K@"+uuid)
+       c.Assert(err, Equals, nil)
+       defer r.Close()
+       c.Check(n, Equals, int64(3))
+       c.Check(uri, Equals, fmt.Sprintf("%s/%s", ksLocal.url, hash+"+K@"+uuid))
+
+       content, err := ioutil.ReadAll(r)
+       c.Check(err, Equals, nil)
+       c.Check(content, DeepEquals, []byte("foo"))
+}
+
 type BarHandler struct {
        handled chan string
 }
@@ -507,7 +586,7 @@ func (s *StandaloneSuite) TestChecksum(c *C) {
        arv, err := arvadosclient.MakeArvadosClient()
        kc, _ := MakeKeepClient(&arv)
        arv.ApiToken = "abc123"
-       kc.SetServiceRoots(map[string]string{"x": ks.url})
+       kc.SetServiceRoots(map[string]string{"x": ks.url}, nil)
 
        r, n, _, err := kc.Get(barhash)
        _, err = ioutil.ReadAll(r)
@@ -535,26 +614,27 @@ func (s *StandaloneSuite) TestGetWithFailures(c *C) {
                c,
                hash,
                "abc123",
+               http.StatusOK,
                content}
 
        arv, err := arvadosclient.MakeArvadosClient()
        kc, _ := MakeKeepClient(&arv)
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks1 := RunSomeFakeKeepServers(st, 1)
        ks2 := RunSomeFakeKeepServers(fh, 4)
 
        for i, k := range ks1 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
        for i, k := range ks2 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i+len(ks1))] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        // This test works only if one of the failing services is
        // attempted before the succeeding service. Otherwise,
@@ -562,7 +642,7 @@ func (s *StandaloneSuite) TestGetWithFailures(c *C) {
        // the choice of block content "waz" and the UUIDs of the fake
        // servers, so we just tried different strings until we found
        // an example that passes this Assert.)
-       c.Assert(NewRootSorter(service_roots, hash).GetSortedRoots()[0], Not(Equals), ks1[0].url)
+       c.Assert(NewRootSorter(localRoots, hash).GetSortedRoots()[0], Not(Equals), ks1[0].url)
 
        r, n, url2, err := kc.Get(hash)
 
@@ -634,16 +714,16 @@ func (s *StandaloneSuite) TestPutProxy(c *C) {
        kc.Want_replicas = 2
        kc.Using_proxy = true
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks1 := RunSomeFakeKeepServers(st, 1)
 
        for i, k := range ks1 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
 
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        _, replicas, err := kc.PutB([]byte("foo"))
        <-st.handled
@@ -665,15 +745,15 @@ func (s *StandaloneSuite) TestPutProxyInsufficientReplicas(c *C) {
        kc.Want_replicas = 3
        kc.Using_proxy = true
        arv.ApiToken = "abc123"
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
 
        ks1 := RunSomeFakeKeepServers(st, 1)
 
        for i, k := range ks1 {
-               service_roots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
+               localRoots[fmt.Sprintf("zzzzz-bi6l4-fakefakefake%03d", i)] = k.url
                defer k.listener.Close()
        }
-       kc.SetServiceRoots(service_roots)
+       kc.SetServiceRoots(localRoots, nil)
 
        _, replicas, err := kc.PutB([]byte("foo"))
        <-st.handled
index 940a110081dbaa46920cb472ea09d9e0635fd219..6393503e965b92b7ff27e8320e46e28eed69857d 100644 (file)
@@ -76,7 +76,7 @@ func (this *KeepClient) setClientSettingsStore() {
        }
 }
 
-func (this *KeepClient) DiscoverKeepServers() (map[string]string, error) {
+func (this *KeepClient) DiscoverKeepServers() error {
        type svcList struct {
                Items []keepDisk `json:"items"`
        }
@@ -86,31 +86,40 @@ func (this *KeepClient) DiscoverKeepServers() (map[string]string, error) {
 
        if err != nil {
                if err := this.Arvados.List("keep_disks", nil, &m); err != nil {
-                       return nil, err
+                       return err
                }
        }
 
        listed := make(map[string]bool)
-       service_roots := make(map[string]string)
+       localRoots := make(map[string]string)
+       gatewayRoots := make(map[string]string)
 
-       for _, element := range m.Items {
-               n := ""
-
-               if element.SSL {
-                       n = "s"
+       for _, service := range m.Items {
+               scheme := "http"
+               if service.SSL {
+                       scheme = "https"
                }
-
-               // Construct server URL
-               url := fmt.Sprintf("http%s://%s:%d", n, element.Hostname, element.Port)
+               url := fmt.Sprintf("%s://%s:%d", scheme, service.Hostname, service.Port)
 
                // Skip duplicates
-               if !listed[url] {
-                       listed[url] = true
-                       service_roots[element.Uuid] = url
+               if listed[url] {
+                       continue
                }
-               if element.SvcType == "proxy" {
+               listed[url] = true
+
+               switch service.SvcType {
+               case "disk":
+                       localRoots[service.Uuid] = url
+               case "proxy":
+                       localRoots[service.Uuid] = url
                        this.Using_proxy = true
                }
+               // Gateway services are only used when specified by
+               // UUID, so there's nothing to gain by filtering them
+               // by service type. Including all accessible services
+               // (gateway and otherwise) merely accommodates more
+               // service configurations.
+               gatewayRoots[service.Uuid] = url
        }
 
        if this.Using_proxy {
@@ -119,9 +128,8 @@ func (this *KeepClient) DiscoverKeepServers() (map[string]string, error) {
                this.setClientSettingsStore()
        }
 
-       this.SetServiceRoots(service_roots)
-
-       return service_roots, nil
+       this.SetServiceRoots(localRoots, gatewayRoots)
+       return nil
 }
 
 type uploadStatus struct {
@@ -204,7 +212,7 @@ func (this KeepClient) putReplicas(
        requestId := fmt.Sprintf("%x", md5.Sum([]byte(locator+time.Now().String())))[0:8]
 
        // Calculate the ordering for uploading to servers
-       sv := NewRootSorter(this.ServiceRoots(), hash).GetSortedRoots()
+       sv := NewRootSorter(this.LocalRoots(), hash).GetSortedRoots()
 
        // The next server to try contacting
        next_server := 0
index 6196b502021a6036aa96c33d61c9bc6fb5d4f4f1..842a36d8ed6145062f166347f4b79126eadce196 100644 (file)
@@ -62,7 +62,7 @@ class KeepLocator(object):
             self.size = None
         for hint in pieces:
             if self.HINT_RE.match(hint) is None:
-                raise ValueError("unrecognized hint data {}".format(hint))
+                raise ValueError("invalid hint format: {}".format(hint))
             elif hint.startswith('A'):
                 self.parse_permission_hint(hint)
             else:
@@ -518,6 +518,7 @@ class KeepClient(object):
                 if not proxy.endswith('/'):
                     proxy += '/'
                 self.api_token = api_token
+                self._gateway_services = {}
                 self._keep_services = [{
                     'uuid': 'proxy',
                     '_service_root': proxy,
@@ -531,6 +532,7 @@ class KeepClient(object):
                     api_client = arvados.api('v1')
                 self.api_client = api_client
                 self.api_token = api_client.api_token
+                self._gateway_services = {}
                 self._keep_services = None
                 self.using_proxy = None
                 self._static_services_list = False
@@ -560,21 +562,31 @@ class KeepClient(object):
             except Exception:  # API server predates Keep services.
                 keep_services = self.api_client.keep_disks().list()
 
-            self._keep_services = keep_services.execute().get('items')
-            if not self._keep_services:
+            accessible = keep_services.execute().get('items')
+            if not accessible:
                 raise arvados.errors.NoKeepServersError()
 
-            self.using_proxy = any(ks.get('service_type') == 'proxy'
-                                   for ks in self._keep_services)
-
             # Precompute the base URI for each service.
-            for r in self._keep_services:
+            for r in accessible:
                 r['_service_root'] = "{}://[{}]:{:d}/".format(
                     'https' if r['service_ssl_flag'] else 'http',
                     r['service_host'],
                     r['service_port'])
+
+            # Gateway services are only used when specified by UUID,
+            # so there's nothing to gain by filtering them by
+            # service_type.
+            self._gateway_services = {ks.get('uuid'): ks for ks in accessible}
+            _logger.debug(str(self._gateway_services))
+
+            self._keep_services = [
+                ks for ks in accessible
+                if ks.get('service_type') in ['disk', 'proxy']]
             _logger.debug(str(self._keep_services))
 
+            self.using_proxy = any(ks.get('service_type') == 'proxy'
+                                   for ks in self._keep_services)
+
     def _service_weight(self, data_hash, service_uuid):
         """Compute the weight of a Keep service endpoint for a data
         block with a known hash.
@@ -584,31 +596,46 @@ class KeepClient(object):
         """
         return hashlib.md5(data_hash + service_uuid[-15:]).hexdigest()
 
-    def weighted_service_roots(self, data_hash, force_rebuild=False):
+    def weighted_service_roots(self, locator, force_rebuild=False):
         """Return an array of Keep service endpoints, in the order in
         which they should be probed when reading or writing data with
-        the given hash.
+        the given hash+hints.
         """
         self.build_services_list(force_rebuild)
 
-        # Sort the available services by weight (heaviest first) for
-        # this data_hash, and return their service_roots (base URIs)
+        sorted_roots = []
+
+        # Use the services indicated by the given +K@... remote
+        # service hints, if any are present and can be resolved to a
+        # URI.
+        for hint in locator.hints:
+            if hint.startswith('K@'):
+                if len(hint) == 7:  # 'K@' + 5 characters: substituted into the keep.<...>.arvadosapi.com URL
+                    sorted_roots.append(
+                        "https://keep.{}.arvadosapi.com/".format(hint[2:]))
+                elif len(hint) == 29:  # 'K@' + 27 characters: looked up as a key of _gateway_services
+                    svc = self._gateway_services.get(hint[2:])
+                    if svc:  # keys that are not in _gateway_services are skipped silently
+                        sorted_roots.append(svc['_service_root'])
+
+        # Sort the available local services by weight (heaviest first)
+        # for this locator, and return their service_roots (base URIs)
         # in that order.
-        sorted_roots = [
+        sorted_roots.extend([  # appended after any hinted roots, so hinted roots come first
             svc['_service_root'] for svc in sorted(
                 self._keep_services,
                 reverse=True,
-                key=lambda svc: self._service_weight(data_hash, svc['uuid']))]
-        _logger.debug(data_hash + ': ' + str(sorted_roots))
+                key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))])
+        _logger.debug("{}: {}".format(locator, sorted_roots))
         return sorted_roots
 
-    def map_new_services(self, roots_map, md5_s, force_rebuild, **headers):
+    def map_new_services(self, roots_map, locator, force_rebuild, **headers):
         # roots_map is a dictionary, mapping Keep service root strings
         # to KeepService objects.  Poll for Keep services, and add any
         # new ones to roots_map.  Return the current list of local
         # root strings.
         headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
-        local_roots = self.weighted_service_roots(md5_s, force_rebuild)
+        local_roots = self.weighted_service_roots(locator, force_rebuild)
         for root in local_roots:
             if root not in roots_map:
                 roots_map[root] = self.KeepService(root, self.session, **headers)
@@ -664,28 +691,40 @@ class KeepClient(object):
         if ',' in loc_s:
             return ''.join(self.get(x) for x in loc_s.split(','))
         locator = KeepLocator(loc_s)
-        expect_hash = locator.md5sum
-        slot, first = self.block_cache.reserve_cache(expect_hash)
+        slot, first = self.block_cache.reserve_cache(locator.md5sum)
         if not first:
             v = slot.get()
             return v
 
+        # If the locator has hints specifying a prefix (indicating a
+        # remote keepproxy) or the UUID of a local gateway service,
+        # read data from the indicated service(s) instead of the usual
+        # list of local disk services.
+        hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
+                      for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
+        hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
+                           for hint in locator.hints if (
+                                   hint.startswith('K@') and
+                                   len(hint) == 29 and
+                                   self._gateway_services.get(hint[2:])
+                                   )])
+        # Map root URLs to their KeepService objects.
+        roots_map = {root: self.KeepService(root, self.session) for root in hint_roots}
+
         # See #3147 for a discussion of the loop implementation.  Highlights:
         # * Refresh the list of Keep services after each failure, in case
         #   it's being updated.
         # * Retry until we succeed, we're out of retries, or every available
         #   service has returned permanent failure.
-        hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
-                      for hint in locator.hints if hint.startswith('K@')]
-        # Map root URLs their KeepService objects.
-        roots_map = {root: self.KeepService(root, self.session) for root in hint_roots}
+        sorted_roots = []
+        roots_map = {}
         blob = None
         loop = retry.RetryLoop(num_retries, self._check_loop_result,
                                backoff_start=2)
         for tries_left in loop:
             try:
-                local_roots = self.map_new_services(
-                    roots_map, expect_hash,
+                sorted_roots = self.map_new_services(
+                    roots_map, locator,
                     force_rebuild=(tries_left < num_retries))
             except Exception as error:
                 loop.save_result(error)
@@ -694,7 +733,7 @@ class KeepClient(object):
             # Query KeepService objects that haven't returned
             # permanent failure, in our specified shuffle order.
             services_to_try = [roots_map[root]
-                               for root in (local_roots + hint_roots)
+                               for root in sorted_roots
                                if roots_map[root].usable()]
             for keep_service in services_to_try:
                 blob = keep_service.get(locator, timeout=self.current_timeout(num_retries-tries_left))
@@ -708,22 +747,17 @@ class KeepClient(object):
         if loop.success():
             return blob
 
-        try:
-            all_roots = local_roots + hint_roots
-        except NameError:
-            # We never successfully fetched local_roots.
-            all_roots = hint_roots
         # Q: Including 403 is necessary for the Keep tests to continue
         # passing, but maybe they should expect KeepReadError instead?
-        not_founds = sum(1 for key in all_roots
+        not_founds = sum(1 for key in sorted_roots
                          if roots_map[key].last_status() in {403, 404, 410})
         service_errors = ((key, roots_map[key].last_result)
-                          for key in all_roots)
+                          for key in sorted_roots)
         if not roots_map:
             raise arvados.errors.KeepReadError(
                 "failed to read {}: no Keep services available ({})".format(
                     loc_s, loop.last_result()))
-        elif not_founds == len(all_roots):
+        elif not_founds == len(sorted_roots):
             raise arvados.errors.NotFoundError(
                 "{} not found".format(loc_s), service_errors)
         else:
@@ -758,6 +792,7 @@ class KeepClient(object):
         data_hash = hashlib.md5(data).hexdigest()
         if copies < 1:
             return data_hash
+        locator = KeepLocator(data_hash + '+' + str(len(data)))
 
         headers = {}
         if self.using_proxy:
@@ -770,7 +805,7 @@ class KeepClient(object):
         for tries_left in loop:
             try:
                 local_roots = self.map_new_services(
-                    roots_map, data_hash,
+                    roots_map, locator,
                     force_rebuild=(tries_left < num_retries), **headers)
             except Exception as error:
                 loop.save_result(error)
index 644dfffbaca0657a934a43cf0742e03cc227f62b..a10802ae81b7b2bf439f8452052a0ba1e9643617 100644 (file)
@@ -104,7 +104,8 @@ class ApiClientMock(object):
                            service_type='disk',
                            service_host=None,
                            service_port=None,
-                           service_ssl_flag=False):
+                           service_ssl_flag=False,
+                           additional_services=[]):
         if api_mock is None:
             api_mock = self.api_client_mock()
         body = {
@@ -116,7 +117,7 @@ class ApiClientMock(object):
                 'service_port': service_port or 65535-i,
                 'service_ssl_flag': service_ssl_flag,
                 'service_type': service_type,
-            } for i in range(0, count)]
+            } for i in range(0, count)] + additional_services
         }
         self._mock_api_call(api_mock.keep_services().accessible, status, body)
         return api_mock
index baae28e3d78ff8de224dfaf509856cd625dbe8f4..be13c55048a97d640f95d91e4e787824814b7f62 100644 (file)
@@ -251,7 +251,7 @@ class KeepProxyTestCase(run_test_server.TestCaseWithServers):
 class KeepClientServiceTestCase(unittest.TestCase, tutil.ApiClientMock):
     def get_service_roots(self, api_client):
         keep_client = arvados.KeepClient(api_client=api_client)
-        services = keep_client.weighted_service_roots('000000')
+        services = keep_client.weighted_service_roots(arvados.KeepLocator('0'*32))
         return [urlparse.urlparse(url) for url in sorted(services)]
 
     def test_ssl_flag_respected_in_roots(self):
@@ -344,7 +344,7 @@ class KeepClientServiceTestCase(unittest.TestCase, tutil.ApiClientMock):
         api_client = self.mock_keep_services(count=16)
         keep_client = arvados.KeepClient(api_client=api_client)
         for i, hash in enumerate(hashes):
-            roots = keep_client.weighted_service_roots(hash)
+            roots = keep_client.weighted_service_roots(arvados.KeepLocator(hash))
             got_order = [
                 re.search(r'//\[?keep0x([0-9a-f]+)', root).group(1)
                 for root in roots]
@@ -357,14 +357,14 @@ class KeepClientServiceTestCase(unittest.TestCase, tutil.ApiClientMock):
         api_client = self.mock_keep_services(count=initial_services)
         keep_client = arvados.KeepClient(api_client=api_client)
         probes_before = [
-            keep_client.weighted_service_roots(hash) for hash in hashes]
+            keep_client.weighted_service_roots(arvados.KeepLocator(hash)) for hash in hashes]
         for added_services in range(1, 12):
             api_client = self.mock_keep_services(count=initial_services+added_services)
             keep_client = arvados.KeepClient(api_client=api_client)
             total_penalty = 0
             for hash_index in range(len(hashes)):
                 probe_after = keep_client.weighted_service_roots(
-                    hashes[hash_index])
+                    arvados.KeepLocator(hashes[hash_index]))
                 penalty = probe_after.index(probes_before[hash_index][0])
                 self.assertLessEqual(penalty, added_services)
                 total_penalty += penalty
@@ -457,6 +457,65 @@ class KeepClientServiceTestCase(unittest.TestCase, tutil.ApiClientMock):
         self.assertEqual(2, len(exc_check.exception.request_errors()))
 
 
+class KeepClientGatewayTestCase(unittest.TestCase, tutil.ApiClientMock):  # tests for KeepClient.get() with +K@ hints in the locator
+    def mock_disks_and_gateways(self, disks=3, gateways=1):  # sets self.gateways, self.gateway_roots, self.api_client and self.keepClient
+        self.gateways = [{
+                'uuid': 'zzzzz-bi6l4-gateway{:08d}'.format(i),  # 27 characters, so 'K@'+uuid is a 29-character hint
+                'owner_uuid': 'zzzzz-tpzed-000000000000000',
+                'service_host': 'gatewayhost{}'.format(i),
+                'service_port': 12345,
+                'service_ssl_flag': True,
+                'service_type': 'gateway:test',  # neither 'disk' nor 'proxy'
+        } for i in range(gateways)]
+        self.gateway_roots = [  # expected base URI of each gateway, in the same order as self.gateways
+            "https://[{service_host}]:{service_port}/".format(**gw)
+            for gw in self.gateways]
+        self.api_client = self.mock_keep_services(
+            count=disks, additional_services=self.gateways)
+        self.keepClient = arvados.KeepClient(api_client=self.api_client)
+
+    @mock.patch('requests.Session')
+    def test_get_with_gateway_hint_first(self, MockSession):
+        MockSession.return_value.get.return_value = tutil.fake_requests_response(
+            code=200, body='foo', headers={'Content-Length': 3})
+        self.mock_disks_and_gateways()
+        locator = 'acbd18db4cc2f85cedef654fccc4a4d8+3+K@' + self.gateways[0]['uuid']
+        self.assertEqual('foo', self.keepClient.get(locator))
+        self.assertEqual((self.gateway_roots[0]+locator,),  # the first GET must target the hinted gateway
+                         MockSession.return_value.get.call_args_list[0][0])
+
+    @mock.patch('requests.Session')
+    def test_get_with_gateway_hints_in_order(self, MockSession):
+        gateways = 4
+        disks = 3
+        MockSession.return_value.get.return_value = tutil.fake_requests_response(
+            code=404, body='')  # every mocked GET answers 404
+        self.mock_disks_and_gateways(gateways=gateways, disks=disks)
+        locator = '+'.join(['acbd18db4cc2f85cedef654fccc4a4d8+3'] +
+                           ['K@'+gw['uuid'] for gw in self.gateways])  # one K@ hint per gateway, in list order
+        with self.assertRaises(arvados.errors.NotFoundError):
+            self.keepClient.get(locator)
+        # Gateways are tried first, in the order given.
+        for i, root in enumerate(self.gateway_roots):
+            self.assertEqual((root+locator,),
+                             MockSession.return_value.get.call_args_list[i][0])
+        # Disk services are tried next.
+        for i in range(gateways, gateways+disks):
+            self.assertRegexpMatches(
+                MockSession.return_value.get.call_args_list[i][0][0],
+                r'keep0x')
+
+    @mock.patch('requests.Session')
+    def test_get_with_remote_proxy_hint(self, MockSession):
+        MockSession.return_value.get.return_value = tutil.fake_requests_response(
+            code=200, body='foo', headers={'Content-Length': 3})
+        self.mock_disks_and_gateways()
+        locator = 'acbd18db4cc2f85cedef654fccc4a4d8+3+K@xyzzy'  # 'K@' + 5 characters instead of a service UUID
+        self.assertEqual('foo', self.keepClient.get(locator))
+        self.assertEqual(('https://keep.xyzzy.arvadosapi.com/'+locator,),  # the first GET must target the URL built from the hint
+                         MockSession.return_value.get.call_args_list[0][0])
+
+
 class KeepClientRetryTestMixin(object):
     # Testing with a local Keep store won't exercise the retry behavior.
     # Instead, our strategy is:
index 581f7f48739fd2fef826ef6b6dfae6c9a8baefec..af81ba242e08d9199d2e79f65a79dce6b0c26c68 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "errors"
        "flag"
        "fmt"
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
@@ -105,7 +106,7 @@ func main() {
                log.Fatalf("Could not listen on %v", listen)
        }
 
-       go RefreshServicesList(&kc)
+       go RefreshServicesList(kc)
 
        // Shut down the server gracefully (by closing the listener)
        // if SIGTERM is received.
@@ -118,10 +119,10 @@ func main() {
        signal.Notify(term, syscall.SIGTERM)
        signal.Notify(term, syscall.SIGINT)
 
-       log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
+       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
 
        // Start listening for requests.
-       http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
+       http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
 
        log.Println("shutting down")
 }
@@ -134,27 +135,24 @@ type ApiTokenCache struct {
 
 // Refresh the keep service list every five minutes.
 func RefreshServicesList(kc *keepclient.KeepClient) {
-       var sleeptime time.Duration
+       previousRoots := "" // summary string that was logged last; "" forces the next successful list to be logged
        for {
-               oldservices := kc.ServiceRoots()
-               newservices, err := kc.DiscoverKeepServers()
-               if err == nil && len(newservices) > 0 {
-                       s1 := fmt.Sprint(oldservices)
-                       s2 := fmt.Sprint(newservices)
-                       if s1 != s2 {
-                               log.Printf("Updated server list to %v", s2)
-                       }
-                       sleeptime = 300 * time.Second
+               if err := kc.DiscoverKeepServers(); err != nil {
+                       log.Println("Error retrieving services list:", err)
+                       time.Sleep(3*time.Second) // short delay before retrying after a failure
+                       previousRoots = "" // log the list again once discovery succeeds, even if unchanged
+               } else if len(kc.LocalRoots()) == 0 { // only local roots are checked here; gateway roots are not
+                       log.Println("Received empty services list")
+                       time.Sleep(3*time.Second) // same short retry delay as the error case
+                       previousRoots = ""
                } else {
-                       // There was an error, or the list is empty, so wait 3 seconds and try again.
-                       if err != nil {
-                               log.Printf("Error retrieving server list: %v", err)
-                       } else {
-                               log.Printf("Retrieved an empty server list")
+                       newRoots := fmt.Sprint("Locals ", kc.LocalRoots(), ", gateways ", kc.GatewayRoots())
+                       if newRoots != previousRoots { // log only when the formatted summary changed
+                               log.Println("Updated services list:", newRoots)
+                               previousRoots = newRoots
                        }
-                       sleeptime = 3 * time.Second
+                       time.Sleep(300*time.Second) // 300 seconds: the five-minute refresh interval
                }
-               time.Sleep(sleeptime)
        }
 }
 
@@ -258,14 +256,14 @@ func MakeRESTRouter(
        rest := mux.NewRouter()
 
        if enable_get {
-               rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
+               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
                        GetBlockHandler{kc, t}).Methods("GET", "HEAD")
-               rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
+               rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
        }
 
        if enable_put {
-               rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
-               rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
+               rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
+               rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
                rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
                rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
                rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
@@ -293,22 +291,32 @@ func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request
        SetCorsHeaders(resp)
 }
 
+var BadAuthorizationHeader = errors.New("Missing or invalid Authorization header") // set with StatusForbidden by GetBlockHandler when CheckAuthorizationHeader does not pass
+var ContentLengthMismatch = errors.New("Actual length != expected content length") // NOTE(review): no use is visible in this part of the diff; confirm where it is reported
+var MethodNotSupported = errors.New("Method not supported") // set with StatusNotImplemented by GetBlockHandler for methods other than GET and HEAD
+
 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        SetCorsHeaders(resp)
 
-       kc := *this.KeepClient
-
-       hash := mux.Vars(req)["hash"]
-       hints := mux.Vars(req)["hints"]
-
-       locator := keepclient.MakeLocator2(hash, hints)
+       locator := mux.Vars(req)["locator"]
+       var err error
+       var status int
+       var expectLength, responseLength int64
+       var proxiedURI = "-"
+
+       defer func() {
+               log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
+               if status != http.StatusOK {
+                       http.Error(resp, err.Error(), status)
+               }
+       }()
 
-       log.Printf("%s: %s %s begin", GetRemoteAddress(req), req.Method, hash)
+       kc := *this.KeepClient
 
        var pass bool
        var tok string
        if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
-               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+               status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
 
@@ -318,92 +326,91 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        kc.Arvados = &arvclient
 
        var reader io.ReadCloser
-       var err error
-       var blocklen int64
 
-       if req.Method == "GET" {
-               reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
+       switch req.Method {
+       case "HEAD":
+               expectLength, proxiedURI, err = kc.Ask(locator)
+       case "GET":
+               reader, expectLength, proxiedURI, err = kc.Get(locator)
                if reader != nil {
                        defer reader.Close()
                }
-       } else if req.Method == "HEAD" {
-               blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
+       default:
+               status, err = http.StatusNotImplemented, MethodNotSupported
+               return
        }
 
-       if blocklen == -1 {
-               log.Printf("%s: %s %s Keep server did not return Content-Length",
-                       GetRemoteAddress(req), req.Method, hash)
+       if expectLength == -1 {
+               log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
        }
 
-       var status = 0
        switch err {
        case nil:
                status = http.StatusOK
-               resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
-               if reader != nil {
-                       n, err2 := io.Copy(resp, reader)
-                       if blocklen > -1 && n != blocklen {
-                               log.Printf("%s: %s %s %v %v mismatched copy size expected Content-Length: %v",
-                                       GetRemoteAddress(req), req.Method, hash, status, n, blocklen)
-                       } else if err2 == nil {
-                               log.Printf("%s: %s %s %v %v",
-                                       GetRemoteAddress(req), req.Method, hash, status, n)
-                       } else {
-                               log.Printf("%s: %s %s %v %v copy error: %v",
-                                       GetRemoteAddress(req), req.Method, hash, status, n, err2.Error())
+               resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
+               switch req.Method {
+               case "HEAD":
+                       responseLength = 0
+               case "GET":
+                       responseLength, err = io.Copy(resp, reader)
+                       if err == nil && expectLength > -1 && responseLength != expectLength {
+                               err = ContentLengthMismatch
                        }
-               } else {
-                       log.Printf("%s: %s %s %v 0", GetRemoteAddress(req), req.Method, hash, status)
                }
        case keepclient.BlockNotFound:
                status = http.StatusNotFound
-               http.Error(resp, "Not Found", http.StatusNotFound)
        default:
                status = http.StatusBadGateway
-               http.Error(resp, err.Error(), http.StatusBadGateway)
-       }
-
-       if err != nil {
-               log.Printf("%s: %s %s %v error: %v",
-                       GetRemoteAddress(req), req.Method, hash, status, err.Error())
        }
 }
 
+var LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
+var LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
+
 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        SetCorsHeaders(resp)
 
        kc := *this.KeepClient
+       var err error
+       var expectLength int64 = -1
+       var status = http.StatusInternalServerError
+       var wroteReplicas int
+       var locatorOut string = "-"
+
+       defer func() {
+               log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
+               if status != http.StatusOK {
+                       http.Error(resp, err.Error(), status)
+               }
+       }()
 
-       hash := mux.Vars(req)["hash"]
-       hints := mux.Vars(req)["hints"]
-
-       locator := keepclient.MakeLocator2(hash, hints)
+       locatorIn := mux.Vars(req)["locator"]
 
-       var contentLength int64 = -1
        if req.Header.Get("Content-Length") != "" {
-               _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
+               _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
                if err != nil {
-                       resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
+                       resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
                }
 
        }
 
-       log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
-
-       if contentLength < 0 {
-               http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
+       if expectLength < 0 {
+               err = LengthRequiredError
+               status = http.StatusLengthRequired
                return
        }
 
-       if locator.Size > 0 && int64(locator.Size) != contentLength {
-               http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
+       if loc := keepclient.MakeLocator(locatorIn); loc.Size > 0 && int64(loc.Size) != expectLength {
+               err = LengthMismatchError
+               status = http.StatusBadRequest
                return
        }
 
        var pass bool
        var tok string
        if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
-               http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
+               err = BadAuthorizationHeader
+               status = http.StatusForbidden
                return
        }
 
@@ -422,57 +429,42 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        }
 
        // Now try to put the block through
-       var replicas int
-       var put_err error
-       if hash == "" {
+       if locatorIn == "" {
-               if bytes, err := ioutil.ReadAll(req.Body); err != nil {
-                       msg := fmt.Sprintf("Error reading request body: %s", err)
-                       log.Printf(msg)
-                       http.Error(resp, msg, http.StatusInternalServerError)
+               var bytes []byte // declared separately so "=" below assigns the handler-level err, which the deferred logger/responder and the switch below read
+               if bytes, err = ioutil.ReadAll(req.Body); err != nil {
+                       err = fmt.Errorf("Error reading request body: %s", err)
+                       status = http.StatusInternalServerError
                        return
-               } else {
-                       hash, replicas, put_err = kc.PutB(bytes)
                }
+               locatorOut, wroteReplicas, err = kc.PutB(bytes)
        } else {
-               hash, replicas, put_err = kc.PutHR(hash, req.Body, contentLength)
+               locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
        }
 
        // Tell the client how many successful PUTs we accomplished
-       resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
+       resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
 
-       switch put_err {
+       switch err {
        case nil:
-               // Default will return http.StatusOK
-               log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
-               n, err2 := io.WriteString(resp, hash)
-               if err2 != nil {
-                       log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
-               }
+               status = http.StatusOK
+               _, err = io.WriteString(resp, locatorOut)
 
        case keepclient.OversizeBlockError:
                // Too much data
-               http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
+               status = http.StatusRequestEntityTooLarge
 
        case keepclient.InsufficientReplicasError:
-               if replicas > 0 {
+               if wroteReplicas > 0 {
                        // At least one write is considered success.  The
                        // client can decide if getting less than the number of
                        // replications it asked for is a fatal error.
-                       // Default will return http.StatusOK
-                       n, err2 := io.WriteString(resp, hash)
-                       if err2 != nil {
-                               log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
-                       }
+                       status = http.StatusOK
+                       _, err = io.WriteString(resp, locatorOut)
                } else {
-                       http.Error(resp, put_err.Error(), http.StatusServiceUnavailable)
+                       status = http.StatusServiceUnavailable
                }
 
        default:
-               http.Error(resp, put_err.Error(), http.StatusBadGateway)
-       }
-
-       if put_err != nil {
-               log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, put_err.Error())
+               status = http.StatusBadGateway
        }
-
 }
index e3b4e36b63de23dee806a35a9c9d55958cbc3afd..5f6e2b9bc8eda0ef2d958071301e949a277b57c0 100644 (file)
@@ -117,10 +117,10 @@ func runProxy(c *C, args []string, port int, bogusClientToken bool) keepclient.K
        }
        kc.SetServiceRoots(map[string]string{
                "proxy": fmt.Sprintf("http://localhost:%v", port),
-       })
+       }, nil)
        c.Check(kc.Using_proxy, Equals, true)
-       c.Check(len(kc.ServiceRoots()), Equals, 1)
-       for _, root := range kc.ServiceRoots() {
+       c.Check(len(kc.LocalRoots()), Equals, 1)
+       for _, root := range kc.LocalRoots() {
                c.Check(root, Equals, fmt.Sprintf("http://localhost:%v", port))
        }
        log.Print("keepclient created")
@@ -154,8 +154,8 @@ func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
        c.Assert(err, Equals, nil)
        c.Check(kc.Arvados.External, Equals, true)
        c.Check(kc.Using_proxy, Equals, true)
-       c.Check(len(kc.ServiceRoots()), Equals, 1)
-       for _, root := range kc.ServiceRoots() {
+       c.Check(len(kc.LocalRoots()), Equals, 1)
+       for _, root := range kc.LocalRoots() {
                c.Check(root, Equals, "http://localhost:29950")
        }
        os.Setenv("ARVADOS_EXTERNAL_CLIENT", "")
index a363bac2553998e6356216f77472bcbf537b78d3..c1371dcc5752caed634e72b2904e5fa8aaac9ed6 100644 (file)
@@ -276,7 +276,7 @@ func main() {
        }
 
        // Initialize Pull queue and worker
-       keepClient := keepclient.KeepClient{
+       keepClient := &keepclient.KeepClient{
                Arvados:       nil,
                Want_replicas: 1,
                Using_proxy:   true,
index fac4bb15030eaaa8334bf375dc2a9baa4695fbb0..d85458a325a1c44e2e53d177da6bd12f8adbe07b 100644 (file)
@@ -19,7 +19,7 @@ import (
                        Skip the rest of the servers if no errors
                Repeat
 */
-func RunPullWorker(pullq *WorkQueue, keepClient keepclient.KeepClient) {
+func RunPullWorker(pullq *WorkQueue, keepClient *keepclient.KeepClient) {
        nextItem := pullq.NextItem
        for item := range nextItem {
                pullRequest := item.(PullRequest)
@@ -39,14 +39,14 @@ func RunPullWorker(pullq *WorkQueue, keepClient keepclient.KeepClient) {
                Using this token & signature, retrieve the given block.
                Write to storage
 */
-func PullItemAndProcess(pullRequest PullRequest, token string, keepClient keepclient.KeepClient) (err error) {
+func PullItemAndProcess(pullRequest PullRequest, token string, keepClient *keepclient.KeepClient) (err error) {
        keepClient.Arvados.ApiToken = token
 
        service_roots := make(map[string]string)
        for _, addr := range pullRequest.Servers {
                service_roots[addr] = addr
        }
-       keepClient.SetServiceRoots(service_roots)
+       keepClient.SetServiceRoots(service_roots, nil)
 
        // Generate signature with a random token
        expires_at := time.Now().Add(60 * time.Second)
@@ -75,7 +75,7 @@ func PullItemAndProcess(pullRequest PullRequest, token string, keepClient keepcl
 }
 
 // Fetch the content for the given locator using keepclient.
-var GetContent = func(signedLocator string, keepClient keepclient.KeepClient) (
+var GetContent = func(signedLocator string, keepClient *keepclient.KeepClient) (
        reader io.ReadCloser, contentLength int64, url string, err error) {
        reader, blocklen, url, err := keepClient.Get(signedLocator)
        return reader, blocklen, url, err
index b293cf92ea87260dd487e5e9d190a85aca779708..7a930d58c8372ef9434310c4f0574b8c09eb2a00 100644 (file)
@@ -10,7 +10,7 @@ import (
        "testing"
 )
 
-var keepClient keepclient.KeepClient
+var keepClient *keepclient.KeepClient
 
 type PullWorkIntegrationTestData struct {
        Name     string
@@ -33,7 +33,7 @@ func SetupPullWorkerIntegrationTest(t *testing.T, testData PullWorkIntegrationTe
        }
 
        // keep client
-       keepClient = keepclient.KeepClient{
+       keepClient = &keepclient.KeepClient{
                Arvados:       &arv,
                Want_replicas: 1,
                Using_proxy:   true,
@@ -42,17 +42,15 @@ func SetupPullWorkerIntegrationTest(t *testing.T, testData PullWorkIntegrationTe
 
        // discover keep services
        var servers []string
-       service_roots, err := keepClient.DiscoverKeepServers()
-       if err != nil {
+       if err := keepClient.DiscoverKeepServers(); err != nil {
                t.Error("Error discovering keep services")
        }
-       for _, host := range service_roots {
+       for _, host := range keepClient.LocalRoots() {
                servers = append(servers, host)
        }
 
        // Put content if the test needs it
        if wantData {
-               keepClient.SetServiceRoots(service_roots)
                locator, _, err := keepClient.PutB([]byte(testData.Content))
                if err != nil {
                        t.Errorf("Error putting test data in setup for %s %s %v", testData.Content, locator, err)
index f0e9e65f1ee1015a57c2bd87e8d9c926978f21c4..124c9b835c1beef29ad85e25e75b0771e7d4b9c9 100644 (file)
@@ -244,7 +244,7 @@ func performTest(testData PullWorkerTestData, c *C) {
        testPullLists[testData.name] = testData.response_body
 
        // Override GetContent to mock keepclient Get functionality
-       GetContent = func(signedLocator string, keepClient keepclient.KeepClient) (
+       GetContent = func(signedLocator string, keepClient *keepclient.KeepClient) (
                reader io.ReadCloser, contentLength int64, url string, err error) {
 
                processedPullLists[testData.name] = testData.response_body