14287: Clean up context key usage.
author Tom Clegg <tclegg@veritasgenetics.com>
Tue, 25 Jun 2019 15:39:41 +0000 (11:39 -0400)
committer Tom Clegg <tclegg@veritasgenetics.com>
Tue, 25 Jun 2019 15:39:41 +0000 (11:39 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tclegg@veritasgenetics.com>

lib/controller/federation/conn.go
lib/controller/railsproxy/railsproxy.go
lib/controller/router/router.go
lib/controller/rpc/conn.go
sdk/go/arvados/client.go
sdk/go/arvados/context.go
sdk/go/auth/auth.go
sdk/go/auth/handlers.go

index 3b60328954eb8a73cecdf9c648ff29b40a3ddf05..2d84db4df90f188d49702e7b2c1565835ab02fba 100644 (file)
@@ -51,7 +51,7 @@ func New(cluster *arvados.Cluster) arvados.API {
 func saltedTokenProvider(local backend, remoteID string) rpc.TokenProvider {
        return func(ctx context.Context) ([]string, error) {
                var tokens []string
-               incoming, ok := ctx.Value(auth.ContextKeyCredentials).(*auth.Credentials)
+               incoming, ok := auth.FromContext(ctx)
                if !ok {
                        return nil, errors.New("no token provided")
                }
@@ -63,7 +63,7 @@ func saltedTokenProvider(local backend, remoteID string) rpc.TokenProvider {
                        case auth.ErrSalted:
                                tokens = append(tokens, token)
                        case auth.ErrObsoleteToken:
-                               ctx := context.WithValue(ctx, auth.ContextKeyCredentials, &auth.Credentials{Tokens: []string{token}})
+                               ctx := auth.NewContext(ctx, &auth.Credentials{Tokens: []string{token}})
                                aca, err := local.APIClientAuthorizationCurrent(ctx, arvados.GetOptions{})
                                if errStatus(err) == http.StatusUnauthorized {
                                        // pass through unmodified
index 5070fa396d80b3c4308015dd1424ab2037d17abc..576e603eedd758f8ff53f2556e1161b6957b0691 100644 (file)
@@ -44,7 +44,7 @@ func NewConn(cluster *arvados.Cluster) *rpc.Conn {
 }
 
 func provideIncomingToken(ctx context.Context) ([]string, error) {
-       incoming, ok := ctx.Value(auth.ContextKeyCredentials).(*auth.Credentials)
+       incoming, ok := auth.FromContext(ctx)
        if !ok {
                return nil, errors.New("no token provided")
        }
index f846f2dcdf6f2b33aa65dad321932b69d859d3b6..f37c7ea9073ac51c0553ecf03c91ff4a9b1b2e92 100644 (file)
@@ -244,8 +244,7 @@ func (rtr *router) addRoute(endpoint arvados.APIEndpoint, defaultOpts func() int
                                }
                        }
                }
-               ctx := req.Context()
-               ctx = context.WithValue(ctx, auth.ContextKeyCredentials, creds)
+               ctx := auth.NewContext(req.Context(), creds)
                ctx = arvados.ContextWithRequestID(ctx, req.Header.Get("X-Request-Id"))
                logger.WithFields(logrus.Fields{
                        "apiEndpoint": endpoint,
index b32717f9a090d00bf9215050afe293fd78f4a018..e07eaf40affbe3ec0dc0d78422686926eea550c0 100644 (file)
@@ -19,10 +19,6 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvados"
 )
 
-type contextKey string
-
-const ContextKeyCredentials contextKey = "credentials"
-
 type TokenProvider func(context.Context) ([]string, error)
 
 type Conn struct {
@@ -72,13 +68,13 @@ func (conn *Conn) requestAndDecode(ctx context.Context, dst interface{}, ep arva
        if err != nil {
                return err
        } else if len(tokens) > 0 {
-               ctx = context.WithValue(ctx, "Authorization", "Bearer "+tokens[0])
+               ctx = arvados.ContextWithAuthorization(ctx, "Bearer "+tokens[0])
        } else {
                // Use a non-empty auth string to ensure we override
                // any default token set on aClient -- and to avoid
                // having the remote prompt us to send a token by
                // responding 401.
-               ctx = context.WithValue(ctx, "Authorization", "Bearer -")
+               ctx = arvados.ContextWithAuthorization(ctx, "Bearer -")
        }
 
        // Encode opts to JSON and decode from there to a
index 102018bb1280289efe923643c3db521dc509756d..a5815987b192a86c9ee646205bcc9ea0f7986dcc 100644 (file)
@@ -121,16 +121,16 @@ var reqIDGen = httpserver.IDGenerator{Prefix: "req-"}
 // Do adds Authorization and X-Request-Id headers and then calls
 // (*http.Client)Do().
 func (c *Client) Do(req *http.Request) (*http.Response, error) {
-       if auth, _ := req.Context().Value("Authorization").(string); auth != "" {
+       if auth, _ := req.Context().Value(contextKeyAuthorization{}).(string); auth != "" {
                req.Header.Add("Authorization", auth)
        } else if c.AuthToken != "" {
                req.Header.Add("Authorization", "OAuth2 "+c.AuthToken)
        }
 
        if req.Header.Get("X-Request-Id") == "" {
-               reqid, _ := req.Context().Value(contextKeyRequestID).(string)
+               reqid, _ := req.Context().Value(contextKeyRequestID{}).(string)
                if reqid == "" {
-                       reqid, _ = c.context().Value(contextKeyRequestID).(string)
+                       reqid, _ = c.context().Value(contextKeyRequestID{}).(string)
                }
                if reqid == "" {
                        reqid = reqIDGen.Next()
index 555cfc8e9087874bf41cc087773e499209182dd9..6ecf85b4e0403f02d0109152c4de16253d8aef7e 100644 (file)
@@ -8,10 +8,17 @@ import (
        "context"
 )
 
-type contextKey string
-
-var contextKeyRequestID contextKey = "X-Request-Id"
+type contextKeyRequestID struct{}
+type contextKeyAuthorization struct{}
 
 func ContextWithRequestID(ctx context.Context, reqid string) context.Context {
-       return context.WithValue(ctx, contextKeyRequestID, reqid)
+       return context.WithValue(ctx, contextKeyRequestID{}, reqid)
+}
+
+// ContextWithAuthorization returns a child context that (when used
+// with (*Client)RequestAndDecodeContext) sends the given
+// Authorization header value instead of the Client's default
+// AuthToken.
+func ContextWithAuthorization(ctx context.Context, value string) context.Context {
+       return context.WithValue(ctx, contextKeyAuthorization{}, value)
 }
index de3b1e9523467d754d1354d587b67a73a506bfef..c2f6a0e8f0885e68a98f7e62a4ee4f17d0d930d2 100644 (file)
@@ -5,6 +5,7 @@
 package auth
 
 import (
+       "context"
        "encoding/base64"
        "net/http"
        "net/url"
@@ -19,8 +20,17 @@ func NewCredentials() *Credentials {
        return &Credentials{Tokens: []string{}}
 }
 
+func NewContext(ctx context.Context, c *Credentials) context.Context {
+       return context.WithValue(ctx, contextKeyCredentials{}, c)
+}
+
+func FromContext(ctx context.Context) (*Credentials, bool) {
+       c, ok := ctx.Value(contextKeyCredentials{}).(*Credentials)
+       return c, ok
+}
+
 func CredentialsFromRequest(r *http.Request) *Credentials {
-       if c, ok := r.Context().Value(ContextKeyCredentials).(*Credentials); ok {
+       if c, ok := FromContext(r.Context()); ok {
                // preloaded by middleware
                return c
        }
index 9fa501ab7abf6b71e3ccc5953c8c3f76a37cfcec..b638f7982516b431e15322adde0b55b0637c8a3f 100644 (file)
@@ -9,17 +9,15 @@ import (
        "net/http"
 )
 
-type contextKey string
-
-var ContextKeyCredentials contextKey = "credentials"
+type contextKeyCredentials struct{}
 
 // LoadToken wraps the next handler, adding credentials to the request
 // context so subsequent handlers can access them efficiently via
 // CredentialsFromRequest.
 func LoadToken(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               if _, ok := r.Context().Value(ContextKeyCredentials).(*Credentials); !ok {
-                       r = r.WithContext(context.WithValue(r.Context(), ContextKeyCredentials, CredentialsFromRequest(r)))
+               if _, ok := r.Context().Value(contextKeyCredentials{}).(*Credentials); !ok {
+                       r = r.WithContext(context.WithValue(r.Context(), contextKeyCredentials{}, CredentialsFromRequest(r)))
                }
                next.ServeHTTP(w, r)
        })