2760: Merge branch '2760-folder-hierarchy' refs #2760
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "sync"
14         "time"
15 )
16
// DEFAULT_ADDR is the default value of the -listen flag: the TCP address on
// which the proxy listens for requests when no -listen flag is given
// (":25107" means port 25107 on all network interfaces).
const DEFAULT_ADDR = ":25107"

// listener is the socket opened by main via net.Listen and passed to
// http.Serve. NOTE(review): it is package-level, but no use outside main is
// visible in this file.
var listener net.Listener
22
23 func main() {
24         var (
25                 listen           string
26                 no_get           bool
27                 no_put           bool
28                 default_replicas int
29                 pidfile          string
30         )
31
32         flagset := flag.NewFlagSet("default", flag.ExitOnError)
33
34         flagset.StringVar(
35                 &listen,
36                 "listen",
37                 DEFAULT_ADDR,
38                 "Interface on which to listen for requests, in the format "+
39                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
40                         "to listen on all network interfaces.")
41
42         flagset.BoolVar(
43                 &no_get,
44                 "no-get",
45                 false,
46                 "If set, disable GET operations")
47
48         flagset.BoolVar(
49                 &no_put,
50                 "no-put",
51                 false,
52                 "If set, disable PUT operations")
53
54         flagset.IntVar(
55                 &default_replicas,
56                 "default-replicas",
57                 2,
58                 "Default number of replicas to write if not specified by the client.")
59
60         flagset.StringVar(
61                 &pidfile,
62                 "pid",
63                 "",
64                 "Path to write pid file")
65
66         flagset.Parse(os.Args[1:])
67
68         kc, err := keepclient.MakeKeepClient()
69         if err != nil {
70                 log.Fatalf("Error setting up keep client %s", err.Error())
71         }
72
73         if pidfile != "" {
74                 f, err := os.Create(pidfile)
75                 if err == nil {
76                         fmt.Fprint(f, os.Getpid())
77                         f.Close()
78                 } else {
79                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
80                 }
81         }
82
83         kc.Want_replicas = default_replicas
84
85         listener, err = net.Listen("tcp", listen)
86         if err != nil {
87                 log.Fatalf("Could not listen on %v", listen)
88         }
89
90         go RefreshServicesList(&kc)
91
92         log.Printf("Arvados Keep proxy started listening on %v with server list %v", listener.Addr(), kc.ServiceRoots())
93
94         // Start listening for requests.
95         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
96 }
97
// ApiTokenCache remembers API tokens that have recently been validated, so
// that not every request needs a round trip to the API server.
type ApiTokenCache struct {
	// lock guards tokens.
	lock sync.Mutex
	// tokens maps a token to the Unix time (seconds) at which its cached
	// validation expires; a zero value means "not cached".
	tokens map[string]int64
	// expireTime is how many seconds a newly remembered token stays cached.
	expireTime int64
}
103
104 // Refresh the keep service list every five minutes.
105 func RefreshServicesList(kc *keepclient.KeepClient) {
106         for {
107                 time.Sleep(300 * time.Second)
108                 oldservices := kc.ServiceRoots()
109                 kc.DiscoverKeepServers()
110                 newservices := kc.ServiceRoots()
111                 s1 := fmt.Sprint(oldservices)
112                 s2 := fmt.Sprint(newservices)
113                 if s1 != s2 {
114                         log.Printf("Updated server list to %v", s2)
115                 }
116         }
117 }
118
119 // Cache the token and set an expire time.  If we already have an expire time
120 // on the token, it is not updated.
121 func (this *ApiTokenCache) RememberToken(token string) {
122         this.lock.Lock()
123         defer this.lock.Unlock()
124
125         now := time.Now().Unix()
126         if this.tokens[token] == 0 {
127                 this.tokens[token] = now + this.expireTime
128         }
129 }
130
131 // Check if the cached token is known and still believed to be valid.
132 func (this *ApiTokenCache) RecallToken(token string) bool {
133         this.lock.Lock()
134         defer this.lock.Unlock()
135
136         now := time.Now().Unix()
137         if this.tokens[token] == 0 {
138                 // Unknown token
139                 return false
140         } else if now < this.tokens[token] {
141                 // Token is known and still valid
142                 return true
143         } else {
144                 // Token is expired
145                 this.tokens[token] = 0
146                 return false
147         }
148 }
149
// GetRemoteAddress describes the client that sent req, for log messages.
//
// If a fronting proxy supplied X-Real-IP, that address is returned. The
// X-Forwarded-For chain is appended only when it is present and differs
// from X-Real-IP. Without X-Real-IP, the connection's RemoteAddr is returned.
func GetRemoteAddress(req *http.Request) string {
	realip := req.Header.Get("X-Real-IP")
	if realip == "" {
		return req.RemoteAddr
	}
	// An absent X-Forwarded-For must not produce "ip (X-Forwarded-For )".
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" && forwarded != realip {
		return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
	}
	return realip
}
160
161 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
162         var auth string
163         if auth = req.Header.Get("Authorization"); auth == "" {
164                 return false
165         }
166
167         var tok string
168         _, err := fmt.Sscanf(auth, "OAuth2 %s", &tok)
169         if err != nil {
170                 // Scanning error
171                 return false
172         }
173
174         if cache.RecallToken(tok) {
175                 // Valid in the cache, short circut
176                 return true
177         }
178
179         var usersreq *http.Request
180
181         if usersreq, err = http.NewRequest("HEAD", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
182                 // Can't construct the request
183                 log.Printf("%s: CheckAuthorizationHeader error: %v", GetRemoteAddress(req), err)
184                 return false
185         }
186
187         // Add api token header
188         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
189
190         // Actually make the request
191         var resp *http.Response
192         if resp, err = kc.Client.Do(usersreq); err != nil {
193                 // Something else failed
194                 log.Printf("%s: CheckAuthorizationHeader error connecting to API server: %v", GetRemoteAddress(req), err.Error())
195                 return false
196         }
197
198         if resp.StatusCode != http.StatusOK {
199                 // Bad status
200                 log.Printf("%s: CheckAuthorizationHeader API server responded: %v", GetRemoteAddress(req), resp.Status)
201                 return false
202         }
203
204         // Success!  Update cache
205         cache.RememberToken(tok)
206
207         return true
208 }
209
210 type GetBlockHandler struct {
211         *keepclient.KeepClient
212         *ApiTokenCache
213 }
214
215 type PutBlockHandler struct {
216         *keepclient.KeepClient
217         *ApiTokenCache
218 }
219
220 type InvalidPathHandler struct{}
221
222 // MakeRESTRouter
223 //     Returns a mux.Router that passes GET and PUT requests to the
224 //     appropriate handlers.
225 //
226 func MakeRESTRouter(
227         enable_get bool,
228         enable_put bool,
229         kc *keepclient.KeepClient) *mux.Router {
230
231         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
232
233         rest := mux.NewRouter()
234
235         if enable_get {
236                 gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
237                 ghsig := rest.Handle(
238                         `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
239                         GetBlockHandler{kc, t})
240
241                 gh.Methods("GET", "HEAD")
242                 ghsig.Methods("GET", "HEAD")
243         }
244
245         if enable_put {
246                 rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t}).Methods("PUT")
247         }
248
249         rest.NotFoundHandler = InvalidPathHandler{}
250
251         return rest
252 }
253
254 func (this InvalidPathHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
255         log.Printf("%s: %s %s unroutable", GetRemoteAddress(req), req.Method, req.URL.Path)
256         http.Error(resp, "Bad request", http.StatusBadRequest)
257 }
258
259 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
260
261         kc := *this.KeepClient
262
263         hash := mux.Vars(req)["hash"]
264         signature := mux.Vars(req)["signature"]
265         timestamp := mux.Vars(req)["timestamp"]
266
267         log.Printf("%s: %s %s", GetRemoteAddress(req), req.Method, hash)
268
269         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
270                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
271                 return
272         }
273
274         var reader io.ReadCloser
275         var err error
276         var blocklen int64
277
278         if req.Method == "GET" {
279                 reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
280                 defer reader.Close()
281         } else if req.Method == "HEAD" {
282                 blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
283         }
284
285         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
286
287         switch err {
288         case nil:
289                 if reader != nil {
290                         n, err2 := io.Copy(resp, reader)
291                         if n != blocklen {
292                                 log.Printf("%s: %s %s mismatched return %v with Content-Length %v error", GetRemoteAddress(req), req.Method, hash, n, blocklen, err.Error())
293                         } else if err2 == nil {
294                                 log.Printf("%s: %s %s success returned %v bytes", GetRemoteAddress(req), req.Method, hash, n)
295                         } else {
296                                 log.Printf("%s: %s %s returned %v bytes error %v", GetRemoteAddress(req), req.Method, hash, n, err.Error())
297                         }
298                 } else {
299                         log.Printf("%s: %s %s success", GetRemoteAddress(req), req.Method, hash)
300                 }
301         case keepclient.BlockNotFound:
302                 http.Error(resp, "Not found", http.StatusNotFound)
303         default:
304                 http.Error(resp, err.Error(), http.StatusBadGateway)
305         }
306
307         if err != nil {
308                 log.Printf("%s: %s %s error %s", GetRemoteAddress(req), req.Method, hash, err.Error())
309         }
310 }
311
312 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
313
314         kc := *this.KeepClient
315
316         hash := mux.Vars(req)["hash"]
317
318         var contentLength int64 = -1
319         if req.Header.Get("Content-Length") != "" {
320                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
321                 if err != nil {
322                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
323                 }
324
325         }
326
327         log.Printf("%s: %s %s Content-Length %v", GetRemoteAddress(req), req.Method, hash, contentLength)
328
329         if contentLength < 1 {
330                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
331                 return
332         }
333
334         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
335                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
336                 return
337         }
338
339         // Check if the client specified the number of replicas
340         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
341                 var r int
342                 _, err := fmt.Sscanf(req.Header.Get(keepclient.X_Keep_Desired_Replicas), "%d", &r)
343                 if err != nil {
344                         kc.Want_replicas = r
345                 }
346         }
347
348         // Now try to put the block through
349         replicas, err := kc.PutHR(hash, req.Body, contentLength)
350
351         // Tell the client how many successful PUTs we accomplished
352         resp.Header().Set(keepclient.X_Keep_Replicas_Stored, fmt.Sprintf("%d", replicas))
353
354         switch err {
355         case nil:
356                 // Default will return http.StatusOK
357                 log.Printf("%s: %s %s finished, stored %v replicas (desired %v)", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas)
358
359         case keepclient.OversizeBlockError:
360                 // Too much data
361                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
362
363         case keepclient.InsufficientReplicasError:
364                 if replicas > 0 {
365                         // At least one write is considered success.  The
366                         // client can decide if getting less than the number of
367                         // replications it asked for is a fatal error.
368                         // Default will return http.StatusOK
369                 } else {
370                         http.Error(resp, "", http.StatusServiceUnavailable)
371                 }
372
373         default:
374                 http.Error(resp, err.Error(), http.StatusBadGateway)
375         }
376
377         if err != nil {
378                 log.Printf("%s: %s %s stored %v replicas (desired %v) got error %v", GetRemoteAddress(req), req.Method, hash, replicas, kc.Want_replicas, err.Error())
379         }
380
381 }