Merge branch '8784-dir-listings'
[arvados.git] / services / keepstore / azure_blob_volume_test.go
index a240c23e1622b525f62a6957a23c025651f94190..4256ec0d0cb599e259ff7cabcc6f3407fd2e6dce 100644 (file)
@@ -1,13 +1,19 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
        "bytes"
+       "context"
+       "crypto/md5"
        "encoding/base64"
+       "encoding/json"
        "encoding/xml"
        "flag"
        "fmt"
        "io/ioutil"
-       "log"
        "math/rand"
        "net"
        "net/http"
@@ -20,13 +26,18 @@ import (
        "testing"
        "time"
 
+       log "github.com/Sirupsen/logrus"
        "github.com/curoverse/azure-sdk-for-go/storage"
+       check "gopkg.in/check.v1"
 )
 
 const (
-       // The same fake credentials used by Microsoft's Azure emulator
-       emulatorAccountName = "devstoreaccount1"
-       emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
+       // This cannot be the fake account name "devstoreaccount1"
+       // used by Microsoft's Azure emulator: the Azure SDK
+       // recognizes that magic string and changes its behavior to
+       // cater to the Azure SDK's own test suite.
+       //
+       // The key is the Azure emulator's publicly known fake key,
+       // reused under a different account name. These credentials
+       // are passed to storage.NewClient when the tests connect to
+       // the local stub server instead of real Azure storage.
+       fakeAccountName = "fakeAccountName"
+       fakeAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
 )
 
 var azureTestContainer string
@@ -73,6 +84,7 @@ func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
        h.blobs[container+"|"+hash] = &azBlob{
                Data:        data,
                Mtime:       time.Now(),
+               Metadata:    make(map[string]string),
                Uncommitted: make(map[string][]byte),
        }
 }
@@ -92,6 +104,8 @@ func (h *azStubHandler) unlockAndRace() {
        h.Lock()
 }
 
+// rangeRegexp recognizes a single "bytes=first-last" Range request header
+// (both offsets inclusive). It is the only range form the stub's Get Blob
+// handler honors; any other Range value results in the whole blob being
+// returned.
+var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
+
 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
        h.Lock()
        defer h.Unlock()
@@ -133,14 +147,23 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                        h.blobs[container+"|"+hash] = &azBlob{
                                Mtime:       time.Now(),
                                Uncommitted: make(map[string][]byte),
+                               Metadata:    make(map[string]string),
                                Etag:        makeEtag(),
                        }
                        h.unlockAndRace()
                }
+               metadata := make(map[string]string)
+               for k, v := range r.Header {
+                       if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
+                               name := k[len("x-ms-meta-"):]
+                               metadata[strings.ToLower(name)] = v[0]
+                       }
+               }
                h.blobs[container+"|"+hash] = &azBlob{
                        Data:        body,
                        Mtime:       time.Now(),
                        Uncommitted: make(map[string][]byte),
+                       Metadata:    metadata,
                        Etag:        makeEtag(),
                }
                rw.WriteHeader(http.StatusCreated)
@@ -193,22 +216,46 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                blob.Metadata = make(map[string]string)
                for k, v := range r.Header {
                        if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
-                               blob.Metadata[k] = v[0]
+                               name := k[len("x-ms-meta-"):]
+                               blob.Metadata[strings.ToLower(name)] = v[0]
                        }
                }
                blob.Mtime = time.Now()
                blob.Etag = makeEtag()
+       case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
+               // "Get Blob Metadata" API
+               if !blobExists {
+                       rw.WriteHeader(http.StatusNotFound)
+                       return
+               }
+               for k, v := range blob.Metadata {
+                       rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
+               }
+               return
        case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
                // "Get Blob" API
                if !blobExists {
                        rw.WriteHeader(http.StatusNotFound)
                        return
                }
+               data := blob.Data
+               if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
+                       b0, err0 := strconv.Atoi(rangeSpec[1])
+                       b1, err1 := strconv.Atoi(rangeSpec[2])
+                       if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
+                               rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
+                               rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
+                               return
+                       }
+                       rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
+                       rw.WriteHeader(http.StatusPartialContent)
+                       data = data[b0 : b1+1]
+               }
                rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
-               rw.Header().Set("Content-Length", strconv.Itoa(len(blob.Data)))
+               rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
                if r.Method == "GET" {
-                       if _, err := rw.Write(blob.Data); err != nil {
-                               log.Printf("write %+q: %s", blob.Data, err)
+                       if _, err := rw.Write(data); err != nil {
+                               log.Printf("write %+q: %s", data, err)
                        }
                }
                h.unlockAndRace()
@@ -249,14 +296,20 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                        }
                        if len(resp.Blobs) > 0 || marker == "" || marker == hash {
                                blob := h.blobs[container+"|"+hash]
-                               resp.Blobs = append(resp.Blobs, storage.Blob{
+                               bmeta := map[string]string(nil)
+                               if r.Form.Get("include") == "metadata" {
+                                       bmeta = blob.Metadata
+                               }
+                               b := storage.Blob{
                                        Name: hash,
                                        Properties: storage.BlobProperties{
                                                LastModified:  blob.Mtime.Format(time.RFC1123),
                                                ContentLength: int64(len(blob.Data)),
                                                Etag:          blob.Etag,
                                        },
-                               })
+                                       Metadata: bmeta,
+                               }
+                               resp.Blobs = append(resp.Blobs, b)
                        }
                }
                buf, err := xml.Marshal(resp)
@@ -292,10 +345,10 @@ type TestableAzureBlobVolume struct {
        *AzureBlobVolume
        azHandler *azStubHandler
        azStub    *httptest.Server
-       t         *testing.T
+       t         TB
 }
 
-func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *TestableAzureBlobVolume {
+func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
        azHandler := newAzStubHandler()
        azStub := httptest.NewServer(azHandler)
 
@@ -306,7 +359,7 @@ func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *T
                // Connect to stub instead of real Azure storage service
                stubURLBase := strings.Split(azStub.URL, "://")[1]
                var err error
-               if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
+               if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
                        t.Fatal(err)
                }
                container = "fakecontainername"
@@ -322,7 +375,14 @@ func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *T
                }
        }
 
-       v := NewAzureBlobVolume(azClient, container, readonly, replication)
+       bs := azClient.GetBlobService()
+       v := &AzureBlobVolume{
+               ContainerName:    container,
+               ReadOnly:         readonly,
+               AzureReplication: replication,
+               azClient:         azClient,
+               bsClient:         &azureBlobClient{client: &bs},
+       }
 
        return &TestableAzureBlobVolume{
                AzureBlobVolume: v,
@@ -332,6 +392,29 @@ func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *T
        }
 }
 
+// Register the stubbed Azure volume suite with gocheck.
+var _ = check.Suite(&StubbedAzureBlobSuite{})
+
+// StubbedAzureBlobSuite runs gocheck tests against an AzureBlobVolume
+// backed by the in-process azStubHandler. Each test gets a fresh volume.
+// Package-level settings changed by SetUpTest are saved here and put back
+// by TearDownTest, so they do not leak into other tests.
+type StubbedAzureBlobSuite struct {
+       volume            *TestableAzureBlobVolume
+       origHTTPTransport http.RoundTripper
+       origRaceInterval  time.Duration
+       origRacePollTime  time.Duration
+}
+
+// SetUpTest routes HTTP traffic to the stub dialer and shortens the
+// write-race timing variables. Presumably this keeps race-handling code
+// paths from waiting long in tests. It then creates a writable test volume
+// with replication 3.
+func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
+       s.origHTTPTransport = http.DefaultTransport
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+       s.origRaceInterval = azureWriteRaceInterval
+       s.origRacePollTime = azureWriteRacePollTime
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+
+       s.volume = NewTestableAzureBlobVolume(c, false, 3)
+}
+
+// TearDownTest tears down the test volume and restores every
+// package-level setting that SetUpTest changed.
+func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
+       s.volume.Teardown()
+       azureWriteRaceInterval = s.origRaceInterval
+       azureWriteRacePollTime = s.origRacePollTime
+       http.DefaultTransport = s.origHTTPTransport
+}
+
 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
        defer func(t http.RoundTripper) {
                http.DefaultTransport = t
@@ -341,11 +424,32 @@ func TestAzureBlobVolumeWithGeneric(t *testing.T) {
        }
        azureWriteRaceInterval = time.Millisecond
        azureWriteRacePollTime = time.Nanosecond
-       DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
+       DoGenericVolumeTests(t, func(t TB) TestableVolume {
                return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
        })
 }
 
+func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
+       defer func(b int) {
+               azureMaxGetBytes = b
+       }(azureMaxGetBytes)
+
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+       // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
+       for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
+               DoGenericVolumeTests(t, func(t TB) TestableVolume {
+                       return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
+               })
+       }
+}
+
 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
        defer func(t http.RoundTripper) {
                http.DefaultTransport = t
@@ -355,11 +459,54 @@ func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
        }
        azureWriteRaceInterval = time.Millisecond
        azureWriteRacePollTime = time.Nanosecond
-       DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
+       DoGenericVolumeTests(t, func(t TB) TestableVolume {
                return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
        })
 }
 
+func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+
+       v := NewTestableAzureBlobVolume(t, false, 3)
+       defer v.Teardown()
+
+       for _, size := range []int{
+               2<<22 - 1, // one <max read
+               2 << 22,   // one =max read
+               2<<22 + 1, // one =max read, one <max
+               2 << 23,   // two =max reads
+               BlockSize - 1,
+               BlockSize,
+       } {
+               data := make([]byte, size)
+               for i := range data {
+                       data[i] = byte((i + 7) & 0xff)
+               }
+               hash := fmt.Sprintf("%x", md5.Sum(data))
+               err := v.Put(context.Background(), hash, data)
+               if err != nil {
+                       t.Error(err)
+               }
+               gotData := make([]byte, len(data))
+               gotLen, err := v.Get(context.Background(), hash, gotData)
+               if err != nil {
+                       t.Error(err)
+               }
+               gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
+               if gotLen != size {
+                       t.Errorf("length mismatch: got %d != %d", gotLen, size)
+               }
+               if gotHash != hash {
+                       t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
+               }
+       }
+}
+
 func TestAzureBlobVolumeReplication(t *testing.T) {
        for r := 1; r <= 4; r++ {
                v := NewTestableAzureBlobVolume(t, false, r)
@@ -387,7 +534,7 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        allDone := make(chan struct{})
        v.azHandler.race = make(chan chan struct{})
        go func() {
-               err := v.Put(TestHash, TestBlock)
+               err := v.Put(context.Background(), TestHash, TestBlock)
                if err != nil {
                        t.Error(err)
                }
@@ -396,11 +543,10 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        // Wait for the stub's Put to create the empty blob
        v.azHandler.race <- continuePut
        go func() {
-               buf, err := v.Get(TestHash)
+               buf := make([]byte, len(TestBlock))
+               _, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Error(err)
-               } else {
-                       bufs.Put(buf)
                }
                close(allDone)
        }()
@@ -435,20 +581,20 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
                t.Errorf("Index %+q should be empty", buf.Bytes())
        }
 
-       v.TouchWithDate(TestHash, time.Now().Add(-1982 * time.Millisecond))
+       v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
 
        allDone := make(chan struct{})
        go func() {
                defer close(allDone)
-               buf, err := v.Get(TestHash)
+               buf := make([]byte, BlockSize)
+               n, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Error(err)
                        return
                }
-               if len(buf) != 0 {
-                       t.Errorf("Got %+q, expected empty buf", buf)
+               if n != 0 {
+                       t.Errorf("Got %+q, expected empty buf", buf[:n])
                }
-               bufs.Put(buf)
        }()
        select {
        case <-allDone:
@@ -463,12 +609,106 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
        }
 }
 
+func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               v.PutRaw(TestHash, TestBlock)
+               _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
+               return err
+       })
+}
+
+func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               return v.Put(ctx, TestHash, make([]byte, BlockSize))
+       })
+}
+
+func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               v.PutRaw(TestHash, TestBlock)
+               return v.Compare(ctx, TestHash, TestBlock2)
+       })
+}
+
+func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+
+       v := NewTestableAzureBlobVolume(t, false, 3)
+       defer v.Teardown()
+       v.azHandler.race = make(chan chan struct{})
+
+       ctx, cancel := context.WithCancel(context.Background())
+       allDone := make(chan struct{})
+       go func() {
+               defer close(allDone)
+               err := testFunc(ctx, v)
+               if err != context.Canceled {
+                       t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
+               }
+       }()
+       releaseHandler := make(chan struct{})
+       select {
+       case <-allDone:
+               t.Error("testFunc finished without waiting for v.azHandler.race")
+       case <-time.After(10 * time.Second):
+               t.Error("timed out waiting to enter handler")
+       case v.azHandler.race <- releaseHandler:
+       }
+
+       cancel()
+
+       select {
+       case <-time.After(10 * time.Second):
+               t.Error("timed out waiting to cancel")
+       case <-allDone:
+       }
+
+       go func() {
+               <-releaseHandler
+       }()
+}
+
+func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
+       stats := func() string {
+               buf, err := json.Marshal(s.volume.InternalStats())
+               c.Check(err, check.IsNil)
+               return string(buf)
+       }
+
+       c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
+       c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
+
+       loc := "acbd18db4cc2f85cedef654fccc4a4d8"
+       _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.NotNil)
+       c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
+       c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
+       c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
+       c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
+
+       err = s.volume.Put(context.Background(), loc, []byte("foo"))
+       c.Check(err, check.IsNil)
+       c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
+       c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
+
+       _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.IsNil)
+       _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.IsNil)
+       c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
+}
+
+// PutRaw stores data under locator in this volume's container by calling
+// the stub handler directly, bypassing the volume's own Put.
 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
-       v.azHandler.PutRaw(v.containerName, locator, data)
+       v.azHandler.PutRaw(v.ContainerName, locator, data)
 }
 
+// TouchWithDate forwards to the stub handler's TouchWithDate for locator in
+// this volume's container. Presumably this sets the stub blob's
+// modification time to lastPut; confirm against azStubHandler.
 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
-       v.azHandler.TouchWithDate(v.containerName, locator, lastPut)
+       v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
 }
 
 func (v *TestableAzureBlobVolume) Teardown() {