Merge branch '21535-multi-wf-delete'
[arvados.git] / sdk / go / auth / auth.go
index ad1d398c763d7eaacefefcde8993e39044582f2a..da9b4ea5b8f193fd83072e72ae3ece3cfa6602bc 100644 (file)
@@ -5,6 +5,7 @@
 package auth
 
 import (
+       "context"
        "encoding/base64"
        "net/http"
        "net/url"
@@ -15,11 +16,24 @@ type Credentials struct {
        Tokens []string
 }
 
-func NewCredentials() *Credentials {
-       return &Credentials{Tokens: []string{}}
+func NewCredentials(tokens ...string) *Credentials {
+       return &Credentials{Tokens: tokens}
 }
 
-func NewCredentialsFromHTTPRequest(r *http.Request) *Credentials {
+func NewContext(ctx context.Context, c *Credentials) context.Context {
+       return context.WithValue(ctx, contextKeyCredentials{}, c)
+}
+
+func FromContext(ctx context.Context) (*Credentials, bool) {
+       c, ok := ctx.Value(contextKeyCredentials{}).(*Credentials)
+       return c, ok
+}
+
+func CredentialsFromRequest(r *http.Request) *Credentials {
+       if c, ok := FromContext(r.Context()); ok {
+               // preloaded by middleware
+               return c
+       }
        c := NewCredentials()
        c.LoadTokensFromHTTPRequest(r)
        return c
@@ -40,13 +54,13 @@ func (a *Credentials) LoadTokensFromHTTPRequest(r *http.Request) {
        // Load plain token from "Authorization: OAuth2 ..." header
        // (typically used by smart API clients)
        if toks := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(toks) == 2 && (toks[0] == "OAuth2" || toks[0] == "Bearer") {
-               a.Tokens = append(a.Tokens, toks[1])
+               a.Tokens = append(a.Tokens, strings.TrimSpace(toks[1]))
        }
 
        // Load base64-encoded token from "Authorization: Basic ..."
        // header (typically used by git via credential helper)
        if _, password, ok := r.BasicAuth(); ok {
-               a.Tokens = append(a.Tokens, password)
+               a.Tokens = append(a.Tokens, strings.TrimSpace(password))
        }
 
        // Load tokens from query string. It's generally not a good
@@ -62,7 +76,9 @@ func (a *Credentials) LoadTokensFromHTTPRequest(r *http.Request) {
        // find/report decoding errors in a suitable way.
        qvalues, _ := url.ParseQuery(r.URL.RawQuery)
        if val, ok := qvalues["api_token"]; ok {
-               a.Tokens = append(a.Tokens, val...)
+               for _, token := range val {
+                       a.Tokens = append(a.Tokens, strings.TrimSpace(token))
+               }
        }
 
        a.loadTokenFromCookie(r)
@@ -80,10 +96,10 @@ func (a *Credentials) loadTokenFromCookie(r *http.Request) {
        if err != nil {
                return
        }
-       a.Tokens = append(a.Tokens, string(token))
+       a.Tokens = append(a.Tokens, strings.TrimSpace(string(token)))
 }
 
-// LoadTokensFromHTTPRequestBody() loads credentials from the request
+// LoadTokensFromHTTPRequestBody loads credentials from the request
 // body.
 //
 // This is separate from LoadTokensFromHTTPRequest() because it's not
@@ -97,7 +113,7 @@ func (a *Credentials) LoadTokensFromHTTPRequestBody(r *http.Request) error {
                return err
        }
        if t := r.PostFormValue("api_token"); t != "" {
-               a.Tokens = append(a.Tokens, t)
+               a.Tokens = append(a.Tokens, strings.TrimSpace(t))
        }
        return nil
 }