Merge branch '5824-keep-web' into 5824-keep-web-workbench
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "sync"
20         "syscall"
21         "time"
22 )
23
24 // Default TCP address on which to listen for requests.
25 // Override with -listen.
26 const DefaultAddr = ":25107"
27
28 var listener net.Listener
29
30 func main() {
31         var (
32                 listen           string
33                 no_get           bool
34                 no_put           bool
35                 default_replicas int
36                 timeout          int64
37                 pidfile          string
38         )
39
40         flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
41
42         flagset.StringVar(
43                 &listen,
44                 "listen",
45                 DefaultAddr,
46                 "Interface on which to listen for requests, in the format "+
47                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48                         "to listen on all network interfaces.")
49
50         flagset.BoolVar(
51                 &no_get,
52                 "no-get",
53                 false,
54                 "If set, disable GET operations")
55
56         flagset.BoolVar(
57                 &no_put,
58                 "no-put",
59                 false,
60                 "If set, disable PUT operations")
61
62         flagset.IntVar(
63                 &default_replicas,
64                 "default-replicas",
65                 2,
66                 "Default number of replicas to write if not specified by the client.")
67
68         flagset.Int64Var(
69                 &timeout,
70                 "timeout",
71                 15,
72                 "Timeout on requests to internal Keep services (default 15 seconds)")
73
74         flagset.StringVar(
75                 &pidfile,
76                 "pid",
77                 "",
78                 "Path to write pid file")
79
80         flagset.Parse(os.Args[1:])
81
82         if pidfile != "" {
83                 f, err := os.Create(pidfile)
84                 if err != nil {
85                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
86                 }
87                 fmt.Fprint(f, os.Getpid())
88                 f.Close()
89                 defer os.Remove(pidfile)
90         }
91
92         arv, err := arvadosclient.MakeArvadosClient()
93         if err != nil {
94                 log.Fatalf("setting up arvados client: %v", err)
95         }
96         kc, err := keepclient.MakeKeepClient(&arv)
97         if err != nil {
98                 log.Fatalf("setting up keep client: %v", err)
99         }
100         kc.Want_replicas = default_replicas
101         kc.Client.Timeout = time.Duration(timeout) * time.Second
102         go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
103
104         listener, err = net.Listen("tcp", listen)
105         if err != nil {
106                 log.Fatalf("Could not listen on %v", listen)
107         }
108         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
109
110         // Shut down the server gracefully (by closing the listener)
111         // if SIGTERM is received.
112         term := make(chan os.Signal, 1)
113         go func(sig <-chan os.Signal) {
114                 s := <-sig
115                 log.Println("caught signal:", s)
116                 listener.Close()
117         }(term)
118         signal.Notify(term, syscall.SIGTERM)
119         signal.Notify(term, syscall.SIGINT)
120
121         // Start serving requests.
122         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
123
124         log.Println("shutting down")
125 }
126
// ApiTokenCache remembers API tokens that the API server has recently
// accepted, so that repeated requests bearing the same token can be
// authorized without another round trip. Its methods take lock before
// touching tokens, so one cache can be shared by all request handlers.
// tokens must be a non-nil map (RememberToken writes to it).
type ApiTokenCache struct {
	tokens     map[string]int64 // token -> Unix time (seconds) at which the entry expires; 0/absent = not cached
	lock       sync.Mutex       // guards tokens
	expireTime int64            // lifetime, in seconds, given to a newly remembered token
}
132
133 // Refresh the keep service list on SIGHUP; when the given interval
134 // has elapsed since the last refresh; and (if the last refresh
135 // failed) the given errInterval has elapsed.
136 func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
137         var previousRoots = []map[string]string{}
138
139         timer := time.NewTimer(interval)
140         gotHUP := make(chan os.Signal, 1)
141         signal.Notify(gotHUP, syscall.SIGHUP)
142
143         for {
144                 select {
145                 case <-gotHUP:
146                 case <-timer.C:
147                 }
148                 timer.Reset(interval)
149
150                 if err := kc.DiscoverKeepServers(); err != nil {
151                         log.Println("Error retrieving services list: %v (retrying in %v)", err, errInterval)
152                         timer.Reset(errInterval)
153                         continue
154                 }
155                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
156
157                 if !reflect.DeepEqual(previousRoots, newRoots) {
158                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
159                         previousRoots = newRoots
160                 }
161
162                 if len(newRoots[0]) == 0 {
163                         log.Printf("WARNING: No local services (retrying in %v)", errInterval)
164                         timer.Reset(errInterval)
165                 }
166         }
167 }
168
169 // Cache the token and set an expire time.  If we already have an expire time
170 // on the token, it is not updated.
171 func (this *ApiTokenCache) RememberToken(token string) {
172         this.lock.Lock()
173         defer this.lock.Unlock()
174
175         now := time.Now().Unix()
176         if this.tokens[token] == 0 {
177                 this.tokens[token] = now + this.expireTime
178         }
179 }
180
181 // Check if the cached token is known and still believed to be valid.
182 func (this *ApiTokenCache) RecallToken(token string) bool {
183         this.lock.Lock()
184         defer this.lock.Unlock()
185
186         now := time.Now().Unix()
187         if this.tokens[token] == 0 {
188                 // Unknown token
189                 return false
190         } else if now < this.tokens[token] {
191                 // Token is known and still valid
192                 return true
193         } else {
194                 // Token is expired
195                 this.tokens[token] = 0
196                 return false
197         }
198 }
199
// GetRemoteAddress describes where req came from, for use in log lines.
// If the request has an X-Forwarded-For header, its value is put in
// front of the connection's peer address, separated by a comma.
func GetRemoteAddress(req *http.Request) string {
	forwarded := req.Header.Get("X-Forwarded-For")
	if forwarded == "" {
		return req.RemoteAddr
	}
	return forwarded + "," + req.RemoteAddr
}
206
207 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
208         var auth string
209         if auth = req.Header.Get("Authorization"); auth == "" {
210                 return false, ""
211         }
212
213         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
214         if err != nil {
215                 // Scanning error
216                 return false, ""
217         }
218
219         if cache.RecallToken(tok) {
220                 // Valid in the cache, short circut
221                 return true, tok
222         }
223
224         arv := *kc.Arvados
225         arv.ApiToken = tok
226         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
227                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
228                 return false, ""
229         }
230
231         // Success!  Update cache
232         cache.RememberToken(tok)
233
234         return true, tok
235 }
236
237 type GetBlockHandler struct {
238         *keepclient.KeepClient
239         *ApiTokenCache
240 }
241
242 type PutBlockHandler struct {
243         *keepclient.KeepClient
244         *ApiTokenCache
245 }
246
247 type IndexHandler struct {
248         *keepclient.KeepClient
249         *ApiTokenCache
250 }
251
252 type InvalidPathHandler struct{}
253
254 type OptionsHandler struct{}
255
256 // MakeRESTRouter
257 //     Returns a mux.Router that passes GET and PUT requests to the
258 //     appropriate handlers.
259 //
260 func MakeRESTRouter(
261         enable_get bool,
262         enable_put bool,
263         kc *keepclient.KeepClient) *mux.Router {
264
265         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
266
267         rest := mux.NewRouter()
268
269         if enable_get {
270                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
271                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
272                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
273
274                 // List all blocks
275                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
276
277                 // List blocks whose hash has the given prefix
278                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
279         }
280
281         if enable_put {
282                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
283                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
284                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
285                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
286                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
287         }
288
289         rest.NotFoundHandler = InvalidPathHandler{}
290
291         return rest
292 }
293
// SetCorsHeaders adds the CORS response headers that allow browser
// clients on any origin to call the proxy.
func SetCorsHeaders(resp http.ResponseWriter) {
	header := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		// NOTE(review): 86486400 s is ~1001 days; possibly 86400 (one
		// day) was intended. Browsers clamp this value anyway.
		{"Access-Control-Max-Age", "86486400"},
	} {
		header.Set(kv[0], kv[1])
	}
}
300
301 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
302         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
303         http.Error(resp, "Bad request", http.StatusBadRequest)
304 }
305
306 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
307         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
308         SetCorsHeaders(resp)
309 }
310
// Errors reported to clients (and logged) by the request handlers.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@" hint followed by five lowercase
// alphanumerics, together with the "+" (or end of string) after it.
// Replacing matches with "$1" strips the hint but keeps the separator.
// The pattern is a constant, so MustCompile is used: a bad pattern is a
// programming error that should fail at start-up instead of silently
// leaving removeHint nil.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
316
317 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
318         SetCorsHeaders(resp)
319
320         locator := mux.Vars(req)["locator"]
321         var err error
322         var status int
323         var expectLength, responseLength int64
324         var proxiedURI = "-"
325
326         defer func() {
327                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
328                 if status != http.StatusOK {
329                         http.Error(resp, err.Error(), status)
330                 }
331         }()
332
333         kc := *this.KeepClient
334
335         var pass bool
336         var tok string
337         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
338                 status, err = http.StatusForbidden, BadAuthorizationHeader
339                 return
340         }
341
342         // Copy ArvadosClient struct and use the client's API token
343         arvclient := *kc.Arvados
344         arvclient.ApiToken = tok
345         kc.Arvados = &arvclient
346
347         var reader io.ReadCloser
348
349         locator = removeHint.ReplaceAllString(locator, "$1")
350
351         switch req.Method {
352         case "HEAD":
353                 expectLength, proxiedURI, err = kc.Ask(locator)
354         case "GET":
355                 reader, expectLength, proxiedURI, err = kc.Get(locator)
356                 if reader != nil {
357                         defer reader.Close()
358                 }
359         default:
360                 status, err = http.StatusNotImplemented, MethodNotSupported
361                 return
362         }
363
364         if expectLength == -1 {
365                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
366         }
367
368         switch respErr := err.(type) {
369         case nil:
370                 status = http.StatusOK
371                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
372                 switch req.Method {
373                 case "HEAD":
374                         responseLength = 0
375                 case "GET":
376                         responseLength, err = io.Copy(resp, reader)
377                         if err == nil && expectLength > -1 && responseLength != expectLength {
378                                 err = ContentLengthMismatch
379                         }
380                 }
381         case keepclient.Error:
382                 if respErr == keepclient.BlockNotFound {
383                         status = http.StatusNotFound
384                 } else if respErr.Temporary() {
385                         status = http.StatusBadGateway
386                 } else {
387                         status = 422
388                 }
389         default:
390                 status = http.StatusInternalServerError
391         }
392 }
393
// Errors reported by PutBlockHandler when a request's length information
// is missing or inconsistent.
var (
	// LengthRequiredError: no usable Content-Length header was sent.
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	// LengthMismatchError: the locator's size hint disagrees with
	// Content-Length.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
396
397 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
398         SetCorsHeaders(resp)
399
400         kc := *this.KeepClient
401         var err error
402         var expectLength int64 = -1
403         var status = http.StatusInternalServerError
404         var wroteReplicas int
405         var locatorOut string = "-"
406
407         defer func() {
408                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
409                 if status != http.StatusOK {
410                         http.Error(resp, err.Error(), status)
411                 }
412         }()
413
414         locatorIn := mux.Vars(req)["locator"]
415
416         if req.Header.Get("Content-Length") != "" {
417                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
418                 if err != nil {
419                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
420                 }
421
422         }
423
424         if expectLength < 0 {
425                 err = LengthRequiredError
426                 status = http.StatusLengthRequired
427                 return
428         }
429
430         if locatorIn != "" {
431                 var loc *keepclient.Locator
432                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
433                         status = http.StatusBadRequest
434                         return
435                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
436                         err = LengthMismatchError
437                         status = http.StatusBadRequest
438                         return
439                 }
440         }
441
442         var pass bool
443         var tok string
444         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
445                 err = BadAuthorizationHeader
446                 status = http.StatusForbidden
447                 return
448         }
449
450         // Copy ArvadosClient struct and use the client's API token
451         arvclient := *kc.Arvados
452         arvclient.ApiToken = tok
453         kc.Arvados = &arvclient
454
455         // Check if the client specified the number of replicas
456         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
457                 var r int
458                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
459                 if err != nil {
460                         kc.Want_replicas = r
461                 }
462         }
463
464         // Now try to put the block through
465         if locatorIn == "" {
466                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
467                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
468                         status = http.StatusInternalServerError
469                         return
470                 } else {
471                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
472                 }
473         } else {
474                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
475         }
476
477         // Tell the client how many successful PUTs we accomplished
478         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
479
480         switch err {
481         case nil:
482                 status = http.StatusOK
483                 _, err = io.WriteString(resp, locatorOut)
484
485         case keepclient.OversizeBlockError:
486                 // Too much data
487                 status = http.StatusRequestEntityTooLarge
488
489         case keepclient.InsufficientReplicasError:
490                 if wroteReplicas > 0 {
491                         // At least one write is considered success.  The
492                         // client can decide if getting less than the number of
493                         // replications it asked for is a fatal error.
494                         status = http.StatusOK
495                         _, err = io.WriteString(resp, locatorOut)
496                 } else {
497                         status = http.StatusServiceUnavailable
498                 }
499
500         default:
501                 status = http.StatusBadGateway
502         }
503 }
504
505 // ServeHTTP implementation for IndexHandler
506 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
507 // For each keep server found in LocalRoots:
508 //   Invokes GetIndex using keepclient
509 //   Expects "complete" response (terminating with blank new line)
510 //   Aborts on any errors
511 // Concatenates responses from all those keep servers and returns
512 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
513         SetCorsHeaders(resp)
514
515         prefix := mux.Vars(req)["prefix"]
516         var err error
517         var status int
518
519         defer func() {
520                 if status != http.StatusOK {
521                         http.Error(resp, err.Error(), status)
522                 }
523         }()
524
525         kc := *handler.KeepClient
526
527         ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
528         if !ok {
529                 status, err = http.StatusForbidden, BadAuthorizationHeader
530                 return
531         }
532
533         // Copy ArvadosClient struct and use the client's API token
534         arvclient := *kc.Arvados
535         arvclient.ApiToken = token
536         kc.Arvados = &arvclient
537
538         // Only GET method is supported
539         if req.Method != "GET" {
540                 status, err = http.StatusNotImplemented, MethodNotSupported
541                 return
542         }
543
544         // Get index from all LocalRoots and write to resp
545         var reader io.Reader
546         for uuid := range kc.LocalRoots() {
547                 reader, err = kc.GetIndex(uuid, prefix)
548                 if err != nil {
549                         status = http.StatusBadGateway
550                         return
551                 }
552
553                 _, err = io.Copy(resp, reader)
554                 if err != nil {
555                         status = http.StatusBadGateway
556                         return
557                 }
558         }
559
560         // Got index from all the keep servers and wrote to resp
561         status = http.StatusOK
562         resp.Write([]byte("\n"))
563 }