1885: Made ServiceRoots atomically updatable, so that KeepProxy can poll for updates to the Keep service list.
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "sync"
14         "time"
15 )
16
// DEFAULT_ADDR is the default TCP address on which to listen for
// requests. It is only the default value of the -listen flag; the flag
// can override it at startup.
const DEFAULT_ADDR = ":25107"

// listener is the network listener that main opens on the -listen
// address and passes to http.Serve.
var listener net.Listener
22
23 func main() {
24         var (
25                 listen           string
26                 no_get           bool
27                 no_put           bool
28                 default_replicas int
29                 pidfile          string
30         )
31
32         flag.StringVar(
33                 &listen,
34                 "listen",
35                 DEFAULT_ADDR,
36                 "Interface on which to listen for requests, in the format "+
37                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
38                         "to listen on all network interfaces.")
39
40         flag.BoolVar(
41                 &no_get,
42                 "no-get",
43                 false,
44                 "If set, disable GET operations")
45
46         flag.BoolVar(
47                 &no_get,
48                 "no-put",
49                 false,
50                 "If set, disable PUT operations")
51
52         flag.IntVar(
53                 &default_replicas,
54                 "default-replicas",
55                 2,
56                 "Default number of replicas to write if not specified by the client.")
57
58         flag.StringVar(
59                 &pidfile,
60                 "pid",
61                 "",
62                 "Path to write pid file")
63
64         flag.Parse()
65
66         /*if no_get == false {
67                 log.Print("Must specify -no-get")
68                 return
69         }*/
70
71         kc, err := keepclient.MakeKeepClient()
72         if err != nil {
73                 log.Print(err)
74                 return
75         }
76
77         if pidfile != "" {
78                 f, err := os.Create(pidfile)
79                 if err == nil {
80                         fmt.Fprint(f, os.Getpid())
81                         f.Close()
82                 } else {
83                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
84                 }
85         }
86
87         kc.Want_replicas = default_replicas
88
89         listener, err = net.Listen("tcp", listen)
90         if err != nil {
91                 log.Printf("Could not listen on %v", listen)
92                 return
93         }
94
95         go RefreshServicesList(&kc)
96
97         // Start listening for requests.
98         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, &kc))
99 }
100
// ApiTokenCache remembers API tokens that were recently accepted by the
// API server, so that repeated requests with the same token can skip the
// validation round trip.
type ApiTokenCache struct {
	// tokens maps an API token to the Unix time (seconds) at which its
	// cache entry expires; 0 means the token is not cached.
	tokens map[string]int64

	// lock guards tokens.
	lock sync.Mutex

	// expireTime is the number of seconds a remembered token is
	// considered valid.
	expireTime int64
}
106
107 // Refresh the keep service list every five minutes.
108 func RefreshServicesList(kc *keepclient.KeepClient) {
109         for {
110                 time.Sleep(300 * time.Second)
111                 kc.DiscoverKeepServers()
112         }
113 }
114
115 // Cache the token and set an expire time.  If we already have an expire time
116 // on the token, it is not updated.
117 func (this *ApiTokenCache) RememberToken(token string) {
118         this.lock.Lock()
119         defer this.lock.Unlock()
120
121         now := time.Now().Unix()
122         if this.tokens[token] == 0 {
123                 this.tokens[token] = now + this.expireTime
124         }
125 }
126
127 // Check if the cached token is known and still believed to be valid.
128 func (this *ApiTokenCache) RecallToken(token string) bool {
129         this.lock.Lock()
130         defer this.lock.Unlock()
131
132         now := time.Now().Unix()
133         if this.tokens[token] == 0 {
134                 // Unknown token
135                 return false
136         } else if now < this.tokens[token] {
137                 // Token is known and still valid
138                 return true
139         } else {
140                 // Token is expired
141                 this.tokens[token] = 0
142                 return false
143         }
144 }
145
146 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
147         if req.Header.Get("Authorization") == "" {
148                 return false
149         }
150
151         var tok string
152         _, err := fmt.Sscanf(req.Header.Get("Authorization"), "OAuth2 %s", &tok)
153         if err != nil {
154                 // Scanning error
155                 return false
156         }
157
158         if cache.RecallToken(tok) {
159                 // Valid in the cache, short circut
160                 return true
161         }
162
163         var usersreq *http.Request
164
165         if usersreq, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
166                 // Can't construct the request
167                 log.Print("CheckAuthorizationHeader error: %v", err)
168                 return false
169         }
170
171         // Add api token header
172         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
173
174         // Actually make the request
175         var resp *http.Response
176         if resp, err = kc.Client.Do(usersreq); err != nil {
177                 // Something else failed
178                 log.Print("CheckAuthorizationHeader error: %v", err)
179                 return false
180         }
181
182         if resp.StatusCode != http.StatusOK {
183                 // Bad status
184                 return false
185         }
186
187         // Success!  Update cache
188         cache.RememberToken(tok)
189
190         return true
191 }
192
193 type GetBlockHandler struct {
194         *keepclient.KeepClient
195         *ApiTokenCache
196 }
197
198 type PutBlockHandler struct {
199         *keepclient.KeepClient
200         *ApiTokenCache
201 }
202
203 // MakeRESTRouter
204 //     Returns a mux.Router that passes GET and PUT requests to the
205 //     appropriate handlers.
206 //
207 func MakeRESTRouter(
208         enable_get bool,
209         enable_put bool,
210         kc *keepclient.KeepClient) *mux.Router {
211
212         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
213
214         rest := mux.NewRouter()
215         gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
216         ghsig := rest.Handle(
217                 `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
218                 GetBlockHandler{kc, t})
219         ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
220
221         if enable_get {
222                 gh.Methods("GET", "HEAD")
223                 ghsig.Methods("GET", "HEAD")
224         }
225
226         if enable_put {
227                 ph.Methods("PUT")
228         }
229
230         return rest
231 }
232
233 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
234
235         kc := *this.KeepClient
236
237         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
238                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
239         }
240
241         hash := mux.Vars(req)["hash"]
242         signature := mux.Vars(req)["signature"]
243         timestamp := mux.Vars(req)["timestamp"]
244
245         var reader io.ReadCloser
246         var err error
247         var blocklen int64
248
249         if req.Method == "GET" {
250                 reader, blocklen, _, err = kc.AuthorizedGet(hash, signature, timestamp)
251                 defer reader.Close()
252         } else if req.Method == "HEAD" {
253                 blocklen, _, err = kc.AuthorizedAsk(hash, signature, timestamp)
254         }
255
256         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
257
258         switch err {
259         case nil:
260                 if reader != nil {
261                         io.Copy(resp, reader)
262                 }
263         case keepclient.BlockNotFound:
264                 http.Error(resp, "Not found", http.StatusNotFound)
265         default:
266                 http.Error(resp, err.Error(), http.StatusBadGateway)
267         }
268 }
269
270 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
271
272         log.Print("PutBlockHandler start")
273
274         kc := *this.KeepClient
275
276         if !CheckAuthorizationHeader(kc, this.ApiTokenCache, req) {
277                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
278         }
279
280         hash := mux.Vars(req)["hash"]
281
282         var contentLength int64 = -1
283         if req.Header.Get("Content-Length") != "" {
284                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
285                 if err != nil {
286                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
287                 }
288
289         }
290
291         if contentLength < 1 {
292                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
293                 return
294         }
295
296         // Check if the client specified the number of replicas
297         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
298                 var r int
299                 _, err := fmt.Sscanf(req.Header.Get("X-Keep-Desired-Replicas"), "%d", &r)
300                 if err != nil {
301                         kc.Want_replicas = r
302                 }
303         }
304
305         // Now try to put the block through
306         replicas, err := kc.PutHR(hash, req.Body, contentLength)
307
308         log.Printf("Replicas stored: %v err: %v", replicas, err)
309
310         // Tell the client how many successful PUTs we accomplished
311         resp.Header().Set("X-Keep-Replicas-Stored", fmt.Sprintf("%d", replicas))
312
313         switch err {
314         case nil:
315                 // Default will return http.StatusOK
316
317         case keepclient.OversizeBlockError:
318                 // Too much data
319                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
320
321         case keepclient.InsufficientReplicasError:
322                 if replicas > 0 {
323                         // At least one write is considered success.  The
324                         // client can decide if getting less than the number of
325                         // replications it asked for is a fatal error.
326                         // Default will return http.StatusOK
327                 } else {
328                         http.Error(resp, "", http.StatusServiceUnavailable)
329                 }
330
331         default:
332                 http.Error(resp, err.Error(), http.StatusBadGateway)
333         }
334
335 }