Merge branch 'patch-1' of https://github.com/mr-c/arvados into mr-c-patch-1
[arvados.git] / sdk / go / arvados / client.go
index a5815987b192a86c9ee646205bcc9ea0f7986dcc..562c8c1e7d7c66528a2ce0874eca034c9eb7b328 100644 (file)
@@ -20,7 +20,7 @@ import (
        "strings"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/httpserver"
+       "git.arvados.org/arvados.git/sdk/go/httpserver"
 )
 
 // A Client is an HTTP client with an API endpoint and a set of
@@ -54,9 +54,19 @@ type Client struct {
        // arvadosclient.ArvadosClient.)
        KeepServiceURIs []string `json:",omitempty"`
 
+       // HTTP headers to add/override in outgoing requests.
+       SendHeader http.Header
+
+       // Timeout for requests. NewClientFromConfig and
+       // NewClientFromEnv return a Client with a default 5 minute
+       // timeout.  To disable this timeout and rely on each
+       // http.Request's context deadline instead, set Timeout to
+       // zero.
+       Timeout time.Duration
+
        dd *DiscoveryDocument
 
-       ctx context.Context
+       defaultRequestID string
 }
 
 // The default http.Client used by a Client with Insecure==true and
@@ -64,12 +74,10 @@ type Client struct {
 var InsecureHTTPClient = &http.Client{
        Transport: &http.Transport{
                TLSClientConfig: &tls.Config{
-                       InsecureSkipVerify: true}},
-       Timeout: 5 * time.Minute}
+                       InsecureSkipVerify: true}}}
 
 // The default http.Client used by a Client otherwise.
-var DefaultSecureClient = &http.Client{
-       Timeout: 5 * time.Minute}
+var DefaultSecureClient = &http.Client{}
 
 // NewClientFromConfig creates a new Client that uses the endpoints in
 // the given cluster.
@@ -84,6 +92,7 @@ func NewClientFromConfig(cluster *Cluster) (*Client, error) {
                Scheme:   ctrlURL.Scheme,
                APIHost:  ctrlURL.Host,
                Insecure: cluster.TLS.Insecure,
+               Timeout:  5 * time.Minute,
        }, nil
 }
 
@@ -113,6 +122,7 @@ func NewClientFromEnv() *Client {
                AuthToken:       os.Getenv("ARVADOS_API_TOKEN"),
                Insecure:        insecure,
                KeepServiceURIs: svcs,
+               Timeout:         5 * time.Minute,
        }
 }
 
@@ -128,11 +138,12 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
        }
 
        if req.Header.Get("X-Request-Id") == "" {
-               reqid, _ := req.Context().Value(contextKeyRequestID{}).(string)
-               if reqid == "" {
-                       reqid, _ = c.context().Value(contextKeyRequestID{}).(string)
-               }
-               if reqid == "" {
+               var reqid string
+               if ctxreqid, _ := req.Context().Value(contextKeyRequestID{}).(string); ctxreqid != "" {
+                       reqid = ctxreqid
+               } else if c.defaultRequestID != "" {
+                       reqid = c.defaultRequestID
+               } else {
                        reqid = reqIDGen.Next()
                }
                if req.Header == nil {
@@ -141,12 +152,54 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
                        req.Header.Set("X-Request-Id", reqid)
                }
        }
-       return c.httpClient().Do(req)
+       var cancel context.CancelFunc
+       if c.Timeout > 0 {
+               ctx := req.Context()
+               ctx, cancel = context.WithDeadline(ctx, time.Now().Add(c.Timeout))
+               req = req.WithContext(ctx)
+       }
+       resp, err := c.httpClient().Do(req)
+       if err == nil && cancel != nil {
+               // We need to call cancel() eventually, but we can't
+               // use "defer cancel()" because the context has to
+               // stay alive until the caller has finished reading
+               // the response body.
+               resp.Body = cancelOnClose{ReadCloser: resp.Body, cancel: cancel}
+       } else if cancel != nil {
+               cancel()
+       }
+       return resp, err
+}
+
+// cancelOnClose calls a provided CancelFunc when its wrapped
+// ReadCloser's Close() method is called.
+type cancelOnClose struct {
+       io.ReadCloser
+       cancel context.CancelFunc
+}
+
+func (coc cancelOnClose) Close() error {
+       err := coc.ReadCloser.Close()
+       coc.cancel()
+       return err
+}
+
+func isRedirectStatus(code int) bool {
+       switch code {
+       case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
+               return true
+       default:
+               return false
+       }
 }
 
 // DoAndDecode performs req and unmarshals the response (which must be
 // JSON) into dst. Use this instead of RequestAndDecode if you need
 // more control of the http.Request object.
+//
+// If the response status indicates an HTTP redirect, the Location
+// header value is unmarshalled to dst as a RedirectLocation
+// key/field.
 func (c *Client) DoAndDecode(dst interface{}, req *http.Request) error {
        resp, err := c.Do(req)
        if err != nil {
@@ -157,13 +210,28 @@ func (c *Client) DoAndDecode(dst interface{}, req *http.Request) error {
        if err != nil {
                return err
        }
-       if resp.StatusCode != 200 {
-               return newTransactionError(req, resp, buf)
-       }
-       if dst == nil {
+       switch {
+       case resp.StatusCode == http.StatusOK && dst == nil:
+               return nil
+       case resp.StatusCode == http.StatusOK:
+               return json.Unmarshal(buf, dst)
+
+       // If the caller uses a client with a custom CheckRedirect
+       // func, Do() might return the 3xx response instead of
+       // following it.
+       case isRedirectStatus(resp.StatusCode) && dst == nil:
                return nil
+       case isRedirectStatus(resp.StatusCode):
+               // Copy the redirect target URL to dst.RedirectLocation.
+               buf, err := json.Marshal(map[string]string{"redirect_location": resp.Header.Get("Location")})
+               if err != nil {
+                       return err
+               }
+               return json.Unmarshal(buf, dst)
+
+       default:
+               return newTransactionError(req, resp, buf)
        }
-       return json.Unmarshal(buf, dst)
 }
 
 // Convert an arbitrary struct to url.Values. For example,
@@ -235,7 +303,7 @@ func anythingToValues(params interface{}) (url.Values, error) {
 //
 // path must not contain a query string.
 func (c *Client) RequestAndDecode(dst interface{}, method, path string, body io.Reader, params interface{}) error {
-       return c.RequestAndDecodeContext(c.context(), dst, method, path, body, params)
+       return c.RequestAndDecodeContext(context.Background(), dst, method, path, body, params)
 }
 
 func (c *Client) RequestAndDecodeContext(ctx context.Context, dst interface{}, method, path string, body io.Reader, params interface{}) error {
@@ -250,9 +318,8 @@ func (c *Client) RequestAndDecodeContext(ctx context.Context, dst interface{}, m
        }
        if urlValues == nil {
                // Nothing to send
-       } else if method == "GET" || method == "HEAD" || body != nil {
-               // Must send params in query part of URL (FIXME: what
-               // if resulting URL is too long?)
+       } else if body != nil || ((method == "GET" || method == "HEAD") && len(urlValues.Encode()) < 1000) {
+               // Send params in query part of URL
                u, err := url.Parse(urlString)
                if err != nil {
                        return err
@@ -266,8 +333,15 @@ func (c *Client) RequestAndDecodeContext(ctx context.Context, dst interface{}, m
        if err != nil {
                return err
        }
+       if (method == "GET" || method == "HEAD") && body != nil {
+               req.Header.Set("X-Http-Method-Override", method)
+               req.Method = "POST"
+       }
        req = req.WithContext(ctx)
        req.Header.Set("Content-type", "application/x-www-form-urlencoded")
+       for k, v := range c.SendHeader {
+               req.Header[k] = v
+       }
        return c.DoAndDecode(dst, req)
 }
 
@@ -295,17 +369,10 @@ func (c *Client) UpdateBody(rsc resource) io.Reader {
 // header.
 func (c *Client) WithRequestID(reqid string) *Client {
        cc := *c
-       cc.ctx = ContextWithRequestID(cc.context(), reqid)
+       cc.defaultRequestID = reqid
        return &cc
 }
 
-func (c *Client) context() context.Context {
-       if c.ctx == nil {
-               return context.Background()
-       }
-       return c.ctx
-}
-
 func (c *Client) httpClient() *http.Client {
        switch {
        case c.Client != nil: