Merge branch '21717-keepstore-cors'
[arvados.git] / services / keepstore / router.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package keepstore
6
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync/atomic"

	"git.arvados.org/arvados.git/lib/service"
	"git.arvados.org/arvados.git/sdk/go/arvados"
	"git.arvados.org/arvados.git/sdk/go/auth"
	"git.arvados.org/arvados.git/sdk/go/httpserver"
	"git.arvados.org/arvados.git/sdk/go/keepclient"
	"github.com/gorilla/mux"
)
25
26 type router struct {
27         http.Handler
28         keepstore *keepstore
29         puller    *puller
30         trasher   *trasher
31 }
32
// newRouter builds keepstore's HTTP handler.
//
// Routes, grouped by method (gorilla/mux tries routes in registration
// order and the first match wins, so the order below is significant):
//
//	GET/HEAD /{locator}                            handleBlockRead
//	GET/HEAD /index, /index/{prefix}               handleIndex      (admin)
//	GET/HEAD /mounts                               handleMounts     (admin)
//	GET/HEAD /mounts/{uuid}/blocks[/{prefix}]      handleIndex      (admin)
//	PUT      /{locator}                            handleBlockWrite
//	PUT      /pull, /trash                         pull/trash lists (admin)
//	PUT      /untrash/{locator}                    handleUntrash    (admin)
//	TOUCH    /{locator}                            handleBlockTouch (admin)
//	DELETE   /{locator}                            handleBlockTrash (admin)
//	OPTIONS  any path                              handleOptions
//
// "admin" routes only run when the request carries the cluster's
// SystemRootToken. Unknown paths and unsupported methods get 400 Bad
// Request. The mux is wrapped in auth.LoadToken and then corsHandler.
func newRouter(keepstore *keepstore, puller *puller, trasher *trasher) service.Handler {
	rtr := &router{keepstore: keepstore, puller: puller, trasher: trasher}

	// adminOnly wraps h so it only runs for requests bearing the
	// cluster's SystemRootToken.
	adminOnly := func(h http.HandlerFunc) http.HandlerFunc {
		return auth.RequireLiteralToken(keepstore.cluster.SystemRootToken, h).ServeHTTP
	}

	const blockPath = `/{locator:[0-9a-f]{32}.*}`
	mx := mux.NewRouter()

	reads := mx.Methods(http.MethodGet, http.MethodHead).Subrouter()
	reads.HandleFunc(blockPath, rtr.handleBlockRead)
	reads.HandleFunc(`/index`, adminOnly(rtr.handleIndex))
	reads.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, adminOnly(rtr.handleIndex))
	reads.HandleFunc(`/mounts`, adminOnly(rtr.handleMounts))
	reads.HandleFunc(`/mounts/{uuid}/blocks`, adminOnly(rtr.handleIndex))
	reads.HandleFunc(`/mounts/{uuid}/blocks/{prefix:[0-9a-f]{0,32}}`, adminOnly(rtr.handleIndex))

	writes := mx.Methods(http.MethodPut).Subrouter()
	writes.HandleFunc(blockPath, rtr.handleBlockWrite)
	writes.HandleFunc(`/pull`, adminOnly(rtr.handlePullList))
	writes.HandleFunc(`/trash`, adminOnly(rtr.handleTrashList))
	writes.HandleFunc(`/untrash`+blockPath, adminOnly(rtr.handleUntrash))

	touches := mx.Methods("TOUCH").Subrouter()
	touches.HandleFunc(blockPath, adminOnly(rtr.handleBlockTouch))

	// Named "deletes" rather than "delete" to avoid shadowing the builtin.
	deletes := mx.Methods(http.MethodDelete).Subrouter()
	deletes.HandleFunc(blockPath, adminOnly(rtr.handleBlockTrash))

	preflight := mx.Methods(http.MethodOptions).Subrouter()
	preflight.NewRoute().PathPrefix(`/`).HandlerFunc(rtr.handleOptions)

	mx.NotFoundHandler = http.HandlerFunc(rtr.handleBadRequest)
	mx.MethodNotAllowedHandler = http.HandlerFunc(rtr.handleBadRequest)

	rtr.Handler = corsHandler(auth.LoadToken(mx))
	return rtr
}
68
69 func (rtr *router) CheckHealth() error {
70         return nil
71 }
72
73 func (rtr *router) Done() <-chan struct{} {
74         return nil
75 }
76
// handleBlockRead serves GET and HEAD requests for /{locator}.
//
// For GET, the locator must be accepted by getLocatorInfo and must carry
// a size hint, unless its hash is d41d8cd98f00b204e9800998ecf8427e (the
// md5 of zero bytes). Without a size hint an md5 mismatch could not be
// reported, so such requests are refused with errMethodNotAllowed.
//
// For HEAD, the locator is not pre-checked. BlockRead still runs, but
// the body bytes it writes are discarded.
//
// If the first comma-separated field of the X-Keep-Signature request
// header is "local", any locator BlockRead passes to its LocalLocator
// callback is returned in the X-Keep-Locator response header.
func (rtr *router) handleBlockRead(w http.ResponseWriter, req *http.Request) {
	locator := mux.Vars(req)["locator"]
	isHead := req.Method == http.MethodHead

	// Intervening proxies must not return a cached GET response
	// to a prior request if a X-Keep-Signature request header has
	// been added or changed.
	w.Header().Add("Vary", keepclient.XKeepSignature)

	var localLocator func(string)
	if sigMode, _, _ := strings.Cut(req.Header.Get(keepclient.XKeepSignature), ","); sigMode == "local" {
		localLocator = func(loc string) {
			w.Header().Set(keepclient.XKeepLocator, loc)
		}
	}

	var sink http.ResponseWriter = w
	if isHead {
		sink = discardWrite{ResponseWriter: w}
	} else {
		li, err := getLocatorInfo(locator)
		if err != nil {
			rtr.handleError(w, req, err)
			return
		}
		if li.size == 0 && li.hash != "d41d8cd98f00b204e9800998ecf8427e" {
			// GET {hash} (with no size hint) is not allowed
			// because we can't report md5 mismatches.
			rtr.handleError(w, req, errMethodNotAllowed)
			return
		}
	}

	n, err := rtr.keepstore.BlockRead(req.Context(), arvados.BlockReadOptions{
		Locator:      locator,
		WriteTo:      sink,
		LocalLocator: localLocator,
	})
	// Once GET body bytes have gone out, the status line is committed
	// and the error cannot be reported. discardWrite never forwards
	// body bytes, so HEAD errors can always be reported.
	if err != nil && (n == 0 || isHead) {
		rtr.handleError(w, req, err)
	}
}
110
111 func (rtr *router) handleBlockWrite(w http.ResponseWriter, req *http.Request) {
112         dataSize, _ := strconv.Atoi(req.Header.Get("Content-Length"))
113         replicas, _ := strconv.Atoi(req.Header.Get(keepclient.XKeepDesiredReplicas))
114         resp, err := rtr.keepstore.BlockWrite(req.Context(), arvados.BlockWriteOptions{
115                 Hash:           mux.Vars(req)["locator"],
116                 Reader:         req.Body,
117                 DataSize:       dataSize,
118                 RequestID:      req.Header.Get("X-Request-Id"),
119                 StorageClasses: trimSplit(req.Header.Get(keepclient.XKeepStorageClasses), ","),
120                 Replicas:       replicas,
121         })
122         if err != nil {
123                 rtr.handleError(w, req, err)
124                 return
125         }
126         w.Header().Set(keepclient.XKeepReplicasStored, fmt.Sprintf("%d", resp.Replicas))
127         scc := ""
128         for k, n := range resp.StorageClasses {
129                 if n > 0 {
130                         if scc != "" {
131                                 scc += "; "
132                         }
133                         scc += fmt.Sprintf("%s=%d", k, n)
134                 }
135         }
136         w.Header().Set(keepclient.XKeepStorageClassesConfirmed, scc)
137         w.WriteHeader(http.StatusOK)
138         fmt.Fprintln(w, resp.Locator)
139 }
140
141 func (rtr *router) handleBlockTouch(w http.ResponseWriter, req *http.Request) {
142         err := rtr.keepstore.BlockTouch(req.Context(), mux.Vars(req)["locator"])
143         rtr.handleError(w, req, err)
144 }
145
146 func (rtr *router) handleBlockTrash(w http.ResponseWriter, req *http.Request) {
147         err := rtr.keepstore.BlockTrash(req.Context(), mux.Vars(req)["locator"])
148         rtr.handleError(w, req, err)
149 }
150
151 func (rtr *router) handleMounts(w http.ResponseWriter, req *http.Request) {
152         json.NewEncoder(w).Encode(rtr.keepstore.Mounts())
153 }
154
155 func (rtr *router) handleIndex(w http.ResponseWriter, req *http.Request) {
156         prefix := req.FormValue("prefix")
157         if prefix == "" {
158                 prefix = mux.Vars(req)["prefix"]
159         }
160         cw := &countingWriter{writer: w}
161         err := rtr.keepstore.Index(req.Context(), indexOptions{
162                 MountUUID: mux.Vars(req)["uuid"],
163                 Prefix:    prefix,
164                 WriteTo:   cw,
165         })
166         if err != nil && cw.n.Load() == 0 {
167                 // Nothing was written, so it's not too late to report
168                 // an error via http response header. (Otherwise, all
169                 // we can do is omit the trailing newline below to
170                 // indicate something went wrong.)
171                 rtr.handleError(w, req, err)
172                 return
173         }
174         if err == nil {
175                 // A trailing blank line signals to the caller that
176                 // the response is complete.
177                 w.Write([]byte("\n"))
178         }
179 }
180
// handlePullList replaces the puller's pull list with the JSON array of
// PullListItem in the request body (PUT /pull, admin only).
//
// It responds 400 Bad Request in two cases:
//   - the body cannot be decoded. This is the client's fault, so it must
//     not be reported as a 500 server error.
//   - the first item's locator is a bare 32-character hash with no
//     size hint.
//
// On success nothing is written and the implicit 200 OK applies.
func (rtr *router) handlePullList(w http.ResponseWriter, req *http.Request) {
	var pl []PullListItem
	if err := json.NewDecoder(req.Body).Decode(&pl); err != nil {
		rtr.handleError(w, req, httpserver.ErrorWithStatus(err, http.StatusBadRequest))
		return
	}
	req.Body.Close()
	// Only the first entry is inspected. Presumably a list from an
	// outdated keep-balance lacks size hints throughout, so one entry
	// suffices to detect it — TODO confirm.
	if len(pl) > 0 && len(pl[0].Locator) == 32 {
		rtr.handleError(w, req, httpserver.ErrorWithStatus(errors.New("rejecting pull list containing a locator without a size hint -- this probably means keep-balance needs to be upgraded"), http.StatusBadRequest))
		return
	}
	rtr.puller.SetPullList(pl)
}
195
// handleTrashList replaces the trasher's trash list with the JSON array
// of TrashListItem in the request body (PUT /trash, admin only).
//
// A body that cannot be decoded is the client's fault, so it is
// reported as 400 Bad Request rather than a 500 server error. On
// success nothing is written and the implicit 200 OK applies.
func (rtr *router) handleTrashList(w http.ResponseWriter, req *http.Request) {
	var tl []TrashListItem
	if err := json.NewDecoder(req.Body).Decode(&tl); err != nil {
		rtr.handleError(w, req, httpserver.ErrorWithStatus(err, http.StatusBadRequest))
		return
	}
	req.Body.Close()
	rtr.trasher.SetTrashList(tl)
}
206
207 func (rtr *router) handleUntrash(w http.ResponseWriter, req *http.Request) {
208         err := rtr.keepstore.BlockUntrash(req.Context(), mux.Vars(req)["locator"])
209         rtr.handleError(w, req, err)
210 }
211
// handleBadRequest is the router's fallback for requests that match no
// route or use an unsupported method. It always responds with
// 400 Bad Request and a plain-text body.
func (rtr *router) handleBadRequest(w http.ResponseWriter, _ *http.Request) {
	const msg = "Bad Request"
	http.Error(w, msg, http.StatusBadRequest)
}
215
216 func (rtr *router) handleOptions(w http.ResponseWriter, req *http.Request) {
217 }
218
219 func (rtr *router) handleError(w http.ResponseWriter, req *http.Request, err error) {
220         if req.Context().Err() != nil {
221                 w.WriteHeader(499)
222                 return
223         }
224         if err == nil {
225                 return
226         } else if os.IsNotExist(err) {
227                 w.WriteHeader(http.StatusNotFound)
228         } else if statusErr := interface{ HTTPStatus() int }(nil); errors.As(err, &statusErr) {
229                 w.WriteHeader(statusErr.HTTPStatus())
230         } else {
231                 w.WriteHeader(http.StatusInternalServerError)
232         }
233         fmt.Fprintln(w, err.Error())
234 }
235
236 type countingWriter struct {
237         writer io.Writer
238         n      atomic.Int64
239 }
240
241 func (cw *countingWriter) Write(p []byte) (int, error) {
242         n, err := cw.writer.Write(p)
243         cw.n.Add(int64(n))
244         return n, err
245 }
246
// trimSplit splits s around each occurrence of sep, strips leading and
// trailing whitespace from every piece, and drops pieces that end up
// empty. If no piece survives, the result is nil.
func trimSplit(s, sep string) []string {
	pieces := strings.Split(s, sep)
	var kept []string
	for i := range pieces {
		if p := strings.TrimSpace(pieces[i]); len(p) > 0 {
			kept = append(kept, p)
		}
	}
	return kept
}
259
260 // setSizeOnWrite sets the Content-Length header to the given size on
261 // first write.
262 type setSizeOnWrite struct {
263         http.ResponseWriter
264         size  int
265         wrote bool
266 }
267
268 func (ss *setSizeOnWrite) Write(p []byte) (int, error) {
269         if !ss.wrote {
270                 ss.Header().Set("Content-Length", fmt.Sprintf("%d", ss.size))
271                 ss.wrote = true
272         }
273         return ss.ResponseWriter.Write(p)
274 }
275
276 type discardWrite struct {
277         http.ResponseWriter
278 }
279
280 func (discardWrite) Write(p []byte) (int, error) {
281         return len(p), nil
282 }
283
284 func corsHandler(h http.Handler) http.Handler {
285         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
286                 SetCORSHeaders(w)
287                 h.ServeHTTP(w, r)
288         })
289 }
290
291 var corsHeaders = map[string]string{
292         "Access-Control-Allow-Methods":  "GET, HEAD, PUT, OPTIONS",
293         "Access-Control-Allow-Origin":   "*",
294         "Access-Control-Allow-Headers":  "Authorization, Content-Length, Content-Type, " + keepclient.XKeepDesiredReplicas + ", " + keepclient.XKeepSignature + ", " + keepclient.XKeepStorageClasses,
295         "Access-Control-Expose-Headers": keepclient.XKeepLocator + ", " + keepclient.XKeepReplicasStored + ", " + keepclient.XKeepStorageClassesConfirmed,
296         "Access-Control-Max-Age":        "86486400",
297 }
298
299 func SetCORSHeaders(w http.ResponseWriter) {
300         for k, v := range corsHeaders {
301                 w.Header().Set(k, v)
302         }
303 }