5824: Refresh keepproxy services list on SIGHUP. Update Workbench upload test to...
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "sync"
20         "syscall"
21         "time"
22 )
23
// DEFAULT_ADDR is the TCP address used as the default value of the
// -listen flag.
const (
	DEFAULT_ADDR = ":25107"
)

// listener is the network listener on which requests are accepted.
// main assigns it, and main's signal handler closes it to shut the
// server down.
var (
	listener net.Listener
)
29
30 func main() {
31         var (
32                 listen           string
33                 no_get           bool
34                 no_put           bool
35                 default_replicas int
36                 timeout          int64
37                 pidfile          string
38         )
39
40         flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
41
42         flagset.StringVar(
43                 &listen,
44                 "listen",
45                 DEFAULT_ADDR,
46                 "Interface on which to listen for requests, in the format "+
47                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48                         "to listen on all network interfaces.")
49
50         flagset.BoolVar(
51                 &no_get,
52                 "no-get",
53                 false,
54                 "If set, disable GET operations")
55
56         flagset.BoolVar(
57                 &no_put,
58                 "no-put",
59                 false,
60                 "If set, disable PUT operations")
61
62         flagset.IntVar(
63                 &default_replicas,
64                 "default-replicas",
65                 2,
66                 "Default number of replicas to write if not specified by the client.")
67
68         flagset.Int64Var(
69                 &timeout,
70                 "timeout",
71                 15,
72                 "Timeout on requests to internal Keep services (default 15 seconds)")
73
74         flagset.StringVar(
75                 &pidfile,
76                 "pid",
77                 "",
78                 "Path to write pid file")
79
80         flagset.Parse(os.Args[1:])
81
82         arv, err := arvadosclient.MakeArvadosClient()
83         if err != nil {
84                 log.Fatalf("Error setting up arvados client %s", err.Error())
85         }
86
87         kc, err := keepclient.MakeKeepClient(&arv)
88         if err != nil {
89                 log.Fatalf("Error setting up keep client %s", err.Error())
90         }
91
92         if pidfile != "" {
93                 f, err := os.Create(pidfile)
94                 if err != nil {
95                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
96                 }
97                 fmt.Fprint(f, os.Getpid())
98                 f.Close()
99                 defer os.Remove(pidfile)
100         }
101
102         kc.Want_replicas = default_replicas
103
104         kc.Client.Timeout = time.Duration(timeout) * time.Second
105
106         listener, err = net.Listen("tcp", listen)
107         if err != nil {
108                 log.Fatalf("Could not listen on %v", listen)
109         }
110
111         go RefreshServicesList(kc, 5*time.Minute, 3*time.Second)
112
113         // Shut down the server gracefully (by closing the listener)
114         // if SIGTERM is received.
115         term := make(chan os.Signal, 1)
116         go func(sig <-chan os.Signal) {
117                 s := <-sig
118                 log.Println("caught signal:", s)
119                 listener.Close()
120         }(term)
121         signal.Notify(term, syscall.SIGTERM)
122         signal.Notify(term, syscall.SIGINT)
123
124         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
125
126         // Start listening for requests.
127         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
128
129         log.Println("shutting down")
130 }
131
// ApiTokenCache remembers API tokens for a limited time so that a
// token does not have to be re-checked on every request.
type ApiTokenCache struct {
	// tokens maps a token to the Unix time (in seconds) at which its
	// cache entry expires. A zero or missing entry means the token is
	// not cached.
	tokens map[string]int64

	// lock is held by the cache methods while they read or write
	// tokens.
	lock sync.Mutex

	// expireTime is the number of seconds added to the current Unix
	// time when a token is first remembered.
	expireTime int64
}
137
138 // Refresh the keep service list on SIGHUP; when the given interval
139 // has elapsed since the last refresh; and (if the last refresh
140 // failed) the given errInterval has elapsed.
141 func RefreshServicesList(kc *keepclient.KeepClient, interval, errInterval time.Duration) {
142         var previousRoots = []map[string]string{}
143
144         timer := time.NewTimer(interval)
145         gotHUP := make(chan os.Signal, 1)
146         signal.Notify(gotHUP, syscall.SIGHUP)
147
148         for {
149                 select {
150                 case <-gotHUP:
151                 case <-timer.C:
152                 }
153                 timer.Reset(interval)
154
155                 if err := kc.DiscoverKeepServers(); err != nil {
156                         log.Println("Error retrieving services list: %v (retrying in %v)", err, errInterval)
157                         timer.Reset(errInterval)
158                         continue
159                 }
160                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
161
162                 if !reflect.DeepEqual(previousRoots, newRoots) {
163                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
164                         previousRoots = newRoots
165                 }
166
167                 if len(newRoots[0]) == 0 {
168                         log.Printf("WARNING: No local services (retrying in %v)", errInterval)
169                         timer.Reset(errInterval)
170                 }
171         }
172 }
173
174 // Cache the token and set an expire time.  If we already have an expire time
175 // on the token, it is not updated.
176 func (this *ApiTokenCache) RememberToken(token string) {
177         this.lock.Lock()
178         defer this.lock.Unlock()
179
180         now := time.Now().Unix()
181         if this.tokens[token] == 0 {
182                 this.tokens[token] = now + this.expireTime
183         }
184 }
185
186 // Check if the cached token is known and still believed to be valid.
187 func (this *ApiTokenCache) RecallToken(token string) bool {
188         this.lock.Lock()
189         defer this.lock.Unlock()
190
191         now := time.Now().Unix()
192         if this.tokens[token] == 0 {
193                 // Unknown token
194                 return false
195         } else if now < this.tokens[token] {
196                 // Token is known and still valid
197                 return true
198         } else {
199                 // Token is expired
200                 this.tokens[token] = 0
201                 return false
202         }
203 }
204
// GetRemoteAddress returns a description of the request's origin for
// use in log messages.
//
// Without an X-Real-IP header the result is req.RemoteAddr. With one,
// the result is the X-Real-IP value, followed by the X-Forwarded-For
// value in parentheses when the two headers differ.
func GetRemoteAddress(req *http.Request) string {
	realIP := req.Header.Get("X-Real-IP")
	if realIP == "" {
		return req.RemoteAddr
	}

	forwardedFor := req.Header.Get("X-Forwarded-For")
	if forwardedFor == realIP {
		return realIP
	}
	return fmt.Sprintf("%s (X-Forwarded-For %s)", realIP, forwardedFor)
}
215
216 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
217         var auth string
218         if auth = req.Header.Get("Authorization"); auth == "" {
219                 return false, ""
220         }
221
222         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
223         if err != nil {
224                 // Scanning error
225                 return false, ""
226         }
227
228         if cache.RecallToken(tok) {
229                 // Valid in the cache, short circut
230                 return true, tok
231         }
232
233         arv := *kc.Arvados
234         arv.ApiToken = tok
235         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
236                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
237                 return false, ""
238         }
239
240         // Success!  Update cache
241         cache.RememberToken(tok)
242
243         return true, tok
244 }
245
246 type GetBlockHandler struct {
247         *keepclient.KeepClient
248         *ApiTokenCache
249 }
250
251 type PutBlockHandler struct {
252         *keepclient.KeepClient
253         *ApiTokenCache
254 }
255
256 type IndexHandler struct {
257         *keepclient.KeepClient
258         *ApiTokenCache
259 }
260
261 type InvalidPathHandler struct{}
262
263 type OptionsHandler struct{}
264
265 // MakeRESTRouter
266 //     Returns a mux.Router that passes GET and PUT requests to the
267 //     appropriate handlers.
268 //
269 func MakeRESTRouter(
270         enable_get bool,
271         enable_put bool,
272         kc *keepclient.KeepClient) *mux.Router {
273
274         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
275
276         rest := mux.NewRouter()
277
278         if enable_get {
279                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
280                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
281                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
282
283                 // List all blocks
284                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
285
286                 // List blocks whose hash has the given prefix
287                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
288         }
289
290         if enable_put {
291                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
292                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
293                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
294                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
295                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
296         }
297
298         rest.NotFoundHandler = InvalidPathHandler{}
299
300         return rest
301 }
302
// SetCorsHeaders sets the Access-Control-* response headers: the
// allowed methods, a wildcard origin, the allowed request headers, and
// the max age.
func SetCorsHeaders(resp http.ResponseWriter) {
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	}
	for _, kv := range corsHeaders {
		resp.Header().Set(kv[0], kv[1])
	}
}
309
310 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
311         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
312         http.Error(resp, "Bad request", http.StatusBadRequest)
313 }
314
315 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
316         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
317         SetCorsHeaders(resp)
318 }
319
// Errors reported to clients by the request handlers.
var (
	// BadAuthorizationHeader is reported when a request's
	// Authorization header is missing or not accepted.
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")

	// ContentLengthMismatch is recorded when the number of bytes
	// copied to the client differs from the expected length.
	ContentLengthMismatch = errors.New("Actual length != expected content length")

	// MethodNotSupported is reported when a handler receives a
	// request method it does not implement.
	MethodNotSupported = errors.New("Method not supported")
)

// removeHint matches a "+K@" hint followed by five lowercase
// alphanumeric characters, together with the "+" that follows it (or
// the end of the string). Replacing a match with "$1" removes the hint
// and keeps the separator.
//
// MustCompile is used so that an invalid pattern fails at startup
// rather than leaving a nil regexp behind.
var removeHint = regexp.MustCompile("\\+K@[a-z0-9]{5}(\\+|$)")
325
326 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
327         SetCorsHeaders(resp)
328
329         locator := mux.Vars(req)["locator"]
330         var err error
331         var status int
332         var expectLength, responseLength int64
333         var proxiedURI = "-"
334
335         defer func() {
336                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
337                 if status != http.StatusOK {
338                         http.Error(resp, err.Error(), status)
339                 }
340         }()
341
342         kc := *this.KeepClient
343
344         var pass bool
345         var tok string
346         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
347                 status, err = http.StatusForbidden, BadAuthorizationHeader
348                 return
349         }
350
351         // Copy ArvadosClient struct and use the client's API token
352         arvclient := *kc.Arvados
353         arvclient.ApiToken = tok
354         kc.Arvados = &arvclient
355
356         var reader io.ReadCloser
357
358         locator = removeHint.ReplaceAllString(locator, "$1")
359
360         switch req.Method {
361         case "HEAD":
362                 expectLength, proxiedURI, err = kc.Ask(locator)
363         case "GET":
364                 reader, expectLength, proxiedURI, err = kc.Get(locator)
365                 if reader != nil {
366                         defer reader.Close()
367                 }
368         default:
369                 status, err = http.StatusNotImplemented, MethodNotSupported
370                 return
371         }
372
373         if expectLength == -1 {
374                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
375         }
376
377         switch err {
378         case nil:
379                 status = http.StatusOK
380                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
381                 switch req.Method {
382                 case "HEAD":
383                         responseLength = 0
384                 case "GET":
385                         responseLength, err = io.Copy(resp, reader)
386                         if err == nil && expectLength > -1 && responseLength != expectLength {
387                                 err = ContentLengthMismatch
388                         }
389                 }
390         case keepclient.BlockNotFound:
391                 status = http.StatusNotFound
392         default:
393                 status = http.StatusBadGateway
394         }
395 }
396
// Errors reported for PUT/POST requests with a bad Content-Length.
var (
	// LengthRequiredError carries the standard status text for
	// 411 Length Required.
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))

	// LengthMismatchError is reported when the size hint in the
	// locator disagrees with the Content-Length header.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
399
400 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
401         SetCorsHeaders(resp)
402
403         kc := *this.KeepClient
404         var err error
405         var expectLength int64 = -1
406         var status = http.StatusInternalServerError
407         var wroteReplicas int
408         var locatorOut string = "-"
409
410         defer func() {
411                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
412                 if status != http.StatusOK {
413                         http.Error(resp, err.Error(), status)
414                 }
415         }()
416
417         locatorIn := mux.Vars(req)["locator"]
418
419         if req.Header.Get("Content-Length") != "" {
420                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
421                 if err != nil {
422                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
423                 }
424
425         }
426
427         if expectLength < 0 {
428                 err = LengthRequiredError
429                 status = http.StatusLengthRequired
430                 return
431         }
432
433         if locatorIn != "" {
434                 var loc *keepclient.Locator
435                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
436                         status = http.StatusBadRequest
437                         return
438                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
439                         err = LengthMismatchError
440                         status = http.StatusBadRequest
441                         return
442                 }
443         }
444
445         var pass bool
446         var tok string
447         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
448                 err = BadAuthorizationHeader
449                 status = http.StatusForbidden
450                 return
451         }
452
453         // Copy ArvadosClient struct and use the client's API token
454         arvclient := *kc.Arvados
455         arvclient.ApiToken = tok
456         kc.Arvados = &arvclient
457
458         // Check if the client specified the number of replicas
459         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
460                 var r int
461                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
462                 if err != nil {
463                         kc.Want_replicas = r
464                 }
465         }
466
467         // Now try to put the block through
468         if locatorIn == "" {
469                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
470                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
471                         status = http.StatusInternalServerError
472                         return
473                 } else {
474                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
475                 }
476         } else {
477                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
478         }
479
480         // Tell the client how many successful PUTs we accomplished
481         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
482
483         switch err {
484         case nil:
485                 status = http.StatusOK
486                 _, err = io.WriteString(resp, locatorOut)
487
488         case keepclient.OversizeBlockError:
489                 // Too much data
490                 status = http.StatusRequestEntityTooLarge
491
492         case keepclient.InsufficientReplicasError:
493                 if wroteReplicas > 0 {
494                         // At least one write is considered success.  The
495                         // client can decide if getting less than the number of
496                         // replications it asked for is a fatal error.
497                         status = http.StatusOK
498                         _, err = io.WriteString(resp, locatorOut)
499                 } else {
500                         status = http.StatusServiceUnavailable
501                 }
502
503         default:
504                 status = http.StatusBadGateway
505         }
506 }
507
508 // ServeHTTP implementation for IndexHandler
509 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
510 // For each keep server found in LocalRoots:
511 //   Invokes GetIndex using keepclient
512 //   Expects "complete" response (terminating with blank new line)
513 //   Aborts on any errors
514 // Concatenates responses from all those keep servers and returns
515 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
516         SetCorsHeaders(resp)
517
518         prefix := mux.Vars(req)["prefix"]
519         var err error
520         var status int
521
522         defer func() {
523                 if status != http.StatusOK {
524                         http.Error(resp, err.Error(), status)
525                 }
526         }()
527
528         kc := *handler.KeepClient
529
530         ok, token := CheckAuthorizationHeader(kc, handler.ApiTokenCache, req)
531         if !ok {
532                 status, err = http.StatusForbidden, BadAuthorizationHeader
533                 return
534         }
535
536         // Copy ArvadosClient struct and use the client's API token
537         arvclient := *kc.Arvados
538         arvclient.ApiToken = token
539         kc.Arvados = &arvclient
540
541         // Only GET method is supported
542         if req.Method != "GET" {
543                 status, err = http.StatusNotImplemented, MethodNotSupported
544                 return
545         }
546
547         // Get index from all LocalRoots and write to resp
548         var reader io.Reader
549         for uuid := range kc.LocalRoots() {
550                 reader, err = kc.GetIndex(uuid, prefix)
551                 if err != nil {
552                         status = http.StatusBadGateway
553                         return
554                 }
555
556                 _, err = io.Copy(resp, reader)
557                 if err != nil {
558                         status = http.StatusBadGateway
559                         return
560                 }
561         }
562
563         // Got index from all the keep servers and wrote to resp
564         status = http.StatusOK
565         resp.Write([]byte("\n"))
566 }