7200: The incomplete response when no such prefix exists will be "\n". Update keepcli...
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "strings"
20         "sync"
21         "syscall"
22         "time"
23 )
24
// DEFAULT_ADDR is the TCP address the proxy listens on when the -listen
// flag is not given on the command line.
const DEFAULT_ADDR = ":25107"

// listener is the socket on which the HTTP server accepts connections.
// main assigns it, and closes it when SIGTERM or SIGINT is received.
var listener net.Listener
30
31 func main() {
32         var (
33                 listen           string
34                 no_get           bool
35                 no_put           bool
36                 default_replicas int
37                 timeout          int64
38                 pidfile          string
39         )
40
41         flagset := flag.NewFlagSet("default", flag.ExitOnError)
42
43         flagset.StringVar(
44                 &listen,
45                 "listen",
46                 DEFAULT_ADDR,
47                 "Interface on which to listen for requests, in the format "+
48                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
49                         "to listen on all network interfaces.")
50
51         flagset.BoolVar(
52                 &no_get,
53                 "no-get",
54                 false,
55                 "If set, disable GET operations")
56
57         flagset.BoolVar(
58                 &no_put,
59                 "no-put",
60                 false,
61                 "If set, disable PUT operations")
62
63         flagset.IntVar(
64                 &default_replicas,
65                 "default-replicas",
66                 2,
67                 "Default number of replicas to write if not specified by the client.")
68
69         flagset.Int64Var(
70                 &timeout,
71                 "timeout",
72                 15,
73                 "Timeout on requests to internal Keep services (default 15 seconds)")
74
75         flagset.StringVar(
76                 &pidfile,
77                 "pid",
78                 "",
79                 "Path to write pid file")
80
81         flagset.Parse(os.Args[1:])
82
83         arv, err := arvadosclient.MakeArvadosClient()
84         if err != nil {
85                 log.Fatalf("Error setting up arvados client %s", err.Error())
86         }
87
88         kc, err := keepclient.MakeKeepClient(&arv)
89         if err != nil {
90                 log.Fatalf("Error setting up keep client %s", err.Error())
91         }
92
93         if pidfile != "" {
94                 f, err := os.Create(pidfile)
95                 if err != nil {
96                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
97                 }
98                 fmt.Fprint(f, os.Getpid())
99                 f.Close()
100                 defer os.Remove(pidfile)
101         }
102
103         kc.Want_replicas = default_replicas
104
105         kc.Client.Timeout = time.Duration(timeout) * time.Second
106
107         listener, err = net.Listen("tcp", listen)
108         if err != nil {
109                 log.Fatalf("Could not listen on %v", listen)
110         }
111
112         go RefreshServicesList(kc)
113
114         // Shut down the server gracefully (by closing the listener)
115         // if SIGTERM is received.
116         term := make(chan os.Signal, 1)
117         go func(sig <-chan os.Signal) {
118                 s := <-sig
119                 log.Println("caught signal:", s)
120                 listener.Close()
121         }(term)
122         signal.Notify(term, syscall.SIGTERM)
123         signal.Notify(term, syscall.SIGINT)
124
125         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
126
127         // Start listening for requests.
128         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
129
130         log.Println("shutting down")
131 }
132
// ApiTokenCache remembers API tokens that were recently accepted, so that
// they do not have to be re-validated on every request.
type ApiTokenCache struct {
	// tokens maps a token to the Unix time (in seconds) at which its
	// cache entry expires. A zero value is treated as "not cached".
	tokens     map[string]int64
	// lock is held by RememberToken and RecallToken while they use tokens.
	lock       sync.Mutex
	// expireTime is the number of seconds added to the current time when
	// a token is remembered.
	expireTime int64
}
138
139 // Refresh the keep service list every five minutes.
140 func RefreshServicesList(kc *keepclient.KeepClient) {
141         var previousRoots = []map[string]string{}
142         var delay time.Duration = 0
143         for {
144                 time.Sleep(delay * time.Second)
145                 delay = 300
146                 if err := kc.DiscoverKeepServers(); err != nil {
147                         log.Println("Error retrieving services list:", err)
148                         delay = 3
149                         continue
150                 }
151                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
152                 if !reflect.DeepEqual(previousRoots, newRoots) {
153                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
154                 }
155                 if len(newRoots[0]) == 0 {
156                         log.Print("WARNING: No local services. Retrying in 3 seconds.")
157                         delay = 3
158                 }
159                 previousRoots = newRoots
160         }
161 }
162
163 // Cache the token and set an expire time.  If we already have an expire time
164 // on the token, it is not updated.
165 func (this *ApiTokenCache) RememberToken(token string) {
166         this.lock.Lock()
167         defer this.lock.Unlock()
168
169         now := time.Now().Unix()
170         if this.tokens[token] == 0 {
171                 this.tokens[token] = now + this.expireTime
172         }
173 }
174
175 // Check if the cached token is known and still believed to be valid.
176 func (this *ApiTokenCache) RecallToken(token string) bool {
177         this.lock.Lock()
178         defer this.lock.Unlock()
179
180         now := time.Now().Unix()
181         if this.tokens[token] == 0 {
182                 // Unknown token
183                 return false
184         } else if now < this.tokens[token] {
185                 // Token is known and still valid
186                 return true
187         } else {
188                 // Token is expired
189                 this.tokens[token] = 0
190                 return false
191         }
192 }
193
// GetRemoteAddress returns a description of the client address for use in
// log messages. When the request carries an X-Real-IP header that value is
// used, annotated with the X-Forwarded-For header if one is present and
// differs from it; otherwise the connection's RemoteAddr is returned.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// Only mention X-Forwarded-For when it adds information; an absent
	// header previously produced "ip (X-Forwarded-For )".
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
204
205 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
206         var auth string
207         if auth = req.Header.Get("Authorization"); auth == "" {
208                 return false, ""
209         }
210
211         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
212         if err != nil {
213                 // Scanning error
214                 return false, ""
215         }
216
217         if cache.RecallToken(tok) {
218                 // Valid in the cache, short circut
219                 return true, tok
220         }
221
222         arv := *kc.Arvados
223         arv.ApiToken = tok
224         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
225                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
226                 return false, ""
227         }
228
229         // Success!  Update cache
230         cache.RememberToken(tok)
231
232         return true, tok
233 }
234
235 type GetBlockHandler struct {
236         *keepclient.KeepClient
237         *ApiTokenCache
238 }
239
240 type PutBlockHandler struct {
241         *keepclient.KeepClient
242         *ApiTokenCache
243 }
244
245 type IndexHandler struct {
246         *keepclient.KeepClient
247         *ApiTokenCache
248 }
249
250 type InvalidPathHandler struct{}
251
252 type OptionsHandler struct{}
253
254 // MakeRESTRouter
255 //     Returns a mux.Router that passes GET and PUT requests to the
256 //     appropriate handlers.
257 //
258 func MakeRESTRouter(
259         enable_get bool,
260         enable_put bool,
261         kc *keepclient.KeepClient) *mux.Router {
262
263         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
264
265         rest := mux.NewRouter()
266
267         if enable_get {
268                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
269                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
270                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
271
272                 // List all blocks
273                 rest.Handle(`/index`, IndexHandler{kc, t}).Methods("GET")
274
275                 // List blocks whose hash has the given prefix
276                 rest.Handle(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler{kc, t}).Methods("GET")
277         }
278
279         if enable_put {
280                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
281                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
282                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
283                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
284                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
285         }
286
287         rest.NotFoundHandler = InvalidPathHandler{}
288
289         return rest
290 }
291
// SetCorsHeaders sets the Access-Control-* response headers that the proxy
// sends with its responses: allowed methods, origin, request headers and
// max age.
func SetCorsHeaders(resp http.ResponseWriter) {
	header := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		{"Access-Control-Max-Age", "86486400"},
	} {
		header.Set(kv[0], kv[1])
	}
}
298
299 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
300         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
301         http.Error(resp, "Bad request", http.StatusBadRequest)
302 }
303
304 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
305         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
306         SetCorsHeaders(resp)
307 }
308
// Errors reported to clients by the request handlers.
var (
	// BadAuthorizationHeader is used when the token check fails.
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")

	// ContentLengthMismatch is used when the number of bytes copied to
	// the client differs from the expected length.
	ContentLengthMismatch = errors.New("Actual length != expected content length")

	// MethodNotSupported is used when a handler receives a request
	// method it does not implement.
	MethodNotSupported = errors.New("Method not supported")
)
312
// removeHint matches a "+K@" locator hint: "+K@" followed by five lowercase
// alphanumeric characters and then either "+" or the end of the string.
// The capture group holds that trailing delimiter, so replacing a match
// with "$1" removes the hint while keeping the locator well formed.
// MustCompile is used so that a bad pattern fails at start-up instead of
// leaving a nil *Regexp behind.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
314
315 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
316         SetCorsHeaders(resp)
317
318         locator := mux.Vars(req)["locator"]
319         var err error
320         var status int
321         var expectLength, responseLength int64
322         var proxiedURI = "-"
323
324         defer func() {
325                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
326                 if status != http.StatusOK {
327                         http.Error(resp, err.Error(), status)
328                 }
329         }()
330
331         kc := *this.KeepClient
332
333         var pass bool
334         var tok string
335         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
336                 status, err = http.StatusForbidden, BadAuthorizationHeader
337                 return
338         }
339
340         // Copy ArvadosClient struct and use the client's API token
341         arvclient := *kc.Arvados
342         arvclient.ApiToken = tok
343         kc.Arvados = &arvclient
344
345         var reader io.ReadCloser
346
347         locator = removeHint.ReplaceAllString(locator, "$1")
348
349         switch req.Method {
350         case "HEAD":
351                 expectLength, proxiedURI, err = kc.Ask(locator)
352         case "GET":
353                 reader, expectLength, proxiedURI, err = kc.Get(locator)
354                 if reader != nil {
355                         defer reader.Close()
356                 }
357         default:
358                 status, err = http.StatusNotImplemented, MethodNotSupported
359                 return
360         }
361
362         if expectLength == -1 {
363                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
364         }
365
366         switch err {
367         case nil:
368                 status = http.StatusOK
369                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
370                 switch req.Method {
371                 case "HEAD":
372                         responseLength = 0
373                 case "GET":
374                         responseLength, err = io.Copy(resp, reader)
375                         if err == nil && expectLength > -1 && responseLength != expectLength {
376                                 err = ContentLengthMismatch
377                         }
378                 }
379         case keepclient.BlockNotFound:
380                 status = http.StatusNotFound
381         default:
382                 status = http.StatusBadGateway
383         }
384 }
385
// Errors reported by PutBlockHandler for problems with the request length.
var (
	// LengthRequiredError carries the standard status text for
	// 411 Length Required.
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))

	// LengthMismatchError is used when the size hint in the locator
	// disagrees with the request's Content-Length.
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
388
389 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
390         SetCorsHeaders(resp)
391
392         kc := *this.KeepClient
393         var err error
394         var expectLength int64 = -1
395         var status = http.StatusInternalServerError
396         var wroteReplicas int
397         var locatorOut string = "-"
398
399         defer func() {
400                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
401                 if status != http.StatusOK {
402                         http.Error(resp, err.Error(), status)
403                 }
404         }()
405
406         locatorIn := mux.Vars(req)["locator"]
407
408         if req.Header.Get("Content-Length") != "" {
409                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
410                 if err != nil {
411                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
412                 }
413
414         }
415
416         if expectLength < 0 {
417                 err = LengthRequiredError
418                 status = http.StatusLengthRequired
419                 return
420         }
421
422         if locatorIn != "" {
423                 var loc *keepclient.Locator
424                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
425                         status = http.StatusBadRequest
426                         return
427                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
428                         err = LengthMismatchError
429                         status = http.StatusBadRequest
430                         return
431                 }
432         }
433
434         var pass bool
435         var tok string
436         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
437                 err = BadAuthorizationHeader
438                 status = http.StatusForbidden
439                 return
440         }
441
442         // Copy ArvadosClient struct and use the client's API token
443         arvclient := *kc.Arvados
444         arvclient.ApiToken = tok
445         kc.Arvados = &arvclient
446
447         // Check if the client specified the number of replicas
448         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
449                 var r int
450                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
451                 if err != nil {
452                         kc.Want_replicas = r
453                 }
454         }
455
456         // Now try to put the block through
457         if locatorIn == "" {
458                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
459                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
460                         status = http.StatusInternalServerError
461                         return
462                 } else {
463                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
464                 }
465         } else {
466                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
467         }
468
469         // Tell the client how many successful PUTs we accomplished
470         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
471
472         switch err {
473         case nil:
474                 status = http.StatusOK
475                 _, err = io.WriteString(resp, locatorOut)
476
477         case keepclient.OversizeBlockError:
478                 // Too much data
479                 status = http.StatusRequestEntityTooLarge
480
481         case keepclient.InsufficientReplicasError:
482                 if wroteReplicas > 0 {
483                         // At least one write is considered success.  The
484                         // client can decide if getting less than the number of
485                         // replications it asked for is a fatal error.
486                         status = http.StatusOK
487                         _, err = io.WriteString(resp, locatorOut)
488                 } else {
489                         status = http.StatusServiceUnavailable
490                 }
491
492         default:
493                 status = http.StatusBadGateway
494         }
495 }
496
497 // ServeHTTP implemenation for IndexHandler
498 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
499 // For each keep server found in LocalRoots:
500 //   Invokes GetIndex using keepclient
501 //   Expects "complete" response (terminating with blank new line)
502 //   Aborts on any errors
503 // Concatenates responses from all those keep servers and returns
504 func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
505         SetCorsHeaders(resp)
506
507         prefix := mux.Vars(req)["prefix"]
508         var err error
509         var status int
510
511         defer func() {
512                 if status != http.StatusOK {
513                         http.Error(resp, err.Error(), status)
514                 }
515         }()
516
517         kc := *handler.KeepClient
518
519         var pass bool
520         var tok string
521         if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
522                 status, err = http.StatusForbidden, BadAuthorizationHeader
523                 return
524         }
525
526         // Copy ArvadosClient struct and use the client's API token
527         arvclient := *kc.Arvados
528         arvclient.ApiToken = tok
529         kc.Arvados = &arvclient
530
531         var indexResp []byte
532         var reader io.Reader
533
534         switch req.Method {
535         case "GET":
536                 for uuid := range kc.LocalRoots() {
537                         reader, err = kc.GetIndex(uuid, prefix)
538                         if err != nil {
539                                 break
540                         }
541
542                         var readBytes []byte
543                         readBytes, err = ioutil.ReadAll(reader)
544                         if err != nil {
545                                 break
546                         }
547
548                         // Got index; verify that it is complete
549                         // The response should be "\n" if no locators matched the prefix
550                         // Else, it should be a list of locators followed by a blank line
551                         if (!strings.HasSuffix(string(readBytes), "\n\n")) && (string(readBytes) != "\n") {
552                                 err = errors.New("Got incomplete index")
553                         }
554
555                         // Trim the extra empty new line found in response from each server
556                         indexResp = append(indexResp, (readBytes[0 : len(readBytes)-1])...)
557                 }
558
559                 // Append empty line at the end of concatenation of all server responses
560                 indexResp = append(indexResp, ([]byte("\n"))...)
561         default:
562                 status, err = http.StatusNotImplemented, MethodNotSupported
563                 return
564         }
565
566         switch err {
567         case nil:
568                 status = http.StatusOK
569                 resp.Header().Set("Content-Length", fmt.Sprint(len(indexResp)))
570                 _, err = resp.Write(indexResp)
571         default:
572                 status = http.StatusBadGateway
573         }
574 }