18947: Refactor keepproxy as an arvados-server subcommand.
[arvados.git] / services / keepproxy / keepproxy_test.go
index 052109bf29b925af9ef8945274ce099c96e9f112..12d65201dafd4503699a63802d37b6ff0b662c13 100644 (file)
@@ -2,14 +2,16 @@
 //
 // SPDX-License-Identifier: AGPL-3.0
 
-package main
+package keepproxy
 
 import (
        "bytes"
+       "context"
        "crypto/md5"
        "fmt"
        "io/ioutil"
        "math/rand"
+       "net"
        "net/http"
        "net/http/httptest"
        "strings"
@@ -22,6 +24,7 @@ import (
        "git.arvados.org/arvados.git/sdk/go/arvadosclient"
        "git.arvados.org/arvados.git/sdk/go/arvadostest"
        "git.arvados.org/arvados.git/sdk/go/ctxlog"
+       "git.arvados.org/arvados.git/sdk/go/httpserver"
        "git.arvados.org/arvados.git/sdk/go/keepclient"
        log "github.com/sirupsen/logrus"
 
@@ -54,27 +57,6 @@ type NoKeepServerSuite struct{}
 
 var TestProxyUUID = "zzzzz-bi6l4-lrixqc4fxofbmzz"
 
-// Wait (up to 1 second) for keepproxy to listen on a port. This
-// avoids a race condition where we hit a "connection refused" error
-// because we start testing the proxy too soon.
-func waitForListener() {
-       const (
-               ms = 5
-       )
-       for i := 0; listener == nil && i < 10000; i += ms {
-               time.Sleep(ms * time.Millisecond)
-       }
-       if listener == nil {
-               panic("Timed out waiting for listener to start")
-       }
-}
-
-func closeListener() {
-       if listener != nil {
-               listener.Close()
-       }
-}
-
 func (s *ServerRequiredSuite) SetUpSuite(c *C) {
        arvadostest.StartKeep(2, false)
 }
@@ -111,7 +93,12 @@ func (s *NoKeepServerSuite) SetUpTest(c *C) {
        arvadostest.ResetEnv()
 }
 
-func runProxy(c *C, bogusClientToken bool, loadKeepstoresFromConfig bool, kp *arvados.UploadDownloadRolePermissions) (*keepclient.KeepClient, *bytes.Buffer) {
+type testServer struct {
+       *httpserver.Server
+       proxyHandler *proxyHandler
+}
+
+func runProxy(c *C, bogusClientToken bool, loadKeepstoresFromConfig bool, kp *arvados.UploadDownloadRolePermissions) (*testServer, *keepclient.KeepClient, *bytes.Buffer) {
        cfg, err := config.NewLoader(nil, ctxlog.TestLogger(c)).Load()
        c.Assert(err, Equals, nil)
        cluster, err := cfg.GetCluster("")
@@ -128,38 +115,47 @@ func runProxy(c *C, bogusClientToken bool, loadKeepstoresFromConfig bool, kp *ar
                cluster.Collections.KeepproxyPermission = *kp
        }
 
-       listener = nil
        logbuf := &bytes.Buffer{}
        logger := log.New()
        logger.Out = logbuf
-       go func() {
-               run(logger, cluster)
-               defer closeListener()
-       }()
-       waitForListener()
+       ctx := ctxlog.Context(context.Background(), logger)
+
+       handler := newHandlerOrErrorHandler(ctx, cluster, cluster.SystemRootToken, nil).(*proxyHandler)
+       srv := &testServer{
+               Server: &httpserver.Server{
+                       Server: http.Server{
+                               BaseContext: func(net.Listener) context.Context { return ctx },
+                               Handler: httpserver.AddRequestIDs(
+                                       httpserver.LogRequests(handler)),
+                       },
+                       Addr: ":",
+               },
+               proxyHandler: handler,
+       }
+       err = srv.Start()
+       c.Assert(err, IsNil)
 
        client := arvados.NewClientFromEnv()
        arv, err := arvadosclient.New(client)
-       c.Assert(err, Equals, nil)
+       c.Assert(err, IsNil)
        if bogusClientToken {
                arv.ApiToken = "bogus-token"
        }
        kc := keepclient.New(arv)
        sr := map[string]string{
-               TestProxyUUID: "http://" + listener.Addr().String(),
+               TestProxyUUID: "http://" + srv.Addr,
        }
        kc.SetServiceRoots(sr, sr, sr)
        kc.Arvados.External = true
-
-       return kc, logbuf
+       return srv, kc, logbuf
 }
 
 func (s *ServerRequiredSuite) TestResponseViaHeader(c *C) {
-       runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, _, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        req, err := http.NewRequest("POST",
-               "http://"+listener.Addr().String()+"/",
+               "http://"+srv.Addr+"/",
                strings.NewReader("TestViaHeader"))
        c.Assert(err, Equals, nil)
        req.Header.Add("Authorization", "OAuth2 "+arvadostest.ActiveToken)
@@ -172,7 +168,7 @@ func (s *ServerRequiredSuite) TestResponseViaHeader(c *C) {
        resp.Body.Close()
 
        req, err = http.NewRequest("GET",
-               "http://"+listener.Addr().String()+"/"+string(locator),
+               "http://"+srv.Addr+"/"+string(locator),
                nil)
        c.Assert(err, Equals, nil)
        resp, err = (&http.Client{}).Do(req)
@@ -182,13 +178,13 @@ func (s *ServerRequiredSuite) TestResponseViaHeader(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestLoopDetection(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        sr := map[string]string{
-               TestProxyUUID: "http://" + listener.Addr().String(),
+               TestProxyUUID: "http://" + srv.Addr,
        }
-       router.(*proxyHandler).KeepClient.SetServiceRoots(sr, sr, sr)
+       srv.proxyHandler.KeepClient.SetServiceRoots(sr, sr, sr)
 
        content := []byte("TestLoopDetection")
        _, _, err := kc.PutB(content)
@@ -200,8 +196,8 @@ func (s *ServerRequiredSuite) TestLoopDetection(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestStorageClassesHeader(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        // Set up fake keepstore to record request headers
        var hdr http.Header
@@ -216,7 +212,7 @@ func (s *ServerRequiredSuite) TestStorageClassesHeader(c *C) {
        sr := map[string]string{
                TestProxyUUID: ts.URL,
        }
-       router.(*proxyHandler).KeepClient.SetServiceRoots(sr, sr, sr)
+       srv.proxyHandler.KeepClient.SetServiceRoots(sr, sr, sr)
 
        // Set up client to ask for storage classes to keepproxy
        kc.StorageClasses = []string{"secure"}
@@ -227,15 +223,15 @@ func (s *ServerRequiredSuite) TestStorageClassesHeader(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestStorageClassesConfirmedHeader(c *C) {
-       runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, _, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        content := []byte("foo")
        hash := fmt.Sprintf("%x", md5.Sum(content))
        client := &http.Client{}
 
        req, err := http.NewRequest("PUT",
-               fmt.Sprintf("http://%s/%s", listener.Addr().String(), hash),
+               fmt.Sprintf("http://%s/%s", srv.Addr, hash),
                bytes.NewReader(content))
        c.Assert(err, IsNil)
        req.Header.Set("X-Keep-Storage-Classes", "default")
@@ -249,8 +245,8 @@ func (s *ServerRequiredSuite) TestStorageClassesConfirmedHeader(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestDesiredReplicas(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        content := []byte("TestDesiredReplicas")
        hash := fmt.Sprintf("%x", md5.Sum(content))
@@ -270,8 +266,8 @@ func (s *ServerRequiredSuite) TestDesiredReplicas(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPutWrongContentLength(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        content := []byte("TestPutWrongContentLength")
        hash := fmt.Sprintf("%x", md5.Sum(content))
@@ -281,7 +277,7 @@ func (s *ServerRequiredSuite) TestPutWrongContentLength(c *C) {
        // fixes the invalid Content-Length header. In order to test
        // our server behavior, we have to call the handler directly
        // using an httptest.ResponseRecorder.
-       rtr, err := MakeRESTRouter(kc, 10*time.Second, &arvados.Cluster{}, log.New())
+       rtr, err := newHandler(context.Background(), kc, 10*time.Second, &arvados.Cluster{})
        c.Assert(err, check.IsNil)
 
        type testcase struct {
@@ -296,7 +292,7 @@ func (s *ServerRequiredSuite) TestPutWrongContentLength(c *C) {
                {"abcdef", http.StatusLengthRequired},
        } {
                req, err := http.NewRequest("PUT",
-                       fmt.Sprintf("http://%s/%s+%d", listener.Addr().String(), hash, len(content)),
+                       fmt.Sprintf("http://%s/%s+%d", srv.Addr, hash, len(content)),
                        bytes.NewReader(content))
                c.Assert(err, IsNil)
                req.Header.Set("Content-Length", t.sendLength)
@@ -310,9 +306,9 @@ func (s *ServerRequiredSuite) TestPutWrongContentLength(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestManyFailedPuts(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
-       router.(*proxyHandler).timeout = time.Nanosecond
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
+       srv.proxyHandler.timeout = time.Nanosecond
 
        buf := make([]byte, 1<<20)
        rand.Read(buf)
@@ -337,8 +333,8 @@ func (s *ServerRequiredSuite) TestManyFailedPuts(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
-       kc, logbuf := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, logbuf := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        hash := fmt.Sprintf("%x", md5.Sum([]byte("foo")))
        var hash2 string
@@ -374,7 +370,7 @@ func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
                c.Check(err, Equals, nil)
                c.Log("Finished PutB (expected success)")
 
-               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="TestCase Administrator" user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
+               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="TestCase Administrator".* user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
                logbuf.Reset()
        }
 
@@ -383,7 +379,7 @@ func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
                c.Assert(err, Equals, nil)
                c.Check(blocklen, Equals, int64(3))
                c.Log("Finished Ask (expected success)")
-               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="TestCase Administrator" user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
+               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="TestCase Administrator".* user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
                logbuf.Reset()
        }
 
@@ -395,7 +391,7 @@ func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
                c.Check(all, DeepEquals, []byte("foo"))
                c.Check(blocklen, Equals, int64(3))
                c.Log("Finished Get (expected success)")
-               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="TestCase Administrator" user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
+               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="TestCase Administrator".* user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
                logbuf.Reset()
        }
 
@@ -421,8 +417,8 @@ func (s *ServerRequiredSuite) TestPutAskGet(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPutAskGetForbidden(c *C) {
-       kc, _ := runProxy(c, true, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, true, false, nil)
+       defer srv.Close()
 
        hash := fmt.Sprintf("%x+3", md5.Sum([]byte("bar")))
 
@@ -455,8 +451,8 @@ func testPermission(c *C, admin bool, perm arvados.UploadDownloadPermission) {
                kp.User = perm
        }
 
-       kc, logbuf := runProxy(c, false, false, &kp)
-       defer closeListener()
+       srv, kc, logbuf := runProxy(c, false, false, &kp)
+       defer srv.Close()
        if admin {
                kc.Arvados.ApiToken = arvadostest.AdminToken
        } else {
@@ -477,10 +473,10 @@ func testPermission(c *C, admin bool, perm arvados.UploadDownloadPermission) {
                        c.Check(err, Equals, nil)
                        c.Log("Finished PutB (expected success)")
                        if admin {
-                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="TestCase Administrator" user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
+                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="TestCase Administrator".* user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
                        } else {
 
-                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="Active User" user_uuid=zzzzz-tpzed-xurymjxw79nv3jz.*`)
+                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block upload".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="Active User".* user_uuid=zzzzz-tpzed-xurymjxw79nv3jz.*`)
                        }
                } else {
                        c.Check(hash2, Equals, "")
@@ -501,9 +497,9 @@ func testPermission(c *C, admin bool, perm arvados.UploadDownloadPermission) {
                        c.Check(blocklen, Equals, int64(3))
                        c.Log("Finished Get (expected success)")
                        if admin {
-                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="TestCase Administrator" user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
+                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="TestCase Administrator".* user_uuid=zzzzz-tpzed-d9tiejq69daie8f.*`)
                        } else {
-                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download" locator=acbd18db4cc2f85cedef654fccc4a4d8\+3 user_full_name="Active User" user_uuid=zzzzz-tpzed-xurymjxw79nv3jz.*`)
+                               c.Check(logbuf.String(), Matches, `(?ms).*msg="Block download".* locator=acbd18db4cc2f85cedef654fccc4a4d8\+3.* user_full_name="Active User".* user_uuid=zzzzz-tpzed-xurymjxw79nv3jz.*`)
                        }
                } else {
                        c.Check(err, FitsTypeOf, &keepclient.ErrNotFound{})
@@ -545,13 +541,13 @@ func (s *ServerRequiredSuite) TestPutGetPermission(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestCorsHeaders(c *C) {
-       runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, _, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        {
                client := http.Client{}
                req, err := http.NewRequest("OPTIONS",
-                       fmt.Sprintf("http://%s/%x+3", listener.Addr().String(), md5.Sum([]byte("foo"))),
+                       fmt.Sprintf("http://%s/%x+3", srv.Addr, md5.Sum([]byte("foo"))),
                        nil)
                c.Assert(err, IsNil)
                req.Header.Add("Access-Control-Request-Method", "PUT")
@@ -567,8 +563,7 @@ func (s *ServerRequiredSuite) TestCorsHeaders(c *C) {
        }
 
        {
-               resp, err := http.Get(
-                       fmt.Sprintf("http://%s/%x+3", listener.Addr().String(), md5.Sum([]byte("foo"))))
+               resp, err := http.Get(fmt.Sprintf("http://%s/%x+3", srv.Addr, md5.Sum([]byte("foo"))))
                c.Check(err, Equals, nil)
                c.Check(resp.Header.Get("Access-Control-Allow-Headers"), Equals, "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
                c.Check(resp.Header.Get("Access-Control-Allow-Origin"), Equals, "*")
@@ -576,13 +571,13 @@ func (s *ServerRequiredSuite) TestCorsHeaders(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPostWithoutHash(c *C) {
-       runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, _, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        {
                client := http.Client{}
                req, err := http.NewRequest("POST",
-                       "http://"+listener.Addr().String()+"/",
+                       "http://"+srv.Addr+"/",
                        strings.NewReader("qux"))
                c.Check(err, IsNil)
                req.Header.Add("Authorization", "OAuth2 "+arvadostest.ActiveToken)
@@ -634,8 +629,8 @@ func (s *ServerRequiredConfigYmlSuite) TestGetIndex(c *C) {
 }
 
 func getIndexWorker(c *C, useConfig bool) {
-       kc, _ := runProxy(c, false, useConfig, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, useConfig, nil)
+       defer srv.Close()
 
        // Put "index-data" blocks
        data := []byte("index-data")
@@ -697,8 +692,8 @@ func getIndexWorker(c *C, useConfig bool) {
 }
 
 func (s *ServerRequiredSuite) TestCollectionSharingToken(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
        hash, _, err := kc.PutB([]byte("shareddata"))
        c.Check(err, IsNil)
        kc.Arvados.ApiToken = arvadostest.FooCollectionSharingToken
@@ -710,8 +705,8 @@ func (s *ServerRequiredSuite) TestCollectionSharingToken(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPutAskGetInvalidToken(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        // Put a test block
        hash, rep, err := kc.PutB([]byte("foo"))
@@ -747,14 +742,14 @@ func (s *ServerRequiredSuite) TestPutAskGetInvalidToken(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestAskGetKeepProxyConnectionError(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        // Point keepproxy at a non-existent keepstore
        locals := map[string]string{
                TestProxyUUID: "http://localhost:12345",
        }
-       router.(*proxyHandler).KeepClient.SetServiceRoots(locals, nil, nil)
+       srv.proxyHandler.KeepClient.SetServiceRoots(locals, nil, nil)
 
        // Ask should result in temporary bad gateway error
        hash := fmt.Sprintf("%x", md5.Sum([]byte("foo")))
@@ -773,8 +768,8 @@ func (s *ServerRequiredSuite) TestAskGetKeepProxyConnectionError(c *C) {
 }
 
 func (s *NoKeepServerSuite) TestAskGetNoKeepServerError(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
        hash := fmt.Sprintf("%x", md5.Sum([]byte("foo")))
        for _, f := range []func() error{
@@ -796,14 +791,14 @@ func (s *NoKeepServerSuite) TestAskGetNoKeepServerError(c *C) {
 }
 
 func (s *ServerRequiredSuite) TestPing(c *C) {
-       kc, _ := runProxy(c, false, false, nil)
-       defer closeListener()
+       srv, kc, _ := runProxy(c, false, false, nil)
+       defer srv.Close()
 
-       rtr, err := MakeRESTRouter(kc, 10*time.Second, &arvados.Cluster{ManagementToken: arvadostest.ManagementToken}, log.New())
+       rtr, err := newHandler(context.Background(), kc, 10*time.Second, &arvados.Cluster{ManagementToken: arvadostest.ManagementToken})
        c.Assert(err, check.IsNil)
 
        req, err := http.NewRequest("GET",
-               "http://"+listener.Addr().String()+"/_health/ping",
+               "http://"+srv.Addr+"/_health/ping",
                nil)
        c.Assert(err, IsNil)
        req.Header.Set("Authorization", "Bearer "+arvadostest.ManagementToken)