Merge branch 'pr/28'
[arvados.git] / sdk / go / keepclient / keepclient.go
index 8272d1162f06b4f72e7561ed5c1fc152052b172b..67c304deaf3ae54b2668cb8c2f2856e909da8c5a 100644 (file)
@@ -13,7 +13,6 @@ import (
        "io/ioutil"
        "log"
        "net/http"
-       "os"
        "regexp"
        "strconv"
        "strings"
@@ -29,7 +28,6 @@ var OversizeBlockError = errors.New("Exceeded maximum block size (" + strconv.It
 var MissingArvadosApiHost = errors.New("Missing required environment variable ARVADOS_API_HOST")
 var MissingArvadosApiToken = errors.New("Missing required environment variable ARVADOS_API_TOKEN")
 var InvalidLocatorError = errors.New("Invalid locator")
-var KeepServerError = errors.New("One or more keep servers returned an error")
 
 // ErrNoSuchKeepServer is returned when GetIndex is invoked with a UUID with no matching keep server
 var ErrNoSuchKeepServer = errors.New("No keep server matching the given UUID is found")
@@ -56,12 +54,15 @@ type KeepClient struct {
        replicasPerService int
 }
 
-// Create a new KeepClient.  This will contact the API server to discover Keep
-// servers.
+// MakeKeepClient calls New and then DiscoverKeepServers, returning the client along with any discovery error.
 func MakeKeepClient(arv *arvadosclient.ArvadosClient) (*KeepClient, error) {
-       var matchTrue = regexp.MustCompile("^(?i:1|yes|true)$")
-       insecure := matchTrue.MatchString(os.Getenv("ARVADOS_API_HOST_INSECURE"))
+       kc := New(arv)
+       return kc, kc.DiscoverKeepServers()
+}
 
+// New creates a KeepClient configured from arv. It does not discover Keep
+// servers; that is the caller's job (see DiscoverKeepServers, MakeKeepClient).
+func New(arv *arvadosclient.ArvadosClient) *KeepClient {
        defaultReplicationLevel := 2
        value, err := arv.Discovery("defaultCollectionReplication")
        if err == nil {
@@ -76,10 +77,10 @@ func MakeKeepClient(arv *arvadosclient.ArvadosClient) (*KeepClient, error) {
                Want_replicas: defaultReplicationLevel,
                Using_proxy:   false,
                Client: &http.Client{Transport: &http.Transport{
-                       TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}}},
+                       TLSClientConfig: &tls.Config{InsecureSkipVerify: arv.ApiInsecure}}},
                Retries: 2,
        }
-       return kc, kc.DiscoverKeepServers()
+       return kc
 }
 
 // Put a block given the block hash, a reader, and the number of bytes
@@ -140,23 +141,21 @@ func (kc *KeepClient) PutR(r io.Reader) (locator string, replicas int, err error
        }
 }
 
-// Get() retrieves a block, given a locator. Returns a reader, the
-// expected data length, the URL the block is being fetched from, and
-// an error.
-//
-// If the block checksum does not match, the final Read() on the
-// reader returned by this method will return a BadChecksum error
-// instead of EOF.
-func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error) {
+// getOrHead tries a GET or HEAD for locator on each Keep server until one
+// returns 200 OK, retrying transient failures up to kc.Retries times.
+func (kc *KeepClient) getOrHead(method string, locator string) (io.ReadCloser, int64, string, error) {
        var errs []string
-       server_error := false
-
-       for _, host := range kc.getSortedRoots(locator) {
-               url := host + "/" + locator
-               tries_remaining := 1 + kc.Retries
-               for tries_remaining > 0 {
-                       tries_remaining -= 1
-                       req, err := http.NewRequest("GET", url, nil)
+
+       tries_remaining := 1 + kc.Retries
+       serversToTry := kc.getSortedRoots(locator)
+       var retryList []string
+
+       for tries_remaining > 0 {
+               tries_remaining -= 1
+               retryList = nil
+               for _, host := range serversToTry {
+                       url := host + "/" + locator
+                       req, err := http.NewRequest(method, url, nil)
                        if err != nil {
                                errs = append(errs, fmt.Sprintf("%s: %v", url, err))
                                continue
@@ -164,45 +163,56 @@ func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error)
                        req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
                        resp, err := kc.Client.Do(req)
                        if err != nil {
-                               // Probably a network error, may be
-                               // transient, can try again.
-                               server_error = true
+                               // Probably a network error, may be transient,
+                               // can try again.
                                errs = append(errs, fmt.Sprintf("%s: %v", url, err))
+                               retryList = append(retryList, host)
                        } else if resp.StatusCode != http.StatusOK {
-                               respbody, _ := ioutil.ReadAll(&io.LimitedReader{resp.Body, 4096})
+                               // Best effort: keep up to 4 KiB of the body for the error message.
+                               respbody, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 4096))
                                resp.Body.Close()
-                               errs = append(errs, fmt.Sprintf("%s: %d %s",
+                               errs = append(errs, fmt.Sprintf("%s: HTTP %d %q",
                                        url, resp.StatusCode, bytes.TrimSpace(respbody)))
 
-                               if resp.StatusCode >= 500 {
-                                       // Server side failure, may be
-                                       // transient, can try again.
-                                       server_error = true
-                               } else {
-                                       // Some other error (4xx),
-                                       // typically 403 or 404, don't
-                                       // try again.
-                                       tries_remaining = 0
+                               if resp.StatusCode == 408 ||
+                                       resp.StatusCode == 429 ||
+                                       resp.StatusCode >= 500 {
+                                       // Timeout, too many requests, or other
+                                       // server side failure, transient
+                                       // error, can try again.
+                                       retryList = append(retryList, host)
                                }
                        } else {
                                // Success.
-                               return HashCheckingReader{
-                                       Reader: resp.Body,
-                                       Hash:   md5.New(),
-                                       Check:  locator[0:32],
-                               }, resp.ContentLength, url, nil
+                               if method == "GET" {
+                                       return HashCheckingReader{
+                                               Reader: resp.Body,
+                                               Hash:   md5.New(),
+                                               Check:  locator[0:32],
+                                       }, resp.ContentLength, url, nil
+                               }
+                               // Not a GET (Ask sends HEAD): only the size is reported, so close the body.
+                               resp.Body.Close()
+                               return nil, resp.ContentLength, url, nil
                        }
                }
+               serversToTry = retryList
        }
-       log.Printf("DEBUG: GET %s failed: %v", locator, errs)
+       log.Printf("DEBUG: %s %s failed: %v", method, locator, errs)
 
-       if server_error {
-               // There was at least one failure to get a final answer
-               return nil, 0, "", KeepServerError
-       } else {
-               // Ever server returned a 4xx error
-               return nil, 0, "", BlockNotFound
-       }
+       // No server answered 200 OK; this includes transient errors that outlasted the retries.
+       return nil, 0, "", BlockNotFound
+}
+
+// Get retrieves a block, given a locator. Returns a reader, the
+// expected data length, the URL the block is being fetched from, and
+// an error.
+//
+// If the block checksum does not match, the final Read() on the
+// reader returned by this method will return a BadChecksum error
+// instead of EOF.
+func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error) {
+       return kc.getOrHead("GET", locator)
 }
 
 // Ask() verifies that a block with the given hash is available and
@@ -213,18 +223,8 @@ func (kc *KeepClient) Get(locator string) (io.ReadCloser, int64, string, error)
 // Returns the data size (content length) reported by the Keep service
 // and the URI reporting the data size.
 func (kc *KeepClient) Ask(locator string) (int64, string, error) {
-       for _, host := range kc.getSortedRoots(locator) {
-               url := host + "/" + locator
-               req, err := http.NewRequest("HEAD", url, nil)
-               if err != nil {
-                       continue
-               }
-               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", kc.Arvados.ApiToken))
-               if resp, err := kc.Client.Do(req); err == nil && resp.StatusCode == http.StatusOK {
-                       return resp.ContentLength, url, nil
-               }
-       }
-       return 0, "", BlockNotFound
+       _, size, url, err := kc.getOrHead("HEAD", locator)
+       return size, url, err
 }
 
 // GetIndex retrieves a list of blocks stored on the given server whose hashes