16152: Fix nil http handler and ignored config args.
author Tom Clegg <tom@tomclegg.ca>
Thu, 13 Feb 2020 20:58:09 +0000 (15:58 -0500)
committer Tom Clegg <tom@tomclegg.ca>
Thu, 13 Feb 2020 20:58:09 +0000 (15:58 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

services/keep-balance/main.go
services/keep-balance/main_test.go [new file with mode: 0644]

index 6e89df9a5552cc34bac4d13e3bae4356d6acf48a..65bd8d4cf098a17610953810ab8147f678616aee 100644 (file)
@@ -9,6 +9,7 @@ import (
        "flag"
        "fmt"
        "io"
+       "net/http"
        "os"
 
        "git.arvados.org/arvados.git/lib/config"
@@ -50,10 +51,17 @@ func runCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.W
                options.Dumper = dumper
        }
 
-       // Only pass along the version flag, which gets handled in RunCommand
+       // Drop our custom args that would be rejected by the generic
+       // service.Command
        args = nil
+       dropFlag := map[string]bool{
+               "once":         true,
+               "commit-pulls": true,
+               "commit-trash": true,
+               "dump":         true,
+       }
        flags.Visit(func(f *flag.Flag) {
-               if f.Name == "version" {
+               if !dropFlag[f.Name] {
                        args = append(args, "-"+f.Name, f.Value.String())
                }
        })
@@ -75,6 +83,7 @@ func runCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.W
                        }
 
                        srv := &Server{
+                               Handler:    http.NotFoundHandler(),
                                Cluster:    cluster,
                                ArvClient:  ac,
                                RunOptions: options,
diff --git a/services/keep-balance/main_test.go b/services/keep-balance/main_test.go
new file mode 100644 (file)
index 0000000..a644550
--- /dev/null
@@ -0,0 +1,84 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package main
+
+import (
+       "bytes"
+       "io/ioutil"
+       "net"
+       "net/http"
+       "time"
+
+       check "gopkg.in/check.v1"
+)
+
+// mainSuite groups the tests in this file that drive runCommand directly.
+type mainSuite struct{}
+
+// Register the suite with gocheck so its Test* methods are run.
+var _ = check.Suite(new(mainSuite))
+
+func (s *mainSuite) TestVersionFlag(c *check.C) {
+       var stdout, stderr bytes.Buffer
+       runCommand("keep-balance", []string{"-version"}, nil, &stdout, &stderr)
+       c.Check(stderr.String(), check.Equals, "")
+       c.Log(stdout.String())
+}
+
+// TestHTTPServer starts keep-balance on a free port with a minimal
+// config, polls /metrics until it answers with the expected metric, and
+// then checks that a non-metrics path is answered with 404.
+func (s *mainSuite) TestHTTPServer(c *check.C) {
+       // Find a free port by briefly listening on :0. Another process
+       // could grab it before keep-balance binds it; acceptable here.
+       ln, err := net.Listen("tcp", ":0")
+       if err != nil {
+               c.Fatal(err)
+       }
+       _, p, err := net.SplitHostPort(ln.Addr().String())
+       ln.Close()
+       c.Assert(err, check.IsNil)
+       config := "Clusters:\n zzzzz:\n  ManagementToken: abcdefg\n  Services: {Keepbalance: {InternalURLs: {'http://localhost:" + p + "/': {}}}}\n"
+       // Talk to the same host:port the config above tells the service
+       // to listen on.
+       base := "http://localhost:" + p
+
+       // Per-request timeout so a hung server cannot stall the poller
+       // indefinitely (http.DefaultClient has no timeout).
+       client := &http.Client{Timeout: time.Second}
+
+       // NOTE(review): runCommand is never stopped and may keep writing
+       // to stdout after this test returns; stdout is also read below on
+       // the timeout path without synchronization, which the race
+       // detector may flag when the test is already failing.
+       var stdout bytes.Buffer
+       go runCommand("keep-balance", []string{"-config", "-"}, bytes.NewBufferString(config), &stdout, &stdout)
+
+       // fetchMetrics makes one attempt to read /metrics. It returns
+       // ok=false, after logging why, when the server is not ready yet.
+       // Because it is its own function, the deferred Close runs at the
+       // end of every attempt instead of piling up inside the poll loop.
+       fetchMetrics := func() (body string, ok bool) {
+               req, err := http.NewRequest(http.MethodGet, base+"/metrics", nil)
+               if err != nil {
+                       c.Logf("new request: %s", err)
+                       return "", false
+               }
+               req.Header.Set("Authorization", "Bearer abcdefg")
+               resp, err := client.Do(req)
+               if err != nil {
+                       c.Logf("error %s", err)
+                       return "", false
+               }
+               defer resp.Body.Close()
+               if resp.StatusCode != http.StatusOK {
+                       c.Logf("http status %d", resp.StatusCode)
+                       return "", false
+               }
+               buf, err := ioutil.ReadAll(resp.Body)
+               if err != nil {
+                       c.Logf("read body: %s", err)
+                       return "", false
+               }
+               return string(buf), true
+       }
+
+       // Poll in the background every 100ms. The body is handed back on
+       // a buffered channel so all assertions run on the test goroutine
+       // (gocheck's Fatal/FailNow only stops the goroutine that calls
+       // it), and closing stop ends the poller if the test gives up.
+       stop := make(chan struct{})
+       defer close(stop)
+       got := make(chan string, 1)
+       go func() {
+               ticker := time.NewTicker(time.Second / 10)
+               defer ticker.Stop()
+               for {
+                       select {
+                       case <-stop:
+                               return
+                       case <-ticker.C:
+                       }
+                       if body, ok := fetchMetrics(); ok {
+                               got <- body
+                               return
+                       }
+               }
+       }()
+       select {
+       case body := <-got:
+               c.Check(body, check.Matches, `(?ms).*arvados_keepbalance_sweep_seconds_sum.*`)
+       case <-time.After(time.Second):
+               c.Log(stdout.String())
+               c.Fatal("timeout")
+       }
+
+       // Check non-metrics URL that gets passed through to us from
+       // service.Command
+       req, err := http.NewRequest(http.MethodGet, base+"/not-metrics", nil)
+       c.Assert(err, check.IsNil)
+       resp, err := client.Do(req)
+       // Assert rather than Check: resp is nil when err != nil, and the
+       // Close below would panic.
+       c.Assert(err, check.IsNil)
+       defer resp.Body.Close()
+       c.Check(resp.StatusCode, check.Equals, http.StatusNotFound)
+}