2798: Completed move of Transfer() related code out to 'buffer' package.
[arvados.git] / sdk / go / src / arvados.org / keepclient / keepclient.go
index ba735235f12b524840f58a7cf5201e40c93e4e4d..829ab0efb043414c6e6c606175af784762c02c5c 100644 (file)
@@ -1,17 +1,34 @@
 package keepclient
 
 import (
+       "arvados.org/buffer"
+       "crypto/md5"
        "crypto/tls"
        "encoding/json"
+       "errors"
        "fmt"
        "io"
+       "io/ioutil"
+       "log"
        "net/http"
+       "os"
        "sort"
        "strconv"
 )
 
// BLOCKSIZE is the size of a Keep block, 64 MiB. PutHR refuses any
// expectedLength larger than this and otherwise uses it as the default
// read-buffer size.
const BLOCKSIZE = 64 << 20

var (
	// BlockNotFound is returned by AuthorizedGet/AuthorizedAsk (and so by
	// Get/Ask) when none of the Keep servers answered with the block.
	BlockNotFound = errors.New("Block not found")

	// InsufficientReplicasError is returned by putReplicas when the list of
	// Keep servers is exhausted before Want_replicas uploads succeeded.
	InsufficientReplicasError = errors.New("Could not write sufficient replicas")
)
+
// KeepClient carries everything needed to reach the Arvados API server and
// the Keep servers it advertises.
type KeepClient struct {
	// ApiServer is the API server's host (optionally host:port) and ApiToken
	// is sent as "OAuth2 <token>" in the Authorization header of every
	// request.
	ApiServer, ApiToken string

	// ApiInsecure records whether TLS certificate verification should be
	// skipped; MakeKeepClient sets it from ARVADOS_API_HOST_INSECURE.
	ApiInsecure bool

	// Service_roots holds the base URLs ("http[s]://host:port") of the Keep
	// servers, de-duplicated and sorted by DiscoverKeepServers.
	Service_roots []string

	// Want_replicas is how many copies the Put* methods try to store.
	Want_replicas int

	// Client performs every HTTP request made by this KeepClient.
	Client *http.Client
}
 
 type KeepDisk struct {
@@ -20,51 +37,75 @@ type KeepDisk struct {
        SSL      bool   `json:"service_ssl_flag"`
 }
 
-func KeepDisks() (service_roots []string, err error) {
+func MakeKeepClient() (kc KeepClient, err error) {
        tr := &http.Transport{
-               TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+               TLSClientConfig: &tls.Config{InsecureSkipVerify: kc.ApiInsecure},
        }
-       client := &http.Client{Transport: tr}
 
+       kc = KeepClient{
+               ApiServer:     os.Getenv("ARVADOS_API_HOST"),
+               ApiToken:      os.Getenv("ARVADOS_API_TOKEN"),
+               ApiInsecure:   (os.Getenv("ARVADOS_API_HOST_INSECURE") != ""),
+               Want_replicas: 2,
+               Client:        &http.Client{Transport: tr}}
+
+       err = (&kc).DiscoverKeepServers()
+
+       return kc, err
+}
+
+func (this *KeepClient) DiscoverKeepServers() error {
+       // Construct request of keep disk list
        var req *http.Request
-       if req, err = http.NewRequest("GET", "https://localhost:3001/arvados/v1/keep_disks", nil); err != nil {
-               return nil, err
+       var err error
+       if req, err = http.NewRequest("GET", fmt.Sprintf("https://%s/arvados/v1/keep_disks", this.ApiServer), nil); err != nil {
+               return err
        }
 
+       // Add api token header
+       req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+
+       // Make the request
        var resp *http.Response
-       req.Header.Add("Authorization", "OAuth2 4axaw8zxe0qm22wa6urpp5nskcne8z88cvbupv653y1njyi05h")
-       if resp, err = client.Do(req); err != nil {
-               return nil, err
+       if resp, err = this.Client.Do(req); err != nil {
+               return err
        }
 
        type SvcList struct {
                Items []KeepDisk `json:"items"`
        }
+
+       // Decode json reply
        dec := json.NewDecoder(resp.Body)
        var m SvcList
        if err := dec.Decode(&m); err != nil {
-               return nil, err
+               return err
        }
 
-       service_roots = make([]string, len(m.Items))
-       for index, element := range m.Items {
+       listed := make(map[string]bool)
+       this.Service_roots = make([]string, 0, len(m.Items))
+
+       for _, element := range m.Items {
                n := ""
                if element.SSL {
                        n = "s"
                }
-               service_roots[index] = fmt.Sprintf("http%s://%s:%d",
-                       n, element.Hostname, element.Port)
-       }
-       sort.Strings(service_roots)
-       return service_roots, nil
-}
 
-func MakeKeepClient() (kc *KeepClient, err error) {
-       sv, err := KeepDisks()
-       if err != nil {
-               return nil, err
+               // Construct server URL
+               url := fmt.Sprintf("http%s://%s:%d", n, element.Hostname, element.Port)
+
+               // Skip duplicates
+               if !listed[url] {
+                       listed[url] = true
+                       this.Service_roots = append(this.Service_roots, url)
+               }
        }
-       return &KeepClient{sv}, nil
+
+       // Must be sorted for ShuffledServiceRoots() to produce consistent
+       // results.
+       sort.Strings(this.Service_roots)
+
+       return nil
 }
 
 func (this KeepClient) ShuffledServiceRoots(hash string) (pseq []string) {
@@ -115,97 +156,237 @@ func (this KeepClient) ShuffledServiceRoots(hash string) (pseq []string) {
        return pseq
 }
 
-func ReadIntoBuffer(buffer []byte, r io.Reader, c chan []byte, reader_error chan error) {
-       ptr := buffer[:]
-       for {
-               n, err := r.Read(ptr)
-               if err != nil {
-                       reader_error <- err
-                       return
-               }
-               c <- ptr[:n]
-               ptr = ptr[n:]
// UploadStatus reports the outcome of one upload attempt to one Keep server,
// as produced by uploadToKeepServer.
type UploadStatus struct {
	// Err is nil when the server replied 200 OK; otherwise it describes the
	// failure (request error, or the server's status line).
	Err error
	// Url is the full URL the block was PUT to.
	Url string
	// StatusCode is the server's HTTP status, or 0 if no reply was received.
	StatusCode int
}
+
+func (this KeepClient) uploadToKeepServer(host string, hash string, body io.ReadCloser,
+       upload_status chan<- UploadStatus, expectedLength int64) {
+
+       log.Printf("Uploading to %s", host)
+
+       var req *http.Request
+       var err error
+       var url = fmt.Sprintf("%s/%s", host, hash)
+       if req, err = http.NewRequest("PUT", url, nil); err != nil {
+               upload_status <- UploadStatus{err, url, 0}
+               return
+       }
+
+       if expectedLength > 0 {
+               req.ContentLength = expectedLength
+       }
+
+       req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+       req.Header.Add("Content-Type", "application/octet-stream")
+       req.Body = body
+
+       var resp *http.Response
+       if resp, err = this.Client.Do(req); err != nil {
+               upload_status <- UploadStatus{err, url, 0}
+               return
+       }
+
+       if resp.StatusCode == http.StatusOK {
+               upload_status <- UploadStatus{nil, url, resp.StatusCode}
+       } else {
+               upload_status <- UploadStatus{errors.New(resp.Status), url, resp.StatusCode}
        }
 }
 
-type Sink struct {
-       out io.Writer
-       err chan error
+func (this KeepClient) putReplicas(
+       hash string,
+       requests chan buffer.ReadRequest,
+       reader_status chan error,
+       expectedLength int64) (replicas int, err error) {
+
+       // Calculate the ordering for uploading to servers
+       sv := this.ShuffledServiceRoots(hash)
+
+       // The next server to try contacting
+       next_server := 0
+
+       // The number of active writers
+       active := 0
+
+       // Used to communicate status from the upload goroutines
+       upload_status := make(chan UploadStatus)
+       defer close(upload_status)
+
+       // Desired number of replicas
+       remaining_replicas := this.Want_replicas
+
+       for remaining_replicas > 0 {
+               for active < remaining_replicas {
+                       // Start some upload requests
+                       if next_server < len(sv) {
+                               go this.uploadToKeepServer(sv[next_server], hash, buffer.MakeBufferReader(requests), upload_status, expectedLength)
+                               next_server += 1
+                               active += 1
+                       } else {
+                               return (this.Want_replicas - remaining_replicas), InsufficientReplicasError
+                       }
+               }
+
+               // Now wait for something to happen.
+               select {
+               case status := <-reader_status:
+                       if status == io.EOF {
+                               // good news!
+                       } else {
+                               // bad news
+                               return (this.Want_replicas - remaining_replicas), status
+                       }
+               case status := <-upload_status:
+                       if status.StatusCode == 200 {
+                               // good news!
+                               remaining_replicas -= 1
+                       } else {
+                               // writing to keep server failed for some reason
+                               log.Printf("Keep server put to %v failed with '%v'",
+                                       status.Url, status.Err)
+                       }
+                       active -= 1
+                       log.Printf("Upload status %v %v %v", status.StatusCode, remaining_replicas, active)
+               }
+       }
+
+       return (this.Want_replicas - remaining_replicas), nil
 }
 
-// Transfer data from a buffer into one or more 'sinks'.
-//
-// Forwards all data read to the writers in "Sinks", including any previous
-// reads into the buffer.  Either one of buffer or io.Reader must be valid, and
-// the other must be nil.  If 'source' is valid, it will read from the reader,
-// store the data in the buffer, and send the data to the sinks.  Otherwise
-// 'buffer' must be valid, and it will send the contents of the buffer to the
-// sinks.
-func Transfer(buffer []byte, source io.Reader, sinks chan Sink, errorchan chan error) {
-       // currently buffered data
-       var ptr []byte
// OversizeBlockError is returned by PutHR, before any data is read, when
// expectedLength is larger than BLOCKSIZE.
var OversizeBlockError = errors.New("Block too big")
 
-       // for receiving slices from ReadIntoBuffer
-       var slices chan []byte
+func (this KeepClient) PutHR(hash string, r io.Reader, expectedLength int64) (replicas int, err error) {
 
-       // indicates whether the buffered data is complete
-       var complete bool = false
+       // Buffer for reads from 'r'
+       var buf []byte
+       if expectedLength > 0 {
+               if expectedLength > BLOCKSIZE {
+                       return 0, OversizeBlockError
+               }
+               buf = make([]byte, expectedLength)
+       } else {
+               buf = make([]byte, BLOCKSIZE)
+       }
 
-       // for receiving errors from ReadIntoBuffer
-       var reader_error chan error = nil
+       // Read requests on Transfer() buffer
+       requests := make(chan buffer.ReadRequest)
+       defer close(requests)
 
-       if source != nil {
-               // allocate the scratch buffer at 64 MiB
-               buffer = make([]byte, 1024*1024*64)
+       // Reporting reader error states
+       reader_status := make(chan error)
+       defer close(reader_status)
 
-               // 'ptr' is a slice indicating the buffer slice that has been
-               // read so far
-               ptr = buffer[0:0]
+       // Start the transfer goroutine
+       go buffer.Transfer(buf, r, requests, reader_status)
 
-               // used to communicate slices of the buffer as read
-               slices := make(chan []byte)
+       return this.putReplicas(hash, requests, reader_status, expectedLength)
+}
 
-               // communicate read errors
-               reader_error = make(chan error)
+func (this KeepClient) PutHB(hash string, buf []byte) (replicas int, err error) {
+       // Read requests on Transfer() buffer
+       requests := make(chan buffer.ReadRequest)
+       defer close(requests)
 
-               // Spin it off
-               go ReadIntoBuffer(buffer, source, slices, reader_error)
-       } else {
-               // use the whole buffer
-               ptr = buffer[:]
+       // Start the transfer goroutine
+       go buffer.Transfer(buf, nil, requests, nil)
 
-               // that's it
-               complete = true
+       return this.putReplicas(hash, requests, nil, int64(len(buf)))
+}
+
+func (this KeepClient) PutB(buffer []byte) (hash string, replicas int, err error) {
+       hash = fmt.Sprintf("%x", md5.Sum(buffer))
+       replicas, err = this.PutHB(hash, buffer)
+       return hash, replicas, err
+}
+
+func (this KeepClient) PutR(r io.Reader) (hash string, replicas int, err error) {
+       if buffer, err := ioutil.ReadAll(r); err != nil {
+               return "", 0, err
+       } else {
+               return this.PutB(buffer)
        }
+}
 
-       // list of sinks to send to
-       sinks_slice := make([]io.Writer, 0)
+func (this KeepClient) Get(hash string) (reader io.ReadCloser,
+       contentLength int64, url string, err error) {
+       return this.AuthorizedGet(hash, "", "")
+}
 
-       select {
-       case e := <-reader_error:
-               // barf
-       case s, valid := <-sinks:
-               if !valid {
-                       // sinks channel closed
-                       return
+func (this KeepClient) AuthorizedGet(hash string,
+       signature string,
+       timestamp string) (reader io.ReadCloser,
+       contentLength int64, url string, err error) {
+
+       // Calculate the ordering for asking servers
+       sv := this.ShuffledServiceRoots(hash)
+
+       for _, host := range sv {
+               var req *http.Request
+               var err error
+               var url string
+               if signature != "" {
+                       url = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
+                               signature, timestamp)
+               } else {
+                       url = fmt.Sprintf("%s/%s", host, hash)
                }
-               sinks_slice = append(sinks_slice, s)
-               go s.Write(ptr)
-       case bk := <-slices:
-               ptr = buffer[0 : len(ptr)+len(bk)]
-               for _, s := range sinks_slice {
-                       go s.Write(bk)
+               if req, err = http.NewRequest("GET", url, nil); err != nil {
+                       continue
+               }
+
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+
+               var resp *http.Response
+               if resp, err = this.Client.Do(req); err != nil {
+                       continue
+               }
+
+               if resp.StatusCode == http.StatusOK {
+                       return resp.Body, resp.ContentLength, url, nil
                }
        }
+
+       return nil, 0, "", BlockNotFound
+}
+
+func (this KeepClient) Ask(hash string) (contentLength int64, url string, err error) {
+       return this.AuthorizedAsk(hash, "", "")
 }
 
-func (this KeepClient) KeepPut(hash string, r io.Reader) {
-       //sv := this.ShuffledServiceRoots(hash)
-       //n := 0
+func (this KeepClient) AuthorizedAsk(hash string, signature string,
+       timestamp string) (contentLength int64, url string, err error) {
+       // Calculate the ordering for asking servers
+       sv := this.ShuffledServiceRoots(hash)
+
+       for _, host := range sv {
+               var req *http.Request
+               var err error
+               if signature != "" {
+                       url = fmt.Sprintf("%s/%s+A%s@%s", host, hash,
+                               signature, timestamp)
+               } else {
+                       url = fmt.Sprintf("%s/%s", host, hash)
+               }
+
+               if req, err = http.NewRequest("HEAD", url, nil); err != nil {
+                       continue
+               }
+
+               req.Header.Add("Authorization", fmt.Sprintf("OAuth2 %s", this.ApiToken))
+
+               var resp *http.Response
+               if resp, err = this.Client.Do(req); err != nil {
+                       continue
+               }
+
+               if resp.StatusCode == http.StatusOK {
+                       return resp.ContentLength, url, nil
+               }
+       }
 
-       //success := make(chan int)
-       sinks := make(chan []io.Writer)
-       errorchan := make(chan error)
+       return 0, "", BlockNotFound
 
-       go Transfer(nil, r, reads, errorchan)
 }