1885: Full-stack integration test (api+keep+keepproxy+keepclient) works!
[arvados.git] / services / keep / src / arvados.org / keepproxy / keepproxy.go
1 package main
2
3 import (
4         "arvados.org/keepclient"
5         "flag"
6         "fmt"
7         "github.com/gorilla/mux"
8         "io"
9         "log"
10         "net"
11         "net/http"
12         "os"
13         "sync"
14         "time"
15 )
16
17 // Default TCP address on which to listen for requests.
18 // Used as the default value of the -listen flag.
19 const DEFAULT_ADDR = ":25107"
20
21 var listener net.Listener
22
23 func main() {
24         var (
25                 listen           string
26                 no_get           bool
27                 no_put           bool
28                 no_head          bool
29                 default_replicas int
30                 pidfile          string
31         )
32
33         flag.StringVar(
34                 &listen,
35                 "listen",
36                 DEFAULT_ADDR,
37                 "Interface on which to listen for requests, in the format "+
38                         "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
39                         "to listen on all network interfaces.")
40
41         flag.BoolVar(
42                 &no_get,
43                 "no-get",
44                 true,
45                 "If set, disable GET operations")
46
47         flag.BoolVar(
48                 &no_get,
49                 "no-put",
50                 false,
51                 "If set, disable PUT operations")
52
53         flag.BoolVar(
54                 &no_head,
55                 "no-head",
56                 true,
57                 "If set, disable HEAD operations")
58
59         flag.IntVar(
60                 &default_replicas,
61                 "default-replicas",
62                 2,
63                 "Default number of replicas to write if not specified by the client.")
64
65         flag.StringVar(
66                 &pidfile,
67                 "pid",
68                 "",
69                 "Path to write pid file")
70
71         flag.Parse()
72
73         /*if no_get == false || no_head == false {
74                 log.Print("Must specify -no-get and -no-head")
75                 return
76         }*/
77
78         kc, err := keepclient.MakeKeepClient()
79         if err != nil {
80                 log.Print(err)
81                 return
82         }
83
84         if pidfile != "" {
85                 f, err := os.Create(pidfile)
86                 if err == nil {
87                         fmt.Fprint(f, os.Getpid())
88                         f.Close()
89                 } else {
90                         log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
91                 }
92         }
93
94         kc.Want_replicas = default_replicas
95
96         listener, err = net.Listen("tcp", listen)
97         if err != nil {
98                 log.Printf("Could not listen on %v", listen)
99                 return
100         }
101
102         // Start listening for requests.
103         http.Serve(listener, MakeRESTRouter(!no_get, !no_put, !no_head, kc))
104 }
105
// ApiTokenCache remembers API tokens for a limited time so that a token
// does not have to be re-validated on every request.  The zero value is an
// empty cache whose entries expire immediately unless expireTime is set.
type ApiTokenCache struct {
	// tokens maps a token to the Unix time (in seconds) at which its
	// cache entry expires.
	tokens map[string]int64
	// lock guards tokens.
	lock sync.Mutex
	// expireTime is how long, in seconds, a remembered token stays valid.
	expireTime int64
}

// RememberToken caches the token and sets an expire time.  If the token
// already has an expire time that has not yet passed, it is not updated.
func (c *ApiTokenCache) RememberToken(token string) {
	c.lock.Lock()
	defer c.lock.Unlock()

	now := time.Now().Unix()
	if expiry, ok := c.tokens[token]; ok && now < expiry {
		// Still valid: keep the original expiry so that a token cannot
		// be kept alive forever without being re-validated.
		return
	}

	// Writing to a nil map panics, so allocate it on first use.
	if c.tokens == nil {
		c.tokens = make(map[string]int64)
	}
	c.tokens[token] = now + c.expireTime
}

// RecallToken reports whether the token is cached and its expire time has
// not yet passed.  An expired entry is removed from the cache.
func (c *ApiTokenCache) RecallToken(token string) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	expiry, ok := c.tokens[token]
	if !ok {
		// Unknown token
		return false
	}
	if time.Now().Unix() < expiry {
		// Token is known and still valid
		return true
	}

	// Token is expired.  Delete it rather than zeroing it so the map does
	// not grow by one entry for every token ever seen.
	delete(c.tokens, token)
	return false
}
142
143 func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) bool {
144         if req.Header.Get("Authorization") == "" {
145                 return false
146         }
147
148         var tok string
149         _, err := fmt.Sscanf(req.Header.Get("Authorization"), "OAuth2 %s", &tok)
150         if err != nil {
151                 // Scanning error
152                 return false
153         }
154
155         if cache.RecallToken(tok) {
156                 // Valid in the cache, short circut
157                 return true
158         }
159
160         var usersreq *http.Request
161
162         if usersreq, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/users/current", kc.ApiServer), nil); err != nil {
163                 // Can't construct the request
164                 log.Print("CheckAuthorizationHeader error: %v", err)
165                 return false
166         }
167
168         // Add api token header
169         usersreq.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", tok))
170
171         // Actually make the request
172         var resp *http.Response
173         if resp, err = kc.Client.Do(usersreq); err != nil {
174                 // Something else failed
175                 log.Print("CheckAuthorizationHeader error: %v", err)
176                 return false
177         }
178
179         if resp.StatusCode != http.StatusOK {
180                 // Bad status
181                 return false
182         }
183
184         // Success!  Update cache
185         cache.RememberToken(tok)
186
187         return true
188 }
189
190 type GetBlockHandler struct {
191         keepclient.KeepClient
192         *ApiTokenCache
193 }
194
195 type PutBlockHandler struct {
196         keepclient.KeepClient
197         *ApiTokenCache
198 }
199
200 // MakeRESTRouter
201 //     Returns a mux.Router that passes GET and PUT requests to the
202 //     appropriate handlers.
203 //
204 func MakeRESTRouter(
205         enable_get bool,
206         enable_put bool,
207         enable_head bool,
208         kc keepclient.KeepClient) *mux.Router {
209
210         t := &ApiTokenCache{tokens: make(map[string]int64), expireTime: 300}
211
212         rest := mux.NewRouter()
213         gh := rest.Handle(`/{hash:[0-9a-f]{32}}`, GetBlockHandler{kc, t})
214         ghsig := rest.Handle(
215                 `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
216                 GetBlockHandler{kc, t})
217         ph := rest.Handle(`/{hash:[0-9a-f]{32}}`, PutBlockHandler{kc, t})
218
219         if enable_get {
220                 gh.Methods("GET")
221                 ghsig.Methods("GET")
222         }
223
224         if enable_put {
225                 ph.Methods("PUT")
226         }
227
228         if enable_head {
229                 gh.Methods("HEAD")
230                 ghsig.Methods("HEAD")
231         }
232
233         return rest
234 }
235
236 func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
237         hash := mux.Vars(req)["hash"]
238         signature := mux.Vars(req)["signature"]
239         timestamp := mux.Vars(req)["timestamp"]
240
241         var reader io.ReadCloser
242         var err error
243         var blocklen int64
244
245         if req.Method == "GET" {
246                 reader, blocklen, _, err = this.KeepClient.AuthorizedGet(hash, signature, timestamp)
247                 defer reader.Close()
248         } else if req.Method == "HEAD" {
249                 blocklen, _, err = this.KeepClient.AuthorizedAsk(hash, signature, timestamp)
250         }
251
252         resp.Header().Set("Content-Length", fmt.Sprint(blocklen))
253
254         switch err {
255         case nil:
256                 if reader != nil {
257                         io.Copy(resp, reader)
258                 }
259         case keepclient.BlockNotFound:
260                 http.Error(resp, "Not found", http.StatusNotFound)
261         default:
262                 http.Error(resp, err.Error(), http.StatusBadGateway)
263         }
264 }
265
266 func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
267         if !CheckAuthorizationHeader(this.KeepClient, this.ApiTokenCache, req) {
268                 http.Error(resp, "Missing or invalid Authorization header", http.StatusForbidden)
269         }
270
271         hash := mux.Vars(req)["hash"]
272
273         var contentLength int64 = -1
274         if req.Header.Get("Content-Length") != "" {
275                 _, err := fmt.Sscanf(req.Header.Get("Content-Length"), "%d", &contentLength)
276                 if err != nil {
277                         resp.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength))
278                 }
279
280         }
281
282         if contentLength < 1 {
283                 http.Error(resp, "Must include Content-Length header", http.StatusLengthRequired)
284                 return
285         }
286
287         // Check if the client specified the number of replicas
288         if req.Header.Get("X-Keep-Desired-Replicas") != "" {
289                 var r int
290                 _, err := fmt.Sscanf(req.Header.Get("X-Keep-Desired-Replicas"), "%d", &r)
291                 if err != nil {
292                         this.KeepClient.Want_replicas = r
293                 }
294         }
295
296         // Now try to put the block through
297         replicas, err := this.KeepClient.PutHR(hash, req.Body, contentLength)
298
299         log.Printf("Replicas stored: %v err: %v", replicas, err)
300
301         // Tell the client how many successful PUTs we accomplished
302         resp.Header().Set("X-Keep-Replicas-Stored", fmt.Sprintf("%d", replicas))
303
304         switch err {
305         case nil:
306                 // Default will return http.StatusOK
307
308         case keepclient.OversizeBlockError:
309                 // Too much data
310                 http.Error(resp, fmt.Sprintf("Exceeded maximum blocksize %d", keepclient.BLOCKSIZE), http.StatusRequestEntityTooLarge)
311
312         case keepclient.InsufficientReplicasError:
313                 if replicas > 0 {
314                         // At least one write is considered success.  The
315                         // client can decide if getting less than the number of
316                         // replications it asked for is a fatal error.
317                         // Default will return http.StatusOK
318                 } else {
319                         http.Error(resp, "", http.StatusServiceUnavailable)
320                 }
321
322         default:
323                 http.Error(resp, err.Error(), http.StatusBadGateway)
324         }
325
326 }