From cf311e8e16ba74467c77b5353afedc29b40a6a41 Mon Sep 17 00:00:00 2001 From: Tom Clegg Date: Mon, 17 Apr 2017 19:19:10 -0400 Subject: [PATCH] 11509: Allow cross-origin requests with Range headers. --- services/keep-web/handler.go | 14 ++++++++++++++ services/keep-web/handler_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/services/keep-web/handler.go b/services/keep-web/handler.go index db7517adc6..a79973b975 100644 --- a/services/keep-web/handler.go +++ b/services/keep-web/handler.go @@ -94,6 +94,20 @@ func (h *handler) ServeHTTP(wOrig http.ResponseWriter, r *http.Request) { httpserver.Log(remoteAddr, statusCode, statusText, w.WroteBodyBytes(), r.Method, r.Host, r.URL.Path, r.URL.RawQuery) }() + if r.Method == "OPTIONS" { + method := r.Header.Get("Access-Control-Request-Method") + if method != "GET" && method != "POST" { + statusCode = http.StatusMethodNotAllowed + return + } + w.Header().Set("Access-Control-Allow-Headers", "Range") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Max-Age", "86400") + statusCode = http.StatusOK + return + } + if r.Method != "GET" && r.Method != "POST" { statusCode, statusText = http.StatusMethodNotAllowed, r.Method return diff --git a/services/keep-web/handler_test.go b/services/keep-web/handler_test.go index 0c960b8c0e..e2eb33e278 100644 --- a/services/keep-web/handler_test.go +++ b/services/keep-web/handler_test.go @@ -18,6 +18,35 @@ var _ = check.Suite(&UnitSuite{}) type UnitSuite struct{} +func (s *UnitSuite) TestCORSPreflight(c *check.C) { + h := handler{Config: &Config{}} + u, _ := url.Parse("http://keep-web.example/c=" + arvadostest.FooCollection + "/foo") + req := &http.Request{ + Method: "OPTIONS", + Host: u.Host, + URL: u, + RequestURI: u.RequestURI(), + Header: http.Header{ + "Origin": {"https://workbench.example"}, + "Access-Control-Request-Method": {"POST"}, + }, + } + + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + c.Check(resp.Code, check.Equals, http.StatusOK) + c.Check(resp.Body.String(), check.Equals, "") + c.Check(resp.Header().Get("Access-Control-Allow-Origin"), check.Equals, "*") + c.Check(resp.Header().Get("Access-Control-Allow-Methods"), check.Equals, "GET, POST") + c.Check(resp.Header().Get("Access-Control-Allow-Headers"), check.Equals, "Range") + + resp = httptest.NewRecorder() + req.Header.Set("Access-Control-Request-Method", "DELETE") + h.ServeHTTP(resp, req) + c.Check(resp.Body.String(), check.Equals, "") + c.Check(resp.Code, check.Equals, http.StatusMethodNotAllowed) +} + func mustParseURL(s string) *url.URL { r, err := url.Parse(s) if err != nil { -- 2.39.5