2754: Merge branch '2754-pipeline-template-description' refs #2754
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "os/signal"
14         "sync"
15         "syscall"
16         "time"
17 )
18
const (
	// DEFAULT_ADDR is the default value of the -listen flag: TCP port 25107
	// on all network interfaces.
	DEFAULT_ADDR = ":25107"
)

var (
	// listener is the socket main serves HTTP requests on. main's SIGTERM
	// handler closes it, which stops the server.
	listener net.Listener
)
24
25 func main() {
26         var (
27                 listen           string
28                 no_get           bool
29                 no_put           bool
30                 default_replicas int
31                 pidfile          string
32         )
33
34         flagset := flag.NewFlagSet("default", flag.ExitOnError)
35
36         flagset.StringVar(
37                 &listen,
38                 "listen",
39                 DEFAULT_ADDR,
40                 "Interface on which to listen for requests, in the format "+
41                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
42                         "to listen on all network interfaces.")
43
44         flagset.BoolVar(
45                 &no_get,
46                 "no-get",
47                 false,
48                 "If set, disable GET operations")
49
50         flagset.BoolVar(
51                 &no_put,
52                 "no-put",
53                 false,
54                 "If set, disable PUT operations")
55
56         flagset.IntVar(
57                 &default_replicas,
58                 "default-replicas",
59                 2,
60                 "Default number of replicas to write if not specified by the client.")
61
62         flagset.StringVar(
63                 &pidfile,
64                 "pid",
65                 "",
66                 "Path to write pid file")
67
68         flagset.Parse(os.Args[1:])
69
70         kc, err := keepclient.MakeKeepClient()
71         if err != nil {
72                 log.Fatalf("Error setting up keep client %s", err.Error())
73         }
74
75         if pidfile != "" {
76                 f, err := os.Create(pidfile)
77                 if err == nil {
78                         fmt.Fprint(f, os.Getpid())
79                         f.Close()
80                 } else {
81                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
82                 }
83         }
84
85         kc.Want_replicas = default_replicas
86
87         listener, err = net.Listen("tcp", listen)
88         if err != nil {
89                 log.Fatalf("Could not listen on %v", listen)
90         }
91
92         go RefreshServicesList(&kc)
93
94         // Shut down the server gracefully (by closing the listener)
95         // if SIGTERM is received.
96         term := make(chan os.Signal, 1)
97         go func(sig <-chan os.Signal) {
98                 s := <-sig
99                 log.Println("caught signal:", s)
100                 listener.Close()
101         }(term)
102         signal.Notify(term, syscall.SIGTERM)
103
104         if pidfile != "" {
105                 f, err := os.Create(pidfile)
106                 if err == nil {
107                         fmt.Fprint(f, os.Getpid())
108                         f.Close()
109                 } else {
110                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
111                 }
112         }
113
114         log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
115
116         // Start listening for requests.
117         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
118
119         log.Println("shutting down")
120
121         if pidfile != "" {
122                 os.Remove(pidfile)
123         }
124 }
125
// ApiTokenCache remembers API tokens that the API server has accepted, so
// repeated requests carrying the same token can skip re-validation until the
// token's cache entry expires.
type ApiTokenCache struct {
	// lock guards tokens.
	lock sync.Mutex

	// tokens maps a token to the Unix time (in seconds) at which its cache
	// entry expires. A value of 0 means "not cached".
	tokens map[string]int64

	// expireTime is how many seconds a token stays cached after it is first
	// remembered.
	expireTime int64
}
131
132 // Refresh the keep service list every five minutes.
133 func RefreshServicesList(kc *keepclient.KeepClient) {
134         for {
135                 time.Sleep(300 * time.Second)
136                 oldservices := kc.ServiceRoots()
137                 kc.DiscoverKeepServers()
138                 newservices := kc.ServiceRoots()
139                 s1 := fmt.Sprint(oldservices)
140                 s2 := fmt.Sprint(newservices)
141                 if s1 != s2 {
142                         log.Printf("Updated server list to %v", s2)
143                 }
144         }
145 }
146
147 // Cache the token and set an expire time.  If we already have an expire time
148 // on the token, it is not updated.
149 func (this *ApiTokenCache) RememberToken(token string) {
150         this.lock.Lock()
151         defer this.lock.Unlock()
152
153         now := time.Now().Unix()
154         if this.tokens[token] == 0 {
155                 this.tokens[token] = now + this.expireTime
156         }
157 }
158
159 // Check if the cached token is known and still believed to be valid.
160 func (this *ApiTokenCache) RecallToken(token string) bool {
161         this.lock.Lock()
162         defer this.lock.Unlock()
163
164         now := time.Now().Unix()
165         if this.tokens[token] == 0 {
166                 // Unknown token
167                 return false
168         } else if now < this.tokens[token] {
169                 // Token is known and still valid
170                 return true
171         } else {
172                 // Token is expired
173                 this.tokens[token] = 0
174                 return false
175         }
176 }
177
// GetRemoteAddress describes the client that sent req, for use in log
// messages. If an X-Real-IP header (typically set by a fronting proxy) is
// present, that address is returned, annotated with the X-Forwarded-For
// value when that header is non-empty and differs from it. Otherwise
// req.RemoteAddr is returned.
//
// These headers are not verified, so the result is only suitable for logging.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// Skip the annotation when X-Forwarded-For is absent; otherwise the log
	// would read "addr (X-Forwarded-For )".
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
188
189 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
190         var auth string
191         if auth = req.Header.Get("Authorization"); auth == "" {
192                 return false
193         }
194
195         var tok string
196         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
197         if err != nil {
198                 // Scanning error
199                 return false
200         }
201
202         if cache.RecallToken(tok) {
203                 // Valid in the cache, short circut
204                 return true
205         }
206
207         var usersreq *http.Request
208
209         if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
210                 // Can't construct the request
211                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
212                 return false
213         }
214
215         // Add api token header
216         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
217
218         // Actually make the request
219         var resp *http.Response
220         if resp, err = kc.Client.Do(usersreq); err != nil {
221                 // Something else failed
222                 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
223                 return false
224         }
225
226         if resp.StatusCode != http.StatusOK {
227                 // Bad status
228                 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
229                 return false
230         }
231
232         // Success!  Update cache
233         cache.RememberToken(tok)
234
235         return true
236 }
237
238 type GetBlockHandler struct {
239         *keepclient.KeepClient
240         *ApiTokenCache
241 }
242
243 type PutBlockHandler struct {
244         *keepclient.KeepClient
245         *ApiTokenCache
246 }
247
248 type InvalidPathHandler struct{}
249
250 // MakeRESTRouter
251 //     Returns a mux.Router that passes GET and PUT requests to the
252 //     appropriate handlers.
253 //
254 func MakeRESTRouter(
255         enable_get bool,
256         enable_put bool,
257         kc *keepclient.KeepClient) *mux.Router {
258
259         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
260
261         rest := mux.NewRouter()
262
263         if enable_get {
264                 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
265                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
266                 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
267         }
268
269         if enable_put {
270                 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
271                 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
272         }
273
274         rest.NotFoundHandler = InvalidPathHandler{}
275
276         return rest
277 }
278
279 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
280         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
281         http.Error(resp, "Bad request", http.StatusBadRequest)
282 }
283
284 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
285
286         kc := *this.KeepClient
287
288         hash := mux.Vars(req)["hash"]
289         hints := mux.Vars(req)["hints"]
290
291         locator := keepclient.MakeLocator2(hash, hints)
292
293         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
294
295         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
296                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
297                 return
298         }
299
300         var reader io.ReadCloser
301         var err error
302         var blocklen int64
303
304         if req.Method == "GET" {
305                 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
306                 defer reader.Close()
307         } else if req.Method == "HEAD" {
308                 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
309         }
310
311         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
312
313         switch err {
314         case nil:
315                 if reader != nil {
316                         n, err2 := io.Copy(resp, reader)
317                         if n != blocklen {
318                                 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
319                         } else if err2 == nil {
320                                 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
321                         } else {
322                                 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
323                         }
324                 } else {
325                         log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
326                 }
327         case keepclient.BlockNotFound:
328                 http.Error(resp, "Not found", http.StatusNotFound)
329         default:
330                 http.Error(resp, err.Error(), http.StatusBadGateway)
331         }
332
333         if err != nil {
334                 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
335         }
336 }
337
338 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
339
340         kc := *this.KeepClient
341
342         hash := mux.Vars(req)["hash"]
343         hints := mux.Vars(req)["hints"]
344
345         locator := keepclient.MakeLocator2(hash, hints)
346
347         var contentLength int64 = -1
348         if req.Header.Get("Content-Length") != "" {
349                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
350                 if err != nil {
351                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
352                 }
353
354         }
355
356         log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
357
358         if contentLength < 1 {
359                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
360                 return
361         }
362
363         if locator.Size > 0 && int64(locator.Size) != contentLength {
364                 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
365                 return
366         }
367
368         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
369                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
370                 return
371         }
372
373         // Check if the client specified the number of replicas
374         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
375                 var r int
376                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
377                 if err != nil {
378                         kc.Want_replicas = r
379                 }
380         }
381
382         // Now try to put the block through
383         hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
384
385         // Tell the client how many successful PUTs we accomplished
386         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
387
388         switch err {
389         case nil:
390                 // Default will return http.StatusOK
391                 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
392                 n, err2 := io.WriteString(resp, hash)
393                 if err2 != nil {
394                         log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
395                 }
396
397         case keepclient.OversizeBlockError:
398                 // Too much data
399                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
400
401         case keepclient.InsufficientReplicasError:
402                 if replicas > 0 {
403                         // At least one write is considered success.  The
404                         // client can decide if getting less than the number of
405                         // replications it asked for is a fatal error.
406                         // Default will return http.StatusOK
407                         n, err2 := io.WriteString(resp, hash)
408                         if err2 != nil {
409                                 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
410                         }
411                 } else {
412                         http.Error(resp, "", http.StatusServiceUnavailable)
413                 }
414
415         default:
416                 http.Error(resp, err.Error(), http.StatusBadGateway)
417         }
418
419         if err != nil {
420                 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
421         }
422
423 }