X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/89be4b30feccc3680ca77339711b29367754dc05..22361307cf41f916afd562e7f33fcdaacefe5f9d:/sdk/go/auth/auth.go diff --git a/sdk/go/auth/auth.go b/sdk/go/auth/auth.go index 3c266e0d3a..da9b4ea5b8 100644 --- a/sdk/go/auth/auth.go +++ b/sdk/go/auth/auth.go @@ -5,6 +5,7 @@ package auth import ( + "context" "encoding/base64" "net/http" "net/url" @@ -15,12 +16,21 @@ type Credentials struct { Tokens []string } -func NewCredentials() *Credentials { - return &Credentials{Tokens: []string{}} +func NewCredentials(tokens ...string) *Credentials { + return &Credentials{Tokens: tokens} +} + +func NewContext(ctx context.Context, c *Credentials) context.Context { + return context.WithValue(ctx, contextKeyCredentials{}, c) +} + +func FromContext(ctx context.Context) (*Credentials, bool) { + c, ok := ctx.Value(contextKeyCredentials{}).(*Credentials) + return c, ok } func CredentialsFromRequest(r *http.Request) *Credentials { - if c, ok := r.Context().Value(contextKeyCredentials).(*Credentials); ok { + if c, ok := FromContext(r.Context()); ok { // preloaded by middleware return c } @@ -44,13 +54,13 @@ func (a *Credentials) LoadTokensFromHTTPRequest(r *http.Request) { // Load plain token from "Authorization: OAuth2 ..." header // (typically used by smart API clients) if toks := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(toks) == 2 && (toks[0] == "OAuth2" || toks[0] == "Bearer") { - a.Tokens = append(a.Tokens, toks[1]) + a.Tokens = append(a.Tokens, strings.TrimSpace(toks[1])) } // Load base64-encoded token from "Authorization: Basic ..." // header (typically used by git via credential helper) if _, password, ok := r.BasicAuth(); ok { - a.Tokens = append(a.Tokens, password) + a.Tokens = append(a.Tokens, strings.TrimSpace(password)) } // Load tokens from query string. It's generally not a good @@ -66,7 +76,9 @@ func (a *Credentials) LoadTokensFromHTTPRequest(r *http.Request) { // find/report decoding errors in a suitable way. qvalues, _ := url.ParseQuery(r.URL.RawQuery) if val, ok := qvalues["api_token"]; ok { - a.Tokens = append(a.Tokens, val...) + for _, token := range val { + a.Tokens = append(a.Tokens, strings.TrimSpace(token)) + } } a.loadTokenFromCookie(r) @@ -84,10 +96,10 @@ func (a *Credentials) loadTokenFromCookie(r *http.Request) { if err != nil { return } - a.Tokens = append(a.Tokens, string(token)) + a.Tokens = append(a.Tokens, strings.TrimSpace(string(token))) } -// LoadTokensFromHTTPRequestBody() loads credentials from the request +// LoadTokensFromHTTPRequestBody loads credentials from the request // body. // // This is separate from LoadTokensFromHTTPRequest() because it's not @@ -101,7 +113,7 @@ func (a *Credentials) LoadTokensFromHTTPRequestBody(r *http.Request) error { return err } if t := r.PostFormValue("api_token"); t != "" { - a.Tokens = append(a.Tokens, t) + a.Tokens = append(a.Tokens, strings.TrimSpace(t)) } return nil }