Merge branch 'master' into 2751-python-sdk-keep-proxy-support refs #2751
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "os/signal"
14         "sync"
15         "syscall"
16         "time"
17 )
18
// DEFAULT_ADDR is the default TCP address on which to listen for requests.
// It is only the default value of the -listen flag; the address actually
// used is whatever -listen resolves to.
const DEFAULT_ADDR = ":25107"

// listener is the socket the proxy serves HTTP on. It is package-level so
// that the SIGTERM handler goroutine started in main can close it, which
// makes http.Serve return and lets main finish shutting down.
var listener net.Listener
24
25 func main() {
26         var (
27                 listen           string
28                 no_get           bool
29                 no_put           bool
30                 default_replicas int
31                 pidfile          string
32         )
33
34         flag.StringVar(
35                 &listen,
36                 "listen",
37                 DEFAULT_ADDR,
38                 "Interface on which to listen for requests, in the format "+
39                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
40                         "to listen on all network interfaces.")
41
42         flag.BoolVar(
43                 &no_get,
44                 "no-get",
45                 false,
46                 "If set, disable GET operations")
47
48         flag.BoolVar(
49                 &no_get,
50                 "no-put",
51                 false,
52                 "If set, disable PUT operations")
53
54         flag.IntVar(
55                 &default_replicas,
56                 "default-replicas",
57                 2,
58                 "Default number of replicas to write if not specified by the client.")
59
60         flag.StringVar(
61                 &pidfile,
62                 "pid",
63                 "",
64                 "Path to write pid file")
65
66         flag.Parse()
67
68         kc, err := keepclient.MakeKeepClient()
69         if err != nil {
70                 log.Fatalf("Error setting up keep client %s", err.Error())
71         }
72
73         if pidfile != "" {
74                 f, err := os.Create(pidfile)
75                 if err == nil {
76                         fmt.Fprint(f, os.Getpid())
77                         f.Close()
78                 } else {
79                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
80                 }
81         }
82
83         kc.Want_replicas = default_replicas
84
85         listener, err = net.Listen("tcp", listen)
86         if err != nil {
87                 log.Fatalf("Could not listen on %v", listen)
88         }
89
90         go RefreshServicesList(&kc)
91
92         // Shut down the server gracefully (by closing the listener)
93         // if SIGTERM is received.
94         term := make(chan os.Signal, 1)
95         go func(sig <-chan os.Signal) {
96                 s := <-sig
97                 log.Println("caught signal:", s)
98                 listener.Close()
99         }(term)
100         signal.Notify(term, syscall.SIGTERM)
101
102         if pidfile != "" {
103                 f, err := os.Create(pidfile)
104                 if err == nil {
105                         fmt.Fprint(f, os.Getpid())
106                         f.Close()
107                 } else {
108                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
109                 }
110         }
111
112         log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
113
114         // Start listening for requests.
115         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
116
117         log.Println("shutting down")
118
119         if pidfile != "" {
120                 os.Remove(pidfile)
121         }
122 }
123
// ApiTokenCache remembers API tokens that the API server recently accepted,
// so that repeated requests with the same token can skip re-validation.
type ApiTokenCache struct {
	// lock guards tokens.
	lock sync.Mutex

	// tokens maps a token to the Unix time (in seconds) at which its cache
	// entry expires. A zero value means the token is not cached.
	tokens map[string]int64

	// expireTime is how many seconds a newly remembered token stays cached.
	expireTime int64
}
129
130 // Refresh the keep service list every five minutes.
131 func RefreshServicesList(kc *keepclient.KeepClient) {
132         for {
133                 time.Sleep(300 * time.Second)
134                 oldservices := kc.ServiceRoots()
135                 kc.DiscoverKeepServers()
136                 newservices := kc.ServiceRoots()
137                 s1 := fmt.Sprint(oldservices)
138                 s2 := fmt.Sprint(newservices)
139                 if s1 != s2 {
140                         log.Printf("Updated server list to %v", s2)
141                 }
142         }
143 }
144
145 // Cache the token and set an expire time.  If we already have an expire time
146 // on the token, it is not updated.
147 func (this *ApiTokenCache) RememberToken(token string) {
148         this.lock.Lock()
149         defer this.lock.Unlock()
150
151         now := time.Now().Unix()
152         if this.tokens[token] == 0 {
153                 this.tokens[token] = now + this.expireTime
154         }
155 }
156
157 // Check if the cached token is known and still believed to be valid.
158 func (this *ApiTokenCache) RecallToken(token string) bool {
159         this.lock.Lock()
160         defer this.lock.Unlock()
161
162         now := time.Now().Unix()
163         if this.tokens[token] == 0 {
164                 // Unknown token
165                 return false
166         } else if now < this.tokens[token] {
167                 // Token is known and still valid
168                 return true
169         } else {
170                 // Token is expired
171                 this.tokens[token] = 0
172                 return false
173         }
174 }
175
// GetRemoteAddress describes the client address of req for log messages.
// If a fronting proxy set X-Real-IP, that address is used, annotated with
// X-Forwarded-For when that header is present and differs from it.
// Otherwise the connection's RemoteAddr is returned.
//
// Both headers are supplied by whoever sent the request and are not
// verified here, so the result is only suitable for logging.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// Skip the annotation when X-Forwarded-For is absent; otherwise the log
	// would show a meaningless "(X-Forwarded-For )".
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
186
187 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
188         var auth string
189         if auth = req.Header.Get("Authorization"); auth == "" {
190                 return false
191         }
192
193         var tok string
194         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
195         if err != nil {
196                 // Scanning error
197                 return false
198         }
199
200         if cache.RecallToken(tok) {
201                 // Valid in the cache, short circut
202                 return true
203         }
204
205         var usersreq *http.Request
206
207         if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
208                 // Can't construct the request
209                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
210                 return false
211         }
212
213         // Add api token header
214         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
215
216         // Actually make the request
217         var resp *http.Response
218         if resp, err = kc.Client.Do(usersreq); err != nil {
219                 // Something else failed
220                 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
221                 return false
222         }
223
224         if resp.StatusCode != http.StatusOK {
225                 // Bad status
226                 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
227                 return false
228         }
229
230         // Success!  Update cache
231         cache.RememberToken(tok)
232
233         return true
234 }
235
236 type GetBlockHandler struct {
237         *keepclient.KeepClient
238         *ApiTokenCache
239 }
240
241 type PutBlockHandler struct {
242         *keepclient.KeepClient
243         *ApiTokenCache
244 }
245
246 type InvalidPathHandler struct{}
247
248 // MakeRESTRouter
249 //     Returns a mux.Router that passes GET and PUT requests to the
250 //     appropriate handlers.
251 //
252 func MakeRESTRouter(
253         enable_get bool,
254         enable_put bool,
255         kc *keepclient.KeepClient) *mux.Router {
256
257         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
258
259         rest := mux.NewRouter()
260         gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
261         ghsig := rest.Handle(
262                 `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
263                 GetBlockHandler{kc, t})
264         ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
265
266         if enable_get {
267                 gh.Methods("GET", "HEAD")
268                 ghsig.Methods("GET", "HEAD")
269         }
270
271         if enable_put {
272                 ph.Methods("PUT")
273         }
274
275         rest.NotFoundHandler = InvalidPathHandler{}
276
277         return rest
278 }
279
280 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
281         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
282         http.Error(resp, "Bad request", http.StatusBadRequest)
283 }
284
285 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
286
287         kc := *this.KeepClient
288
289         hash := mux.Vars(req)["hash"]
290         signature := mux.Vars(req)["signature"]
291         timestamp := mux.Vars(req)["timestamp"]
292
293         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
294
295         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
296                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
297                 return
298         }
299
300         var reader io.ReadCloser
301         var err error
302         var blocklen int64
303
304         if req.Method == "GET" {
305                 reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
306                 defer reader.Close()
307         } else if req.Method == "HEAD" {
308                 blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
309         }
310
311         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
312
313         switch err {
314         case nil:
315                 if reader != nil {
316                         n, err2 := io.Copy(resp, reader)
317                         if n != blocklen {
318                                 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
319                         } else if err2 == nil {
320                                 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
321                         } else {
322                                 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
323                         }
324                 } else {
325                         log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
326                 }
327         case keepclient.BlockNotFound:
328                 http.Error(resp, "Not found", http.StatusNotFound)
329         default:
330                 http.Error(resp, err.Error(), http.StatusBadGateway)
331         }
332
333         if err != nil {
334                 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
335         }
336 }
337
338 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
339
340         kc := *this.KeepClient
341
342         hash := mux.Vars(req)["hash"]
343
344         var contentLength int64 = -1
345         if req.Header.Get("Content-Length") != "" {
346                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
347                 if err != nil {
348                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
349                 }
350
351         }
352
353         log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
354
355         if contentLength < 1 {
356                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
357                 return
358         }
359
360         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
361                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
362                 return
363         }
364
365         // Check if the client specified the number of replicas
366         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
367                 var r int
368                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
369                 if err != nil {
370                         kc.Want_replicas = r
371                 }
372         }
373
374         // Now try to put the block through
375         replicas, err := kc.PutHR(hash, req.Body, contentLength)
376
377         // Tell the client how many successful PUTs we accomplished
378         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
379
380         switch err {
381         case nil:
382                 // Default will return http.StatusOK
383                 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
384
385         case keepclient.OversizeBlockError:
386                 // Too much data
387                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
388
389         case keepclient.InsufficientReplicasError:
390                 if replicas > 0 {
391                         // At least one write is considered success.  The
392                         // client can decide if getting less than the number of
393                         // replications it asked for is a fatal error.
394                         // Default will return http.StatusOK
395                 } else {
396                         http.Error(resp, "", http.StatusServiceUnavailable)
397                 }
398
399         default:
400                 http.Error(resp, err.Error(), http.StatusBadGateway)
401         }
402
403         if err != nil {
404                 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
405         }
406
407 }