13497: Proxy requests to Rails API.
author Tom Clegg <tclegg@veritasgenetics.com>
Tue, 12 Jun 2018 20:00:07 +0000 (16:00 -0400)
committer Tom Clegg <tclegg@veritasgenetics.com>
Thu, 14 Jun 2018 17:35:40 +0000 (13:35 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tclegg@veritasgenetics.com>

lib/controller/cmd.go
lib/controller/handler.go
lib/controller/handler_test.go [new file with mode: 0644]
sdk/go/arvados/config.go
sdk/go/httpserver/logger.go

index 2bb68aed9ed90b31999dd794aee8ff8da0aff5b1..e006b65941615038ab9722b6db9ae11129264072 100644 (file)
@@ -12,6 +12,8 @@ import (
        "git.curoverse.com/arvados.git/sdk/go/arvados"
 )
 
-var Command cmd.Handler = service.Command(arvados.ServiceNameController, func(cluster *arvados.Cluster, _ *arvados.SystemNode) http.Handler {
-       return &Handler{Cluster: cluster}
-})
+// Command is the cmd.Handler for the controller service
+// (arvados.ServiceNameController). newHandler is the http.Handler
+// constructor passed to service.Command.
+var Command cmd.Handler = service.Command(arvados.ServiceNameController, newHandler)
+
+// newHandler returns a *Handler for the given cluster configuration
+// and system node.
+func newHandler(cluster *arvados.Cluster, node *arvados.SystemNode) http.Handler {
+       h := &Handler{Node: node, Cluster: cluster}
+       return h
+}
index f0354d94d93b164c28906b192c7f02770c12c5f4..a1b3848e5cca6869479f8091e7b7f01049f24197 100644 (file)
@@ -6,16 +6,20 @@ package controller
 
 import (
        "io"
+       "net"
        "net/http"
        "net/url"
+       "strings"
        "sync"
 
        "git.curoverse.com/arvados.git/sdk/go/arvados"
        "git.curoverse.com/arvados.git/sdk/go/health"
+       "git.curoverse.com/arvados.git/sdk/go/httpserver"
 )
 
 type Handler struct {
        Cluster *arvados.Cluster
+       Node    *arvados.SystemNode
 
        setupOnce    sync.Once
        handlerStack http.Handler
@@ -37,15 +41,39 @@ func (h *Handler) setup() {
        h.handlerStack = mux
 }
 
-func (h *Handler) proxyRailsAPI(w http.ResponseWriter, incomingReq *http.Request) {
-       url, err := findRailsAPI(h.Cluster)
+// proxyRailsAPI forwards reqIn (method, path, query, headers, and
+// body) to the Rails API server returned by findRailsAPI, and writes
+// the upstream response status and body to w. If the upstream server
+// cannot be located or reached, it responds 500 with the error text.
+func (h *Handler) proxyRailsAPI(w http.ResponseWriter, reqIn *http.Request) {
+       urlOut, err := findRailsAPI(h.Cluster, h.Node)
        if err != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
-       req := *incomingReq
-       req.URL.Host = url.Host
-       resp, err := arvados.InsecureHTTPClient.Do(&req)
+       // Keep the Rails API scheme and host:port; take the path and
+       // query from the incoming request.
+       urlOut = &url.URL{
+               Scheme:   urlOut.Scheme,
+               Host:     urlOut.Host,
+               Path:     reqIn.URL.Path,
+               RawPath:  reqIn.URL.RawPath,
+               RawQuery: reqIn.URL.RawQuery,
+       }
+
+       // Copy headers from incoming request, then add/replace proxy
+       // headers like Via and X-Forwarded-For.
+       hdrOut := http.Header{}
+       for k, v := range reqIn.Header {
+               hdrOut[k] = v
+       }
+       // X-Forwarded-For entries are client IP addresses, so drop the
+       // port from RemoteAddr ("192.0.2.1:1234" => "192.0.2.1").
+       xff := reqIn.RemoteAddr
+       if host, _, err := net.SplitHostPort(xff); err == nil {
+               xff = host
+       }
+       if xffIn := reqIn.Header.Get("X-Forwarded-For"); xffIn != "" {
+               xff = xffIn + "," + xff
+       }
+       hdrOut.Set("X-Forwarded-For", xff)
+       hdrOut.Add("Via", reqIn.Proto+" arvados-controller")
+
+       // Forward the request body (and its length, -1 if unknown):
+       // without it, POST/PUT/PATCH requests would reach Rails API
+       // with an empty body.
+       reqOut := (&http.Request{
+               Method:        reqIn.Method,
+               URL:           urlOut,
+               Header:        hdrOut,
+               Body:          reqIn.Body,
+               ContentLength: reqIn.ContentLength,
+       }).WithContext(reqIn.Context())
+       resp, err := arvados.InsecureHTTPClient.Do(reqOut)
        if err != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
@@ -56,15 +84,27 @@ func (h *Handler) proxyRailsAPI(w http.ResponseWriter, incomingReq *http.Request
                }
        }
        w.WriteHeader(resp.StatusCode)
-       io.Copy(w, resp.Body)
+       n, err := io.Copy(w, resp.Body)
+       if err != nil {
+               httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
+       }
 }
 
 // For now, findRailsAPI always uses the rails API running on this
 // node.
-func findRailsAPI(cluster *arvados.Cluster) (*url.URL, error) {
-       node, err := cluster.GetThisSystemNode()
-       if err != nil {
+func findRailsAPI(cluster *arvados.Cluster, node *arvados.SystemNode) (*url.URL, error) {
+       hostport := node.RailsAPI.Listen
+       if len(hostport) > 1 && hostport[0] == ':' && strings.TrimRight(hostport[1:], "0123456789") == "" {
+               // ":12345" => connect to indicated port on localhost
+               hostport = "localhost" + hostport
+       } else if _, _, err := net.SplitHostPort(hostport); err == nil {
+               // "[::1]:12345" => connect to indicated address & port
+       } else {
                return nil, err
        }
-       return url.Parse("http://" + node.RailsAPI.Listen)
+       proto := "http"
+       if node.RailsAPI.TLS {
+               proto = "https"
+       }
+       return url.Parse(proto + "://" + hostport)
 }
diff --git a/lib/controller/handler_test.go b/lib/controller/handler_test.go
new file mode 100644 (file)
index 0000000..57bb13d
--- /dev/null
@@ -0,0 +1,91 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+       "encoding/json"
+       "net/http"
+       "net/http/httptest"
+       "os"
+       "testing"
+
+       "git.curoverse.com/arvados.git/sdk/go/arvados"
+       "git.curoverse.com/arvados.git/sdk/go/arvadostest"
+       check "gopkg.in/check.v1"
+)
+
+// HandlerSuite tests the controller Handler against the Rails API
+// server named by the ARVADOS_API_HOST environment variable.
+type HandlerSuite struct {
+       cluster *arvados.Cluster // configuration built by SetUpTest
+       handler http.Handler     // handler under test, from newHandler
+}
+
+// Register HandlerSuite with gocheck.
+var _ = check.Suite(&HandlerSuite{})
+
+// Test hooks the registered gocheck suites into "go test".
+func Test(t *testing.T) {
+       check.TestingT(t)
+}
+
+func (s *HandlerSuite) SetUpTest(c *check.C) {
+       s.cluster = &arvados.Cluster{
+               ClusterID: "zzzzz",
+               SystemNodes: map[string]arvados.SystemNode{
+                       "*": {
+                               Controller: arvados.SystemServiceInstance{Listen: ":"},
+                               RailsAPI:   arvados.SystemServiceInstance{Listen: os.Getenv("ARVADOS_API_HOST"), TLS: true},
+                       },
+               },
+       }
+       node := s.cluster.SystemNodes["*"]
+       s.handler = newHandler(s.cluster, &node)
+}
+
+// TestProxyDiscoveryDoc checks that the discovery document is proxied
+// through and decodes into a populated arvados.DiscoveryDocument.
+func (s *HandlerSuite) TestProxyDiscoveryDoc(c *check.C) {
+       req := httptest.NewRequest("GET", "/discovery/v1/apis/arvados/v1/rest", nil)
+       resp := httptest.NewRecorder()
+       s.handler.ServeHTTP(resp, req)
+       c.Check(resp.Code, check.Equals, http.StatusOK)
+       var dd arvados.DiscoveryDocument
+       err := json.Unmarshal(resp.Body.Bytes(), &dd)
+       // Stop here if decoding failed: the field checks below would
+       // only report misleading zero values.
+       c.Assert(err, check.IsNil)
+       c.Check(dd.BlobSignatureTTL > 0, check.Equals, true)
+       c.Check(len(dd.Resources), check.Not(check.Equals), 0)
+       c.Check(len(dd.Schemas), check.Not(check.Equals), 0)
+}
+
+// TestProxyWithoutToken checks that an unauthenticated request for the
+// current user gets a 401 response carrying a JSON "errors" list.
+func (s *HandlerSuite) TestProxyWithoutToken(c *check.C) {
+       rec := httptest.NewRecorder()
+       s.handler.ServeHTTP(rec, httptest.NewRequest("GET", "/arvados/v1/users/current", nil))
+       c.Check(rec.Code, check.Equals, http.StatusUnauthorized)
+       var body map[string]interface{}
+       c.Check(json.Unmarshal(rec.Body.Bytes(), &body), check.IsNil)
+       c.Check(body["errors"], check.FitsTypeOf, []interface{}{})
+}
+
+// TestProxyWithToken checks that a request bearing the active user's
+// token gets that user's record back through the proxy.
+func (s *HandlerSuite) TestProxyWithToken(c *check.C) {
+       rec := httptest.NewRecorder()
+       req := httptest.NewRequest("GET", "/arvados/v1/users/current", nil)
+       req.Header.Set("Authorization", "Bearer "+arvadostest.ActiveToken)
+       s.handler.ServeHTTP(rec, req)
+       c.Check(rec.Code, check.Equals, http.StatusOK)
+       var user arvados.User
+       c.Check(json.Unmarshal(rec.Body.Bytes(), &user), check.IsNil)
+       c.Check(user.UUID, check.Equals, arvadostest.ActiveUserUUID)
+}
+
+// TestProxyNotFound checks that a request for a nonexistent API path
+// gets a 404 response carrying a JSON "errors" list.
+func (s *HandlerSuite) TestProxyNotFound(c *check.C) {
+       rec := httptest.NewRecorder()
+       s.handler.ServeHTTP(rec, httptest.NewRequest("GET", "/arvados/v1/xyzzy", nil))
+       c.Check(rec.Code, check.Equals, http.StatusNotFound)
+       var body map[string]interface{}
+       c.Check(json.Unmarshal(rec.Body.Bytes(), &body), check.IsNil)
+       c.Check(body["errors"], check.FitsTypeOf, []interface{}{})
+}
index 875a274dc6259aedfd648a1785d01ce85bb828dc..e0a2b1d28b7422e365250cc7f74ba553ba9c7011 100644 (file)
@@ -131,4 +131,5 @@ func (sn *SystemNode) ServicePorts() map[ServiceName]string {
 
+// SystemServiceInstance is the configuration of one service on a
+// system node.
 type SystemServiceInstance struct {
+       // Listen is the service's network address, either ":port" or
+       // "host:port" (e.g., "[::1]:12345"), as interpreted by the
+       // controller's findRailsAPI.
        Listen string
+       // TLS indicates clients should connect using https rather than
+       // http (controller's findRailsAPI picks the scheme this way).
+       TLS    bool
 }
index ec3fa7fae18dc6b88022b95771efc1de83f96107..9577718c76e45c1757297d5272c6174f5a454571 100644 (file)
@@ -17,7 +17,10 @@ type contextKey struct {
        name string
 }
 
-var requestTimeContextKey = contextKey{"requestTime"}
+// Keys for the values LogRequests stores in each request's context.
+// The variables' addresses (e.g., &loggerContextKey) are used as the
+// actual context keys.
+var (
+       requestTimeContextKey = contextKey{name: "requestTime"}
+       loggerContextKey      = contextKey{name: "logger"}
+)
 
 // LogRequests wraps an http.Handler, logging each request and
 // response via logger.
@@ -27,7 +30,6 @@ func LogRequests(logger logrus.FieldLogger, h http.Handler) http.Handler {
        }
        return http.HandlerFunc(func(wrapped http.ResponseWriter, req *http.Request) {
                w := &responseTimer{ResponseWriter: WrapResponseWriter(wrapped)}
-               req = req.WithContext(context.WithValue(req.Context(), &requestTimeContextKey, time.Now()))
                lgr := logger.WithFields(logrus.Fields{
                        "RequestID":       req.Header.Get("X-Request-Id"),
                        "remoteAddr":      req.RemoteAddr,
@@ -38,12 +40,25 @@ func LogRequests(logger logrus.FieldLogger, h http.Handler) http.Handler {
                        "reqQuery":        req.URL.RawQuery,
                        "reqBytes":        req.ContentLength,
                })
+               ctx := req.Context()
+               ctx = context.WithValue(ctx, &requestTimeContextKey, time.Now())
+               ctx = context.WithValue(ctx, &loggerContextKey, lgr)
+               req = req.WithContext(ctx)
+
                logRequest(w, req, lgr)
                defer logResponse(w, req, lgr)
                h.ServeHTTP(w, req)
        })
 }
 
+func Logger(req *http.Request) logrus.FieldLogger {
+       if lgr, ok := req.Context().Value(&loggerContextKey).(logrus.FieldLogger); ok {
+               return lgr
+       } else {
+               return logrus.StandardLogger()
+       }
+}
+
+// logRequest logs an info-level "request" entry via lgr, which
+// already carries the request's fields. The w and req arguments are
+// not used.
 func logRequest(w *responseTimer, req *http.Request, lgr *logrus.Entry) {
         lgr.Info("request")
 }