closes #4717
[arvados.git] / services / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "errors"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "git.curoverse.com/arvados.git/sdk/go/keepclient"
9         "github.com/gorilla/mux"
10         "io"
11         "io/ioutil"
12         "log"
13         "net"
14         "net/http"
15         "os"
16         "os/signal"
17         "reflect"
18         "regexp"
19         "sync"
20         "syscall"
21         "time"
22 )
23
// DEFAULT_ADDR is the TCP address keepproxy listens on when the -listen
// flag is not given: port 25107 on all network interfaces.
const DEFAULT_ADDR = ":25107"

// listener is the socket opened by main. It is package-level so that the
// signal-handling goroutine started in main can close it, which makes
// http.Serve return and the process shut down.
var listener net.Listener
29
30 func main() {
31         var (
32                 listen           string
33                 no_get           bool
34                 no_put           bool
35                 default_replicas int
36                 timeout          int64
37                 pidfile          string
38         )
39
40         flagset := flag.NewFlagSet("default", flag.ExitOnError)
41
42         flagset.StringVar(
43                 &listen,
44                 "listen",
45                 DEFAULT_ADDR,
46                 "Interface on which to listen for requests, in the format "+
47                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
48                         "to listen on all network interfaces.")
49
50         flagset.BoolVar(
51                 &no_get,
52                 "no-get",
53                 false,
54                 "If set, disable GET operations")
55
56         flagset.BoolVar(
57                 &no_put,
58                 "no-put",
59                 false,
60                 "If set, disable PUT operations")
61
62         flagset.IntVar(
63                 &default_replicas,
64                 "default-replicas",
65                 2,
66                 "Default number of replicas to write if not specified by the client.")
67
68         flagset.Int64Var(
69                 &timeout,
70                 "timeout",
71                 15,
72                 "Timeout on requests to internal Keep services (default 15 seconds)")
73
74         flagset.StringVar(
75                 &pidfile,
76                 "pid",
77                 "",
78                 "Path to write pid file")
79
80         flagset.Parse(os.Args[1:])
81
82         arv, err := arvadosclient.MakeArvadosClient()
83         if err != nil {
84                 log.Fatalf("Error setting up arvados client %s", err.Error())
85         }
86
87         kc, err := keepclient.MakeKeepClient(&arv)
88         if err != nil {
89                 log.Fatalf("Error setting up keep client %s", err.Error())
90         }
91
92         if pidfile != "" {
93                 f, err := os.Create(pidfile)
94                 if err != nil {
95                         log.Fatalf("Error writing pid file (%s): %s", pidfile, err.Error())
96                 }
97                 fmt.Fprint(f, os.Getpid())
98                 f.Close()
99                 defer os.Remove(pidfile)
100         }
101
102         kc.Want_replicas = default_replicas
103
104         kc.Client.Timeout = time.Duration(timeout) * time.Second
105
106         listener, err = net.Listen("tcp", listen)
107         if err != nil {
108                 log.Fatalf("Could not listen on %v", listen)
109         }
110
111         go RefreshServicesList(kc)
112
113         // Shut down the server gracefully (by closing the listener)
114         // if SIGTERM is received.
115         term := make(chan os.Signal, 1)
116         go func(sig <-chan os.Signal) {
117                 s := <-sig
118                 log.Println("caught signal:", s)
119                 listener.Close()
120         }(term)
121         signal.Notify(term, syscall.SIGTERM)
122         signal.Notify(term, syscall.SIGINT)
123
124         log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
125
126         // Start listening for requests.
127         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
128
129         log.Println("shutting down")
130 }
131
// ApiTokenCache remembers API tokens that were recently accepted by the API
// server, each with the Unix time (in seconds) at which it must be checked
// again.
type ApiTokenCache struct {
	// lock guards every read and write of tokens.
	lock sync.Mutex
	// tokens maps a token to its expiry time in Unix seconds; a missing or
	// zero entry means the token is not known to be valid.
	tokens map[string]int64
	// expireTime is how many seconds a remembered token stays valid.
	expireTime int64
}
137
138 // Refresh the keep service list every five minutes.
139 func RefreshServicesList(kc *keepclient.KeepClient) {
140         var previousRoots = []map[string]string{}
141         var delay time.Duration = 0
142         for {
143                 time.Sleep(delay * time.Second)
144                 delay = 300
145                 if err := kc.DiscoverKeepServers(); err != nil {
146                         log.Println("Error retrieving services list:", err)
147                         delay = 3
148                         continue
149                 }
150                 newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
151                 if !reflect.DeepEqual(previousRoots, newRoots) {
152                         log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
153                 }
154                 if len(newRoots[0]) == 0 {
155                         log.Print("WARNING: No local services. Retrying in 3 seconds.")
156                         delay = 3
157                 }
158                 previousRoots = newRoots
159         }
160 }
161
162 // Cache the token and set an expire time.  If we already have an expire time
163 // on the token, it is not updated.
164 func (this *ApiTokenCache) RememberToken(token string) {
165         this.lock.Lock()
166         defer this.lock.Unlock()
167
168         now := time.Now().Unix()
169         if this.tokens[token] == 0 {
170                 this.tokens[token] = now + this.expireTime
171         }
172 }
173
174 // Check if the cached token is known and still believed to be valid.
175 func (this *ApiTokenCache) RecallToken(token string) bool {
176         this.lock.Lock()
177         defer this.lock.Unlock()
178
179         now := time.Now().Unix()
180         if this.tokens[token] == 0 {
181                 // Unknown token
182                 return false
183         } else if now < this.tokens[token] {
184                 // Token is known and still valid
185                 return true
186         } else {
187                 // Token is expired
188                 this.tokens[token] = 0
189                 return false
190         }
191 }
192
// GetRemoteAddress describes the client's address for log messages.
//
// If the request carries an X-Real-IP header, that address is returned,
// followed by the X-Forwarded-For value in parentheses when that header is
// present and differs from X-Real-IP. Otherwise req.RemoteAddr is returned.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// An absent X-Forwarded-For header must not be reported as an empty one.
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
203
204 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
205         var auth string
206         if auth = req.Header.Get("Authorization"); auth == "" {
207                 return false, ""
208         }
209
210         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
211         if err != nil {
212                 // Scanning error
213                 return false, ""
214         }
215
216         if cache.RecallToken(tok) {
217                 // Valid in the cache, short circut
218                 return true, tok
219         }
220
221         arv := *kc.Arvados
222         arv.ApiToken = tok
223         if err := arv.Call("HEAD", "users", "", "current", nil, nil); err != nil {
224                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
225                 return false, ""
226         }
227
228         // Success!  Update cache
229         cache.RememberToken(tok)
230
231         return true, tok
232 }
233
234 type GetBlockHandler struct {
235         *keepclient.KeepClient
236         *ApiTokenCache
237 }
238
239 type PutBlockHandler struct {
240         *keepclient.KeepClient
241         *ApiTokenCache
242 }
243
244 type InvalidPathHandler struct{}
245
246 type OptionsHandler struct{}
247
248 // MakeRESTRouter
249 //     Returns a mux.Router that passes GET and PUT requests to the
250 //     appropriate handlers.
251 //
252 func MakeRESTRouter(
253         enable_get bool,
254         enable_put bool,
255         kc *keepclient.KeepClient) *mux.Router {
256
257         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
258
259         rest := mux.NewRouter()
260
261         if enable_get {
262                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`,
263                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
264                 rest.Handle(`/{locator:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
265         }
266
267         if enable_put {
268                 rest.Handle(`/{locator:[0-9a-f]{32}\+.*}`, PutBlockHandler{kc, t}).Methods("PUT")
269                 rest.Handle(`/{locator:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
270                 rest.Handle(`/`, PutBlockHandler{kc, t}).Methods("POST")
271                 rest.Handle(`/{any}`, OptionsHandler{}).Methods("OPTIONS")
272                 rest.Handle(`/`, OptionsHandler{}).Methods("OPTIONS")
273         }
274
275         rest.NotFoundHandler = InvalidPathHandler{}
276
277         return rest
278 }
279
// SetCorsHeaders adds the Access-Control-* response headers that let
// browser clients on any origin call the proxy and cache the preflight
// answer.
func SetCorsHeaders(resp http.ResponseWriter) {
	header := resp.Header()
	for _, kv := range [...][2]string{
		{"Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, OPTIONS"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, X-Keep-Desired-Replicas"},
		// NOTE(review): 86486400 s is about 1001 days; 86400 (one day) may
		// have been intended. Most browsers clamp this to their own maximum.
		{"Access-Control-Max-Age", "86486400"},
	} {
		header.Set(kv[0], kv[1])
	}
}
286
287 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
288         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
289         http.Error(resp, "Bad request", http.StatusBadRequest)
290 }
291
292 func (this OptionsHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
293         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, req.URL.Path)
294         SetCorsHeaders(resp)
295 }
296
// Errors sent to clients (and logged) by the block handlers.
var (
	BadAuthorizationHeader = errors.New("Missing or invalid Authorization header")
	ContentLengthMismatch  = errors.New("Actual length != expected content length")
	MethodNotSupported     = errors.New("Method not supported")
)

// removeHint matches a "+K@xxxxx" locator hint together with the "+" that
// follows it (or the end of the string), so that
// removeHint.ReplaceAllString(loc, "$1") deletes the hint while keeping any
// later hints "+"-separated. The pattern is a constant, so MustCompile turns
// a typo into a startup panic instead of a nil *Regexp used at request time.
var removeHint = regexp.MustCompile(`\+K@[a-z0-9]{5}(\+|$)`)
302
303 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
304         SetCorsHeaders(resp)
305
306         locator := mux.Vars(req)["locator"]
307         var err error
308         var status int
309         var expectLength, responseLength int64
310         var proxiedURI = "-"
311
312         defer func() {
313                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, responseLength, proxiedURI, err)
314                 if status != http.StatusOK {
315                         http.Error(resp, err.Error(), status)
316                 }
317         }()
318
319         kc := *this.KeepClient
320
321         var pass bool
322         var tok string
323         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
324                 status, err = http.StatusForbidden, BadAuthorizationHeader
325                 return
326         }
327
328         // Copy ArvadosClient struct and use the client's API token
329         arvclient := *kc.Arvados
330         arvclient.ApiToken = tok
331         kc.Arvados = &arvclient
332
333         var reader io.ReadCloser
334
335         locator = removeHint.ReplaceAllString(locator, "$1")
336
337         switch req.Method {
338         case "HEAD":
339                 expectLength, proxiedURI, err = kc.Ask(locator)
340         case "GET":
341                 reader, expectLength, proxiedURI, err = kc.Get(locator)
342                 if reader != nil {
343                         defer reader.Close()
344                 }
345         default:
346                 status, err = http.StatusNotImplemented, MethodNotSupported
347                 return
348         }
349
350         if expectLength == -1 {
351                 log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
352         }
353
354         switch err {
355         case nil:
356                 status = http.StatusOK
357                 resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
358                 switch req.Method {
359                 case "HEAD":
360                         responseLength = 0
361                 case "GET":
362                         responseLength, err = io.Copy(resp, reader)
363                         if err == nil && expectLength > -1 && responseLength != expectLength {
364                                 err = ContentLengthMismatch
365                         }
366                 }
367         case keepclient.BlockNotFound:
368                 status = http.StatusNotFound
369         default:
370                 status = http.StatusBadGateway
371         }
372 }
373
// Errors returned by PutBlockHandler when the request's length information
// is missing or inconsistent.
var (
	LengthRequiredError = errors.New(http.StatusText(http.StatusLengthRequired))
	LengthMismatchError = errors.New("Locator size hint does not match Content-Length header")
)
376
377 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
378         SetCorsHeaders(resp)
379
380         kc := *this.KeepClient
381         var err error
382         var expectLength int64 = -1
383         var status = http.StatusInternalServerError
384         var wroteReplicas int
385         var locatorOut string = "-"
386
387         defer func() {
388                 log.Println(GetRemoteAddress(req), req.Method, req.URL.Path, status, expectLength, kc.Want_replicas, wroteReplicas, locatorOut, err)
389                 if status != http.StatusOK {
390                         http.Error(resp, err.Error(), status)
391                 }
392         }()
393
394         locatorIn := mux.Vars(req)["locator"]
395
396         if req.Header.Get("Content-Length") != "" {
397                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &expectLength)
398                 if err != nil {
399                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", expectLength))
400                 }
401
402         }
403
404         if expectLength < 0 {
405                 err = LengthRequiredError
406                 status = http.StatusLengthRequired
407                 return
408         }
409
410         if locatorIn != "" {
411                 var loc *keepclient.Locator
412                 if loc, err = keepclient.MakeLocator(locatorIn); err != nil {
413                         status = http.StatusBadRequest
414                         return
415                 } else if loc.Size > 0 && int64(loc.Size) != expectLength {
416                         err = LengthMismatchError
417                         status = http.StatusBadRequest
418                         return
419                 }
420         }
421
422         var pass bool
423         var tok string
424         if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
425                 err = BadAuthorizationHeader
426                 status = http.StatusForbidden
427                 return
428         }
429
430         // Copy ArvadosClient struct and use the client's API token
431         arvclient := *kc.Arvados
432         arvclient.ApiToken = tok
433         kc.Arvados = &arvclient
434
435         // Check if the client specified the number of replicas
436         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
437                 var r int
438                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
439                 if err != nil {
440                         kc.Want_replicas = r
441                 }
442         }
443
444         // Now try to put the block through
445         if locatorIn == "" {
446                 if bytes, err := ioutil.ReadAll(req.Body); err != nil {
447                         err = errors.New(fmt.Sprintf("Error reading request body: %s", err))
448                         status = http.StatusInternalServerError
449                         return
450                 } else {
451                         locatorOut, wroteReplicas, err = kc.PutB(bytes)
452                 }
453         } else {
454                 locatorOut, wroteReplicas, err = kc.PutHR(locatorIn, req.Body, expectLength)
455         }
456
457         // Tell the client how many successful PUTs we accomplished
458         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", wroteReplicas))
459
460         switch err {
461         case nil:
462                 status = http.StatusOK
463                 _, err = io.WriteString(resp, locatorOut)
464
465         case keepclient.OversizeBlockError:
466                 // Too much data
467                 status = http.StatusRequestEntityTooLarge
468
469         case keepclient.InsufficientReplicasError:
470                 if wroteReplicas > 0 {
471                         // At least one write is considered success.  The
472                         // client can decide if getting less than the number of
473                         // replications it asked for is a fatal error.
474                         status = http.StatusOK
475                         _, err = io.WriteString(resp, locatorOut)
476                 } else {
477                         status = http.StatusServiceUnavailable
478                 }
479
480         default:
481                 status = http.StatusBadGateway
482         }
483 }