7241: Stub Azure API calls
author Tom Clegg <tom@curoverse.com>
Thu, 24 Sep 2015 22:31:43 +0000 (18:31 -0400)
committer Tom Clegg <tom@curoverse.com>
Fri, 25 Sep 2015 19:31:52 +0000 (15:31 -0400)
services/keepstore/azure_blob_volume.go
services/keepstore/azure_blob_volume_test.go
services/keepstore/collision.go
services/keepstore/volume_generic_test.go
services/keepstore/volume_unix.go

index 0d0e5462258b47d7bf45c4d91761d0512416faea..35b1dc79bfc28b3a2c336f26dd3732bee2f07afc 100644 (file)
@@ -7,6 +7,7 @@ import (
        "io"
        "io/ioutil"
        "log"
+       "os"
        "strings"
        "time"
 
@@ -18,6 +19,18 @@ var (
        azureStorageAccountKeyFile string
 )
 
// readKeyFromFile returns the contents of file with surrounding
// whitespace removed, for use as an Azure storage account key.
//
// It returns an error if the file cannot be read, or if nothing but
// whitespace is left after trimming.
func readKeyFromFile(file string) (string, error) {
	raw, err := ioutil.ReadFile(file)
	if err != nil {
		return "", errors.New("reading key from " + file + ": " + err.Error())
	}
	if key := strings.TrimSpace(string(raw)); key != "" {
		return key, nil
	}
	return "", errors.New("empty account key in " + file)
}
+
// azureVolumeAdder creates Azure blob volumes for the embedded volumeSet.
// Its Set method takes a container name, reads the account key from
// azureStorageAccountKeyFile and builds a storage client for it.
// NOTE(review): Set(string) error suggests this is used as a flag.Value
// for a command-line flag — confirm against the flag registration.
type azureVolumeAdder struct {
	*volumeSet
}
@@ -26,13 +39,9 @@ func (s *azureVolumeAdder) Set(containerName string) error {
        if containerName == "" {
                return errors.New("no container name given")
        }
-       buf, err := ioutil.ReadFile(azureStorageAccountKeyFile)
+       accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
        if err != nil {
-               return errors.New("reading key from " + azureStorageAccountKeyFile + ": " + err.Error())
-       }
-       accountKey := strings.TrimSpace(string(buf))
-       if accountKey == "" {
-               return errors.New("empty account key in " + azureStorageAccountKeyFile)
+               return err
        }
        azClient, err := storage.NewBasicClient(azureStorageAccountName, accountKey)
        if err != nil {
@@ -98,6 +107,16 @@ func (v *AzureBlobVolume) Check() error {
 func (v *AzureBlobVolume) Get(loc string) ([]byte, error) {
        rdr, err := v.bsClient.GetBlob(v.containerName, loc)
        if err != nil {
+               if strings.Contains(err.Error(), "404 Not Found") {
+                       // "storage: service returned without a response body (404 Not Found)"
+                       return nil, os.ErrNotExist
+               }
+               return nil, err
+       }
+       switch err := err.(type) {
+       case nil:
+       default:
+               log.Printf("ERROR IN Get(): %T %#v", err, err)
                return nil, err
        }
        defer rdr.Close()
@@ -112,11 +131,19 @@ func (v *AzureBlobVolume) Get(loc string) ([]byte, error) {
        }
 }
 
-func (v *AzureBlobVolume) Compare(loc string, data []byte) error {
-       return NotFoundError
+func (v *AzureBlobVolume) Compare(loc string, expect []byte) error {
+       rdr, err := v.bsClient.GetBlob(v.containerName, loc)
+       if err != nil {
+               return err
+       }
+       defer rdr.Close()
+       return compareReaderWithBuf(rdr, expect, loc[:32])
 }
 
 func (v *AzureBlobVolume) Put(loc string, block []byte) error {
+       if v.readonly {
+               return MethodDisabledError
+       }
        if err := v.bsClient.CreateBlockBlob(v.containerName, loc); err != nil {
                return err
        }
@@ -128,6 +155,14 @@ func (v *AzureBlobVolume) Put(loc string, block []byte) error {
 }
 
 func (v *AzureBlobVolume) Touch(loc string) error {
+       if v.readonly {
+               return MethodDisabledError
+       }
+       if exists, err := v.bsClient.BlobExists(v.containerName, loc); err != nil {
+               return err
+       } else if !exists {
+               return os.ErrNotExist
+       }
        return v.bsClient.PutBlockList(v.containerName, loc, []storage.Block{{"MA==", storage.BlockStatusCommitted}})
 }
 
@@ -153,7 +188,7 @@ func (v *AzureBlobVolume) IndexTo(prefix string, writer io.Writer) error {
                        if err != nil {
                                return err
                        }
-                       fmt.Fprintf(writer, "%s+%d\n", b.Name, t.Unix())
+                       fmt.Fprintf(writer, "%s+%d %d\n", b.Name, b.Properties.ContentLength, t.Unix())
                }
                if resp.NextMarker == "" {
                        return nil
@@ -164,6 +199,9 @@ func (v *AzureBlobVolume) IndexTo(prefix string, writer io.Writer) error {
 
 func (v *AzureBlobVolume) Delete(loc string) error {
        // TODO: Use leases to handle races with Touch and Put.
+       if v.readonly {
+               return MethodDisabledError
+       }
        if t, err := v.Mtime(loc); err != nil {
                return err
        } else if time.Since(t) < blobSignatureTTL {
index 59021c0b3f2789ddfbebe063a9b695ea08979411..619c013a3164aada97147fbe348aa954edaf2457 100644 (file)
@@ -1,12 +1,19 @@
 package main
 
 import (
+       "encoding/base64"
+       "encoding/xml"
+       "flag"
+       "io/ioutil"
        "log"
        "net"
        "net/http"
        "net/http/httptest"
        "regexp"
+       "sort"
+       "strconv"
        "strings"
+       "sync"
        "testing"
        "time"
 
@@ -19,9 +26,187 @@ const (
        emulatorAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
 )
 
-type azStubHandler struct {}
// azureTestContainer is set by the -test.azure-storage-container-volume
// flag (registered in init). When it is empty, the tests talk to an
// in-process stub instead of a real Azure storage container.
var azureTestContainer string
 
-func (azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+func init() {
+       flag.StringVar(
+               &azureTestContainer,
+               "test.azure-storage-container-volume",
+               "",
+               "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
+}
+
// azBlob is the stub's in-memory state for one Azure block blob.
type azBlob struct {
	// Data is the blob's committed content.
	Data []byte
	// Mtime is reported to clients as the blob's Last-Modified time.
	Mtime time.Time
	// Uncommitted holds staged blocks that have not been committed yet,
	// keyed by decoded block ID.
	Uncommitted map[string][]byte
}

// azStubHandler is an http.Handler that emulates enough of the Azure
// Blob Storage service for the AzureBlobVolume tests.
//
// The embedded mutex guards blobs and every *azBlob stored in it. It
// must be held for any access, because HTTP requests are served on
// other goroutines.
type azStubHandler struct {
	sync.Mutex
	// blobs maps container + "|" + blob name to that blob's state.
	blobs map[string]*azBlob
}

// newAzStubHandler returns an azStubHandler holding no blobs.
func newAzStubHandler() *azStubHandler {
	return &azStubHandler{
		blobs: make(map[string]*azBlob),
	}
}

// TouchWithDate sets the modification time of the blob named hash in
// container to t. It does nothing if no such blob exists.
func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
	// Lock like every other accessor. Without it, this races with
	// requests being served concurrently by the HTTP test server.
	h.Lock()
	defer h.Unlock()
	if blob, ok := h.blobs[container+"|"+hash]; ok {
		blob.Mtime = t
	}
}

// PutRaw stores data as the committed content of the blob named hash in
// container, bypassing HTTP. Any existing blob (including its staged
// blocks) is replaced, and the modification time is set to now.
func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
	h.Lock()
	defer h.Unlock()
	h.blobs[container+"|"+hash] = &azBlob{
		Data:        data,
		Mtime:       time.Now(),
		Uncommitted: make(map[string][]byte),
	}
}
+
+func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+       h.Lock()
+       defer h.Unlock()
+       // defer log.Printf("azStubHandler: %+v", r)
+
+       path := strings.Split(r.URL.Path, "/")
+       container := path[1]
+       hash := ""
+       if len(path) > 2 {
+               hash = path[2]
+       }
+
+       if err := r.ParseForm(); err != nil {
+               log.Printf("azStubHandler(%+v): %s", r, err)
+               rw.WriteHeader(http.StatusBadRequest)
+               return
+       }
+
+       body, err := ioutil.ReadAll(r.Body)
+       if err != nil {
+               return
+       }
+
+       type blockListRequestBody struct {
+               XMLName     xml.Name `xml:"BlockList"`
+               Uncommitted []string
+       }
+
+       blob, blobExists := h.blobs[container + "|" + hash]
+
+       switch {
+       case r.Method == "PUT" && r.Form.Get("comp") == "" && r.Header.Get("Content-Length") == "0":
+               rw.WriteHeader(http.StatusCreated)
+               h.blobs[container + "|" + hash] = &azBlob{
+                       Data:  body,
+                       Mtime: time.Now(),
+                       Uncommitted: make(map[string][]byte),
+               }
+       case r.Method == "PUT" && r.Form.Get("comp") == "block":
+               if !blobExists {
+                       log.Printf("Got block for nonexistent blob: %+v", r)
+                       rw.WriteHeader(http.StatusBadRequest)
+                       return
+               }
+               blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
+               if err != nil || len(blockID) == 0 {
+                       log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
+                       rw.WriteHeader(http.StatusBadRequest)
+                       return
+               }
+               blob.Uncommitted[string(blockID)] = body
+               rw.WriteHeader(http.StatusCreated)
+       case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
+               bl := &blockListRequestBody{}
+               if err := xml.Unmarshal(body, bl); err != nil {
+                       log.Printf("xml Unmarshal: %s", err)
+                       rw.WriteHeader(http.StatusBadRequest)
+                       return
+               }
+               for _, encBlockID := range bl.Uncommitted {
+                       blockID, err := base64.StdEncoding.DecodeString(encBlockID)
+                       if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
+                               log.Printf("Invalid blockid: %+q", encBlockID)
+                               rw.WriteHeader(http.StatusBadRequest)
+                               return
+                       }
+                       blob.Data = blob.Uncommitted[string(blockID)]
+                       log.Printf("body %+q, bl %+v, blockID %+q, data %+q", body, bl, blockID, blob.Data)
+               }
+               rw.WriteHeader(http.StatusCreated)
+       case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
+               if !blobExists {
+                       rw.WriteHeader(http.StatusNotFound)
+                       return
+               }
+               rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
+               rw.Header().Set("Content-Length", strconv.Itoa(len(blob.Data)))
+               if r.Method == "GET" {
+                       if _, err := rw.Write(blob.Data); err != nil {
+                               log.Printf("write %+q: %s", blob.Data, err)
+                       }
+               }
+       case r.Method == "DELETE" && hash != "":
+               if !blobExists {
+                       rw.WriteHeader(http.StatusNotFound)
+                       return
+               }
+               delete(h.blobs, container + "|" + hash)
+               rw.WriteHeader(http.StatusAccepted)
+       case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
+               prefix := container + "|" + r.Form.Get("prefix")
+               marker := r.Form.Get("marker")
+
+               maxResults := 2
+               if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
+                       maxResults = n
+               }
+
+               resp := storage.BlobListResponse{
+                       Marker: marker,
+                       NextMarker: "",
+                       MaxResults: int64(maxResults),
+               }
+               var hashes sort.StringSlice
+               for k := range h.blobs {
+                       if strings.HasPrefix(k, prefix) {
+                               hashes = append(hashes, k[len(container)+1:])
+                       }
+               }
+               hashes.Sort()
+               for _, hash := range hashes {
+                       if len(resp.Blobs) == maxResults {
+                               resp.NextMarker = hash
+                               break
+                       }
+                       if len(resp.Blobs) > 0 || marker == "" || marker == hash {
+                               blob := h.blobs[container + "|" + hash]
+                               resp.Blobs = append(resp.Blobs, storage.Blob{
+                                       Name: hash,
+                                       Properties: storage.BlobProperties{
+                                               LastModified: blob.Mtime.Format(time.RFC1123),
+                                               ContentLength: int64(len(blob.Data)),
+                                       },
+                               })
+                       }
+               }
+               buf, err := xml.Marshal(resp)
+               if err != nil {
+                       log.Print(err)
+                       rw.WriteHeader(http.StatusInternalServerError)
+               }
+               rw.Write(buf)
+       default:
+               log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
+               rw.WriteHeader(http.StatusNotImplemented)
+       }
 }
 
 // azStubDialer is a net.Dialer that notices when the Azure driver
@@ -34,7 +219,7 @@ type azStubDialer struct {
// localHostPortRe matches a loopback "host:port" (127.0.0.1, localhost,
// or [::1] followed by a port) anywhere in an address string.
// azStubDialer.Dial uses it to dial just that host:port instead of the
// full address it was given.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
        if hp := localHostPortRe.FindString(address); hp != "" {
-               log.Println("custom dialer: dial", hp, "instead of", address)
+               log.Println("azStubDialer: dial", hp, "instead of", address)
                address = hp
        }
        return d.Dialer.Dial(network, address)
@@ -42,23 +227,43 @@ func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
 
// TestableAzureBlobVolume is an AzureBlobVolume plus the stub Azure
// service created for it. The stub is used as the backend unless
// -test.azure-storage-container-volume names a real container. It lets
// the generic volume tests manipulate stored blobs directly.
type TestableAzureBlobVolume struct {
	*AzureBlobVolume
	// azHandler holds the stub's blob state; PutRaw and TouchWithDate
	// modify it directly, bypassing HTTP.
	azHandler *azStubHandler
	// azStub is the HTTP test server that serves azHandler.
	azStub *httptest.Server
	t      *testing.T
}
 
-func NewTestableAzureBlobVolume(t *testing.T, readonly bool) *TestableAzureBlobVolume {
-       azStub := httptest.NewServer(azStubHandler{})
+func NewTestableAzureBlobVolume(t *testing.T, readonly bool) TestableVolume {
+       azHandler := newAzStubHandler()
+       azStub := httptest.NewServer(azHandler)
 
-       stubURLBase := strings.Split(azStub.URL, "://")[1]
-       azClient, err := storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false)
-       if err != nil {
-               t.Fatal(err)
+       var azClient storage.Client
+
+       container := azureTestContainer
+       if container == "" {
+               // Connect to stub instead of real Azure storage service
+               stubURLBase := strings.Split(azStub.URL, "://")[1]
+               var err error
+               if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
+                       t.Fatal(err)
+               }
+               container = "fakecontainername"
+       } else {
+               // Connect to real Azure storage service
+               accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
+               if err != nil {
+                       t.Fatal(err)
+               }
        }
 
-       v := NewAzureBlobVolume(azClient, "fakecontainername", readonly)
+       v := NewAzureBlobVolume(azClient, container, readonly)
 
        return &TestableAzureBlobVolume{
                AzureBlobVolume: v,
+               azHandler: azHandler,
                azStub: azStub,
                t: t,
        }
@@ -89,10 +294,11 @@ func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
 }
 
// PutRaw stores data at locator directly in the stub handler. It bypasses
// the volume's Put method, and therefore its read-only check.
// NOTE(review): when -test.azure-storage-container-volume selects a real
// Azure container, the volume does not read from the stub, so data
// written here is invisible to it — confirm whether such tests should
// be skipped in that mode.
func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
	v.azHandler.PutRaw(v.containerName, locator, data)
}
 
// TouchWithDate sets the stub's modification time for the blob at locator
// to lastPut. It has no effect if the stub holds no such blob.
// NOTE(review): this only changes the stub, not a real Azure container
// selected with -test.azure-storage-container-volume.
func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
	v.azHandler.TouchWithDate(v.containerName, locator, lastPut)
}
 
 func (v *TestableAzureBlobVolume) Teardown() {
index 210286ad75ab3869aaf6a9690f5ef341eb15b549..447ffa05d7b6b926374c1c961b9242f9e5ee5def 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "bytes"
        "crypto/md5"
        "fmt"
        "io"
@@ -47,3 +48,37 @@ func collisionOrCorrupt(expectMD5 string, buf1, buf2 []byte, rdr io.Reader) erro
        }
        return <-outcome
 }
+
+func compareReaderWithBuf(rdr io.Reader, expect []byte, hash string) error {
+       bufLen := 1 << 20
+       if bufLen > len(expect) && len(expect) > 0 {
+               // No need for bufLen to be longer than
+               // expect, except that len(buf)==0 would
+               // prevent us from handling empty readers the
+               // same way as non-empty readers: reading 0
+               // bytes at a time never reaches EOF.
+               bufLen = len(expect)
+       }
+       buf := make([]byte, bufLen)
+       cmp := expect
+
+       // Loop invariants: all data read so far matched what
+       // we expected, and the first N bytes of cmp are
+       // expected to equal the next N bytes read from
+       // rdr.
+       for {
+               n, err := rdr.Read(buf)
+               if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
+                       return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], buf[:n], rdr)
+               }
+               cmp = cmp[n:]
+               if err == io.EOF {
+                       if len(cmp) != 0 {
+                               return collisionOrCorrupt(hash, expect[:len(expect)-len(cmp)], nil, nil)
+                       }
+                       return nil
+               } else if err != nil {
+                       return err
+               }
+       }
+}
index b64f3453d1e534a81812794b153fd6b1c0924f6e..503e6b9a58e4e16c2fb47b71b80547a77e4d1d8d 100644 (file)
@@ -171,14 +171,14 @@ func testPutBlockWithDifferentContent(t *testing.T, factory TestableVolumeFactor
                // Put must not return a nil error unless it has
                // overwritten the existing data.
                if bytes.Compare(buf, TestBlock2) != 0 {
-                       t.Errorf("Put succeeded but Get returned %+v, expected %+v", buf, TestBlock2)
+                       t.Errorf("Put succeeded but Get returned %+q, expected %+q", buf, TestBlock2)
                }
        } else {
                // It is permissible for Put to fail, but it must
                // leave us with either the original data, the new
                // data, or nothing at all.
                if getErr == nil && bytes.Compare(buf, TestBlock) != 0 && bytes.Compare(buf, TestBlock2) != 0 {
-                       t.Errorf("Put failed but Get returned %+v, which is neither %+v nor %+v", buf, TestBlock, TestBlock2)
+                       t.Errorf("Put failed but Get returned %+q, which is neither %+q nor %+q", buf, TestBlock, TestBlock2)
                }
        }
        if getErr == nil {
@@ -214,26 +214,32 @@ func testPutMultipleBlocks(t *testing.T, factory TestableVolumeFactory) {
        data, err := v.Get(TestHash)
        if err != nil {
                t.Error(err)
-       } else if bytes.Compare(data, TestBlock) != 0 {
-               t.Errorf("Block present, but content is incorrect: Expected: %v  Found: %v", data, TestBlock)
+       } else {
+               if bytes.Compare(data, TestBlock) != 0 {
+                       t.Errorf("Block present, but got %+q, expected %+q", data, TestBlock)
+               }
+               bufs.Put(data)
        }
-       bufs.Put(data)
 
        data, err = v.Get(TestHash2)
        if err != nil {
                t.Error(err)
-       } else if bytes.Compare(data, TestBlock2) != 0 {
-               t.Errorf("Block present, but content is incorrect: Expected: %v  Found: %v", data, TestBlock2)
+       } else {
+               if bytes.Compare(data, TestBlock2) != 0 {
+                       t.Errorf("Block present, but got %+q, expected %+q", data, TestBlock2)
+               }
+               bufs.Put(data)
        }
-       bufs.Put(data)
 
        data, err = v.Get(TestHash3)
        if err != nil {
                t.Error(err)
-       } else if bytes.Compare(data, TestBlock3) != 0 {
-               t.Errorf("Block present, but content is incorrect: Expected: %v  Found: %v", data, TestBlock3)
+       } else {
+               if bytes.Compare(data, TestBlock3) != 0 {
+                       t.Errorf("Block present, but to %+q, expected %+q", data, TestBlock3)
+               }
+               bufs.Put(data)
        }
-       bufs.Put(data)
 }
 
 // testPutAndTouch
@@ -360,6 +366,7 @@ func testIndexTo(t *testing.T, factory TestableVolumeFactory) {
 func testDeleteNewBlock(t *testing.T, factory TestableVolumeFactory) {
        v := factory(t)
        defer v.Teardown()
+       blobSignatureTTL = 300 * time.Second
 
        if v.Writable() == false {
                return
@@ -373,10 +380,12 @@ func testDeleteNewBlock(t *testing.T, factory TestableVolumeFactory) {
        data, err := v.Get(TestHash)
        if err != nil {
                t.Error(err)
-       } else if bytes.Compare(data, TestBlock) != 0 {
-               t.Error("Block still present, but content is incorrect: %+v != %+v", data, TestBlock)
+       } else {
+               if bytes.Compare(data, TestBlock) != 0 {
+                       t.Errorf("Got data %+q, expected %+q", data, TestBlock)
+               }
+               bufs.Put(data)
        }
-       bufs.Put(data)
 }
 
 // Calling Delete() for a block with a timestamp older than
@@ -385,13 +394,14 @@ func testDeleteNewBlock(t *testing.T, factory TestableVolumeFactory) {
 func testDeleteOldBlock(t *testing.T, factory TestableVolumeFactory) {
        v := factory(t)
        defer v.Teardown()
+       blobSignatureTTL = 300 * time.Second
 
        if v.Writable() == false {
                return
        }
 
        v.Put(TestHash, TestBlock)
-       v.TouchWithDate(TestHash, time.Now().Add(-2*blobSignatureTTL*time.Second))
+       v.TouchWithDate(TestHash, time.Now().Add(-2*blobSignatureTTL))
 
        if err := v.Delete(TestHash); err != nil {
                t.Error(err)
index 5d09e84f525eb1c5ab928d5e8bf3ba01a3ab05f0..f498c3c32d3d1fad71609061d32a7ea8c9a222e8 100644 (file)
@@ -2,7 +2,6 @@ package main
 
 import (
        "bufio"
-       "bytes"
        "errors"
        "flag"
        "fmt"
@@ -209,43 +208,11 @@ func (v *UnixVolume) Get(loc string) ([]byte, error) {
 // bytes.Compare(), but uses less memory.
 func (v *UnixVolume) Compare(loc string, expect []byte) error {
        path := v.blockPath(loc)
-       stat, err := v.stat(path)
-       if err != nil {
+       if _, err := v.stat(path); err != nil {
                return err
        }
-       bufLen := 1 << 20
-       if int64(bufLen) > stat.Size() {
-               bufLen = int(stat.Size())
-               if bufLen < 1 {
-                       // len(buf)==0 would prevent us from handling
-                       // empty files the same way as non-empty
-                       // files, because reading 0 bytes at a time
-                       // never reaches EOF.
-                       bufLen = 1
-               }
-       }
-       cmp := expect
-       buf := make([]byte, bufLen)
        return v.getFunc(path, func(rdr io.Reader) error {
-               // Loop invariants: all data read so far matched what
-               // we expected, and the first N bytes of cmp are
-               // expected to equal the next N bytes read from
-               // reader.
-               for {
-                       n, err := rdr.Read(buf)
-                       if n > len(cmp) || bytes.Compare(cmp[:n], buf[:n]) != 0 {
-                               return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], buf[:n], rdr)
-                       }
-                       cmp = cmp[n:]
-                       if err == io.EOF {
-                               if len(cmp) != 0 {
-                                       return collisionOrCorrupt(loc[:32], expect[:len(expect)-len(cmp)], nil, nil)
-                               }
-                               return nil
-                       } else if err != nil {
-                               return err
-                       }
-               }
+               return compareReaderWithBuf(rdr, expect, loc[:32])
        })
 }