12167: Propagate X-Request-Id in API calls.
[arvados.git] / sdk / go / keepclient / keepclient.go
index b886684c348aed30abd0aa6ccaeb16a24af2aad7..620bdbec4eaa64a7809bbbc539bf59d0c6ea2275 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
 /* Provides low-level Get/Put primitives for accessing Arvados Keep blocks. */
 package keepclient
 
@@ -17,12 +21,24 @@ import (
        "time"
 
        "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
-       "git.curoverse.com/arvados.git/sdk/go/streamer"
+       "git.curoverse.com/arvados.git/sdk/go/asyncbuf"
 )
 
 // A Keep "block" is 64MB.
 const BLOCKSIZE = 64 * 1024 * 1024
 
+var (
+       DefaultRequestTimeout      = 20 * time.Second
+       DefaultConnectTimeout      = 2 * time.Second
+       DefaultTLSHandshakeTimeout = 4 * time.Second
+       DefaultKeepAlive           = 180 * time.Second
+
+       DefaultProxyRequestTimeout      = 300 * time.Second
+       DefaultProxyConnectTimeout      = 30 * time.Second
+       DefaultProxyTLSHandshakeTimeout = 10 * time.Second
+       DefaultProxyKeepAlive           = 120 * time.Second
+)
+
 // Error interface with an error and boolean indicating whether the error is temporary
 type Error interface {
        error
@@ -76,9 +92,9 @@ type HTTPClient interface {
 type KeepClient struct {
        Arvados            *arvadosclient.ArvadosClient
        Want_replicas      int
-       localRoots         *map[string]string
-       writableLocalRoots *map[string]string
-       gatewayRoots       *map[string]string
+       localRoots         map[string]string
+       writableLocalRoots map[string]string
+       gatewayRoots       map[string]string
        lock               sync.RWMutex
        HTTPClient         HTTPClient
        Retries            int
@@ -89,6 +105,9 @@ type KeepClient struct {
 
        // Any non-disk typed services found in the list of keepservers?
        foundNonDiskSvc bool
+
+       // Disable automatic discovery of keep services
+       disableDiscovery bool
 }
 
 // MakeKeepClient creates a new KeepClient, calls
@@ -96,12 +115,11 @@ type KeepClient struct {
 // use.
 func MakeKeepClient(arv *arvadosclient.ArvadosClient) (*KeepClient, error) {
        kc := New(arv)
-       return kc, kc.DiscoverKeepServers()
+       return kc, kc.discoverServices()
 }
 
-// New creates a new KeepClient. The caller must call
-// DiscoverKeepServers() before using the returned client to read or
-// write data.
+// New creates a new KeepClient. Service discovery will occur on the
+// next read/write operation.
 func New(arv *arvadosclient.ArvadosClient) *KeepClient {
        defaultReplicationLevel := 2
        value, err := arv.Discovery("defaultCollectionReplication")
@@ -138,10 +156,12 @@ func (kc *KeepClient) PutHR(hash string, r io.Reader, dataBytes int64) (string,
                bufsize = BLOCKSIZE
        }
 
-       t := streamer.AsyncStreamFromReader(bufsize, HashCheckingReader{r, md5.New(), hash})
-       defer t.Close()
-
-       return kc.putReplicas(hash, t, dataBytes)
+       buf := asyncbuf.NewBuffer(make([]byte, 0, bufsize))
+       go func() {
+               _, err := io.Copy(buf, HashCheckingReader{r, md5.New(), hash})
+               buf.CloseWithError(err)
+       }()
+       return kc.putReplicas(hash, buf.NewReader, dataBytes)
 }
 
 // PutHB writes a block to Keep. The hash of the bytes is given in
@@ -149,9 +169,8 @@ func (kc *KeepClient) PutHR(hash string, r io.Reader, dataBytes int64) (string,
 //
 // Return values are the same as for PutHR.
 func (kc *KeepClient) PutHB(hash string, buf []byte) (string, int, error) {
-       t := streamer.AsyncStreamFromSlice(buf)
-       defer t.Close()
-       return kc.putReplicas(hash, t, int64(len(buf)))
+       newReader := func() io.Reader { return bytes.NewBuffer(buf) }
+       return kc.putReplicas(hash, newReader, int64(len(buf)))
 }
 
 // PutB writes a block to Keep. It computes the hash itself.
@@ -181,6 +200,15 @@ func (kc *KeepClient) getOrHead(method string, locator string) (io.ReadCloser, i
                return ioutil.NopCloser(bytes.NewReader(nil)), 0, "", nil
        }
 
+       var expectLength int64
+       if parts := strings.SplitN(locator, "+", 3); len(parts) < 2 {
+               expectLength = -1
+       } else if n, err := strconv.ParseInt(parts[1], 10, 64); err != nil {
+               expectLength = -1
+       } else {
+               expectLength = n
+       }
+
        var errs []string
 
        tries_remaining := 1 + kc.Retries
@@ -211,7 +239,9 @@ func (kc *KeepClient) getOrHead(method string, locator string) (io.ReadCloser, i
                                // can try again.
                                errs = append(errs, fmt.Sprintf("%s: %v", url, err))
                                retryList = append(retryList, host)
-                       } else if resp.StatusCode != http.StatusOK {
+                               continue
+                       }
+                       if resp.StatusCode != http.StatusOK {
                                var respbody []byte
                                respbody, _ = ioutil.ReadAll(&io.LimitedReader{R: resp.Body, N: 4096})
                                resp.Body.Close()
@@ -228,20 +258,29 @@ func (kc *KeepClient) getOrHead(method string, locator string) (io.ReadCloser, i
                                } else if resp.StatusCode == 404 {
                                        count404++
                                }
-                       } else {
-                               // Success.
-                               if method == "GET" {
-                                       return HashCheckingReader{
-                                               Reader: resp.Body,
-                                               Hash:   md5.New(),
-                                               Check:  locator[0:32],
-                                       }, resp.ContentLength, url, nil
-                               } else {
+                               continue
+                       }
+                       if expectLength < 0 {
+                               if resp.ContentLength < 0 {
                                        resp.Body.Close()
-                                       return nil, resp.ContentLength, url, nil
+                                       return nil, 0, "", fmt.Errorf("error reading %q: no size hint, no Content-Length header in response", locator)
                                }
+                               expectLength = resp.ContentLength
+                       } else if resp.ContentLength >= 0 && expectLength != resp.ContentLength {
+                               resp.Body.Close()
+                               return nil, 0, "", fmt.Errorf("error reading %q: size hint %d != Content-Length %d", locator, expectLength, resp.ContentLength)
+                       }
+                       // Success
+                       if method == "GET" {
+                               return HashCheckingReader{
+                                       Reader: resp.Body,
+                                       Hash:   md5.New(),
+                                       Check:  locator[0:32],
+                               }, expectLength, url, nil
+                       } else {
+                               resp.Body.Close()
+                               return nil, expectLength, url, nil
                        }
-
                }
                serversToTry = retryList
        }
@@ -270,6 +309,12 @@ func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error)
        return kc.getOrHead("GET", locator)
 }
 
+// ReadAt() retrieves a portion of a block from the cache if it's
+// present, otherwise from the network.
+func (kc *KeepClient) ReadAt(locator string, p []byte, off int) (int, error) {
+       return kc.cache().ReadAt(kc, locator, p, off)
+}
+
 // Ask() verifies that a block with the given hash is available and
 // readable, according to at least one Keep service. Unlike Get, it
 // does not retrieve the data or verify that the data content matches
@@ -337,55 +382,47 @@ func (kc *KeepClient) GetIndex(keepServiceUUID, prefix string) (io.Reader, error
 // LocalRoots() returns the map of local (i.e., disk and proxy) Keep
 // services: uuid -> baseURI.
 func (kc *KeepClient) LocalRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.localRoots
+       return kc.localRoots
 }
 
 // GatewayRoots() returns the map of Keep remote gateway services:
 // uuid -> baseURI.
 func (kc *KeepClient) GatewayRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.gatewayRoots
+       return kc.gatewayRoots
 }
 
 // WritableLocalRoots() returns the map of writable local Keep services:
 // uuid -> baseURI.
 func (kc *KeepClient) WritableLocalRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.writableLocalRoots
+       return kc.writableLocalRoots
 }
 
-// SetServiceRoots updates the localRoots and gatewayRoots maps,
-// without risk of disrupting operations that are already in progress.
+// SetServiceRoots disables service discovery and updates the
+// localRoots and gatewayRoots maps, without disrupting operations
+// that are already in progress.
 //
-// The KeepClient makes its own copy of the supplied maps, so the
-// caller can reuse/modify them after SetServiceRoots returns, but
-// they should not be modified by any other goroutine while
-// SetServiceRoots is running.
-func (kc *KeepClient) SetServiceRoots(newLocals, newWritableLocals, newGateways map[string]string) {
-       locals := make(map[string]string)
-       for uuid, root := range newLocals {
-               locals[uuid] = root
-       }
-
-       writables := make(map[string]string)
-       for uuid, root := range newWritableLocals {
-               writables[uuid] = root
-       }
-
-       gateways := make(map[string]string)
-       for uuid, root := range newGateways {
-               gateways[uuid] = root
-       }
+// The supplied maps must not be modified after calling
+// SetServiceRoots.
+func (kc *KeepClient) SetServiceRoots(locals, writables, gateways map[string]string) {
+       kc.disableDiscovery = true
+       kc.setServiceRoots(locals, writables, gateways)
+}
 
+func (kc *KeepClient) setServiceRoots(locals, writables, gateways map[string]string) {
        kc.lock.Lock()
        defer kc.lock.Unlock()
-       kc.localRoots = &locals
-       kc.writableLocalRoots = &writables
-       kc.gatewayRoots = &gateways
+       kc.localRoots = locals
+       kc.writableLocalRoots = writables
+       kc.gatewayRoots = gateways
 }
 
 // getSortedRoots returns a list of base URIs of Keep services, in the
@@ -424,6 +461,10 @@ func (kc *KeepClient) cache() *BlockCache {
        }
 }
 
+func (kc *KeepClient) ClearBlockCache() {
+       kc.cache().Clear()
+}
+
 var (
        // There are four global http.Client objects for the four
        // possible permutations of TLS behavior (verify/skip-verify)
@@ -452,34 +493,44 @@ func (kc *KeepClient) httpClient() HTTPClient {
                return c
        }
 
-       var requestTimeout, connectTimeout, keepAliveInterval, tlsTimeout time.Duration
+       var requestTimeout, connectTimeout, keepAlive, tlsTimeout time.Duration
        if kc.foundNonDiskSvc {
                // Use longer timeouts when connecting to a proxy,
                // because this usually means the intervening network
                // is slower.
-               requestTimeout = 300 * time.Second
-               connectTimeout = 30 * time.Second
-               tlsTimeout = 10 * time.Second
-               keepAliveInterval = 120 * time.Second
+               requestTimeout = DefaultProxyRequestTimeout
+               connectTimeout = DefaultProxyConnectTimeout
+               tlsTimeout = DefaultProxyTLSHandshakeTimeout
+               keepAlive = DefaultProxyKeepAlive
        } else {
-               requestTimeout = 20 * time.Second
-               connectTimeout = 2 * time.Second
-               tlsTimeout = 4 * time.Second
-               keepAliveInterval = 180 * time.Second
+               requestTimeout = DefaultRequestTimeout
+               connectTimeout = DefaultConnectTimeout
+               tlsTimeout = DefaultTLSHandshakeTimeout
+               keepAlive = DefaultKeepAlive
        }
-       transport := &http.Transport{
-               Dial: (&net.Dialer{
-                       Timeout:   connectTimeout,
-                       KeepAlive: keepAliveInterval,
-               }).Dial,
-               TLSClientConfig:     arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure),
-               TLSHandshakeTimeout: tlsTimeout,
-       }
-       go func() {
-               for range time.NewTicker(10 * time.Minute).C {
-                       transport.CloseIdleConnections()
+
+       transport, ok := http.DefaultTransport.(*http.Transport)
+       if ok {
+               copy := *transport
+               transport = &copy
+       } else {
+               // Evidently the application has replaced
+               // http.DefaultTransport with a different type, so we
+               // need to build our own from scratch using the Go 1.8
+               // defaults.
+               transport = &http.Transport{
+                       MaxIdleConns:          100,
+                       IdleConnTimeout:       90 * time.Second,
+                       ExpectContinueTimeout: time.Second,
                }
-       }()
+       }
+       transport.DialContext = (&net.Dialer{
+               Timeout:   connectTimeout,
+               KeepAlive: keepAlive,
+               DualStack: true,
+       }).DialContext
+       transport.TLSHandshakeTimeout = tlsTimeout
+       transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
        c := &http.Client{
                Timeout:   requestTimeout,
                Transport: transport,