15954: Merge branch 'master'
[arvados.git] / lib / controller / federation / conn.go
index 1b1ffef24cf3a1c6bb2245f3d4c9cb258118e620..279b7a51d5d8d4e57f920721db10ace45268b1d2 100644 (file)
@@ -7,7 +7,6 @@ package federation
 import (
        "bytes"
        "context"
-       "crypto/md5"
        "encoding/json"
        "errors"
        "fmt"
@@ -35,10 +34,14 @@ func New(cluster *arvados.Cluster) *Conn {
        local := localdb.NewConn(cluster)
        remotes := map[string]backend{}
        for id, remote := range cluster.RemoteClusters {
-               if !remote.Proxy {
+               if !remote.Proxy || id == cluster.ClusterID {
                        continue
                }
-               remotes[id] = rpc.NewConn(id, &url.URL{Scheme: remote.Scheme, Host: remote.Host}, remote.Insecure, saltedTokenProvider(local, id))
+               conn := rpc.NewConn(id, &url.URL{Scheme: remote.Scheme, Host: remote.Host}, remote.Insecure, saltedTokenProvider(local, id))
+               // Older versions of controller rely on the Via header
+               // to detect loops.
+               conn.SendHeader = http.Header{"Via": {"HTTP/1.1 arvados-controller"}}
+               remotes[id] = conn
        }
 
        return &Conn{
@@ -116,8 +119,13 @@ func (conn *Conn) chooseBackend(id string) backend {
 // or "" for the local backend.
 //
 // A non-nil error means all backends failed.
-func (conn *Conn) tryLocalThenRemotes(ctx context.Context, fn func(context.Context, string, backend) error) error {
-       if err := fn(ctx, "", conn.local); err == nil || errStatus(err) != http.StatusNotFound {
+func (conn *Conn) tryLocalThenRemotes(ctx context.Context, forwardedFor string, fn func(context.Context, string, backend) error) error {
+       if err := fn(ctx, "", conn.local); err == nil || errStatus(err) != http.StatusNotFound || forwardedFor != "" {
+               // Note: forwardedFor != "" means this request came
+               // from a remote cluster, so we don't take a second
+               // hop. This avoids cycles, redundant calls to a
+               // mutually reachable remote, and use of double-salted
+               // tokens.
                return err
        }
 
@@ -160,26 +168,6 @@ func rewriteManifest(mt, remoteID string) string {
        })
 }
 
-// this could be in sdk/go/arvados
-func portableDataHash(mt string) string {
-       h := md5.New()
-       blkRe := regexp.MustCompile(`^ [0-9a-f]{32}\+\d+`)
-       size := 0
-       _ = regexp.MustCompile(` ?[^ ]*`).ReplaceAllFunc([]byte(mt), func(tok []byte) []byte {
-               if m := blkRe.Find(tok); m != nil {
-                       // write hash+size, ignore remaining block hints
-                       tok = m
-               }
-               n, err := h.Write(tok)
-               if err != nil {
-                       panic(err)
-               }
-               size += n
-               return nil
-       })
-       return fmt.Sprintf("%x+%d", h.Sum(nil), size)
-}
-
 func (conn *Conn) ConfigGet(ctx context.Context) (json.RawMessage, error) {
        var buf bytes.Buffer
        err := config.ExportJSON(&buf, conn.cluster)
@@ -213,6 +201,32 @@ func (conn *Conn) Login(ctx context.Context, options arvados.LoginOptions) (arva
        }
 }
 
+func (conn *Conn) Logout(ctx context.Context, options arvados.LogoutOptions) (arvados.LogoutResponse, error) {
+       // If the logout request comes with an API token from a known
+       // remote cluster, redirect to that cluster's logout handler
+       // so it has an opportunity to clear sessions, expire tokens,
+       // etc. Otherwise use the local endpoint.
+       reqauth, ok := auth.FromContext(ctx)
+       if !ok || len(reqauth.Tokens) == 0 || len(reqauth.Tokens[0]) < 8 || !strings.HasPrefix(reqauth.Tokens[0], "v2/") {
+               return conn.local.Logout(ctx, options)
+       }
+       id := reqauth.Tokens[0][3:8]
+       if id == conn.cluster.ClusterID {
+               return conn.local.Logout(ctx, options)
+       }
+       remote, ok := conn.remotes[id]
+       if !ok {
+               return conn.local.Logout(ctx, options)
+       }
+       baseURL := remote.BaseURL()
+       target, err := baseURL.Parse(arvados.EndpointLogout.Path)
+       if err != nil {
+               return arvados.LogoutResponse{}, fmt.Errorf("internal error getting redirect target: %s", err)
+       }
+       target.RawQuery = url.Values{"return_to": {options.ReturnTo}}.Encode()
+       return arvados.LogoutResponse{RedirectLocation: target.String()}, nil
+}
+
 func (conn *Conn) CollectionGet(ctx context.Context, options arvados.GetOptions) (arvados.Collection, error) {
        if len(options.UUID) == 27 {
                // UUID is really a UUID
@@ -224,15 +238,17 @@ func (conn *Conn) CollectionGet(ctx context.Context, options arvados.GetOptions)
        } else {
                // UUID is a PDH
                first := make(chan arvados.Collection, 1)
-               err := conn.tryLocalThenRemotes(ctx, func(ctx context.Context, remoteID string, be backend) error {
-                       c, err := be.CollectionGet(ctx, options)
+               err := conn.tryLocalThenRemotes(ctx, options.ForwardedFor, func(ctx context.Context, remoteID string, be backend) error {
+                       remoteOpts := options
+                       remoteOpts.ForwardedFor = conn.cluster.ClusterID + "-" + options.ForwardedFor
+                       c, err := be.CollectionGet(ctx, remoteOpts)
                        if err != nil {
                                return err
                        }
                        // options.UUID is either hash+size or
                        // hash+size+hints; only hash+size need to
                        // match the computed PDH.
-                       if pdh := portableDataHash(c.ManifestText); pdh != options.UUID && !strings.HasPrefix(options.UUID, pdh+"+") {
+                       if pdh := arvados.PortableDataHash(c.ManifestText); pdh != options.UUID && !strings.HasPrefix(options.UUID, pdh+"+") {
                                err = httpErrorf(http.StatusBadGateway, "bad portable data hash %q received from remote %q (expected %q)", pdh, remoteID, options.UUID)
                                ctxlog.FromContext(ctx).Warn(err)
                                return err