5824: Merge branch 'master' into 5824-keep-web
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "sync"
20         "syscall"
21         "time"
22 )
23
// DEFAULT_ADDR is the default TCP address on which to listen for requests.
// It is the default value of the -listen flag, which overrides it.
const DEFAULT_ADDR = ":25107"

// listener accepts client connections. main creates it, and the signal
// handler started by main closes it to stop the HTTP server gracefully.
var listener net.Listener
29
30 func main() {
31         var (
32                 listen           string
33                 no_get           bool
34                 no_put           bool
35                 default_replicas int
36                 timeout          int64
37                 pidfile          string
38         )
39
40         flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
41
42         flagset.StringVar(
43                 &listen,
44                 "listen",
45                 DEFAULT_ADDR,
46                 "Interface on which to listen for requests, in the format "+
47                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48                         "to listen on all network interfaces.")
49
50         flagset.BoolVar(
51                 &no_get,
52                 "no-get",
53                 false,
54                 "If set, disable GET operations")
55
56         flagset.BoolVar(
57                 &no_put,
58                 "no-put",
59                 false,
60                 "If set, disable PUT operations")
61
62         flagset.IntVar(
63                 &default_replicas,
64                 "default-replicas",
65                 2,
66                 "Default number of replicas to write if not specified by the client.")
67
68         flagset.Int64Var(
69                 &timeout,
70                 "timeout",
71                 15,
72                 "Timeout on requests to internal Keep services (default 15 seconds)")
73
74         flagset.StringVar(
75                 &pidfile,
76                 "pid",
77                 "",
78                 "Path to write pid file")
79
80         flagset.Parse(os.Args[1:])
81
82         arv, err := arvadosclient.MakeArvadosClient()
83         if err != nil {
84                 log.Fatalf("Error setting up arvados client %s", err.Error())
85         }
86
87         kc, err := keepclient.MakeKeepClient(&arv)
88         if err != nil {
89                 log.Fatalf("Error setting up keep client %s", err.Error())
90         }
91
92         if pidfile != "" {
93                 f, err := os.Create(pidfile)
94                 if err != nil {
95                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
96                 }
97                 fmt.Fprint(f, os.Getpid())
98                 f.Close()
99                 defer os.Remove(pidfile)
100         }
101
102         kc.Want_replicas = default_replicas
103
104         kc.Client.Timeout = time.Duration(timeout) * time.Second
105
106         listener, err = net.Listen("tcp", listen)
107         if err != nil {
108                 log.Fatalf("Could not listen on %v", listen)
109         }
110
111         go RefreshServicesList(kc)
112
113         // Shut down the server gracefully (by closing the listener)
114         // if SIGTERM is received.
115         term := make(chan os.Signal, 1)
116         go func(sig <-chan os.Signal) {
117                 s := <-sig
118                 log.Println("caught signal:", s)
119                 listener.Close()
120         }(term)
121         signal.Notify(term, syscall.SIGTERM)
122         signal.Notify(term, syscall.SIGINT)
123
124         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
125
126         // Start listening for requests.
127         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
128
129         log.Println("shutting down")
130 }
131
// ApiTokenCache remembers API tokens that the API server recently accepted.
// Requests bearing a cached, unexpired token can then skip the round trip
// to the API server.
type ApiTokenCache struct {
	// tokens maps each cached token to the Unix time (in seconds) at which
	// it stops being trusted. A zero value means "not cached".
	tokens map[string]int64

	// lock serializes access to tokens.
	lock sync.Mutex

	// expireTime is how many seconds a newly remembered token stays valid.
	expireTime int64
}
137
138 // Refresh the keep service list every five minutes.
139 func RefreshServicesList(kc *keepclient.KeepClient) {
140         var previousRoots = []map[string]string{}
141         var delay time.Duration = 0
142         for {
143                 time.Sleep(delay * time.Second)
144                 delay = 300
145                 if err := kc.DiscoverKeepServers(); err != nil {
146                         log.Println("Error retrieving services list:", err)
147                         delay = 3
148                         continue
149                 }
150                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
151                 if !reflect.DeepEqual(previousRoots, newRoots) {
152                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
153                 }
154                 if len(newRoots[0]) == 0 {
155                         log.Print("WARNING: No local services. Retrying in 3 seconds.")
156                         delay = 3
157                 }
158                 previousRoots = newRoots
159         }
160 }
161
162 // Cache the token and set an expire time.  If we already have an expire time
163 // on the token, it is not updated.
164 func (this *ApiTokenCache) RememberToken(token string) {
165         this.lock.Lock()
166         defer this.lock.Unlock()
167
168         now := time.Now().Unix()
169         if this.tokens[token] == 0 {
170                 this.tokens[token] = now + this.expireTime
171         }
172 }
173
174 // Check if the cached token is known and still believed to be valid.
175 func (this *ApiTokenCache) RecallToken(token string) bool {
176         this.lock.Lock()
177         defer this.lock.Unlock()
178
179         now := time.Now().Unix()
180         if this.tokens[token] == 0 {
181                 // Unknown token
182                 return false
183         } else if now < this.tokens[token] {
184                 // Token is known and still valid
185                 return true
186         } else {
187                 // Token is expired
188                 this.tokens[token] = 0
189                 return false
190         }
191 }
192
193 func GetRemoteAddress(req *http.Request) string {
194         if realip := req.Header.Get("X-Real-IP"); realip != "" {
195                 if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
196                         return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
197                 } else {
198                         return realip
199                 }
200         }
201         return req.RemoteAddr
202 }
203
204 func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
205         var auth string
206         if auth = req.Header.Get("Authorization"); auth == "" {
207                 return false, ""
208         }
209
210         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
211         if err != nil {
212                 // Scanning error
213                 return false, ""
214         }
215
216         if cache.RecallToken(tok) {
217                 // Valid in the cache, short circut
218                 return true, tok
219         }
220
221         arv := *kc.Arvados
222         arv.ApiToken = tok
223         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
224                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
225                 return false, ""
226         }
227
228         // Success!  Update cache
229         cache.RememberToken(tok)
230
231         return true, tok
232 }
233
234 type GetBlockHandler struct {
235         *keepclient.KeepClient
236         *ApiTokenCache
237 }
238
239 type PutBlockHandler struct {
240         *keepclient.KeepClient
241         *ApiTokenCache
242 }
243
244 type IndexHandler struct {
245         *keepclient.KeepClient
246         *ApiTokenCache
247 }
248
249 type InvalidPathHandler struct{}
250
251 type OptionsHandler struct{}
252
253 // MakeRESTRouter
254 //     Returns a mux.Router that passes GET and PUT requests to the
255 //     appropriate handlers.
256 //
257 func MakeRESTRouter(
258         enable_get bool,
259         enable_put bool,
260         kc *keepclient.KeepClient) *mux.Router {
261
262         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
263
264         rest := mux.NewRouter()
265
266         if enable_get {
267                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
268                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
269                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
270
271                 // List all blocks
272                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
273
274                 // List blocks whose hash has the given prefix
275                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
276         }
277
278         if enable_put {
279                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
280                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
281                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
282                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
283                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
284         }
285
286         rest.NotFoundHandler = InvalidPathHandler{}
287
288         return rest
289 }
290
291 func SetCorsHeaders(resp http.ResponseWriter) {
292         resp.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS")
293         resp.Header().Set("Access-Control-Allow-Origin", "*")
294         resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas")
295         resp.Header().Set("Access-Control-Max-Age", "86486400")
296 }
297
298 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
299         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
300         http.Error(resp, "Bad request", http.StatusBadRequest)
301 }
302
303 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
304         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
305         SetCorsHeaders(resp)
306 }
307
// Errors reported to clients by the block handlers in this file.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" hint in a block locator together with
// the "+" that follows it, if any. Replacing matches with "$1" removes the
// hint and keeps the remaining hints "+"-separated. The pattern is a
// constant, so MustCompile is used instead of silently ignoring a compile
// error.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
313
314 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
315         SetCorsHeaders(resp)
316
317         locator := mux.Vars(req)["locator"]
318         var err error
319         var status int
320         var expectLength, responseLength int64
321         var proxiedURI = "-"
322
323         defer func() {
324                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
325                 if status != http.StatusOK {
326                         http.Error(resp, err.Error(), status)
327                 }
328         }()
329
330         kc := *this.KeepClient
331
332         var pass bool
333         var tok string
334         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
335                 status, err = http.StatusForbidden, BadAuthorizationHeader
336                 return
337         }
338
339         // Copy ArvadosClient struct and use the client's API token
340         arvclient := *kc.Arvados
341         arvclient.ApiToken = tok
342         kc.Arvados = &arvclient
343
344         var reader io.ReadCloser
345
346         locator = removeHint.ReplaceAllString(locator, "$1")
347
348         switch req.Method {
349         case "HEAD":
350                 expectLength, proxiedURI, err = kc.Ask(locator)
351         case "GET":
352                 reader, expectLength, proxiedURI, err = kc.Get(locator)
353                 if reader != nil {
354                         defer reader.Close()
355                 }
356         default:
357                 status, err = http.StatusNotImplemented, MethodNotSupported
358                 return
359         }
360
361         if expectLength == -1 {
362                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
363         }
364
365         switch respErr := err.(type) {
366         case nil:
367                 status = http.StatusOK
368                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
369                 switch req.Method {
370                 case "HEAD":
371                         responseLength = 0
372                 case "GET":
373                         responseLength, err = io.Copy(resp, reader)
374                         if err == nil && expectLength > -1 && responseLength != expectLength {
375                                 err = ContentLengthMismatch
376                         }
377                 }
378         case keepclient.Error:
379                 if respErr == keepclient.BlockNotFound {
380                         status = http.StatusNotFound
381                 } else if respErr.Temporary() {
382                         status = http.StatusBadGateway
383                 } else {
384                         status = 422
385                 }
386         default:
387                 status = http.StatusInternalServerError
388         }
389 }
390
// Errors returned by PutBlockHandler when the request's declared length is
// missing or disagrees with the locator's size hint.
var (
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
393
394 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
395         SetCorsHeaders(resp)
396
397         kc := *this.KeepClient
398         var err error
399         var expectLength int64 = -1
400         var status = http.StatusInternalServerError
401         var wroteReplicas int
402         var locatorOut string = "-"
403
404         defer func() {
405                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
406                 if status != http.StatusOK {
407                         http.Error(resp, err.Error(), status)
408                 }
409         }()
410
411         locatorIn := mux.Vars(req)["locator"]
412
413         if req.Header.Get("Content-Length") != "" {
414                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
415                 if err != nil {
416                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
417                 }
418
419         }
420
421         if expectLength < 0 {
422                 err = LengthRequiredError
423                 status = http.StatusLengthRequired
424                 return
425         }
426
427         if locatorIn != "" {
428                 var loc *keepclient.Locator
429                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
430                         status = http.StatusBadRequest
431                         return
432                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
433                         err = LengthMismatchError
434                         status = http.StatusBadRequest
435                         return
436                 }
437         }
438
439         var pass bool
440         var tok string
441         if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
442                 err = BadAuthorizationHeader
443                 status = http.StatusForbidden
444                 return
445         }
446
447         // Copy ArvadosClient struct and use the client's API token
448         arvclient := *kc.Arvados
449         arvclient.ApiToken = tok
450         kc.Arvados = &arvclient
451
452         // Check if the client specified the number of replicas
453         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
454                 var r int
455                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
456                 if err != nil {
457                         kc.Want_replicas = r
458                 }
459         }
460
461         // Now try to put the block through
462         if locatorIn == "" {
463                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
464                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
465                         status = http.StatusInternalServerError
466                         return
467                 } else {
468                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
469                 }
470         } else {
471                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
472         }
473
474         // Tell the client how many successful PUTs we accomplished
475         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
476
477         switch err {
478         case nil:
479                 status = http.StatusOK
480                 _, err = io.WriteString(resp, locatorOut)
481
482         case keepclient.OversizeBlockError:
483                 // Too much data
484                 status = http.StatusRequestEntityTooLarge
485
486         case keepclient.InsufficientReplicasError:
487                 if wroteReplicas > 0 {
488                         // At least one write is considered success.  The
489                         // client can decide if getting less than the number of
490                         // replications it asked for is a fatal error.
491                         status = http.StatusOK
492                         _, err = io.WriteString(resp, locatorOut)
493                 } else {
494                         status = http.StatusServiceUnavailable
495                 }
496
497         default:
498                 status = http.StatusBadGateway
499         }
500 }
501
502 // ServeHTTP implementation for IndexHandler
503 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
504 // For each keep server found in LocalRoots:
505 //   Invokes GetIndex using keepclient
506 //   Expects "complete" response (terminating with blank new line)
507 //   Aborts on any errors
508 // Concatenates responses from all those keep servers and returns
509 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
510         SetCorsHeaders(resp)
511
512         prefix := mux.Vars(req)["prefix"]
513         var err error
514         var status int
515
516         defer func() {
517                 if status != http.StatusOK {
518                         http.Error(resp, err.Error(), status)
519                 }
520         }()
521
522         kc := *handler.KeepClient
523
524         ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
525         if !ok {
526                 status, err = http.StatusForbidden, BadAuthorizationHeader
527                 return
528         }
529
530         // Copy ArvadosClient struct and use the client's API token
531         arvclient := *kc.Arvados
532         arvclient.ApiToken = token
533         kc.Arvados = &arvclient
534
535         // Only GET method is supported
536         if req.Method != "GET" {
537                 status, err = http.StatusNotImplemented, MethodNotSupported
538                 return
539         }
540
541         // Get index from all LocalRoots and write to resp
542         var reader io.Reader
543         for uuid := range kc.LocalRoots() {
544                 reader, err = kc.GetIndex(uuid, prefix)
545                 if err != nil {
546                         status = http.StatusBadGateway
547                         return
548                 }
549
550                 _, err = io.Copy(resp, reader)
551                 if err != nil {
552                         status = http.StatusBadGateway
553                         return
554                 }
555         }
556
557         // Got index from all the keep servers and wrote to resp
558         status = http.StatusOK
559         resp.Write([]byte("\n"))
560 }