2798: Added AuthorizedGet(), Ask() and AuthorizedAsk(). Added BLOCKSIZE
[arvados.git] / sdk / go / src / arvados.org / keepclient / keepclient.go
index 93fcf4b3b95542cd6a337b7f9fd98673cf8cdc0c..2738cefa7c3144ec613aa8b6f8d6922e33debc44 100644 (file)
@@ -15,13 +15,19 @@ import (
        "strconv"
 )
 
+// BLOCKSIZE is the size of a full Keep block: 64 MiB.
+const BLOCKSIZE = 1 << 26
+
+// Sentinel errors returned by KeepClient operations.
+var (
+       // BlockNotFound is returned when none of the servers that were
+       // asked answered with the requested block.
+       BlockNotFound = errors.New("Block not found")
+
+       // InsufficientReplicasError is returned when the servers ran out
+       // before the wanted number of replicas had been written.
+       InsufficientReplicasError = errors.New("Could not write sufficient replicas")
+)
+
+// KeepClient carries the settings and HTTP client used to read blocks from
+// and write blocks to Keep servers.
 type KeepClient struct {
+       // API server host; MakeKeepClient takes it from ARVADOS_API_HOST.
        ApiServer     string
+       // API token, sent as "Authorization: OAuth2 <token>" on requests.
        ApiToken      string
+       // True when ARVADOS_API_HOST_INSECURE is set to a non-empty value.
        ApiInsecure   bool
+       // Keep server roots -- presumably filled in by DiscoverKeepServers();
+       // TODO confirm against that function's body.
        Service_roots []string
+       // Number of replicas a put tries to write (MakeKeepClient sets 2).
        Want_replicas int
-       client        *http.Client
+       // HTTP client used for every request; exported so callers can
+       // replace it.
+       Client        *http.Client
 }
 
 type KeepDisk struct {
@@ -30,24 +36,24 @@ type KeepDisk struct {
        SSL      bool   `json:"service_ssl_flag"`
 }
 
-func MakeKeepClient() (kc *KeepClient, err error) {
-       kc = &KeepClient{
-               ApiServer:   os.Getenv("ARVADOS_API_HOST"),
-               ApiToken:    os.Getenv("ARVADOS_API_TOKEN"),
-               ApiInsecure: (os.Getenv("ARVADOS_API_HOST_INSECURE") != "")}
-
+// MakeKeepClient builds a KeepClient from the ARVADOS_API_HOST,
+// ARVADOS_API_TOKEN and ARVADOS_API_HOST_INSECURE environment variables,
+// asks for two replicas by default, and then runs DiscoverKeepServers().
+//
+// The client is returned even when discovery fails; err carries the
+// discovery error in that case.
+func MakeKeepClient() (kc KeepClient, err error) {
        tr := &http.Transport{
                TLSClientConfig: &tls.Config{InsecureSkipVerify: kc.ApiInsecure},
        }
 
-       kc.client = &http.Client{Transport: tr}
+       kc = KeepClient{
+               ApiServer:     os.Getenv("ARVADOS_API_HOST"),
+               ApiToken:      os.Getenv("ARVADOS_API_TOKEN"),
+               ApiInsecure:   (os.Getenv("ARVADOS_API_HOST_INSECURE") != ""),
+               Want_replicas: 2,
+               Client:        &http.Client{Transport: tr}}
+
+       // tr was built while kc was still the zero value, so the setting
+       // read from ARVADOS_API_HOST_INSECURE has to be applied to the
+       // transport here, before the first request is made.
+       tr.TLSClientConfig.InsecureSkipVerify = kc.ApiInsecure
 
-       err = kc.DiscoverKeepDisks()
+       err = (&kc).DiscoverKeepServers()
 
        return kc, err
 }
 
-func (this *KeepClient) DiscoverKeepDisks() error {
+func (this *KeepClient) DiscoverKeepServers() error {
        // Construct request of keep disk list
        var req *http.Request
        var err error
@@ -60,7 +66,7 @@ func (this *KeepClient) DiscoverKeepDisks() error {
 
        // Make the request
        var resp *http.Response
-       if resp, err = this.client.Do(req); err != nil {
+       if resp, err = this.Client.Do(req); err != nil {
                return err
        }
 
@@ -163,7 +169,6 @@ func ReadIntoBuffer(buffer []byte, r io.Reader, slices chan<- ReaderSlice) {
        // Initially use entire buffer as scratch space
        ptr := buffer[:]
        for {
-               log.Printf("ReadIntoBuffer doing read")
                var n int
                var err error
                if len(ptr) > 0 {
@@ -187,18 +192,13 @@ func ReadIntoBuffer(buffer []byte, r io.Reader, slices chan<- ReaderSlice) {
 
                // End on error (includes EOF)
                if err != nil {
-                       log.Printf("ReadIntoBuffer sending error %d %s", n, err.Error())
                        slices <- ReaderSlice{nil, err}
                        return
                }
 
-               log.Printf("ReadIntoBuffer got %d", n)
-
                if n > 0 {
-                       log.Printf("ReadIntoBuffer sending readerslice")
                        // Make a slice with the contents of the read
                        slices <- ReaderSlice{ptr[:n], nil}
-                       log.Printf("ReadIntoBuffer sent readerslice")
 
                        // Adjust the scratch space slice
                        ptr = ptr[n:]
@@ -208,15 +208,15 @@ func ReadIntoBuffer(buffer []byte, r io.Reader, slices chan<- ReaderSlice) {
 
 // A read request to the Transfer() function
+//
+// offset is the position in the buffer at which the read starts, maxsize
+// is the largest number of bytes the reply may carry, and result is the
+// channel on which the ReadResult is delivered.
 type ReadRequest struct {
-       offset int
-       p      []byte
-       result chan<- ReadResult
+       offset  int
+       maxsize int
+       result  chan<- ReadResult
 }
 
 // A read result from the Transfer() function
+//
+// slice is a sub-slice of the transfer buffer itself (not a copy), so the
+// receiver must copy out whatever it wants to keep.  At end of data the
+// result is a nil slice with err set to io.EOF.
 type ReadResult struct {
-       n   int
-       err error
+       slice []byte
+       err   error
 }
 
 // Reads from the buffer managed by the Transfer()
@@ -232,16 +232,41 @@ func MakeBufferReader(requests chan<- ReadRequest) BufferReader {
 
 // Reads from the buffer managed by the Transfer()
+//
+// Asks for up to len(p) bytes starting at the current offset, waits for
+// the reply, advances the offset by the number of bytes received and
+// copies them into p.  Returns io.ErrUnexpectedEOF if the responses
+// channel has been closed.
 func (this BufferReader) Read(p []byte) (n int, err error) {
-       this.requests <- ReadRequest{*this.offset, p, this.responses}
+       this.requests <- ReadRequest{*this.offset, len(p), this.responses}
        rr, valid := <-this.responses
        if valid {
-               *this.offset += rr.n
-               return rr.n, rr.err
+               // The reply is never longer than len(p), so copy() takes
+               // all of it.
+               *this.offset += len(rr.slice)
+               return copy(p, rr.slice), rr.err
        } else {
                return 0, io.ErrUnexpectedEOF
        }
 }
 
+// WriteTo copies the rest of the buffer managed by Transfer() to dest,
+// requesting it in chunks of at most 32 KiB.
+//
+// The returned count is the number of bytes dest actually accepted.
+// Reaching the end of the data (io.EOF from Transfer) is reported as a nil
+// error.  A reader error, a write error on dest, or a short write stops
+// the copy and is returned; io.ErrUnexpectedEOF is returned if the
+// responses channel has been closed.
+func (this BufferReader) WriteTo(dest io.Writer) (written int64, err error) {
+       for {
+               this.requests <- ReadRequest{*this.offset, 32 * 1024, this.responses}
+               rr, valid := <-this.responses
+               if !valid {
+                       // The responses channel was closed underneath us.
+                       return written, io.ErrUnexpectedEOF
+               }
+
+               log.Printf("WriteTo slice %v %d %v", *this.offset, len(rr.slice), rr.err)
+
+               if rr.err == io.EOF {
+                       // EOF is not an error.
+                       return written, nil
+               }
+               if rr.err != nil {
+                       return written, rr.err
+               }
+
+               // Advance only by what dest accepted, so that neither the
+               // offset nor the reported count runs ahead of a failed or
+               // partial write.
+               n, werr := dest.Write(rr.slice)
+               *this.offset += n
+               written += int64(n)
+               if werr != nil {
+                       return written, werr
+               }
+               if n < len(rr.slice) {
+                       return written, io.ErrShortWrite
+               }
+       }
+}
+
 // Close the responses channel
 func (this BufferReader) Close() error {
        close(this.responses)
@@ -251,12 +276,18 @@ func (this BufferReader) Close() error {
 // Handle a read request.  Returns true if a response was sent, and false if
 // the request should be queued.
 func HandleReadRequest(req ReadRequest, body []byte, complete bool) bool {
-       log.Printf("HandleReadRequest %d %d %t", req.offset, len(body), complete)
+       log.Printf("HandleReadRequest %d %d %d", req.offset, req.maxsize, len(body))
        if req.offset < len(body) {
-               req.result <- ReadResult{copy(req.p, body[req.offset:]), nil}
+               var end int
+               if req.offset+req.maxsize < len(body) {
+                       end = req.offset + req.maxsize
+               } else {
+                       end = len(body)
+               }
+               req.result <- ReadResult{body[req.offset:end], nil}
                return true
        } else if complete && req.offset >= len(body) {
-               req.result <- ReadResult{0, io.EOF}
+               req.result <- ReadResult{nil, io.EOF}
                return true
        } else {
                return false
@@ -282,6 +313,7 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
                body = source_buffer[:0]
 
                // used to communicate slices of the buffer as they are
+               // ReadIntoBuffer will close 'slices' when it is done with it
                slices = make(chan ReaderSlice)
 
                // Spin it off
@@ -297,14 +329,11 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
        pending_requests := make([]ReadRequest, 0)
 
        for {
-               log.Printf("Doing select")
                select {
                case req, valid := <-requests:
-                       log.Printf("Got read request")
                        // Handle a buffer read request
                        if valid {
                                if !HandleReadRequest(req, body, complete) {
-                                       log.Printf("Queued")
                                        pending_requests = append(pending_requests, req)
                                }
                        } else {
@@ -315,8 +344,6 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
                case bk, valid := <-slices:
                        // Got a new slice from the reader
                        if valid {
-                               log.Printf("Got readerslice %d", len(bk.slice))
-
                                if bk.reader_error != nil {
                                        reader_error <- bk.reader_error
                                        if bk.reader_error == io.EOF {
@@ -338,7 +365,6 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
                                n := 0
                                for n < len(pending_requests) {
                                        if HandleReadRequest(pending_requests[n], body, complete) {
-                                               log.Printf("ReadRequest handled")
 
                                                // move the element from the
                                                // back of the slice to
@@ -347,7 +373,6 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
                                                pending_requests[n] = pending_requests[len(pending_requests)-1]
                                                pending_requests = pending_requests[0 : len(pending_requests)-1]
                                        } else {
-                                               log.Printf("ReadRequest re-queued")
 
                                                // Request wasn't handled, so keep it in the request slice
                                                n += 1
@@ -367,39 +392,51 @@ func Transfer(source_buffer []byte, source_reader io.Reader, requests <-chan Rea
        }
 }
 
-type UploadError struct {
-       err error
-       url string
+// UploadStatus is the outcome of one upload attempt, sent by
+// uploadToKeepServer() on its status channel.
+type UploadStatus struct {
+       // nil when the server answered 200 OK; otherwise the request error
+       // or an error made from the response status line.
+       Err        error
+       // The URL the PUT was sent to.
+       Url        string
+       // HTTP status code of the response, or 0 when there was none.
+       StatusCode int
 }
 
-func (this KeepClient) uploadToKeepServer(host string, hash string, body io.ReadCloser, upload_status chan<- UploadError) {
+// uploadToKeepServer PUTs body to host under the name hash and reports the
+// outcome as exactly one UploadStatus on upload_status.
+//
+// expectedLength, when greater than zero, is sent as the request's
+// Content-Length.  A 200 OK response is reported with a nil Err; any other
+// response is reported with an error made from its status line, and a
+// failure to build or send the request with StatusCode 0.
+func (this KeepClient) uploadToKeepServer(host string, hash string, body io.ReadCloser,
+       upload_status chan<- UploadStatus, expectedLength int64) {
+
+       log.Printf("Uploading to %s", host)
+
        var req *http.Request
        var err error
        var url = fmt.Sprintf("%s/%s", host, hash)
        if req, err = http.NewRequest("PUT", url, nil); err != nil {
-               upload_status <- UploadError{err, url}
+               upload_status <- UploadStatus{err, url, 0}
                return
        }
 
+       if expectedLength > 0 {
+               req.ContentLength = expectedLength
+       }
+
        req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+       req.Header.Add("Content-Type", "application/octet-stream")
        req.Body = body
 
        var resp *http.Response
-       if resp, err = this.client.Do(req); err != nil {
-               upload_status <- UploadError{err, url}
+       if resp, err = this.Client.Do(req); err != nil {
+               upload_status <- UploadStatus{err, url, 0}
+               return
        }
 
+       // Drain and close the response body so the connection is released
+       // back to the transport instead of being leaked.
+       defer func() {
+               io.Copy(ioutil.Discard, resp.Body)
+               resp.Body.Close()
+       }()
+
        if resp.StatusCode == http.StatusOK {
-               upload_status <- UploadError{io.EOF, url}
+               upload_status <- UploadStatus{nil, url, resp.StatusCode}
+       } else {
+               upload_status <- UploadStatus{errors.New(resp.Status), url, resp.StatusCode}
        }
 }
 
-var KeepWriteError = errors.New("Could not write sufficient replicas")
-
 func (this KeepClient) putReplicas(
        hash string,
        requests chan ReadRequest,
-       reader_status chan error) error {
+       reader_status chan error,
+       expectedLength int64) (replicas int, err error) {
 
        // Calculate the ordering for uploading to servers
        sv := this.ShuffledServiceRoots(hash)
@@ -411,21 +448,21 @@ func (this KeepClient) putReplicas(
        active := 0
 
        // Used to communicate status from the upload goroutines
-       upload_status := make(chan UploadError)
+       upload_status := make(chan UploadStatus)
        defer close(upload_status)
 
        // Desired number of replicas
-       want_replicas := this.Want_replicas
+       remaining_replicas := this.Want_replicas
 
-       for want_replicas > 0 {
-               for active < want_replicas {
+       for remaining_replicas > 0 {
+               for active < remaining_replicas {
                        // Start some upload requests
                        if next_server < len(sv) {
-                               go this.uploadToKeepServer(sv[next_server], hash, MakeBufferReader(requests), upload_status)
+                               go this.uploadToKeepServer(sv[next_server], hash, MakeBufferReader(requests), upload_status, expectedLength)
                                next_server += 1
                                active += 1
                        } else {
-                               return KeepWriteError
+                               return (this.Want_replicas - remaining_replicas), InsufficientReplicasError
                        }
                }
 
@@ -436,27 +473,39 @@ func (this KeepClient) putReplicas(
                                // good news!
                        } else {
                                // bad news
-                               return status
+                               return (this.Want_replicas - remaining_replicas), status
                        }
                case status := <-upload_status:
-                       if status.err == io.EOF {
+                       if status.StatusCode == 200 {
                                // good news!
-                               want_replicas -= 1
+                               remaining_replicas -= 1
                        } else {
                                // writing to keep server failed for some reason
-                               log.Printf("Got error %s uploading to %s", status.err, status.url)
+                               log.Printf("Keep server put to %v failed with '%v'",
+                                       status.Url, status.Err)
                        }
                        active -= 1
+                       log.Printf("Upload status %v %v %v", status.StatusCode, remaining_replicas, active)
                }
        }
 
-       return nil
+       return (this.Want_replicas - remaining_replicas), nil
 }
 
-func (this KeepClient) PutHR(hash string, r io.Reader) error {
+// OversizeBlockError is returned by PutHR() when the expected length is
+// larger than BLOCKSIZE.
+var OversizeBlockError = errors.New("Block too big")
+
+func (this KeepClient) PutHR(hash string, r io.Reader, expectedLength int64) (replicas int, err error) {
 
        // Buffer for reads from 'r'
-       buffer := make([]byte, 64*1024*1024)
+       var buffer []byte
+       if expectedLength > 0 {
+               if expectedLength > BLOCKSIZE {
+                       return 0, OversizeBlockError
+               }
+               buffer = make([]byte, expectedLength)
+       } else {
+               buffer = make([]byte, BLOCKSIZE)
+       }
 
        // Read requests on Transfer() buffer
        requests := make(chan ReadRequest)
@@ -464,14 +513,15 @@ func (this KeepClient) PutHR(hash string, r io.Reader) error {
 
        // Reporting reader error states
        reader_status := make(chan error)
+       defer close(reader_status)
 
        // Start the transfer goroutine
        go Transfer(buffer, r, requests, reader_status)
 
-       return this.putReplicas(hash, requests, reader_status)
+       return this.putReplicas(hash, requests, reader_status, expectedLength)
 }
 
-func (this KeepClient) PutHB(hash string, buffer []byte) error {
+func (this KeepClient) PutHB(hash string, buffer []byte) (replicas int, err error) {
        // Read requests on Transfer() buffer
        requests := make(chan ReadRequest)
        defer close(requests)
@@ -479,17 +529,100 @@ func (this KeepClient) PutHB(hash string, buffer []byte) error {
        // Start the transfer goroutine
        go Transfer(buffer, nil, requests, nil)
 
-       return this.putReplicas(hash, requests, nil)
+       return this.putReplicas(hash, requests, nil, int64(len(buffer)))
 }
 
-func (this KeepClient) PutB(buffer []byte) error {
-       return this.PutHB(fmt.Sprintf("%x", md5.Sum(buffer)), buffer)
+// PutB writes buffer to Keep under its own MD5 digest.
+//
+// Returns the hex MD5 of buffer together with whatever PutHB() reports:
+// the number of replicas written and any error.
+func (this KeepClient) PutB(buffer []byte) (hash string, replicas int, err error) {
+       hash = fmt.Sprintf("%x", md5.Sum(buffer))
+       replicas, err = this.PutHB(hash, buffer)
+       return hash, replicas, err
 }
 
-func (this KeepClient) PutR(r io.Reader) error {
+// PutR reads all of r into memory and writes the result with PutB().
+//
+// Returns an empty hash, zero replicas and the read error if r cannot be
+// read to the end; otherwise returns what PutB() returns.
+func (this KeepClient) PutR(r io.Reader) (hash string, replicas int, err error) {
        if buffer, err := ioutil.ReadAll(r); err != nil {
-               return err
+               return "", 0, err
        } else {
                return this.PutB(buffer)
        }
 }
+
+// Get fetches the block named by hash without a signature.  It is
+// AuthorizedGet() called with an empty signature and timestamp.
+func (this KeepClient) Get(hash string) (reader io.ReadCloser,
+       contentLength int64, url string, err error) {
+       const unsigned = ""
+       reader, contentLength, url, err = this.AuthorizedGet(hash, unsigned, unsigned)
+       return reader, contentLength, url, err
+}
+
+// AuthorizedGet asks each server, in the order given by
+// ShuffledServiceRoots(hash), for the block named by hash and returns the
+// first 200 OK response.
+//
+// When signature is non-empty the request is for
+// "<hash>+A<signature>@<timestamp>", otherwise for the bare hash.
+//
+// On success returns the response body (which the caller must Close), the
+// response's content length and the URL that served it.  Servers that
+// cannot be reached or that answer with another status are skipped; if
+// none succeeds, BlockNotFound is returned.
+func (this KeepClient) AuthorizedGet(hash string,
+       signature string,
+       timestamp string) (reader io.ReadCloser,
+       contentLength int64, url string, err error) {
+
+       // Calculate the ordering for asking servers
+       sv := this.ShuffledServiceRoots(hash)
+
+       for _, host := range sv {
+               var locator string
+               if signature != "" {
+                       locator = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
+                               signature, timestamp)
+               } else {
+                       locator = fmt.Sprintf("%s/%s", host, hash)
+               }
+
+               req, reqErr := http.NewRequest("GET", locator, nil)
+               if reqErr != nil {
+                       continue
+               }
+
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+
+               resp, doErr := this.Client.Do(req)
+               if doErr != nil {
+                       continue
+               }
+
+               if resp.StatusCode == http.StatusOK {
+                       // Ownership of the body passes to the caller.
+                       return resp.Body, resp.ContentLength, locator, nil
+               }
+
+               // This server does not have the block: drain and close the
+               // body so the connection is not leaked before trying the
+               // next server.
+               io.Copy(ioutil.Discard, resp.Body)
+               resp.Body.Close()
+       }
+
+       return nil, 0, "", BlockNotFound
+}
+
+// Ask checks for the block named by hash without a signature.  It is
+// AuthorizedAsk() called with an empty signature and timestamp.
+func (this KeepClient) Ask(hash string) (contentLength int64, url string, err error) {
+       const unsigned = ""
+       contentLength, url, err = this.AuthorizedAsk(hash, unsigned, unsigned)
+       return contentLength, url, err
+}
+
+// AuthorizedAsk sends a HEAD request for the block named by hash to each
+// server, in the order given by ShuffledServiceRoots(hash), and returns at
+// the first 200 OK response.
+//
+// When signature is non-empty the request is for
+// "<hash>+A<signature>@<timestamp>", otherwise for the bare hash.
+//
+// On success returns the response's content length and the URL that
+// answered.  Servers that cannot be reached or that answer with another
+// status are skipped; if none succeeds, BlockNotFound is returned.
+func (this KeepClient) AuthorizedAsk(hash string, signature string,
+       timestamp string) (contentLength int64, url string, err error) {
+       // Calculate the ordering for asking servers
+       sv := this.ShuffledServiceRoots(hash)
+
+       for _, host := range sv {
+               var locator string
+               if signature != "" {
+                       locator = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
+                               signature, timestamp)
+               } else {
+                       locator = fmt.Sprintf("%s/%s", host, hash)
+               }
+
+               req, reqErr := http.NewRequest("HEAD", locator, nil)
+               if reqErr != nil {
+                       continue
+               }
+
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+
+               resp, doErr := this.Client.Do(req)
+               if doErr != nil {
+                       continue
+               }
+
+               // Only the status and headers are needed; the body must
+               // still be closed so the connection is released.
+               resp.Body.Close()
+
+               if resp.StatusCode == http.StatusOK {
+                       return resp.ContentLength, locator, nil
+               }
+       }
+
+       return 0, "", BlockNotFound
+}