7710: text updates around the -service-type argument.
[arvados.git] / services / keepproxy / keepproxy.go
index 865212d747691f6423e5b1701575f15206f2abd4..7b5cd2befb8f69bd25fa62674d01590214aec5ad 100644 (file)
@@ -14,17 +14,15 @@ import (
        "net/http"
        "os"
        "os/signal"
-       "reflect"
        "regexp"
-       "strings"
        "sync"
        "syscall"
        "time"
 )
 
 // Default TCP address on which to listen for requests.
-// Initialized by the -listen flag.
-const DEFAULT_ADDR = ":25107"
+// Override with -listen.
+const DefaultAddr = ":25107"
 
 var listener net.Listener
 
@@ -38,12 +36,12 @@ func main() {
                pidfile          string
        )
 
-       flagset := flag.NewFlagSet("default", flag.ExitOnError)
+       flagset := flag.NewFlagSet("keepproxy", flag.ExitOnError)
 
        flagset.StringVar(
                &listen,
                "listen",
-               DEFAULT_ADDR,
+               DefaultAddr,
                "Interface on which to listen for requests, in the format "+
                        "ipaddr:port. e.g. -listen=10.0.1.24:8000. Use -listen=:port "+
                        "to listen on all network interfaces.")
@@ -85,6 +83,9 @@ func main() {
                log.Fatalf("Error setting up arvados client %s", err.Error())
        }
 
+       if os.Getenv("ARVADOS_DEBUG") != "" {
+               keepclient.DebugPrintf = log.Printf
+       }
        kc, err := keepclient.MakeKeepClient(&arv)
        if err != nil {
                log.Fatalf("Error setting up keep client %s", err.Error())
@@ -101,15 +102,14 @@ func main() {
        }
 
        kc.Want_replicas = default_replicas
-
        kc.Client.Timeout = time.Duration(timeout) * time.Second
+       go kc.RefreshServices(5*time.Minute, 3*time.Second)
 
        listener, err = net.Listen("tcp", listen)
        if err != nil {
                log.Fatalf("Could not listen on %v", listen)
        }
-
-       go RefreshServicesList(kc)
+       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
 
        // Shut down the server gracefully (by closing the listener)
        // if SIGTERM is received.
@@ -122,9 +122,7 @@ func main() {
        signal.Notify(term, syscall.SIGTERM)
        signal.Notify(term, syscall.SIGINT)
 
-       log.Printf("Arvados Keep proxy started listening on %v", listener.Addr())
-
-       // Start listening for requests.
+       // Start serving requests.
        http.Serve(listener, MakeRESTRouter(!no_get, !no_put, kc))
 
        log.Println("shutting down")
@@ -136,30 +134,6 @@ type ApiTokenCache struct {
        expireTime int64
 }
 
-// Refresh the keep service list every five minutes.
-func RefreshServicesList(kc *keepclient.KeepClient) {
-       var previousRoots = []map[string]string{}
-       var delay time.Duration = 0
-       for {
-               time.Sleep(delay * time.Second)
-               delay = 300
-               if err := kc.DiscoverKeepServers(); err != nil {
-                       log.Println("Error retrieving services list:", err)
-                       delay = 3
-                       continue
-               }
-               newRoots := []map[string]string{kc.LocalRoots(), kc.GatewayRoots()}
-               if !reflect.DeepEqual(previousRoots, newRoots) {
-                       log.Printf("Updated services list: locals %v gateways %v", newRoots[0], newRoots[1])
-               }
-               if len(newRoots[0]) == 0 {
-                       log.Print("WARNING: No local services. Retrying in 3 seconds.")
-                       delay = 3
-               }
-               previousRoots = newRoots
-       }
-}
-
 // Cache the token and set an expire time.  If we already have an expire time
 // on the token, it is not updated.
 func (this *ApiTokenCache) RememberToken(token string) {
@@ -192,17 +166,13 @@ func (this *ApiTokenCache) RecallToken(token string) bool {
 }
 
 func GetRemoteAddress(req *http.Request) string {
-       if realip := req.Header.Get("X-Real-IP"); realip != "" {
-               if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != realip {
-                       return fmt.Sprintf("%s (X-Forwarded-For %s)", realip, forwarded)
-               } else {
-                       return realip
-               }
+       if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
+               return xff + "," + req.RemoteAddr
        }
        return req.RemoteAddr
 }
 
-func CheckAuthorizationHeader(kc keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
+func CheckAuthorizationHeader(kc *keepclient.KeepClient, cache *ApiTokenCache, req *http.Request) (pass bool, tok string) {
        var auth string
        if auth = req.Header.Get("Authorization"); auth == "" {
                return false, ""
@@ -332,7 +302,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
@@ -363,7 +333,7 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                log.Println("Warning:", GetRemoteAddress(req), req.Method, proxiedURI, "Content-Length not provided")
        }
 
-       switch err {
+       switch respErr := err.(type) {
        case nil:
                status = http.StatusOK
                resp.Header().Set("Content-Length", fmt.Sprint(expectLength))
@@ -376,10 +346,16 @@ func (this GetBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
                                err = ContentLengthMismatch
                        }
                }
-       case keepclient.BlockNotFound:
-               status = http.StatusNotFound
+       case keepclient.Error:
+               if respErr == keepclient.BlockNotFound {
+                       status = http.StatusNotFound
+               } else if respErr.Temporary() {
+                       status = http.StatusBadGateway
+               } else {
+                       status = 422
+               }
        default:
-               status = http.StatusBadGateway
+               status = http.StatusInternalServerError
        }
 }
 
@@ -433,7 +409,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        var pass bool
        var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, this.ApiTokenCache, req); !pass {
+       if pass, tok = CheckAuthorizationHeader(&kc, this.ApiTokenCache, req); !pass {
                err = BadAuthorizationHeader
                status = http.StatusForbidden
                return
@@ -494,7 +470,7 @@ func (this PutBlockHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
        }
 }
 
-// ServeHTTP implemenation for IndexHandler
+// ServeHTTP implementation for IndexHandler
 // Supports only GET requests for /index/{prefix:[0-9a-f]{0,32}}
 // For each keep server found in LocalRoots:
 //   Invokes GetIndex using keepclient
@@ -516,59 +492,40 @@ func (handler IndexHandler) ServeHTTP(resp http.ResponseWriter, req *http.Reques
 
        kc := *handler.KeepClient
 
-       var pass bool
-       var tok string
-       if pass, tok = CheckAuthorizationHeader(kc, handler.ApiTokenCache, req); !pass {
+       ok, token := CheckAuthorizationHeader(&kc, handler.ApiTokenCache, req)
+       if !ok {
                status, err = http.StatusForbidden, BadAuthorizationHeader
                return
        }
 
        // Copy ArvadosClient struct and use the client's API token
        arvclient := *kc.Arvados
-       arvclient.ApiToken = tok
+       arvclient.ApiToken = token
        kc.Arvados = &arvclient
 
-       var indexResp []byte
-       var reader io.Reader
-
-       switch req.Method {
-       case "GET":
-               for uuid := range kc.LocalRoots() {
-                       reader, err = kc.GetIndex(uuid, prefix)
-                       if err != nil {
-                               break
-                       }
-
-                       var readBytes []byte
-                       readBytes, err = ioutil.ReadAll(reader)
-                       if err != nil {
-                               break
-                       }
-
-                       // Got index; verify that it is complete
-                       // The response should be "\n" if no locators matched the prefix
-                       // Else, it should be a list of locators followed by a blank line
-                       if (!strings.HasSuffix(string(readBytes), "\n\n")) && (string(readBytes) != "\n") {
-                               err = errors.New("Got incomplete index")
-                       }
-
-                       // Trim the extra empty new line found in response from each server
-                       indexResp = append(indexResp, (readBytes[0 : len(readBytes)-1])...)
-               }
-
-               // Append empty line at the end of concatenation of all server responses
-               indexResp = append(indexResp, ([]byte("\n"))...)
-       default:
+       // Only GET method is supported
+       if req.Method != "GET" {
                status, err = http.StatusNotImplemented, MethodNotSupported
                return
        }
 
-       switch err {
-       case nil:
-               status = http.StatusOK
-               resp.Header().Set("Content-Length", fmt.Sprint(len(indexResp)))
-               _, err = resp.Write(indexResp)
-       default:
-               status = http.StatusBadGateway
+       // Get index from all LocalRoots and write to resp
+       var reader io.Reader
+       for uuid := range kc.LocalRoots() {
+               reader, err = kc.GetIndex(uuid, prefix)
+               if err != nil {
+                       status = http.StatusBadGateway
+                       return
+               }
+
+               _, err = io.Copy(resp, reader)
+               if err != nil {
+                       status = http.StatusBadGateway
+                       return
+               }
        }
+
+       // Got index from all the keep servers and wrote to resp
+       status = http.StatusOK
+       resp.Write([]byte("\n"))
 }