22124: Reject requests with double slash, except GET //locator.
[arvados.git] / services / keepstore / router.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepstore
6
7 import (
8         "encoding/json"
9         "errors"
10         "fmt"
11         "io"
12         "net/http"
13         "os"
14         "strconv"
15         "strings"
16         "sync/atomic"
17
18         "git.arvados.org/arvados.git/lib/service"
19         "git.arvados.org/arvados.git/sdk/go/arvados"
20         "git.arvados.org/arvados.git/sdk/go/auth"
21         "git.arvados.org/arvados.git/sdk/go/httpserver"
22         "git.arvados.org/arvados.git/sdk/go/keepclient"
23         "github.com/gorilla/mux"
24 )
25
// router is the keepstore HTTP service handler built by newRouter.
// The embedded http.Handler is the assembled request pipeline
// (corsHandler -> auth.LoadToken -> gorilla/mux route dispatch).
type router struct {
	http.Handler
	keepstore *keepstore // backend for block, index and mount requests
	puller    *puller    // receives pull lists (PUT /pull)
	trasher   *trasher   // receives trash lists (PUT /trash)
}
32
33 func newRouter(keepstore *keepstore, puller *puller, trasher *trasher) service.Handler {
34         rtr := &router{
35                 keepstore: keepstore,
36                 puller:    puller,
37                 trasher:   trasher,
38         }
39         adminonly := func(h http.HandlerFunc) http.HandlerFunc {
40                 return auth.RequireLiteralToken(keepstore.cluster.SystemRootToken, h).ServeHTTP
41         }
42
43         r := mux.NewRouter()
44         // Without SkipClean(true), the gorilla/mux package responds
45         // to "PUT //foo" with a 301 redirect to "/foo", which causes
46         // a (Fetch Standard compliant) client to repeat the request
47         // as "GET /foo", which is clearly inappropriate in this case.
48         // It's less confusing if we just return 400.
49         r.SkipClean(true)
50         locatorPath := `/{locator:[0-9a-f]{32}.*}`
51         get := r.Methods(http.MethodGet, http.MethodHead).Subrouter()
52         get.HandleFunc(locatorPath, rtr.handleBlockRead)
53         get.HandleFunc("/"+locatorPath, rtr.handleBlockRead) // for compatibility -- see TestBlockRead_DoubleSlash
54         get.HandleFunc(`/index`, adminonly(rtr.handleIndex))
55         get.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, adminonly(rtr.handleIndex))
56         get.HandleFunc(`/mounts`, adminonly(rtr.handleMounts))
57         get.HandleFunc(`/mounts/{uuid}/blocks`, adminonly(rtr.handleIndex))
58         get.HandleFunc(`/mounts/{uuid}/blocks/{prefix:[0-9a-f]{0,32}}`, adminonly(rtr.handleIndex))
59         put := r.Methods(http.MethodPut).Subrouter()
60         put.HandleFunc(locatorPath, rtr.handleBlockWrite)
61         put.HandleFunc(`/pull`, adminonly(rtr.handlePullList))
62         put.HandleFunc(`/trash`, adminonly(rtr.handleTrashList))
63         put.HandleFunc(`/untrash`+locatorPath, adminonly(rtr.handleUntrash))
64         touch := r.Methods("TOUCH").Subrouter()
65         touch.HandleFunc(locatorPath, adminonly(rtr.handleBlockTouch))
66         delete := r.Methods(http.MethodDelete).Subrouter()
67         delete.HandleFunc(locatorPath, adminonly(rtr.handleBlockTrash))
68         options := r.Methods(http.MethodOptions).Subrouter()
69         options.NewRoute().PathPrefix(`/`).HandlerFunc(rtr.handleOptions)
70         r.NotFoundHandler = http.HandlerFunc(rtr.handleBadRequest)
71         r.MethodNotAllowedHandler = http.HandlerFunc(rtr.handleBadRequest)
72         rtr.Handler = corsHandler(auth.LoadToken(r))
73         return rtr
74 }
75
76 func (rtr *router) CheckHealth() error {
77         return nil
78 }
79
80 func (rtr *router) Done() <-chan struct{} {
81         return nil
82 }
83
84 func (rtr *router) handleBlockRead(w http.ResponseWriter, req *http.Request) {
85         // Intervening proxies must not return a cached GET response
86         // to a prior request if a X-Keep-Signature request header has
87         // been added or changed.
88         w.Header().Add("Vary", keepclient.XKeepSignature)
89         var localLocator func(string)
90         if strings.SplitN(req.Header.Get(keepclient.XKeepSignature), ",", 2)[0] == "local" {
91                 localLocator = func(locator string) {
92                         w.Header().Set(keepclient.XKeepLocator, locator)
93                 }
94         }
95         out := w
96         if req.Method == http.MethodHead {
97                 out = discardWrite{ResponseWriter: w}
98         } else if li, err := getLocatorInfo(mux.Vars(req)["locator"]); err != nil {
99                 rtr.handleError(w, req, err)
100                 return
101         } else if li.size == 0 && li.hash != "d41d8cd98f00b204e9800998ecf8427e" {
102                 // GET {hash} (with no size hint) is not allowed
103                 // because we can't report md5 mismatches.
104                 rtr.handleError(w, req, errMethodNotAllowed)
105                 return
106         }
107         n, err := rtr.keepstore.BlockRead(req.Context(), arvados.BlockReadOptions{
108                 Locator:      mux.Vars(req)["locator"],
109                 WriteTo:      out,
110                 LocalLocator: localLocator,
111         })
112         if err != nil && (n == 0 || req.Method == http.MethodHead) {
113                 rtr.handleError(w, req, err)
114                 return
115         }
116 }
117
118 func (rtr *router) handleBlockWrite(w http.ResponseWriter, req *http.Request) {
119         dataSize, _ := strconv.Atoi(req.Header.Get("Content-Length"))
120         replicas, _ := strconv.Atoi(req.Header.Get(keepclient.XKeepDesiredReplicas))
121         resp, err := rtr.keepstore.BlockWrite(req.Context(), arvados.BlockWriteOptions{
122                 Hash:           mux.Vars(req)["locator"],
123                 Reader:         req.Body,
124                 DataSize:       dataSize,
125                 RequestID:      req.Header.Get("X-Request-Id"),
126                 StorageClasses: trimSplit(req.Header.Get(keepclient.XKeepStorageClasses), ","),
127                 Replicas:       replicas,
128         })
129         if err != nil {
130                 rtr.handleError(w, req, err)
131                 return
132         }
133         w.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", resp.Replicas))
134         scc := ""
135         for k, n := range resp.StorageClasses {
136                 if n > 0 {
137                         if scc != "" {
138                                 scc += ", "
139                         }
140                         scc += fmt.Sprintf("%s=%d", k, n)
141                 }
142         }
143         w.Header().Set(keepclient.XKeepStorageClassesConfirmed, scc)
144         w.WriteHeader(http.StatusOK)
145         fmt.Fprintln(w, resp.Locator)
146 }
147
148 func (rtr *router) handleBlockTouch(w http.ResponseWriter, req *http.Request) {
149         err := rtr.keepstore.BlockTouch(req.Context(), mux.Vars(req)["locator"])
150         rtr.handleError(w, req, err)
151 }
152
153 func (rtr *router) handleBlockTrash(w http.ResponseWriter, req *http.Request) {
154         err := rtr.keepstore.BlockTrash(req.Context(), mux.Vars(req)["locator"])
155         rtr.handleError(w, req, err)
156 }
157
158 func (rtr *router) handleMounts(w http.ResponseWriter, req *http.Request) {
159         json.NewEncoder(w).Encode(rtr.keepstore.Mounts())
160 }
161
162 func (rtr *router) handleIndex(w http.ResponseWriter, req *http.Request) {
163         prefix := req.FormValue("prefix")
164         if prefix == "" {
165                 prefix = mux.Vars(req)["prefix"]
166         }
167         cw := &countingWriter{writer: w}
168         err := rtr.keepstore.Index(req.Context(), indexOptions{
169                 MountUUID: mux.Vars(req)["uuid"],
170                 Prefix:    prefix,
171                 WriteTo:   cw,
172         })
173         if err != nil && cw.n.Load() == 0 {
174                 // Nothing was written, so it's not too late to report
175                 // an error via http response header. (Otherwise, all
176                 // we can do is omit the trailing newline below to
177                 // indicate something went wrong.)
178                 rtr.handleError(w, req, err)
179                 return
180         }
181         if err == nil {
182                 // A trailing blank line signals to the caller that
183                 // the response is complete.
184                 w.Write([]byte("\n"))
185         }
186 }
187
188 func (rtr *router) handlePullList(w http.ResponseWriter, req *http.Request) {
189         var pl []PullListItem
190         err := json.NewDecoder(req.Body).Decode(&pl)
191         if err != nil {
192                 rtr.handleError(w, req, err)
193                 return
194         }
195         req.Body.Close()
196         if len(pl) > 0 && len(pl[0].Locator) == 32 {
197                 rtr.handleError(w, req, httpserver.ErrorWithStatus(errors.New("rejecting pull list containing a locator without a size hint -- this probably means keep-balance needs to be upgraded"), http.StatusBadRequest))
198                 return
199         }
200         rtr.puller.SetPullList(pl)
201 }
202
203 func (rtr *router) handleTrashList(w http.ResponseWriter, req *http.Request) {
204         var tl []TrashListItem
205         err := json.NewDecoder(req.Body).Decode(&tl)
206         if err != nil {
207                 rtr.handleError(w, req, err)
208                 return
209         }
210         req.Body.Close()
211         rtr.trasher.SetTrashList(tl)
212 }
213
214 func (rtr *router) handleUntrash(w http.ResponseWriter, req *http.Request) {
215         err := rtr.keepstore.BlockUntrash(req.Context(), mux.Vars(req)["locator"])
216         rtr.handleError(w, req, err)
217 }
218
219 func (rtr *router) handleBadRequest(w http.ResponseWriter, req *http.Request) {
220         http.Error(w, "Bad Request", http.StatusBadRequest)
221 }
222
223 func (rtr *router) handleOptions(w http.ResponseWriter, req *http.Request) {
224 }
225
226 func (rtr *router) handleError(w http.ResponseWriter, req *http.Request, err error) {
227         if req.Context().Err() != nil {
228                 w.WriteHeader(499)
229                 return
230         }
231         if err == nil {
232                 return
233         } else if os.IsNotExist(err) {
234                 w.WriteHeader(http.StatusNotFound)
235         } else if statusErr := interface{ HTTPStatus() int }(nil); errors.As(err, &statusErr) {
236                 w.WriteHeader(statusErr.HTTPStatus())
237         } else {
238                 w.WriteHeader(http.StatusInternalServerError)
239         }
240         fmt.Fprintln(w, err.Error())
241 }
242
// countingWriter forwards writes to writer and keeps a running total
// of the bytes the underlying writer accepted. The total is an
// atomic.Int64, so reading it with Load is safe even if writes happen
// concurrently.
type countingWriter struct {
	writer io.Writer
	n      atomic.Int64
}

// Write forwards p to the wrapped writer. It adds the number of bytes
// reported written to the total, even when an error is also returned.
func (c *countingWriter) Write(p []byte) (written int, err error) {
	written, err = c.writer.Write(p)
	c.n.Add(int64(written))
	return written, err
}
253
// trimSplit splits s at every occurrence of sep and trims surrounding
// whitespace from each piece. Pieces that end up empty are discarded.
// When nothing remains, the result is nil rather than an empty slice.
func trimSplit(s, sep string) []string {
	pieces := strings.Split(s, sep)
	var kept []string
	for i := range pieces {
		if p := strings.TrimSpace(pieces[i]); len(p) > 0 {
			kept = append(kept, p)
		}
	}
	return kept
}
266
// setSizeOnWrite wraps an http.ResponseWriter. Just before the first
// Write is forwarded, it sets the Content-Length response header to
// size; later writes pass straight through. If the wrapped writer has
// already sent its headers (e.g. WriteHeader was called), changing
// the header map no longer affects the response.
type setSizeOnWrite struct {
	http.ResponseWriter
	size  int
	wrote bool
}

// Write sets Content-Length on the first call only, then forwards p
// to the wrapped ResponseWriter.
func (ss *setSizeOnWrite) Write(p []byte) (int, error) {
	if ss.wrote {
		return ss.ResponseWriter.Write(p)
	}
	ss.wrote = true
	ss.Header().Set("Content-Length", fmt.Sprintf("%d", ss.size))
	return ss.ResponseWriter.Write(p)
}
282
// discardWrite is an http.ResponseWriter that drops the response body.
// Header and WriteHeader are promoted from the embedded writer, so
// headers and the status code still pass through. handleBlockRead
// uses it for HEAD requests.
type discardWrite struct {
	http.ResponseWriter
}

// Write claims all of p was written without writing any of it.
func (d discardWrite) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}
290
291 func corsHandler(h http.Handler) http.Handler {
292         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293                 SetCORSHeaders(w)
294                 h.ServeHTTP(w, r)
295         })
296 }
297
298 var corsHeaders = map[string]string{
299         "Access-Control-Allow-Methods":  "GET, HEAD, PUT, OPTIONS",
300         "Access-Control-Allow-Origin":   "*",
301         "Access-Control-Allow-Headers":  "Authorization, Content-Length, Content-Type, " + keepclient.XKeepDesiredReplicas + ", " + keepclient.XKeepSignature + ", " + keepclient.XKeepStorageClasses,
302         "Access-Control-Expose-Headers": keepclient.XKeepLocator + ", " + keepclient.XKeepReplicasStored + ", " + keepclient.XKeepStorageClassesConfirmed,
303         "Access-Control-Max-Age":        "86486400",
304 }
305
306 func SetCORSHeaders(w http.ResponseWriter) {
307         for k, v := range corsHeaders {
308                 w.Header().Set(k, v)
309         }
310 }