7200: Use if statement instead of switch to check http method in keepclient; strip...
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "sync"
20         "syscall"
21         "time"
22 )
23
const (
	// DEFAULT_ADDR is the TCP address used when the -listen flag is not
	// given: port 25107 on all network interfaces.
	DEFAULT_ADDR = ":25107"
)

var (
	// listener accepts client connections. main opens it and closes it
	// when SIGTERM or SIGINT arrives, which makes http.Serve return.
	listener net.Listener
)
29
30 func main() {
31         var (
32                 listen           string
33                 no_get           bool
34                 no_put           bool
35                 default_replicas int
36                 timeout          int64
37                 pidfile          string
38         )
39
40         flagset := flag.NewFlagSet("default", flag.ExitOnError)
41
42         flagset.StringVar(
43                 &listen,
44                 "listen",
45                 DEFAULT_ADDR,
46                 "Interface on which to listen for requests, in the format "+
47                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48                         "to listen on all network interfaces.")
49
50         flagset.BoolVar(
51                 &no_get,
52                 "no-get",
53                 false,
54                 "If set, disable GET operations")
55
56         flagset.BoolVar(
57                 &no_put,
58                 "no-put",
59                 false,
60                 "If set, disable PUT operations")
61
62         flagset.IntVar(
63                 &default_replicas,
64                 "default-replicas",
65                 2,
66                 "Default number of replicas to write if not specified by the client.")
67
68         flagset.Int64Var(
69                 &timeout,
70                 "timeout",
71                 15,
72                 "Timeout on requests to internal Keep services (default 15 seconds)")
73
74         flagset.StringVar(
75                 &pidfile,
76                 "pid",
77                 "",
78                 "Path to write pid file")
79
80         flagset.Parse(os.Args[1:])
81
82         arv, err := arvadosclient.MakeArvadosClient()
83         if err != nil {
84                 log.Fatalf("Error setting up arvados client %s", err.Error())
85         }
86
87         kc, err := keepclient.MakeKeepClient(&arv)
88         if err != nil {
89                 log.Fatalf("Error setting up keep client %s", err.Error())
90         }
91
92         if pidfile != "" {
93                 f, err := os.Create(pidfile)
94                 if err != nil {
95                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
96                 }
97                 fmt.Fprint(f, os.Getpid())
98                 f.Close()
99                 defer os.Remove(pidfile)
100         }
101
102         kc.Want_replicas = default_replicas
103
104         kc.Client.Timeout = time.Duration(timeout) * time.Second
105
106         listener, err = net.Listen("tcp", listen)
107         if err != nil {
108                 log.Fatalf("Could not listen on %v", listen)
109         }
110
111         go RefreshServicesList(kc)
112
113         // Shut down the server gracefully (by closing the listener)
114         // if SIGTERM is received.
115         term := make(chan os.Signal, 1)
116         go func(sig <-chan os.Signal) {
117                 s := <-sig
118                 log.Println("caught signal:", s)
119                 listener.Close()
120         }(term)
121         signal.Notify(term, syscall.SIGTERM)
122         signal.Notify(term, syscall.SIGINT)
123
124         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
125
126         // Start listening for requests.
127         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
128
129         log.Println("shutting down")
130 }
131
// ApiTokenCache remembers API tokens that the API server has recently
// accepted, so that repeated requests with the same token do not each cost
// a round trip to the API server.
type ApiTokenCache struct {
	// tokens maps a token to the Unix time (in seconds) at which its cache
	// entry expires. A zero or missing value means "not cached".
	tokens map[string]int64

	// lock guards tokens.
	lock sync.Mutex

	// expireTime is how many seconds a newly remembered token stays cached.
	expireTime int64
}
137
138 // Refresh the keep service list every five minutes.
139 func RefreshServicesList(kc *keepclient.KeepClient) {
140         var previousRoots = []map[string]string{}
141         var delay time.Duration = 0
142         for {
143                 time.Sleep(delay * time.Second)
144                 delay = 300
145                 if err := kc.DiscoverKeepServers(); err != nil {
146                         log.Println("Error retrieving services list:", err)
147                         delay = 3
148                         continue
149                 }
150                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
151                 if !reflect.DeepEqual(previousRoots, newRoots) {
152                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
153                 }
154                 if len(newRoots[0]) == 0 {
155                         log.Print("WARNING: No local services. Retrying in 3 seconds.")
156                         delay = 3
157                 }
158                 previousRoots = newRoots
159         }
160 }
161
162 // Cache the token and set an expire time.  If we already have an expire time
163 // on the token, it is not updated.
164 func (this *ApiTokenCache) RememberToken(token string) {
165         this.lock.Lock()
166         defer this.lock.Unlock()
167
168         now := time.Now().Unix()
169         if this.tokens[token] == 0 {
170                 this.tokens[token] = now + this.expireTime
171         }
172 }
173
174 // Check if the cached token is known and still believed to be valid.
175 func (this *ApiTokenCache) RecallToken(token string) bool {
176         this.lock.Lock()
177         defer this.lock.Unlock()
178
179         now := time.Now().Unix()
180         if this.tokens[token] == 0 {
181                 // Unknown token
182                 return false
183         } else if now < this.tokens[token] {
184                 // Token is known and still valid
185                 return true
186         } else {
187                 // Token is expired
188                 this.tokens[token] = 0
189                 return false
190         }
191 }
192
// GetRemoteAddress describes the client of req for log messages.
//
// If a front-end proxy supplied X-Real-IP, that address is used. The
// X-Forwarded-For chain is appended only when it is present and differs
// from X-Real-IP. This avoids a meaningless "(X-Forwarded-For )" suffix
// when the header is missing. Without X-Real-IP, req.RemoteAddr is
// returned.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
203
204 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
205         var auth string
206         if auth = req.Header.Get("Authorization"); auth == "" {
207                 return false, ""
208         }
209
210         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
211         if err != nil {
212                 // Scanning error
213                 return false, ""
214         }
215
216         if cache.RecallToken(tok) {
217                 // Valid in the cache, short circut
218                 return true, tok
219         }
220
221         arv := *kc.Arvados
222         arv.ApiToken = tok
223         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
224                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
225                 return false, ""
226         }
227
228         // Success!  Update cache
229         cache.RememberToken(tok)
230
231         return true, tok
232 }
233
234 type GetBlockHandler struct {
235         *keepclient.KeepClient
236         *ApiTokenCache
237 }
238
239 type PutBlockHandler struct {
240         *keepclient.KeepClient
241         *ApiTokenCache
242 }
243
244 type IndexHandler struct {
245         *keepclient.KeepClient
246         *ApiTokenCache
247 }
248
249 type InvalidPathHandler struct{}
250
251 type OptionsHandler struct{}
252
253 // MakeRESTRouter
254 //     Returns a mux.Router that passes GET and PUT requests to the
255 //     appropriate handlers.
256 //
257 func MakeRESTRouter(
258         enable_get bool,
259         enable_put bool,
260         kc *keepclient.KeepClient) *mux.Router {
261
262         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
263
264         rest := mux.NewRouter()
265
266         if enable_get {
267                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
268                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
269                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
270
271                 // List all blocks
272                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
273
274                 // List blocks whose hash has the given prefix
275                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
276         }
277
278         if enable_put {
279                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
280                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
281                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
282                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
283                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
284         }
285
286         rest.NotFoundHandler = InvalidPathHandler{}
287
288         return rest
289 }
290
// SetCorsHeaders adds the cross-origin headers that let browser clients
// call the proxy from any origin, using the methods and request headers
// the proxy supports.
func SetCorsHeaders(resp http.ResponseWriter) {
	header := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	} {
		header.Set(kv[0], kv[1])
	}
}
297
298 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
299         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
300         http.Error(resp, "Bad request", http.StatusBadRequest)
301 }
302
303 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
304         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
305         SetCorsHeaders(resp)
306 }
307
// Errors reported to clients (as the response body) by the block and
// index handlers.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" locator hint (five lowercase letters or
// digits) together with the "+" that follows it, if any. Replacing each
// match with "$1" deletes the hint but keeps the remaining "+" fields
// separated.
//
// MustCompile is used so that an invalid pattern fails loudly at program
// start. The previous `regexp.Compile` call discarded the error and would
// have left removeHint nil.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
313
314 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
315         SetCorsHeaders(resp)
316
317         locator := mux.Vars(req)["locator"]
318         var err error
319         var status int
320         var expectLength, responseLength int64
321         var proxiedURI = "-"
322
323         defer func() {
324                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
325                 if status != http.StatusOK {
326                         http.Error(resp, err.Error(), status)
327                 }
328         }()
329
330         kc := *this.KeepClient
331
332         var pass bool
333         var tok string
334         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
335                 status, err = http.StatusForbidden, BadAuthorizationHeader
336                 return
337         }
338
339         // Copy ArvadosClient struct and use the client's API token
340         arvclient := *kc.Arvados
341         arvclient.ApiToken = tok
342         kc.Arvados = &arvclient
343
344         var reader io.ReadCloser
345
346         locator = removeHint.ReplaceAllString(locator, "$1")
347
348         switch req.Method {
349         case "HEAD":
350                 expectLength, proxiedURI, err = kc.Ask(locator)
351         case "GET":
352                 reader, expectLength, proxiedURI, err = kc.Get(locator)
353                 if reader != nil {
354                         defer reader.Close()
355                 }
356         default:
357                 status, err = http.StatusNotImplemented, MethodNotSupported
358                 return
359         }
360
361         if expectLength == -1 {
362                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
363         }
364
365         switch err {
366         case nil:
367                 status = http.StatusOK
368                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
369                 switch req.Method {
370                 case "HEAD":
371                         responseLength = 0
372                 case "GET":
373                         responseLength, err = io.Copy(resp, reader)
374                         if err == nil && expectLength > -1 && responseLength != expectLength {
375                                 err = ContentLengthMismatch
376                         }
377                 }
378         case keepclient.BlockNotFound:
379                 status = http.StatusNotFound
380         default:
381                 status = http.StatusBadGateway
382         }
383 }
384
// Errors reported by PutBlockHandler when the request's length is unusable.
var (
	// LengthRequiredError: the request carried no usable Content-Length.
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	// LengthMismatchError: the locator's size hint disagrees with Content-Length.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
387
388 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
389         SetCorsHeaders(resp)
390
391         kc := *this.KeepClient
392         var err error
393         var expectLength int64 = -1
394         var status = http.StatusInternalServerError
395         var wroteReplicas int
396         var locatorOut string = "-"
397
398         defer func() {
399                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
400                 if status != http.StatusOK {
401                         http.Error(resp, err.Error(), status)
402                 }
403         }()
404
405         locatorIn := mux.Vars(req)["locator"]
406
407         if req.Header.Get("Content-Length") != "" {
408                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
409                 if err != nil {
410                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
411                 }
412
413         }
414
415         if expectLength < 0 {
416                 err = LengthRequiredError
417                 status = http.StatusLengthRequired
418                 return
419         }
420
421         if locatorIn != "" {
422                 var loc *keepclient.Locator
423                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
424                         status = http.StatusBadRequest
425                         return
426                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
427                         err = LengthMismatchError
428                         status = http.StatusBadRequest
429                         return
430                 }
431         }
432
433         var pass bool
434         var tok string
435         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
436                 err = BadAuthorizationHeader
437                 status = http.StatusForbidden
438                 return
439         }
440
441         // Copy ArvadosClient struct and use the client's API token
442         arvclient := *kc.Arvados
443         arvclient.ApiToken = tok
444         kc.Arvados = &arvclient
445
446         // Check if the client specified the number of replicas
447         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
448                 var r int
449                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
450                 if err != nil {
451                         kc.Want_replicas = r
452                 }
453         }
454
455         // Now try to put the block through
456         if locatorIn == "" {
457                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
458                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
459                         status = http.StatusInternalServerError
460                         return
461                 } else {
462                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
463                 }
464         } else {
465                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
466         }
467
468         // Tell the client how many successful PUTs we accomplished
469         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
470
471         switch err {
472         case nil:
473                 status = http.StatusOK
474                 _, err = io.WriteString(resp, locatorOut)
475
476         case keepclient.OversizeBlockError:
477                 // Too much data
478                 status = http.StatusRequestEntityTooLarge
479
480         case keepclient.InsufficientReplicasError:
481                 if wroteReplicas > 0 {
482                         // At least one write is considered success.  The
483                         // client can decide if getting less than the number of
484                         // replications it asked for is a fatal error.
485                         status = http.StatusOK
486                         _, err = io.WriteString(resp, locatorOut)
487                 } else {
488                         status = http.StatusServiceUnavailable
489                 }
490
491         default:
492                 status = http.StatusBadGateway
493         }
494 }
495
496 // ServeHTTP implementation for IndexHandler
497 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
498 // For each keep server found in LocalRoots:
499 //   Invokes GetIndex using keepclient
500 //   Expects "complete" response (terminating with blank new line)
501 //   Aborts on any errors
502 // Concatenates responses from all those keep servers and returns
503 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
504         SetCorsHeaders(resp)
505
506         prefix := mux.Vars(req)["prefix"]
507         var err error
508         var status int
509
510         defer func() {
511                 if status != http.StatusOK {
512                         http.Error(resp, err.Error(), status)
513                 }
514         }()
515
516         kc := *handler.KeepClient
517
518         var pass bool
519         var tok string
520         if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
521                 status, err = http.StatusForbidden, BadAuthorizationHeader
522                 return
523         }
524
525         // Copy ArvadosClient struct and use the client's API token
526         arvclient := *kc.Arvados
527         arvclient.ApiToken = tok
528         kc.Arvados = &arvclient
529
530         // Only GET method is supported
531         if req.Method != "GET" {
532                 status, err = http.StatusNotImplemented, MethodNotSupported
533                 return
534         }
535
536         contentLen := 0
537         var reader io.Reader
538         for uuid := range kc.LocalRoots() {
539                 reader, err = kc.GetIndex(uuid, prefix)
540                 if err != nil {
541                         status = http.StatusBadGateway
542                         return
543                 }
544
545                 var readBytes []byte
546                 readBytes, err = ioutil.ReadAll(reader)
547                 if err != nil {
548                         status = http.StatusBadGateway
549                         return
550                 }
551
552                 // Got index for this server; write to resp
553                 n, err := resp.Write(readBytes)
554                 if err != nil {
555                         status = http.StatusBadGateway
556                         return
557                 }
558                 contentLen += n
559         }
560
561         // Got index from all the keep servers and wrote to resp
562         status = http.StatusOK
563         resp.Header().Set("Content-Length", fmt.Sprint(contentLen+1))
564         resp.Write([]byte("\n"))
565 }