Arvados-DCO-1.1-Signed-off-by: Radhika Chippada <radhika@curoverse.com>
[arvados.git] / sdk / go / keepclient / keepclient.go
index 4cd9f544f7d75f56f3869bc211749b22792a47e3..8b518ac8d51a2858980441ac0a2c81993bbc9436 100644 (file)
@@ -1,30 +1,76 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
 /* Provides low-level Get/Put primitives for accessing Arvados Keep blocks. */
 package keepclient
 
 import (
        "bytes"
        "crypto/md5"
-       "crypto/tls"
        "errors"
        "fmt"
-       "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
-       "git.curoverse.com/arvados.git/sdk/go/streamer"
        "io"
        "io/ioutil"
-       "log"
+       "net"
        "net/http"
        "regexp"
        "strconv"
        "strings"
        "sync"
+       "time"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
+       "git.curoverse.com/arvados.git/sdk/go/streamer"
 )
 
 // A Keep "block" is 64MB.
 const BLOCKSIZE = 64 * 1024 * 1024
 
-var BlockNotFound = errors.New("Block not found")
-var InsufficientReplicasError = errors.New("Could not write sufficient replicas")
-var OversizeBlockError = errors.New("Exceeded maximum block size (" + strconv.Itoa(BLOCKSIZE) + ")")
+var (
+       DefaultRequestTimeout      = 20 * time.Second
+       DefaultConnectTimeout      = 2 * time.Second
+       DefaultTLSHandshakeTimeout = 4 * time.Second
+       DefaultKeepAlive           = 180 * time.Second
+
+       DefaultProxyRequestTimeout      = 300 * time.Second
+       DefaultProxyConnectTimeout      = 30 * time.Second
+       DefaultProxyTLSHandshakeTimeout = 10 * time.Second
+       DefaultProxyKeepAlive           = 120 * time.Second
+)
+
+// Error is an error that also reports, via Temporary(), whether the failure is temporary.
+type Error interface {
+       error
+       Temporary() bool
+}
+
+// multipleResponseError wraps an error with a temporary flag; *multipleResponseError implements Error.
+type multipleResponseError struct {
+       error
+       isTemp bool
+}
+
+func (e *multipleResponseError) Temporary() bool {
+       return e.isTemp
+}
+
+// BlockNotFound is an *ErrNotFound whose embedded multipleResponseError has isTemp set to false.
+var BlockNotFound = &ErrNotFound{multipleResponseError{
+       error:  errors.New("Block not found"),
+       isTemp: false,
+}}
+
+// ErrNotFound embeds a multipleResponseError whose isTemp can be true or false.
+type ErrNotFound struct {
+       multipleResponseError
+}
+
+type InsufficientReplicasError error
+
+type OversizeBlockError error
+
+var ErrOversizeBlock = OversizeBlockError(errors.New("Exceeded maximum block size (" + strconv.Itoa(BLOCKSIZE) + ")"))
 var MissingArvadosApiHost = errors.New("Missing required environment variable ARVADOS_API_HOST")
 var MissingArvadosApiToken = errors.New("Missing required environment variable ARVADOS_API_TOKEN")
 var InvalidLocatorError = errors.New("Invalid locator")
@@ -38,38 +84,56 @@ var ErrIncompleteIndex = errors.New("Got incomplete index")
 const X_Keep_Desired_Replicas = "X-Keep-Desired-Replicas"
 const X_Keep_Replicas_Stored = "X-Keep-Replicas-Stored"
 
+type HTTPClient interface {
+       Do(*http.Request) (*http.Response, error)
+}
+
 // Information about Arvados and Keep servers.
 type KeepClient struct {
        Arvados            *arvadosclient.ArvadosClient
        Want_replicas      int
-       Using_proxy        bool
-       localRoots         *map[string]string
-       writableLocalRoots *map[string]string
-       gatewayRoots       *map[string]string
+       localRoots         map[string]string
+       writableLocalRoots map[string]string
+       gatewayRoots       map[string]string
        lock               sync.RWMutex
-       Client             *http.Client
+       HTTPClient         HTTPClient
+       Retries            int
+       BlockCache         *BlockCache
 
        // set to 1 if all writable services are of disk type, otherwise 0
        replicasPerService int
+
+       // Any non-disk typed services found in the list of keepservers?
+       foundNonDiskSvc bool
+
+       // Disable automatic discovery of keep services
+       disableDiscovery bool
 }
 
-// MakeKeepClient creates a new KeepClient by contacting the API server to discover Keep servers.
+// MakeKeepClient creates a new KeepClient, calls
+// discoverServices(), and returns the client together with any
+// discovery error.
 func MakeKeepClient(arv *arvadosclient.ArvadosClient) (*KeepClient, error) {
        kc := New(arv)
-       return kc, kc.DiscoverKeepServers()
+       return kc, kc.discoverServices()
 }
 
-// New func creates a new KeepClient struct.
-// This func does not discover keep servers. It is the caller's responsibility.
+// New creates a new KeepClient. Service discovery will occur on the
+// next read/write operation.
 func New(arv *arvadosclient.ArvadosClient) *KeepClient {
-       kc := &KeepClient{
+       defaultReplicationLevel := 2
+       value, err := arv.Discovery("defaultCollectionReplication")
+       if err == nil {
+               v, ok := value.(float64)
+               if ok && v > 0 {
+                       defaultReplicationLevel = int(v)
+               }
+       }
+       return &KeepClient{
                Arvados:       arv,
-               Want_replicas: 2,
-               Using_proxy:   false,
-               Client: &http.Client{Transport: &http.Transport{
-                       TLSClientConfig: &tls.Config{InsecureSkipVerify: arv.ApiInsecure}}},
+               Want_replicas: defaultReplicationLevel,
+               Retries:       2,
        }
-       return kc
 }
 
 // Put a block given the block hash, a reader, and the number of bytes
@@ -78,14 +142,14 @@ func New(arv *arvadosclient.ArvadosClient) *KeepClient {
 // Returns the locator for the written block, the number of replicas
 // written, and an error.
 //
-// Returns an InsufficientReplicas error if 0 <= replicas <
+// Returns an InsufficientReplicasError if 0 <= replicas <
 // kc.Wants_replicas.
 func (kc *KeepClient) PutHR(hash string, r io.Reader, dataBytes int64) (string, int, error) {
        // Buffer for reads from 'r'
        var bufsize int
        if dataBytes > 0 {
                if dataBytes > BLOCKSIZE {
-                       return "", 0, OversizeBlockError
+                       return "", 0, ErrOversizeBlock
                }
                bufsize = int(dataBytes)
        } else {
@@ -130,6 +194,89 @@ func (kc *KeepClient) PutR(r io.Reader) (locator string, replicas int, err error
        }
 }
 
+func (kc *KeepClient) getOrHead(method string, locator string) (io.ReadCloser, int64, string, error) {
+       if strings.HasPrefix(locator, "d41d8cd98f00b204e9800998ecf8427e+0") {
+               return ioutil.NopCloser(bytes.NewReader(nil)), 0, "", nil
+       }
+
+       var errs []string
+
+       tries_remaining := 1 + kc.Retries
+
+       serversToTry := kc.getSortedRoots(locator)
+
+       numServers := len(serversToTry)
+       count404 := 0
+
+       var retryList []string
+
+       for tries_remaining > 0 {
+               tries_remaining -= 1
+               retryList = nil
+
+               for _, host := range serversToTry {
+                       url := host + "/" + locator
+
+                       req, err := http.NewRequest(method, url, nil)
+                       if err != nil {
+                               errs = append(errs, fmt.Sprintf("%s: %v", url, err))
+                               continue
+                       }
+                       req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
+                       resp, err := kc.httpClient().Do(req)
+                       if err != nil {
+                               // Probably a network error, may be transient,
+                               // can try again.
+                               errs = append(errs, fmt.Sprintf("%s: %v", url, err))
+                               retryList = append(retryList, host)
+                       } else if resp.StatusCode != http.StatusOK {
+                               var respbody []byte
+                               respbody, _ = ioutil.ReadAll(&io.LimitedReader{R: resp.Body, N: 4096})
+                               resp.Body.Close()
+                               errs = append(errs, fmt.Sprintf("%s: HTTP %d %q",
+                                       url, resp.StatusCode, bytes.TrimSpace(respbody)))
+
+                               if resp.StatusCode == 408 ||
+                                       resp.StatusCode == 429 ||
+                                       resp.StatusCode >= 500 {
+                                       // Timeout, too many requests, or other
+                                       // server side failure, transient
+                                       // error, can try again.
+                                       retryList = append(retryList, host)
+                               } else if resp.StatusCode == 404 {
+                                       count404++
+                               }
+                       } else {
+                               // Success.
+                               if method == "GET" {
+                                       return HashCheckingReader{
+                                               Reader: resp.Body,
+                                               Hash:   md5.New(),
+                                               Check:  locator[0:32],
+                                       }, resp.ContentLength, url, nil
+                               } else {
+                                       resp.Body.Close()
+                                       return nil, resp.ContentLength, url, nil
+                               }
+                       }
+
+               }
+               serversToTry = retryList
+       }
+       DebugPrintf("DEBUG: %s %s failed: %v", method, locator, errs)
+
+       var err error
+       if count404 == numServers {
+               err = BlockNotFound
+       } else {
+               err = &ErrNotFound{multipleResponseError{
+                       error:  fmt.Errorf("%s %s failed: %v", method, locator, errs),
+                       isTemp: len(serversToTry) > 0,
+               }}
+       }
+       return nil, 0, "", err
+}
+
 // Get() retrieves a block, given a locator. Returns a reader, the
 // expected data length, the URL the block is being fetched from, and
 // an error.
@@ -138,34 +285,7 @@ func (kc *KeepClient) PutR(r io.Reader) (locator string, replicas int, err error
 // reader returned by this method will return a BadChecksum error
 // instead of EOF.
 func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error) {
-       var errs []string
-       for _, host := range kc.getSortedRoots(locator) {
-               url := host + "/" + locator
-               req, err := http.NewRequest("GET", url, nil)
-               if err != nil {
-                       errs = append(errs, fmt.Sprintf("%s: %v", url, err))
-                       continue
-               }
-               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
-               resp, err := kc.Client.Do(req)
-               if err != nil {
-                       errs = append(errs, fmt.Sprintf("%s: %v", url, err))
-                       continue
-               } else if resp.StatusCode != http.StatusOK {
-                       respbody, _ := ioutil.ReadAll(&io.LimitedReader{resp.Body, 4096})
-                       resp.Body.Close()
-                       errs = append(errs, fmt.Sprintf("%s: HTTP %d %q",
-                               url, resp.StatusCode, bytes.TrimSpace(respbody)))
-                       continue
-               }
-               return HashCheckingReader{
-                       Reader: resp.Body,
-                       Hash:   md5.New(),
-                       Check:  locator[0:32],
-               }, resp.ContentLength, url, nil
-       }
-       log.Printf("DEBUG: GET %s failed: %v", locator, errs)
-       return nil, 0, "", BlockNotFound
+       return kc.getOrHead("GET", locator)
 }
 
 // Ask() verifies that a block with the given hash is available and
@@ -176,18 +296,8 @@ func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error)
 // Returns the data size (content length) reported by the Keep service
 // and the URI reporting the data size.
 func (kc *KeepClient) Ask(locator string) (int64, string, error) {
-       for _, host := range kc.getSortedRoots(locator) {
-               url := host + "/" + locator
-               req, err := http.NewRequest("HEAD", url, nil)
-               if err != nil {
-                       continue
-               }
-               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
-               if resp, err := kc.Client.Do(req); err == nil && resp.StatusCode == http.StatusOK {
-                       return resp.ContentLength, url, nil
-               }
-       }
-       return 0, "", BlockNotFound
+       _, size, url, err := kc.getOrHead("HEAD", locator)
+       return size, url, err
 }
 
 // GetIndex retrieves a list of blocks stored on the given server whose hashes
@@ -214,7 +324,7 @@ func (kc *KeepClient) GetIndex(keepServiceUUID, prefix string) (io.Reader, error
        }
 
        req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
-       resp, err := kc.Client.Do(req)
+       resp, err := kc.httpClient().Do(req)
        if err != nil {
                return nil, err
        }
@@ -245,55 +355,47 @@ func (kc *KeepClient) GetIndex(keepServiceUUID, prefix string) (io.Reader, error
 // LocalRoots() returns the map of local (i.e., disk and proxy) Keep
 // services: uuid -> baseURI.
 func (kc *KeepClient) LocalRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.localRoots
+       return kc.localRoots
 }
 
 // GatewayRoots() returns the map of Keep remote gateway services:
 // uuid -> baseURI.
 func (kc *KeepClient) GatewayRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.gatewayRoots
+       return kc.gatewayRoots
 }
 
 // WritableLocalRoots() returns the map of writable local Keep services:
 // uuid -> baseURI.
 func (kc *KeepClient) WritableLocalRoots() map[string]string {
+       kc.discoverServices()
        kc.lock.RLock()
        defer kc.lock.RUnlock()
-       return *kc.writableLocalRoots
+       return kc.writableLocalRoots
 }
 
-// SetServiceRoots updates the localRoots and gatewayRoots maps,
-// without risk of disrupting operations that are already in progress.
+// SetServiceRoots disables service discovery and updates the
+// localRoots and gatewayRoots maps, without disrupting operations
+// that are already in progress.
 //
-// The KeepClient makes its own copy of the supplied maps, so the
-// caller can reuse/modify them after SetServiceRoots returns, but
-// they should not be modified by any other goroutine while
-// SetServiceRoots is running.
-func (kc *KeepClient) SetServiceRoots(newLocals, newWritableLocals map[string]string, newGateways map[string]string) {
-       locals := make(map[string]string)
-       for uuid, root := range newLocals {
-               locals[uuid] = root
-       }
-
-       writables := make(map[string]string)
-       for uuid, root := range newWritableLocals {
-               writables[uuid] = root
-       }
-
-       gateways := make(map[string]string)
-       for uuid, root := range newGateways {
-               gateways[uuid] = root
-       }
+// The supplied maps must not be modified after calling
+// SetServiceRoots.
+func (kc *KeepClient) SetServiceRoots(locals, writables, gateways map[string]string) {
+       kc.disableDiscovery = true
+       kc.setServiceRoots(locals, writables, gateways)
+}
 
+func (kc *KeepClient) setServiceRoots(locals, writables, gateways map[string]string) {
        kc.lock.Lock()
        defer kc.lock.Unlock()
-       kc.localRoots = &locals
-       kc.writableLocalRoots = &writables
-       kc.gatewayRoots = &gateways
+       kc.localRoots = locals
+       kc.writableLocalRoots = writables
+       kc.gatewayRoots = gateways
 }
 
 // getSortedRoots returns a list of base URIs of Keep services, in the
@@ -324,6 +426,88 @@ func (kc *KeepClient) getSortedRoots(locator string) []string {
        return found
 }
 
+func (kc *KeepClient) cache() *BlockCache {
+       if kc.BlockCache != nil {
+               return kc.BlockCache
+       } else {
+               return DefaultBlockCache
+       }
+}
+
+var (
+       // There are four global http.Client objects for the four
+       // possible combinations of TLS behavior (verify/skip-verify)
+       // and timeout settings (proxy/non-proxy).
+       defaultClient = map[bool]map[bool]HTTPClient{
+               // defaultClient[false] is used for verified TLS reqs
+               false: {},
+               // defaultClient[true] is used for unverified
+               // (insecure) TLS reqs
+               true: {},
+       }
+       defaultClientMtx sync.Mutex
+)
+
+// httpClient returns the HTTPClient field if it's not nil, otherwise
+// whichever of the four global http.Client objects is suitable for
+// the current environment (i.e., TLS verification on/off, keep
+// services are/aren't proxies).
+func (kc *KeepClient) httpClient() HTTPClient {
+       if kc.HTTPClient != nil {
+               return kc.HTTPClient
+       }
+       defaultClientMtx.Lock()
+       defer defaultClientMtx.Unlock()
+       if c, ok := defaultClient[kc.Arvados.ApiInsecure][kc.foundNonDiskSvc]; ok {
+               return c
+       }
+
+       var requestTimeout, connectTimeout, keepAlive, tlsTimeout time.Duration
+       if kc.foundNonDiskSvc {
+               // Use longer timeouts when connecting to a proxy,
+               // because this usually means the intervening network
+               // is slower.
+               requestTimeout = DefaultProxyRequestTimeout
+               connectTimeout = DefaultProxyConnectTimeout
+               tlsTimeout = DefaultProxyTLSHandshakeTimeout
+               keepAlive = DefaultProxyKeepAlive
+       } else {
+               requestTimeout = DefaultRequestTimeout
+               connectTimeout = DefaultConnectTimeout
+               tlsTimeout = DefaultTLSHandshakeTimeout
+               keepAlive = DefaultKeepAlive
+       }
+
+       transport, ok := http.DefaultTransport.(*http.Transport)
+       if ok {
+               copy := *transport
+               transport = &copy
+       } else {
+               // Evidently the application has replaced
+               // http.DefaultTransport with a different type, so we
+               // need to build our own from scratch using the Go 1.8
+               // defaults.
+               transport = &http.Transport{
+                       MaxIdleConns:          100,
+                       IdleConnTimeout:       90 * time.Second,
+                       ExpectContinueTimeout: time.Second,
+               }
+       }
+       transport.DialContext = (&net.Dialer{
+               Timeout:   connectTimeout,
+               KeepAlive: keepAlive,
+               DualStack: true,
+       }).DialContext
+       transport.TLSHandshakeTimeout = tlsTimeout
+       transport.TLSClientConfig = arvadosclient.MakeTLSConfig(kc.Arvados.ApiInsecure)
+       c := &http.Client{
+               Timeout:   requestTimeout,
+               Transport: transport,
+       }
+       defaultClient[kc.Arvados.ApiInsecure][kc.foundNonDiskSvc] = c
+       return c
+}
+
 type Locator struct {
        Hash  string
        Size  int      // -1 if data size is not known