2751: Teaching run_test_server how to run the proxy server for testing. Added -pid...
[arvados.git] / services / keep / src / keep / keep.go
index fa27b665b4778ea794fef2d2addfaf2acbd11f9b..3f1d5dec5a8128a5864cedfbd827f7d3d2ebdcdf 100644 (file)
@@ -12,8 +12,10 @@ import (
        "io"
        "io/ioutil"
        "log"
+       "net"
        "net/http"
        "os"
+       "os/signal"
        "regexp"
        "strconv"
        "strings"
@@ -28,6 +30,7 @@ import (
 // and/or configuration file settings.
 
 // Default TCP address on which to listen for requests.
+// Initialized by the --listen flag.
 const DEFAULT_ADDR = ":25107"
 
 // A Keep "block" is 64MB.
@@ -40,19 +43,22 @@ const MIN_FREE_KILOBYTES = BLOCKSIZE / 1024
 var PROC_MOUNTS = "/proc/mounts"
 
 // The Keep VolumeManager maintains a list of available volumes.
+// Initialized by the --volumes flag (or by FindKeepVolumes).
 var KeepVM VolumeManager
 
 // enforce_permissions controls whether permission signatures
-// should be enforced (affecting GET and DELETE requests)
+// should be enforced (affecting GET and DELETE requests).
+// Initialized by the --enforce-permissions flag.
 var enforce_permissions bool
 
-// permission_ttl is the time duration (in seconds) for which
-// new permission signatures (returned by PUT requests) will be
-// valid.
-var permission_ttl int
+// permission_ttl is the time duration for which new permission
+// signatures (returned by PUT requests) will be valid.
+// Initialized by the --permission-ttl flag.
+var permission_ttl time.Duration
 
 // data_manager_token represents the API token used by the
 // Data Manager, and is required on certain privileged operations.
+// Initialized by the --data-manager-token-file flag.
 var data_manager_token string
 
 // ==========
@@ -83,7 +89,14 @@ func (e *KeepError) Error() string {
 // data exceeds BLOCKSIZE bytes.
 var ReadErrorTooLong = errors.New("Too long")
 
+// TODO(twp): continue moving as much code as possible out of main
+// so it can be effectively tested. Esp. handling and postprocessing
+// of command line flags (identifying Keep volumes and initializing
+// permission arguments).
+
 func main() {
+       log.Println("Keep started: pid", os.Getpid())
+
        // Parse command-line flags:
        //
        // -listen=ipaddr:port
@@ -103,13 +116,21 @@ func main() {
        //    by looking at currently mounted filesystems for /keep top-level
        //    directories.
 
-       var data_manager_token_file, listen, permission_key_file, volumearg string
-       var serialize_io bool
+       var (
+               data_manager_token_file string
+               listen                  string
+               permission_key_file     string
+               permission_ttl_sec      int
+               serialize_io            bool
+               volumearg               string
+               pidfile                 string
+       )
        flag.StringVar(
                &data_manager_token_file,
                "data-manager-token-file",
                "",
-               "File with the API token used by the Data Manager. All DELETE requests or unqualified GET /index requests must carry this token.")
+               "File with the API token used by the Data Manager. All DELETE "+
+                       "requests or GET /index requests must carry this token.")
        flag.BoolVar(
                &enforce_permissions,
                "enforce-permissions",
@@ -119,27 +140,42 @@ func main() {
                &listen,
                "listen",
                DEFAULT_ADDR,
-               "interface on which to listen for requests, in the format ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port to listen on all network interfaces.")
+               "Interface on which to listen for requests, in the format "+
+                       "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
+                       "to listen on all network interfaces.")
        flag.StringVar(
                &permission_key_file,
                "permission-key-file",
                "",
-               "File containing the secret key for generating and verifying permission signatures.")
+               "File containing the secret key for generating and verifying "+
+                       "permission signatures.")
        flag.IntVar(
-               &permission_ttl,
+               &permission_ttl_sec,
                "permission-ttl",
                300,
-               "Expiration time (in seconds) for newly generated permission signatures.")
+               "Expiration time (in seconds) for newly generated permission "+
+                       "signatures.")
        flag.BoolVar(
                &serialize_io,
                "serialize",
                false,
-               "If set, all read and write operations on local Keep volumes will be serialized.")
+               "If set, all read and write operations on local Keep volumes will "+
+                       "be serialized.")
        flag.StringVar(
                &volumearg,
                "volumes",
                "",
-               "Comma-separated list of directories to use for Keep volumes, e.g. -volumes=/var/keep1,/var/keep2. If empty or not supplied, Keep will scan mounted filesystems for volumes with a /keep top-level directory.")
+               "Comma-separated list of directories to use for Keep volumes, "+
+                       "e.g. -volumes=/var/keep1,/var/keep2. If empty or not "+
+                       "supplied, Keep will scan mounted filesystems for volumes "+
+                       "with a /keep top-level directory.")
+
+       flag.StringVar(
+               &pidfile,
+               "pid",
+               "",
+               "Path to write pid file")
+
        flag.Parse()
 
        // Look for local keep volumes.
@@ -170,24 +206,38 @@ func main() {
        }
 
        // Initialize data manager token and permission key.
+       // If these tokens are specified but cannot be read,
+       // raise a fatal error.
        if data_manager_token_file != "" {
                if buf, err := ioutil.ReadFile(data_manager_token_file); err == nil {
                        data_manager_token = strings.TrimSpace(string(buf))
                } else {
-                       log.Printf("reading data_manager_token: %s\n", err)
+                       log.Fatalf("reading data manager token: %s\n", err)
                }
        }
        if permission_key_file != "" {
                if buf, err := ioutil.ReadFile(permission_key_file); err == nil {
                        PermissionSecret = bytes.TrimSpace(buf)
                } else {
-                       log.Printf("reading data_manager_token: %s\n", err)
+                       log.Fatalf("reading permission key: %s\n", err)
                }
        }
 
-       // If --enforce-permissions is true, we must have a permission key to continue.
-       if enforce_permissions && PermissionSecret == nil {
-               log.Fatal("--enforce-permissions requires a permission key")
+       // Initialize permission TTL
+       permission_ttl = time.Duration(permission_ttl_sec) * time.Second
+
+       // If --enforce-permissions is true, we must have a permission key
+       // to continue.
+       if PermissionSecret == nil {
+               if enforce_permissions {
+                       log.Fatal("--enforce-permissions requires a permission key")
+               } else {
+                       log.Println("Running without a PermissionSecret. Block locators " +
+                               "returned by this server will not be signed, and will be rejected " +
+                               "by a server that enforces permissions.")
+                       log.Println("To fix this, run Keep with --permission-key-file=<path> " +
+                               "to define the location of a file containing the permission key.")
+               }
        }
 
        // Start a round-robin VolumeManager with the volumes we have found.
@@ -195,23 +245,73 @@ func main() {
 
        // Tell the built-in HTTP server to direct all requests to the REST
        // router.
-       http.Handle("/", NewRESTRouter())
+       http.Handle("/", MakeRESTRouter())
+
+       // Set up a TCP listener.
+       listener, err := net.Listen("tcp", listen)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Shut down the server gracefully (by closing the listener)
+       // if SIGTERM is received.
+       term := make(chan os.Signal, 1)
+       go func(sig <-chan os.Signal) {
+               s := <-sig
+               log.Println("caught signal:", s)
+               listener.Close()
+       }(term)
+       signal.Notify(term, syscall.SIGTERM)
+
+       if pidfile != "" {
+               f, err := os.Create(pidfile)
+               if err == nil {
+                       fmt.Fprint(f, os.Getpid())
+                       f.Close()
+               } else {
+                       log.Printf("Error writing pid file (%s): %s", pidfile, err.Error())
+               }
+       }
 
        // Start listening for requests.
-       http.ListenAndServe(listen, nil)
+       srv := &http.Server{Addr: listen}
+       srv.Serve(listener)
+
+       log.Println("shutting down")
+
+       if pidfile != "" {
+               os.Remove(pidfile)
+       }
 }
 
-// NewRESTRouter
+// MakeRESTRouter
 //     Returns a mux.Router that passes GET and PUT requests to the
 //     appropriate handlers.
 //
-func NewRESTRouter() *mux.Router {
+func MakeRESTRouter() *mux.Router {
        rest := mux.NewRouter()
-       rest.HandleFunc(`/{hash:[0-9a-f]{32}}`, GetBlockHandler).Methods("GET", "HEAD")
-       rest.HandleFunc(`/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`, GetBlockHandler).Methods("GET", "HEAD")
+       rest.HandleFunc(
+               `/{hash:[0-9a-f]{32}}`, GetBlockHandler).Methods("GET", "HEAD")
+       rest.HandleFunc(
+               `/{hash:[0-9a-f]{32}}+A{signature:[0-9a-f]+}@{timestamp:[0-9a-f]+}`,
+               GetBlockHandler).Methods("GET", "HEAD")
        rest.HandleFunc(`/{hash:[0-9a-f]{32}}`, PutBlockHandler).Methods("PUT")
+
+       // For IndexHandler we support:
+       //   /index           - returns all locators
+       //   /index/{prefix}  - returns all locators that begin with {prefix}
+       //      {prefix} is a string of hexadecimal digits between 0 and 32 digits.
+       //      If {prefix} is the empty string, return an index of all locators
+       //      (so /index and /index/ behave identically)
+       //      A client may supply a full 32-digit locator string, in which
+       //      case the server will return an index with either zero or one
+       //      entries. This usage allows a client to check whether a block is
+       //      present, and its size and upload time, without retrieving the
+       //      entire block.
+       //
        rest.HandleFunc(`/index`, IndexHandler).Methods("GET", "HEAD")
-       rest.HandleFunc(`/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler).Methods("GET", "HEAD")
+       rest.HandleFunc(
+               `/index/{prefix:[0-9a-f]{0,32}}`, IndexHandler).Methods("GET", "HEAD")
        rest.HandleFunc(`/status.json`, StatusHandler).Methods("GET", "HEAD")
        return rest
 }
@@ -232,7 +332,8 @@ func FindKeepVolumes() []string {
                for scanner.Scan() {
                        args := strings.Fields(scanner.Text())
                        dev, mount := args[0], args[1]
-                       if (dev == "tmpfs" || strings.HasPrefix(dev, "/dev/")) && mount != "/" {
+                       if mount != "/" &&
+                               (dev == "tmpfs" || strings.HasPrefix(dev, "/dev/")) {
                                keep := mount + "/keep"
                                if st, err := os.Stat(keep); err == nil && st.IsDir() {
                                        vols = append(vols, keep)
@@ -246,8 +347,11 @@ func FindKeepVolumes() []string {
        return vols
 }
 
-func GetBlockHandler(w http.ResponseWriter, req *http.Request) {
+func GetBlockHandler(resp http.ResponseWriter, req *http.Request) {
        hash := mux.Vars(req)["hash"]
+
+       log.Printf("%s %s", req.Method, hash)
+
        signature := mux.Vars(req)["signature"]
        timestamp := mux.Vars(req)["timestamp"]
 
@@ -255,24 +359,29 @@ func GetBlockHandler(w http.ResponseWriter, req *http.Request) {
        // request's permission signature.
        if enforce_permissions {
                if signature == "" || timestamp == "" {
-                       http.Error(w, PermissionError.Error(), PermissionError.HTTPCode)
+                       http.Error(resp, PermissionError.Error(), PermissionError.HTTPCode)
                        return
                } else if IsExpired(timestamp) {
-                       http.Error(w, ExpiredError.Error(), ExpiredError.HTTPCode)
-                       return
-               } else if signature != MakePermSignature(hash, GetApiToken(req), timestamp) {
-                       http.Error(w, PermissionError.Error(), PermissionError.HTTPCode)
+                       http.Error(resp, ExpiredError.Error(), ExpiredError.HTTPCode)
                        return
+               } else {
+                       validsig := MakePermSignature(hash, GetApiToken(req), timestamp)
+                       if signature != validsig {
+                               http.Error(resp, PermissionError.Error(), PermissionError.HTTPCode)
+                               return
+                       }
                }
        }
 
        block, err := GetBlock(hash)
        if err != nil {
-               http.Error(w, err.Error(), err.(*KeepError).HTTPCode)
+               // This type assertion is safe because the only errors
+               // GetBlock can return are CorruptError or NotFoundError.
+               http.Error(resp, err.Error(), err.(*KeepError).HTTPCode)
                return
        }
 
-       _, err = w.Write(block)
+       _, err = resp.Write(block)
        if err != nil {
                log.Printf("GetBlockHandler: writing response: %s", err)
        }
@@ -280,9 +389,11 @@ func GetBlockHandler(w http.ResponseWriter, req *http.Request) {
        return
 }
 
-func PutBlockHandler(w http.ResponseWriter, req *http.Request) {
+func PutBlockHandler(resp http.ResponseWriter, req *http.Request) {
        hash := mux.Vars(req)["hash"]
 
+       log.Printf("%s %s", req.Method, hash)
+
        // Read the block data to be stored.
        // If the request exceeds BLOCKSIZE bytes, issue a HTTP 500 error.
        //
@@ -297,13 +408,12 @@ func PutBlockHandler(w http.ResponseWriter, req *http.Request) {
                if err := PutBlock(buf, hash); err == nil {
                        // Success; sign the locator and return it to the client.
                        api_token := GetApiToken(req)
-                       expiry := time.Now().Add( // convert permission_ttl to time.Duration
-                               time.Duration(permission_ttl) * time.Second)
+                       expiry := time.Now().Add(permission_ttl)
                        signed_loc := SignLocator(hash, api_token, expiry)
-                       w.Write([]byte(signed_loc))
+                       resp.Write([]byte(signed_loc))
                } else {
                        ke := err.(*KeepError)
-                       http.Error(w, ke.Error(), ke.HTTPCode)
+                       http.Error(resp, ke.Error(), ke.HTTPCode)
                }
        } else {
                log.Println("error reading request: ", err)
@@ -313,34 +423,31 @@ func PutBlockHandler(w http.ResponseWriter, req *http.Request) {
                        // the maximum request size.
                        errmsg = fmt.Sprintf("Max request size %d bytes", BLOCKSIZE)
                }
-               http.Error(w, errmsg, 500)
+               http.Error(resp, errmsg, 500)
        }
 }
 
 // IndexHandler
 //     A HandleFunc to address /index and /index/{prefix} requests.
 //
-func IndexHandler(w http.ResponseWriter, req *http.Request) {
+func IndexHandler(resp http.ResponseWriter, req *http.Request) {
        prefix := mux.Vars(req)["prefix"]
 
-       // Only the data manager may issue unqualified "GET /index" requests,
+       // Only the data manager may issue /index requests,
        // and only if enforce_permissions is enabled.
-       // If the request is unauthenticated, or does not match the data manager's
-       // API token, return 403 Permission denied.
-       if prefix == "" {
-               api_token := GetApiToken(req)
-               if !enforce_permissions ||
-                       api_token == "" ||
-                       data_manager_token != GetApiToken(req) {
-                       http.Error(w, PermissionError.Error(), PermissionError.HTTPCode)
-                       return
-               }
+       // All other requests return 403 Permission denied.
+       api_token := GetApiToken(req)
+       if !enforce_permissions ||
+               api_token == "" ||
+               data_manager_token != api_token {
+               http.Error(resp, PermissionError.Error(), PermissionError.HTTPCode)
+               return
        }
        var index string
        for _, vol := range KeepVM.Volumes() {
                index = index + vol.Index(prefix)
        }
-       w.Write([]byte(index))
+       resp.Write([]byte(index))
 }
 
 // StatusHandler
@@ -366,14 +473,14 @@ type NodeStatus struct {
        Volumes []*VolumeStatus `json:"volumes"`
 }
 
-func StatusHandler(w http.ResponseWriter, req *http.Request) {
+func StatusHandler(resp http.ResponseWriter, req *http.Request) {
        st := GetNodeStatus()
        if jstat, err := json.Marshal(st); err == nil {
-               w.Write(jstat)
+               resp.Write(jstat)
        } else {
                log.Printf("json.Marshal: %s\n", err)
                log.Printf("NodeStatus = %v\n", st)
-               http.Error(w, err.Error(), 500)
+               http.Error(resp, err.Error(), 500)
        }
 }
 
@@ -443,7 +550,7 @@ func GetBlock(hash string) ([]byte, error) {
                                // they should be sent directly to an event manager at high
                                // priority or logged as urgent problems.
                                //
-                               log.Printf("%s: checksum mismatch for request %s (actual hash %s)\n",
+                               log.Printf("%s: checksum mismatch for request %s (actual %s)\n",
                                        vol, hash, filehash)
                                return buf, CorruptError
                        }
@@ -493,8 +600,8 @@ func PutBlock(block []byte, hash string) error {
        // If we already have a block on disk under this identifier, return
        // success (but check for MD5 collisions).
        // The only errors that GetBlock can return are ErrCorrupt and ErrNotFound.
-       // In either case, we want to write our new (good) block to disk, so there is
-       // nothing special to do if err != nil.
+       // In either case, we want to write our new (good) block to disk,
+       // so there is nothing special to do if err != nil.
        if oldblock, err := GetBlock(hash); err == nil {
                if bytes.Compare(block, oldblock) == 0 {
                        return nil
@@ -578,7 +685,8 @@ func GetApiToken(req *http.Request) string {
 }
 
 // IsExpired returns true if the given Unix timestamp (expressed as a
-// hexadecimal string) is in the past.
+// hexadecimal string) is in the past, or if timestamp_hex cannot be
+// parsed as a hexadecimal string.
 func IsExpired(timestamp_hex string) bool {
        ts, err := strconv.ParseInt(timestamp_hex, 16, 0)
        if err != nil {