Merge branch 'github-pr-224'
[arvados.git] / lib / controller / handler.go
index 05b45a0fc0e6270543cf43189fa490f34230cc80..7c4bb0912fb3feae8871d9a5e2f920bb777738c4 100644 (file)
@@ -140,6 +140,8 @@ func (h *Handler) setup() {
        mux.Handle("/arvados/v1/groups/", rtr)
        mux.Handle("/arvados/v1/links", rtr)
        mux.Handle("/arvados/v1/links/", rtr)
+       mux.Handle("/arvados/v1/authorized_keys", rtr)
+       mux.Handle("/arvados/v1/authorized_keys/", rtr)
        mux.Handle("/login", rtr)
        mux.Handle("/logout", rtr)
        mux.Handle("/arvados/v1/api_client_authorizations", rtr)
@@ -147,6 +149,7 @@ func (h *Handler) setup() {
 
        hs := http.NotFoundHandler()
        hs = prepend(hs, h.proxyRailsAPI)
+       hs = prepend(hs, h.routeContainerEndpoints(rtr))
        hs = prepend(hs, h.limitLogCreateRequests)
        hs = h.setupProxyRemoteCluster(hs)
        hs = prepend(hs, oidcAuthorizer.Middleware)
@@ -171,7 +174,7 @@ func (h *Handler) setup() {
                Name: "arvados-controller",
        }
        h.cache = map[string]*cacheEnt{
-               "/discovery/v1/apis/arvados/v1/rest": &cacheEnt{},
+               "/discovery/v1/apis/arvados/v1/rest": &cacheEnt{validate: validateDiscoveryDoc},
        }
 
        go h.trashSweepWorker()
@@ -202,9 +205,32 @@ func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error)
        if insecure {
                client = h.insecureClient
        }
+       // Clearing the Host field here causes the Go http client to
+       // use the host part of urlOut as the Host header in the
+       // outgoing request, instead of the Host value from the
+       // original request we received.
+       req.Host = ""
        return h.proxy.Do(req, urlOut, client)
 }
 
+// Route /arvados/v1/containers/{uuid}/log*, .../ssh, and
+// .../gateway_tunnel to rtr, pass everything else to next.
+//
+// (http.ServeMux doesn't let us route these without also routing
+// everything under /containers/, which we don't want yet.)
+func (h *Handler) routeContainerEndpoints(rtr http.Handler) middlewareFunc {
+       return func(w http.ResponseWriter, req *http.Request, next http.Handler) {
+               trim := strings.TrimPrefix(req.URL.Path, "/arvados/v1/containers/")
+               if trim != req.URL.Path && (strings.Index(trim, "/log") == 27 ||
+                       strings.Index(trim, "/ssh") == 27 ||
+                       strings.Index(trim, "/gateway_tunnel") == 27) {
+                       rtr.ServeHTTP(w, req)
+               } else {
+                       next.ServeHTTP(w, req)
+               }
+       }
+}
+
 func (h *Handler) limitLogCreateRequests(w http.ResponseWriter, req *http.Request, next http.Handler) {
        if cap(h.limitLogCreate) > 0 && req.Method == http.MethodPost && strings.HasPrefix(req.URL.Path, "/arvados/v1/logs") {
                select {
@@ -222,14 +248,19 @@ func (h *Handler) limitLogCreateRequests(w http.ResponseWriter, req *http.Reques
 // cacheEnt implements a basic stale-while-revalidate cache, suitable
 // for the Arvados discovery document.
 type cacheEnt struct {
+       validate     func(body []byte) error
        mtx          sync.Mutex
        header       http.Header
        body         []byte
+       expireAfter  time.Time
        refreshAfter time.Time
        refreshLock  sync.Mutex
 }
 
-const cacheTTL = 5 * time.Minute
+const (
+       cacheTTL    = 5 * time.Minute
+       cacheExpire = 24 * time.Hour
+)
 
 func (ent *cacheEnt) refresh(path string, do func(*http.Request) (*http.Response, error)) (http.Header, []byte, error) {
        ent.refreshLock.Lock()
@@ -253,13 +284,15 @@ func (ent *cacheEnt) refresh(path string, do func(*http.Request) (*http.Response
 
        ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
        defer cancel()
-       // 0.0.0.0:0 is just a placeholder here -- do(), which is
+       // "http://localhost" is just a placeholder here -- we'll fill
+       // in req.URL.Path below, and then do(), which is
        // localClusterRequest(), will replace the scheme and host
        // parts with the real proxy destination.
-       req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://0.0.0.0:0/"+path, nil)
+       req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
        if err != nil {
                return nil, nil, err
        }
+       req.URL.Path = path
        resp, err := do(req)
        if err != nil {
                return nil, nil, err
@@ -272,12 +305,16 @@ func (ent *cacheEnt) refresh(path string, do func(*http.Request) (*http.Response
                return nil, nil, fmt.Errorf("Read error: %w", err)
        }
        header := http.Header{}
-       for _, k := range []string{"Content-Type", "Etag", "Last-Modified"} {
-               if v, ok := header[k]; ok {
-                       resp.Header[k] = v
+       for k, v := range resp.Header {
+               if !dropHeaders[k] && k != "X-Request-Id" {
+                       header[k] = v
                }
        }
-       if mediatype, _, err := mime.ParseMediaType(header.Get("Content-Type")); err == nil && mediatype == "application/json" {
+       if ent.validate != nil {
+               if err := ent.validate(body); err != nil {
+                       return nil, nil, err
+               }
+       } else if mediatype, _, err := mime.ParseMediaType(header.Get("Content-Type")); err == nil && mediatype == "application/json" {
                if !json.Valid(body) {
                        return nil, nil, errors.New("invalid JSON encoding in response")
                }
@@ -287,12 +324,16 @@ func (ent *cacheEnt) refresh(path string, do func(*http.Request) (*http.Response
        ent.header = header
        ent.body = body
        ent.refreshAfter = time.Now().Add(cacheTTL)
+       ent.expireAfter = time.Now().Add(cacheExpire)
        return ent.header, ent.body, nil
 }
 
 func (ent *cacheEnt) response() (http.Header, []byte, bool) {
        ent.mtx.Lock()
        defer ent.mtx.Unlock()
+       if ent.expireAfter.Before(time.Now()) {
+               ent.header, ent.body, ent.refreshAfter = nil, nil, time.Time{}
+       }
        return ent.header, ent.body, ent.refreshAfter.Before(time.Now())
 }
 
@@ -350,3 +391,15 @@ func findRailsAPI(cluster *arvados.Cluster) (*url.URL, bool, error) {
        }
        return best, cluster.TLS.Insecure, nil
 }
+
+func validateDiscoveryDoc(body []byte) error {
+       var dd arvados.DiscoveryDocument
+       err := json.Unmarshal(body, &dd)
+       if err != nil {
+               return fmt.Errorf("error decoding JSON response: %w", err)
+       }
+       if dd.BasePath == "" {
+               return errors.New("error in discovery document: no value for basePath")
+       }
+       return nil
+}