Merge branch '13697-database-timeout'
author Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 21:01:53 +0000 (17:01 -0400)
committer Tom Clegg <tom@curii.com>
Thu, 23 Sep 2021 21:01:53 +0000 (17:01 -0400)
fixes #13697

Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

lib/controller/auth_test.go
lib/controller/federation_test.go
lib/controller/server_test.go
lib/service/cmd.go
sdk/go/httpserver/logger.go
sdk/go/httpserver/logger_test.go
services/api/config/arvados_config.rb
services/api/config/initializers/db_timeout.rb [new file with mode: 0644]
services/keep-web/server.go

index 69458655ba0c8ce7caf9839d05eb5985f3fd0b7b..17524114671e840ecdac05e2457d9ebbb96c5635 100644 (file)
@@ -8,6 +8,7 @@ import (
        "context"
        "encoding/json"
        "fmt"
+       "net"
        "net/http"
        "net/http/httptest"
        "os"
@@ -99,9 +100,10 @@ func (s *AuthSuite) SetUpTest(c *check.C) {
 
        s.testHandler = &Handler{Cluster: cluster}
        s.testServer = newServerFromIntegrationTestEnv(c)
-       s.testServer.Server.Handler = httpserver.HandlerWithContext(
-               ctxlog.Context(context.Background(), s.log),
-               httpserver.AddRequestIDs(httpserver.LogRequests(s.testHandler)))
+       s.testServer.Server.BaseContext = func(net.Listener) context.Context {
+               return ctxlog.Context(context.Background(), s.log)
+       }
+       s.testServer.Server.Handler = httpserver.AddRequestIDs(httpserver.LogRequests(s.testHandler))
        c.Assert(s.testServer.Start(), check.IsNil)
 }
 
index 6d74ab65f93b6f7c73c95905b601a6090bb50a70..211c7619809ed6a8855248915facef843da55081 100644 (file)
@@ -11,6 +11,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "net"
        "net/http"
        "net/http/httptest"
        "net/url"
@@ -71,9 +72,10 @@ func (s *FederationSuite) SetUpTest(c *check.C) {
        arvadostest.SetServiceURL(&cluster.Services.Controller, "http://localhost:/")
        s.testHandler = &Handler{Cluster: cluster}
        s.testServer = newServerFromIntegrationTestEnv(c)
-       s.testServer.Server.Handler = httpserver.HandlerWithContext(
-               ctxlog.Context(context.Background(), s.log),
-               httpserver.AddRequestIDs(httpserver.LogRequests(s.testHandler)))
+       s.testServer.Server.BaseContext = func(net.Listener) context.Context {
+               return ctxlog.Context(context.Background(), s.log)
+       }
+       s.testServer.Server.Handler = httpserver.AddRequestIDs(httpserver.LogRequests(s.testHandler))
 
        cluster.RemoteClusters = map[string]arvados.RemoteCluster{
                "zzzzz": {
index 051f355716c421e486e25ce1fb717d8355847e30..b2b3365a2015b2ac899a3b62f45d563042267ac9 100644 (file)
@@ -6,6 +6,7 @@ package controller
 
 import (
        "context"
+       "net"
        "net/http"
        "os"
        "path/filepath"
@@ -48,9 +49,10 @@ func newServerFromIntegrationTestEnv(c *check.C) *httpserver.Server {
 
        srv := &httpserver.Server{
                Server: http.Server{
-                       Handler: httpserver.HandlerWithContext(
-                               ctxlog.Context(context.Background(), log),
-                               httpserver.AddRequestIDs(httpserver.LogRequests(handler))),
+                       BaseContext: func(net.Listener) context.Context {
+                               return ctxlog.Context(context.Background(), log)
+                       },
+                       Handler: httpserver.AddRequestIDs(httpserver.LogRequests(handler)),
                },
                Addr: ":",
        }
index 40db4f9c7c7f80f744ab3c44da874794d925c9c1..e67c24f65f39cea4929c95fe30abbdc5ab98a901 100644 (file)
@@ -126,13 +126,14 @@ func (c *command) RunCommand(prog string, args []string, stdin io.Reader, stdout
        }
 
        instrumented := httpserver.Instrument(reg, log,
-               httpserver.HandlerWithContext(ctx,
+               httpserver.HandlerWithDeadline(cluster.API.RequestTimeout.Duration(),
                        httpserver.AddRequestIDs(
                                httpserver.LogRequests(
                                        httpserver.NewRequestLimiter(cluster.API.MaxConcurrentRequests, handler, reg)))))
        srv := &httpserver.Server{
                Server: http.Server{
-                       Handler: instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+                       Handler:     instrumented.ServeAPI(cluster.ManagementToken, instrumented),
+                       BaseContext: func(net.Listener) context.Context { return ctx },
                },
                Addr: listenURL.Host,
        }
index 78a1f77adb9ca9942d00fe559c3192b658e8e5fe..7eb7f0f03d57b571e314f8d87ca6714cf7d6563f 100644 (file)
@@ -5,7 +5,9 @@
 package httpserver
 
 import (
+       "bufio"
        "context"
+       "net"
        "net/http"
        "time"
 
@@ -22,11 +24,42 @@ var (
        requestTimeContextKey = contextKey{"requestTime"}
 )
 
-// HandlerWithContext returns an http.Handler that changes the request
-// context to ctx (replacing http.Server's default
-// context.Background()), then calls next.
-func HandlerWithContext(ctx context.Context, next http.Handler) http.Handler {
+// hijacker is a ResponseWriter that also supports connection
+// hijacking via http.Hijacker.
+type hijacker interface {
+       http.ResponseWriter
+       http.Hijacker
+}
+
+// hijackNotifier wraps a hijacker and closes the hijacked channel
+// when Hijack is called, so the goroutine started by
+// HandlerWithDeadline stops waiting to cancel the request context.
+type hijackNotifier struct {
+       hijacker
+       // Bidirectional (not send-only) so Hijack can tell whether it
+       // has already been closed.
+       hijacked chan bool
+}
+
+// Hijack closes hn.hijacked, then hijacks the wrapped connection.
+//
+// The close happens at most once: a repeated Hijack call is passed
+// through to the wrapped Hijacker (which decides how to report it)
+// instead of panicking on a double close.
+func (hn hijackNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       // Nothing ever sends on hijacked, so a receive only succeeds
+       // once the channel has been closed.
+       select {
+       case <-hn.hijacked:
+               // Already closed by an earlier Hijack call.
+       default:
+               close(hn.hijacked)
+       }
+       return hn.hijacker.Hijack()
+}
+
+// HandlerWithDeadline cancels the request context if the request
+// takes longer than the specified timeout without having its
+// connection hijacked.
+//
+// A zero or negative timeout means no deadline: next is returned
+// unchanged.
+func HandlerWithDeadline(timeout time.Duration, next http.Handler) http.Handler {
+       if timeout <= 0 {
+               // Without this, the timer below would fire
+               // immediately and cancel every request.
+               return next
+       }
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               ctx, cancel := context.WithCancel(r.Context())
+               // Cancelling on return also ends the watcher
+               // goroutine below via ctx.Done().
+               defer cancel()
+               // nodeadline is closed (by hijackNotifier) if the
+               // handler hijacks the connection.
+               nodeadline := make(chan bool)
+               go func() {
+                       // A stoppable timer is released as soon as the
+                       // request finishes; time.After's timer is not
+                       // released until it fires (before Go 1.23).
+                       timer := time.NewTimer(timeout)
+                       defer timer.Stop()
+                       select {
+                       case <-nodeadline:
+                       case <-ctx.Done():
+                       case <-timer.C:
+                               cancel()
+                       }
+               }()
+               if hj, ok := w.(hijacker); ok {
+                       w = hijackNotifier{hj, nodeadline}
+               }
                next.ServeHTTP(w, r.WithContext(ctx))
        })
 }
index af45a640ca38e2bb8baa76244070d4f5753324b0..60768b3fc907681c8598a53ae5ff6f9985ef541f 100644 (file)
@@ -9,6 +9,8 @@ import (
        "context"
        "encoding/json"
        "fmt"
+       "io/ioutil"
+       "net"
        "net/http"
        "net/http/httptest"
        "testing"
@@ -41,6 +43,61 @@ func (s *Suite) SetUpTest(c *check.C) {
        s.ctx = ctxlog.Context(context.Background(), s.log)
 }
 
+func (s *Suite) TestWithDeadline(c *check.C) {
+       base, err := http.NewRequest("GET", "https://foo.example/bar", nil)
+       c.Assert(err, check.IsNil)
+       req := base.WithContext(s.ctx)
+
+       // A 1ms deadline should cancel the request context well
+       // within 1s.
+       cancelled := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               select {
+               case <-time.After(time.Second):
+                       c.Error("timed out")
+               case <-r.Context().Done():
+                       w.Write([]byte("ok"))
+               }
+       })
+       rec := httptest.NewRecorder()
+       HandlerWithDeadline(time.Millisecond, cancelled).ServeHTTP(rec, req)
+       c.Check(rec.Body.String(), check.Equals, "ok")
+
+       // A 1s deadline should leave the request context alone for
+       // at least 1ms.
+       notCancelled := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               select {
+               case <-time.After(time.Millisecond):
+                       w.Write([]byte("ok"))
+               case <-r.Context().Done():
+                       c.Error("request context done too soon")
+               }
+       })
+       rec = httptest.NewRecorder()
+       HandlerWithDeadline(time.Second, notCancelled).ServeHTTP(rec, req)
+       c.Check(rec.Body.String(), check.Equals, "ok")
+}
+
+// TestNoDeadlineAfterHijacked checks that a 1ms deadline does not
+// cancel the request context once the handler has hijacked the
+// connection.
+func (s *Suite) TestNoDeadlineAfterHijacked(c *check.C) {
+       srv := Server{
+               Addr: ":",
+               Server: http.Server{
+                       Handler: HandlerWithDeadline(time.Millisecond, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+                               // This runs on a server goroutine, not the
+                               // test goroutine, so report failures with
+                               // Check and return rather than c.Assert.
+                               conn, _, err := w.(http.Hijacker).Hijack()
+                               if !c.Check(err, check.IsNil) {
+                                       return
+                               }
+                               defer conn.Close()
+                               select {
+                               case <-req.Context().Done():
+                                       c.Error("request context done too soon")
+                               case <-time.After(time.Second / 10):
+                                       conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\nok"))
+                               }
+                       })),
+                       BaseContext: func(net.Listener) context.Context { return s.ctx },
+               },
+       }
+       c.Assert(srv.Start(), check.IsNil)
+       defer srv.Close()
+       resp, err := http.Get("http://" + srv.Addr)
+       c.Assert(err, check.IsNil)
+       defer resp.Body.Close()
+       body, err := ioutil.ReadAll(resp.Body)
+       c.Check(err, check.IsNil)
+       c.Check(string(body), check.Equals, "ok")
+}
+
 func (s *Suite) TestLogRequests(c *check.C) {
        h := AddRequestIDs(LogRequests(
                http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
@@ -52,7 +109,7 @@ func (s *Suite) TestLogRequests(c *check.C) {
        c.Assert(err, check.IsNil)
        resp := httptest.NewRecorder()
 
-       HandlerWithContext(s.ctx, h).ServeHTTP(resp, req)
+       h.ServeHTTP(resp, req.WithContext(s.ctx))
 
        dec := json.NewDecoder(s.logdata)
 
@@ -104,12 +161,12 @@ func (s *Suite) TestLogErrorBody(c *check.C) {
                c.Assert(err, check.IsNil)
                resp := httptest.NewRecorder()
 
-               HandlerWithContext(s.ctx, LogRequests(
+               LogRequests(
                        http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
                                w.WriteHeader(trial.statusCode)
                                w.Write([]byte(trial.sentBody))
                        }),
-               )).ServeHTTP(resp, req)
+               ).ServeHTTP(resp, req.WithContext(s.ctx))
 
                gotReq := make(map[string]interface{})
                err = dec.Decode(&gotReq)
index ea421a289b98cc628d6bfa9b0473a105a2986b0c..49865039c4649a8a8b7270682b97fd073b742504 100644 (file)
@@ -84,6 +84,7 @@ arvcfg.declare_config "API.MaxRequestSize", Integer, :max_request_size
 arvcfg.declare_config "API.MaxIndexDatabaseRead", Integer, :max_index_database_read
 arvcfg.declare_config "API.MaxItemsPerResponse", Integer, :max_items_per_response
 arvcfg.declare_config "API.MaxTokenLifetime", ActiveSupport::Duration
+arvcfg.declare_config "API.RequestTimeout", ActiveSupport::Duration
 arvcfg.declare_config "API.AsyncPermissionsUpdateInterval", ActiveSupport::Duration, :async_permissions_update_interval
 arvcfg.declare_config "Users.AutoSetupNewUsers", Boolean, :auto_setup_new_users
 arvcfg.declare_config "Users.AutoSetupNewUsersWithVmUUID", String, :auto_setup_new_users_with_vm_uuid
diff --git a/services/api/config/initializers/db_timeout.rb b/services/api/config/initializers/db_timeout.rb
new file mode 100644 (file)
index 0000000..3b61f15
--- /dev/null
@@ -0,0 +1,9 @@
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: AGPL-3.0
+
+# Apply API.RequestTimeout to every database connection when it is
+# checked out of the pool, so PostgreSQL aborts statements and lock
+# waits that run longer than a request is allowed to.
+ActiveRecord::ConnectionAdapters::AbstractAdapter.set_callback :checkout, :before, ->(conn) do
+  # Convert to milliseconds via to_f: Duration#to_i truncates to whole
+  # seconds, so a sub-second timeout would become 0, which PostgreSQL
+  # treats as "no timeout".
+  ms = (Rails.configuration.API.RequestTimeout.to_f * 1000).round
+  conn.execute("SET statement_timeout = #{ms}")
+  conn.execute("SET lock_timeout = #{ms}")
+end
index 8f623c627d067f843f2746a2b8b64248006b1a18..586f6b805736e23437357dbc8eb1a8cf4b5a5cc7 100644 (file)
@@ -6,6 +6,7 @@ package main
 
 import (
        "context"
+       "net"
        "net/http"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
@@ -24,10 +25,15 @@ func (srv *server) Start(logger *logrus.Logger) error {
        h := &handler{Config: srv.Config}
        reg := prometheus.NewRegistry()
        h.Config.Cache.registry = reg
-       ctx := ctxlog.Context(context.Background(), logger)
-       mh := httpserver.Instrument(reg, logger, httpserver.HandlerWithContext(ctx, httpserver.AddRequestIDs(httpserver.LogRequests(h))))
+       // Warning: when updating this to use Command() from
+       // lib/service, make sure to implement an exemption in
+       // httpserver.HandlerWithDeadline() so large file uploads are
+       // allowed to take longer than the usual API.RequestTimeout.
+       // See #13697.
+       mh := httpserver.Instrument(reg, logger, httpserver.AddRequestIDs(httpserver.LogRequests(h)))
        h.MetricsAPI = mh.ServeAPI(h.Config.cluster.ManagementToken, http.NotFoundHandler())
        srv.Handler = mh
+       srv.BaseContext = func(net.Listener) context.Context { return ctxlog.Context(context.Background(), logger) }
        var listen arvados.URL
        for listen = range srv.Config.cluster.Services.WebDAV.InternalURLs {
                break