3220: fix tests
[arvados.git] / services / keep / src / keep / keep.go
index 001e66eba2075bb1b6fde9c3ca5f36c8889d09ba..15b73ff798577bfb1cfcd71cf78ad3d356d03a0b 100644 (file)
@@ -5,7 +5,6 @@ import (
        "bytes"
        "crypto/md5"
        "encoding/json"
-       "errors"
        "flag"
        "fmt"
        "github.com/gorilla/mux"
@@ -17,6 +16,7 @@ import (
        "os"
        "os/signal"
        "regexp"
+       "runtime"
        "strconv"
        "strings"
        "syscall"
@@ -70,25 +70,22 @@ type KeepError struct {
 }
 
 var (
-       CollisionError  = &KeepError{400, "Collision"}
-       MD5Error        = &KeepError{401, "MD5 Failure"}
-       PermissionError = &KeepError{401, "Permission denied"}
-       CorruptError    = &KeepError{402, "Corruption"}
-       ExpiredError    = &KeepError{403, "Expired permission signature"}
+       BadRequestError = &KeepError{400, "Bad Request"}
+       CollisionError  = &KeepError{500, "Collision"}
+       RequestHashError= &KeepError{422, "Hash mismatch in request"}
+       PermissionError = &KeepError{403, "Forbidden"}
+       DiskHashError   = &KeepError{500, "Hash mismatch in stored data"}
+       ExpiredError    = &KeepError{401, "Expired permission signature"}
        NotFoundError   = &KeepError{404, "Not Found"}
        GenericError    = &KeepError{500, "Fail"}
        FullError       = &KeepError{503, "Full"}
-       TooLongError    = &KeepError{504, "Too Long"}
+       TooLongError    = &KeepError{504, "Timeout"}
 )
 
 func (e *KeepError) Error() string {
        return e.ErrMsg
 }
 
-// This error is returned by ReadAtMost if the available
-// data exceeds BLOCKSIZE bytes.
-var ReadErrorTooLong = errors.New("Too long")
-
 // TODO(twp): continue moving as much code as possible out of main
 // so it can be effectively tested. Esp. handling and postprocessing
 // of command line flags (identifying Keep volumes and initializing
@@ -123,6 +120,7 @@ func main() {
                permission_ttl_sec      int
                serialize_io            bool
                volumearg               string
+               pidfile                 string
        )
        flag.StringVar(
                &data_manager_token_file,
@@ -151,7 +149,7 @@ func main() {
        flag.IntVar(
                &permission_ttl_sec,
                "permission-ttl",
-               300,
+               1209600,
                "Expiration time (in seconds) for newly generated permission "+
                        "signatures.")
        flag.BoolVar(
@@ -168,6 +166,13 @@ func main() {
                        "e.g. -volumes=/var/keep1,/var/keep2. If empty or not "+
                        "supplied, Keep will scan mounted filesystems for volumes "+
                        "with a /keep top-level directory.")
+
+       flag.StringVar(
+               &pidfile,
+               "pid",
+               "",
+               "Path to write pid file")
+
        flag.Parse()
 
        // Look for local keep volumes.
@@ -255,11 +260,25 @@ func main() {
        }(term)
        signal.Notify(term, syscall.SIGTERM)
 
+       if pidfile != "" {
+               f, err := os.Create(pidfile)
+               if err == nil {
+                       fmt.Fprint(f, os.Getpid())
+                       f.Close()
+               } else {
+                       log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
+               }
+       }
+
        // Start listening for requests.
        srv := &http.Server{Addr: listen}
        srv.Serve(listener)
 
        log.Println("shutting down")
+
+       if pidfile != "" {
+               os.Remove(pidfile)
+       }
 }
 
 // MakeRESTRouter
@@ -268,11 +287,13 @@ func main() {
 //
 func MakeRESTRouter() *mux.Router {
        rest := mux.NewRouter()
+
        rest.HandleFunc(
                `/{hash:[0-9a-f]{32}}`, GetBlockHandler).Methods("GET", "HEAD")
        rest.HandleFunc(
-               `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
+               `/{hash:[0-9a-f]{32}}+{hints}`,
                GetBlockHandler).Methods("GET", "HEAD")
+
        rest.HandleFunc(`/{hash:[0-9a-f]{32}}`, PutBlockHandler).Methods("PUT")
 
        // For IndexHandler we support:
@@ -291,9 +312,18 @@ func MakeRESTRouter() *mux.Router {
        rest.HandleFunc(
                `/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler).Methods("GET", "HEAD")
        rest.HandleFunc(`/status.json`, StatusHandler).Methods("GET", "HEAD")
+
+       // Any request which does not match any of these routes gets
+       // 400 Bad Request.
+       rest.NotFoundHandler = http.HandlerFunc(BadRequestHandler)
+
        return rest
 }
 
+func BadRequestHandler(w http.ResponseWriter, r *http.Request) {
+       http.Error(w, BadRequestError.Error(), BadRequestError.HTTPCode)
+}
+
 // FindKeepVolumes
 //     Returns a list of Keep volumes mounted on this system.
 //
@@ -330,8 +360,29 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
 
        log.Printf("%s %s", req.Method, hash)
 
-       signature := mux.Vars(req)["signature"]
-       timestamp := mux.Vars(req)["timestamp"]
+       hints := mux.Vars(req)["hints"]
+
+       // Parse the locator string and hints from the request.
+       // TODO(twp): implement a Locator type.
+       var signature, timestamp string
+       if hints != "" {
+               signature_pat, _ := regexp.Compile("^A([[:xdigit:]]+)@([[:xdigit:]]{8})$")
+               for _, hint := range strings.Split(hints, "+") {
+                       if match, _ := regexp.MatchString("^[[:digit:]]+$", hint); match {
+                               // Server ignores size hints
+                       } else if m := signature_pat.FindStringSubmatch(hint); m != nil {
+                               signature = m[1]
+                               timestamp = m[2]
+                       } else if match, _ := regexp.MatchString("^[[:upper:]]", hint); match {
+                               // Any unknown hint that starts with an uppercase letter is
+                               // presumed to be valid and ignored, to permit forward compatibility.
+                       } else {
+                               // Unknown format; not a valid locator.
+                               http.Error(resp, BadRequestError.Error(), BadRequestError.HTTPCode)
+                               return
+                       }
+               }
+       }
 
        // If permission checking is in effect, verify this
        // request's permission signature.
@@ -343,8 +394,8 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
                        http.Error(resp, ExpiredError.Error(), ExpiredError.HTTPCode)
                        return
                } else {
-                       validsig := MakePermSignature(hash, GetApiToken(req), timestamp)
-                       if signature != validsig {
+                       req_locator := req.URL.Path[1:] // strip leading slash
+                       if !VerifySignature(req_locator, GetApiToken(req)) {
                                http.Error(resp, PermissionError.Error(), PermissionError.HTTPCode)
                                return
                        }
@@ -352,9 +403,19 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
        }
 
        block, err := GetBlock(hash)
+
+       // Garbage collect after each GET. Fixes #2865.
+       // TODO(twp): review Keep memory usage and see if there's
+       // a better way to do this than blindly garbage collecting
+       // after every block.
+       defer runtime.GC()
+
        if err != nil {
                // This type assertion is safe because the only errors
-               // GetBlock can return are CorruptError or NotFoundError.
+               // GetBlock can return are DiskHashError or NotFoundError.
+               if err == NotFoundError {
+                       log.Printf("%s: not found, giving up\n", hash)
+               }
                http.Error(resp, err.Error(), err.(*KeepError).HTTPCode)
                return
        }
@@ -370,6 +431,10 @@ func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
 }
 
 func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
+       // Garbage collect after each PUT. Fixes #2865.
+       // See also GetBlockHandler.
+       defer runtime.GC()
+
        hash := mux.Vars(req)["hash"]
 
        log.Printf("%s %s", req.Method, hash)
@@ -377,34 +442,34 @@ func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
        // Read the block data to be stored.
        // If the request exceeds BLOCKSIZE bytes, issue a HTTP 500 error.
        //
-       // Note: because req.Body is a buffered Reader, each Read() call will
-       // collect only the data in the network buffer (typically 16384 bytes),
-       // even if it is passed a much larger slice.
-       //
-       // Instead, call ReadAtMost to read data from the socket
-       // repeatedly until either EOF or BLOCKSIZE bytes have been read.
-       //
-       if buf, err := ReadAtMost(req.Body, BLOCKSIZE); err == nil {
+       if req.ContentLength > BLOCKSIZE {
+               http.Error(resp, TooLongError.Error(), TooLongError.HTTPCode)
+               return
+       }
+
+       buf := make([]byte, req.ContentLength)
+       nread, err := io.ReadFull(req.Body, buf)
+       if err != nil {
+               http.Error(resp, err.Error(), 500)
+       } else if int64(nread) < req.ContentLength {
+               http.Error(resp, "request truncated", 500)
+       } else {
                if err := PutBlock(buf, hash); err == nil {
-                       // Success; sign the locator and return it to the client.
+                       // Success; add a size hint, sign the locator if
+                       // possible, and return it to the client.
+                       return_hash := fmt.Sprintf("%s+%d", hash, len(buf))
                        api_token := GetApiToken(req)
-                       expiry := time.Now().Add(permission_ttl)
-                       signed_loc := SignLocator(hash, api_token, expiry)
-                       resp.Write([]byte(signed_loc))
+                       if PermissionSecret != nil && api_token != "" {
+                               expiry := time.Now().Add(permission_ttl)
+                               return_hash = SignLocator(return_hash, api_token, expiry)
+                       }
+                       resp.Write([]byte(return_hash + "\n"))
                } else {
                        ke := err.(*KeepError)
                        http.Error(resp, ke.Error(), ke.HTTPCode)
                }
-       } else {
-               log.Println("error reading request: ", err)
-               errmsg := err.Error()
-               if err == ReadErrorTooLong {
-                       // Use a more descriptive error message that includes
-                       // the maximum request size.
-                       errmsg = fmt.Sprintf("Max request size %d bytes", BLOCKSIZE)
-               }
-               http.Error(resp, errmsg, 500)
        }
+       return
 }
 
 // IndexHandler
@@ -415,7 +480,7 @@ func IndexHandler(resp http.ResponseWriter, req *http.Request) {
 
        // Only the data manager may issue /index requests,
        // and only if enforce_permissions is enabled.
-       // All other requests return 403 Permission denied.
+       // All other requests return 403 Forbidden.
        api_token := GetApiToken(req)
        if !enforce_permissions ||
                api_token == "" ||
@@ -507,12 +572,13 @@ func GetVolumeStatus(volume string) *VolumeStatus {
 
 func GetBlock(hash string) ([]byte, error) {
        // Attempt to read the requested hash from a keep volume.
+       error_to_caller := NotFoundError
+
        for _, vol := range KeepVM.Volumes() {
                if buf, err := vol.Get(hash); err != nil {
                        // IsNotExist is an expected error and may be ignored.
                        // (If all volumes report IsNotExist, we return a NotFoundError)
-                       // A CorruptError should be returned immediately.
-                       // Any other errors should be logged but we continue trying to
+                       // All other errors should be logged but we continue trying to
                        // read.
                        switch {
                        case os.IsNotExist(err):
@@ -532,15 +598,22 @@ func GetBlock(hash string) ([]byte, error) {
                                //
                                log.Printf("%s: checksum mismatch for request %s (actual %s)\n",
                                        vol, hash, filehash)
-                               return buf, CorruptError
+                               error_to_caller = DiskHashError
+                       } else {
+                               // Success!
+                               if error_to_caller != NotFoundError {
+                                               log.Printf("%s: checksum mismatch for request %s but a good copy was found on another volume and returned\n",
+                                                       vol, hash)
+                               }
+                               return buf, nil
                        }
-                       // Success!
-                       return buf, nil
                }
        }
 
-       log.Printf("%s: not found on any volumes, giving up\n", hash)
-       return nil, NotFoundError
+  if error_to_caller != NotFoundError {
+    log.Printf("%s: checksum mismatch, no good copy found\n", hash)
+  }
+       return nil, error_to_caller
 }
 
 /* PutBlock(block, hash)
@@ -555,10 +628,10 @@ func GetBlock(hash string) ([]byte, error) {
    On success, PutBlock returns nil.
    On failure, it returns a KeepError with one of the following codes:
 
-   400 Collision
+   500 Collision
           A different block with the same hash already exists on this
           Keep server.
-   401 MD5Fail
+   422 MD5Fail
           The MD5 hash of the BLOCK does not match the argument HASH.
    503 Full
           There was not enough space left in any Keep volume to store
@@ -574,12 +647,12 @@ func PutBlock(block []byte, hash string) error {
        blockhash := fmt.Sprintf("%x", md5.Sum(block))
        if blockhash != hash {
                log.Printf("%s: MD5 checksum %s did not match request", hash, blockhash)
-               return MD5Error
+               return RequestHashError
        }
 
        // If we already have a block on disk under this identifier, return
        // success (but check for MD5 collisions).
-       // The only errors that GetBlock can return are ErrCorrupt and ErrNotFound.
+       // The only errors that GetBlock can return are DiskHashError and NotFoundError.
        // In either case, we want to write our new (good) block to disk,
        // so there is nothing special to do if err != nil.
        if oldblock, err := GetBlock(hash); err == nil {
@@ -620,24 +693,6 @@ func PutBlock(block []byte, hash string) error {
        }
 }
 
-// ReadAtMost
-//     Reads bytes repeatedly from an io.Reader until either
-//     encountering EOF, or the maxbytes byte limit has been reached.
-//     Returns a byte slice of the bytes that were read.
-//
-//     If the reader contains more than maxbytes, returns a nil slice
-//     and an error.
-//
-func ReadAtMost(r io.Reader, maxbytes int) ([]byte, error) {
-       // Attempt to read one more byte than maxbytes.
-       lr := io.LimitReader(r, int64(maxbytes+1))
-       buf, err := ioutil.ReadAll(lr)
-       if len(buf) > maxbytes {
-               return nil, ReadErrorTooLong
-       }
-       return buf, err
-}
-
 // IsValidLocator
 //     Return true if the specified string is a valid Keep locator.
 //     When Keep is extended to support hash types other than MD5,
@@ -652,13 +707,15 @@ func IsValidLocator(loc string) bool {
        return false
 }
 
-// GetApiToken returns the OAuth token from the Authorization
+// GetApiToken returns the OAuth2 token from the Authorization
 // header of a HTTP request, or an empty string if no matching
 // token is found.
 func GetApiToken(req *http.Request) string {
        if auth, ok := req.Header["Authorization"]; ok {
-               if strings.HasPrefix(auth[0], "OAuth ") {
-                       return auth[0][6:]
+               if pat, err := regexp.Compile(`^OAuth2\s+(.*)`); err != nil {
+                       log.Println(err)
+               } else if match := pat.FindStringSubmatch(auth[0]); match != nil {
+                       return match[1]
                }
        }
        return ""