package auth
import (
+ "context"
"encoding/base64"
"net/http"
"net/url"
Tokens []string
}
-func NewCredentials() *Credentials {
- return &Credentials{Tokens: []string{}}
+// NewCredentials returns a new Credentials holding a copy of tokens.
+// Tokens is a non-nil (possibly empty) slice, as it was before the
+// variadic parameter was added.
+func NewCredentials(tokens ...string) *Credentials {
+ // Copy rather than store tokens directly: when called as
+ // NewCredentials(s...), tokens shares s's backing array, and later
+ // appends to c.Tokens could otherwise overwrite the caller's data.
+ return &Credentials{Tokens: append([]string{}, tokens...)}
}
-func NewCredentialsFromHTTPRequest(r *http.Request) *Credentials {
+// NewContext returns a child of ctx that carries c under this package's
+// private context key; FromContext retrieves it.
+func NewContext(ctx context.Context, c *Credentials) context.Context {
+ key := contextKeyCredentials{}
+ return context.WithValue(ctx, key, c)
+}
+
+// FromContext returns the Credentials stored in ctx by NewContext.
+// The boolean is false when ctx carries no *Credentials or carries a
+// nil *Credentials, so callers never receive (nil, true).
+func FromContext(ctx context.Context) (*Credentials, bool) {
+ c, ok := ctx.Value(contextKeyCredentials{}).(*Credentials)
+ // NewContext(ctx, nil) stores a typed nil, which still satisfies the
+ // type assertion above; treat it the same as "not present".
+ if !ok || c == nil {
+  return nil, false
+ }
+ return c, true
+}
+
+func CredentialsFromRequest(r *http.Request) *Credentials {
+ if c, ok := FromContext(r.Context()); ok {
+ // preloaded by middleware
+ return c
+ }
c := NewCredentials()
c.LoadTokensFromHTTPRequest(r)
return c
// Load plain token from "Authorization: OAuth2 ..." header
// (typically used by smart API clients)
if toks := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(toks) == 2 && (toks[0] == "OAuth2" || toks[0] == "Bearer") {
- a.Tokens = append(a.Tokens, toks[1])
+ a.Tokens = append(a.Tokens, strings.TrimSpace(toks[1]))
}
// Load base64-encoded token from "Authorization: Basic ..."
// header (typically used by git via credential helper)
if _, password, ok := r.BasicAuth(); ok {
- a.Tokens = append(a.Tokens, password)
+ a.Tokens = append(a.Tokens, strings.TrimSpace(password))
}
// Load tokens from query string. It's generally not a good
// find/report decoding errors in a suitable way.
qvalues, _ := url.ParseQuery(r.URL.RawQuery)
if val, ok := qvalues["api_token"]; ok {
- a.Tokens = append(a.Tokens, val...)
+ for _, token := range val {
+ a.Tokens = append(a.Tokens, strings.TrimSpace(token))
+ }
}
a.loadTokenFromCookie(r)
if err != nil {
return
}
- a.Tokens = append(a.Tokens, string(token))
+ a.Tokens = append(a.Tokens, strings.TrimSpace(string(token)))
}
-// LoadTokensFromHTTPRequestBody() loads credentials from the request
+// LoadTokensFromHTTPRequestBody loads credentials from the request
// body.
//
// This is separate from LoadTokensFromHTTPRequest() because it's not
return err
}
if t := r.PostFormValue("api_token"); t != "" {
- a.Tokens = append(a.Tokens, t)
+ a.Tokens = append(a.Tokens, strings.TrimSpace(t))
}
return nil
}