Merge branch 'master' into 2903-remove-pi-active-and-success
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "os/signal"
14         "sync"
15         "syscall"
16         "time"
17 )
18
// DEFAULT_ADDR is the TCP address the proxy listens on by default.
// It is the default value of the -listen flag.
const DEFAULT_ADDR = ":25107"

// listener is the socket the proxy serves requests on. main assigns it,
// and main's signal handler closes it, which makes http.Serve return.
var listener net.Listener
24
25 func main() {
26         var (
27                 listen           string
28                 no_get           bool
29                 no_put           bool
30                 default_replicas int
31                 pidfile          string
32         )
33
34         flagset := flag.NewFlagSet("default", flag.ExitOnError)
35
36         flagset.StringVar(
37                 &listen,
38                 "listen",
39                 DEFAULT_ADDR,
40                 "Interface on which to listen for requests, in the format "+
41                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
42                         "to listen on all network interfaces.")
43
44         flagset.BoolVar(
45                 &no_get,
46                 "no-get",
47                 false,
48                 "If set, disable GET operations")
49
50         flagset.BoolVar(
51                 &no_put,
52                 "no-put",
53                 false,
54                 "If set, disable PUT operations")
55
56         flagset.IntVar(
57                 &default_replicas,
58                 "default-replicas",
59                 2,
60                 "Default number of replicas to write if not specified by the client.")
61
62         flagset.StringVar(
63                 &pidfile,
64                 "pid",
65                 "",
66                 "Path to write pid file")
67
68         flagset.Parse(os.Args[1:])
69
70         kc, err := keepclient.MakeKeepClient()
71         if err != nil {
72                 log.Fatalf("Error setting up keep client %s", err.Error())
73         }
74
75         if pidfile != "" {
76                 f, err := os.Create(pidfile)
77                 if err == nil {
78                         fmt.Fprint(f, os.Getpid())
79                         f.Close()
80                 } else {
81                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
82                 }
83         }
84
85         kc.Want_replicas = default_replicas
86
87         listener, err = net.Listen("tcp", listen)
88         if err != nil {
89                 log.Fatalf("Could not listen on %v", listen)
90         }
91
92         go RefreshServicesList(&kc)
93
94         // Shut down the server gracefully (by closing the listener)
95         // if SIGTERM is received.
96         term := make(chan os.Signal, 1)
97         go func(sig <-chan os.Signal) {
98                 s := <-sig
99                 log.Println("caught signal:", s)
100                 listener.Close()
101         }(term)
102         signal.Notify(term, syscall.SIGTERM)
103         signal.Notify(term, syscall.SIGINT)
104
105         if pidfile != "" {
106                 f, err := os.Create(pidfile)
107                 if err == nil {
108                         fmt.Fprint(f, os.Getpid())
109                         f.Close()
110                 } else {
111                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
112                 }
113         }
114
115         log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
116
117         // Start listening for requests.
118         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
119
120         log.Println("shutting down")
121
122         if pidfile != "" {
123                 os.Remove(pidfile)
124         }
125 }
126
// ApiTokenCache remembers API tokens that the API server recently accepted,
// so CheckAuthorizationHeader can answer repeat requests without asking the
// API server again.
type ApiTokenCache struct {
	// lock serializes access to tokens.
	lock sync.Mutex

	// tokens maps a token to the Unix time (in seconds) at which its cache
	// entry expires. A missing or zero entry means the token is not cached.
	tokens map[string]int64

	// expireTime is how many seconds a newly remembered token stays valid.
	expireTime int64
}
132
133 // Refresh the keep service list every five minutes.
134 func RefreshServicesList(kc *keepclient.KeepClient) {
135         for {
136                 time.Sleep(300 * time.Second)
137                 oldservices := kc.ServiceRoots()
138                 kc.DiscoverKeepServers()
139                 newservices := kc.ServiceRoots()
140                 s1 := fmt.Sprint(oldservices)
141                 s2 := fmt.Sprint(newservices)
142                 if s1 != s2 {
143                         log.Printf("Updated server list to %v", s2)
144                 }
145         }
146 }
147
148 // Cache the token and set an expire time.  If we already have an expire time
149 // on the token, it is not updated.
150 func (this *ApiTokenCache) RememberToken(token string) {
151         this.lock.Lock()
152         defer this.lock.Unlock()
153
154         now := time.Now().Unix()
155         if this.tokens[token] == 0 {
156                 this.tokens[token] = now + this.expireTime
157         }
158 }
159
160 // Check if the cached token is known and still believed to be valid.
161 func (this *ApiTokenCache) RecallToken(token string) bool {
162         this.lock.Lock()
163         defer this.lock.Unlock()
164
165         now := time.Now().Unix()
166         if this.tokens[token] == 0 {
167                 // Unknown token
168                 return false
169         } else if now < this.tokens[token] {
170                 // Token is known and still valid
171                 return true
172         } else {
173                 // Token is expired
174                 this.tokens[token] = 0
175                 return false
176         }
177 }
178
// GetRemoteAddress returns a description of the client for log messages.
// It prefers the X-Real-IP header and, when X-Forwarded-For differs from
// it, appends that header's value too. Without X-Real-IP it falls back to
// the connection's RemoteAddr.
func GetRemoteAddress(req *http.Request) string {
	realIP := req.Header.Get("X-Real-IP")
	if realIP == "" {
		return req.RemoteAddr
	}
	forwardedFor := req.Header.Get("X-Forwarded-For")
	if forwardedFor == realIP {
		return realIP
	}
	return fmt.Sprintf("%s (X-Forwarded-For %s)", realIP, forwardedFor)
}
189
190 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
191         var auth string
192         if auth = req.Header.Get("Authorization"); auth == "" {
193                 return false
194         }
195
196         var tok string
197         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
198         if err != nil {
199                 // Scanning error
200                 return false
201         }
202
203         if cache.RecallToken(tok) {
204                 // Valid in the cache, short circut
205                 return true
206         }
207
208         var usersreq *http.Request
209
210         if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
211                 // Can't construct the request
212                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
213                 return false
214         }
215
216         // Add api token header
217         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
218
219         // Actually make the request
220         var resp *http.Response
221         if resp, err = kc.Client.Do(usersreq); err != nil {
222                 // Something else failed
223                 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
224                 return false
225         }
226
227         if resp.StatusCode != http.StatusOK {
228                 // Bad status
229                 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
230                 return false
231         }
232
233         // Success!  Update cache
234         cache.RememberToken(tok)
235
236         return true
237 }
238
239 type GetBlockHandler struct {
240         *keepclient.KeepClient
241         *ApiTokenCache
242 }
243
244 type PutBlockHandler struct {
245         *keepclient.KeepClient
246         *ApiTokenCache
247 }
248
249 type InvalidPathHandler struct{}
250
251 // MakeRESTRouter
252 //     Returns a mux.Router that passes GET and PUT requests to the
253 //     appropriate handlers.
254 //
255 func MakeRESTRouter(
256         enable_get bool,
257         enable_put bool,
258         kc *keepclient.KeepClient) *mux.Router {
259
260         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
261
262         rest := mux.NewRouter()
263
264         if enable_get {
265                 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`,
266                         GetBlockHandler{kc, t}).Methods("GET", "HEAD")
267                 rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t}).Methods("GET", "HEAD")
268         }
269
270         if enable_put {
271                 rest.Handle(`/{hash:[0-9a-f]{32}}+{hints}`, PutBlockHandler{kc, t}).Methods("PUT")
272                 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
273         }
274
275         rest.NotFoundHandler = InvalidPathHandler{}
276
277         return rest
278 }
279
280 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
281         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
282         http.Error(resp, "Bad request", http.StatusBadRequest)
283 }
284
285 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
286
287         kc := *this.KeepClient
288
289         hash := mux.Vars(req)["hash"]
290         hints := mux.Vars(req)["hints"]
291
292         locator := keepclient.MakeLocator2(hash, hints)
293
294         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
295
296         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
297                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
298                 return
299         }
300
301         var reader io.ReadCloser
302         var err error
303         var blocklen int64
304
305         if req.Method == "GET" {
306                 reader, blocklen, _, err = kc.AuthorizedGet(hash, locator.Signature, locator.Timestamp)
307                 defer reader.Close()
308         } else if req.Method == "HEAD" {
309                 blocklen, _, err = kc.AuthorizedAsk(hash, locator.Signature, locator.Timestamp)
310         }
311
312         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
313
314         switch err {
315         case nil:
316                 if reader != nil {
317                         n, err2 := io.Copy(resp, reader)
318                         if n != blocklen {
319                                 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
320                         } else if err2 == nil {
321                                 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
322                         } else {
323                                 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
324                         }
325                 } else {
326                         log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
327                 }
328         case keepclient.BlockNotFound:
329                 http.Error(resp, "Not found", http.StatusNotFound)
330         default:
331                 http.Error(resp, err.Error(), http.StatusBadGateway)
332         }
333
334         if err != nil {
335                 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
336         }
337 }
338
339 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
340
341         kc := *this.KeepClient
342
343         hash := mux.Vars(req)["hash"]
344         hints := mux.Vars(req)["hints"]
345
346         locator := keepclient.MakeLocator2(hash, hints)
347
348         var contentLength int64 = -1
349         if req.Header.Get("Content-Length") != "" {
350                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
351                 if err != nil {
352                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
353                 }
354
355         }
356
357         log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
358
359         if contentLength < 1 {
360                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
361                 return
362         }
363
364         if locator.Size > 0 && int64(locator.Size) != contentLength {
365                 http.Error(resp, "Locator size hint does not match Content-Length header", http.StatusBadRequest)
366                 return
367         }
368
369         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
370                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
371                 return
372         }
373
374         // Check if the client specified the number of replicas
375         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
376                 var r int
377                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
378                 if err != nil {
379                         kc.Want_replicas = r
380                 }
381         }
382
383         // Now try to put the block through
384         hash, replicas, err := kc.PutHR(hash, req.Body, contentLength)
385
386         // Tell the client how many successful PUTs we accomplished
387         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
388
389         switch err {
390         case nil:
391                 // Default will return http.StatusOK
392                 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
393                 n, err2 := io.WriteString(resp, hash)
394                 if err2 != nil {
395                         log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
396                 }
397
398         case keepclient.OversizeBlockError:
399                 // Too much data
400                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
401
402         case keepclient.InsufficientReplicasError:
403                 if replicas > 0 {
404                         // At least one write is considered success.  The
405                         // client can decide if getting less than the number of
406                         // replications it asked for is a fatal error.
407                         // Default will return http.StatusOK
408                         n, err2 := io.WriteString(resp, hash)
409                         if err2 != nil {
410                                 log.Printf("%s: wrote %v bytes to response body and got error %v", n, err2.Error())
411                         }
412                 } else {
413                         http.Error(resp, "", http.StatusServiceUnavailable)
414                 }
415
416         default:
417                 http.Error(resp, err.Error(), http.StatusBadGateway)
418         }
419
420         if err != nil {
421                 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
422         }
423
424 }