Avoid http client race in test case.
[arvados.git] / services / keepstore / azure_blob_volume_test.go
index d636a5ee86887806372a14e2f291e5c4f2c11b33..1cb6dc380d0a24a002072b1a7465ef640882dd6c 100644 (file)
@@ -1,3 +1,7 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
@@ -5,11 +9,11 @@ import (
        "context"
        "crypto/md5"
        "encoding/base64"
+       "encoding/json"
        "encoding/xml"
        "flag"
        "fmt"
        "io/ioutil"
-       "log"
        "math/rand"
        "net"
        "net/http"
@@ -22,13 +26,18 @@ import (
        "testing"
        "time"
 
-       "github.com/curoverse/azure-sdk-for-go/storage"
+       "github.com/Azure/azure-sdk-for-go/storage"
+       "github.com/ghodss/yaml"
+       check "gopkg.in/check.v1"
 )
 
 const (
-       // The same fake credentials used by Microsoft's Azure emulator
-       emulatorAccountName = "devstoreaccount1"
-       emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
+       // This cannot be the fake account name "devstoreaccount1"
+       // used by Microsoft's Azure emulator: the Azure SDK
+       // recognizes that magic string and changes its behavior to
+       // cater to the Azure SDK's own test suite.
+       //
+       // fakeAccountKey keeps the same value as the emulator key it
+       // replaces (emulatorAccountKey); only the account name had to
+       // change. Both are passed to storage.NewClient when the test
+       // volume connects to the local stub server.
+       fakeAccountName = "fakeaccountname"
+       fakeAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
 )
 
+// NOTE(review): presumably set from a command-line flag ("flag" is
+// imported) to run against a real Azure container instead of the stub;
+// confirm against the flag registration elsewhere in this file.
 var azureTestContainer string
@@ -115,6 +124,11 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                return
        }
 
+       if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
+               rw.WriteHeader(http.StatusLengthRequired)
+               return
+       }
+
        body, err := ioutil.ReadAll(r.Body)
        if err != nil {
                return
@@ -294,7 +308,7 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                                b := storage.Blob{
                                        Name: hash,
                                        Properties: storage.BlobProperties{
-                                               LastModified:  blob.Mtime.Format(time.RFC1123),
+                                               LastModified:  storage.TimeRFC1123(blob.Mtime),
                                                ContentLength: int64(len(blob.Data)),
                                                Etag:          blob.Etag,
                                        },
@@ -350,7 +364,7 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA
                // Connect to stub instead of real Azure storage service
                stubURLBase := strings.Split(azStub.URL, "://")[1]
                var err error
-               if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
+               if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
                        t.Fatal(err)
                }
                container = "fakecontainername"
@@ -366,12 +380,13 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA
                }
        }
 
+       bs := azClient.GetBlobService()
        v := &AzureBlobVolume{
                ContainerName:    container,
                ReadOnly:         readonly,
                AzureReplication: replication,
                azClient:         azClient,
-               bsClient:         azClient.GetBlobService(),
+               container:        &azureContainer{ctr: bs.GetContainerReference(container)},
        }
 
        return &TestableAzureBlobVolume{
@@ -382,6 +397,29 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA
        }
 }
 
+var _ = check.Suite(&StubbedAzureBlobSuite{})
+
+// StubbedAzureBlobSuite runs gocheck tests against an AzureBlobVolume
+// that talks to the in-process Azure stub server.
+type StubbedAzureBlobSuite struct {
+       volume            *TestableAzureBlobVolume
+       origHTTPTransport http.RoundTripper
+
+       // Values of package-level tunables that SetUpTest overrides,
+       // saved so TearDownTest can restore them. Without this the
+       // very short race timeouts would leak into later tests.
+       origWriteRaceInterval time.Duration
+       origWriteRacePollTime time.Duration
+}
+
+// SetUpTest installs an http.Transport that dials through azStubDialer,
+// shortens the write-race timing globals so race handling completes
+// quickly, and creates a fresh writable test volume with replication 3.
+func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
+       s.origHTTPTransport = http.DefaultTransport
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+
+       s.origWriteRaceInterval = azureWriteRaceInterval
+       s.origWriteRacePollTime = azureWriteRacePollTime
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+
+       s.volume = NewTestableAzureBlobVolume(c, false, 3)
+}
+
+// TearDownTest tears down the test volume and restores every piece of
+// global state that SetUpTest replaced.
+func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
+       s.volume.Teardown()
+       azureWriteRaceInterval = s.origWriteRaceInterval
+       azureWriteRacePollTime = s.origWriteRacePollTime
+       http.DefaultTransport = s.origHTTPTransport
+}
+
 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
        defer func(t http.RoundTripper) {
                http.DefaultTransport = t
@@ -466,10 +504,10 @@ func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
                }
                gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
                if gotLen != size {
-                       t.Error("length mismatch: got %d != %d", gotLen, size)
+                       t.Errorf("length mismatch: got %d != %d", gotLen, size)
                }
                if gotHash != hash {
-                       t.Error("hash mismatch: got %s != %s", gotHash, hash)
+                       t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
                }
        }
 }
@@ -498,9 +536,13 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        azureWriteRaceInterval = time.Second
        azureWriteRacePollTime = time.Millisecond
 
-       allDone := make(chan struct{})
+       var wg sync.WaitGroup
+
        v.azHandler.race = make(chan chan struct{})
+
+       wg.Add(1)
        go func() {
+               defer wg.Done()
                err := v.Put(context.Background(), TestHash, TestBlock)
                if err != nil {
                        t.Error(err)
@@ -509,21 +551,22 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
        continuePut := make(chan struct{})
        // Wait for the stub's Put to create the empty blob
        v.azHandler.race <- continuePut
+       wg.Add(1)
        go func() {
+               defer wg.Done()
                buf := make([]byte, len(TestBlock))
                _, err := v.Get(context.Background(), TestHash, buf)
                if err != nil {
                        t.Error(err)
                }
-               close(allDone)
        }()
        // Wait for the stub's Get to get the empty blob
        close(v.azHandler.race)
        // Allow stub's Put to continue, so the real data is ready
        // when the volume's Get retries
        <-continuePut
-       // Wait for volume's Get to return the real data
-       <-allDone
+       // Wait for Get() and Put() to finish
+       wg.Wait()
 }
 
 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
@@ -576,6 +619,112 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
        }
 }
 
+func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               v.PutRaw(TestHash, TestBlock)
+               _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
+               return err
+       })
+}
+
+func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               return v.Put(ctx, TestHash, make([]byte, BlockSize))
+       })
+}
+
+func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
+       testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
+               v.PutRaw(TestHash, TestBlock)
+               return v.Compare(ctx, TestHash, TestBlock2)
+       })
+}
+
+func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+
+       v := NewTestableAzureBlobVolume(t, false, 3)
+       defer v.Teardown()
+       v.azHandler.race = make(chan chan struct{})
+
+       ctx, cancel := context.WithCancel(context.Background())
+       allDone := make(chan struct{})
+       go func() {
+               defer close(allDone)
+               err := testFunc(ctx, v)
+               if err != context.Canceled {
+                       t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
+               }
+       }()
+       releaseHandler := make(chan struct{})
+       select {
+       case <-allDone:
+               t.Error("testFunc finished without waiting for v.azHandler.race")
+       case <-time.After(10 * time.Second):
+               t.Error("timed out waiting to enter handler")
+       case v.azHandler.race <- releaseHandler:
+       }
+
+       cancel()
+
+       select {
+       case <-time.After(10 * time.Second):
+               t.Error("timed out waiting to cancel")
+       case <-allDone:
+       }
+
+       go func() {
+               <-releaseHandler
+       }()
+}
+
+func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
+       stats := func() string {
+               buf, err := json.Marshal(s.volume.InternalStats())
+               c.Check(err, check.IsNil)
+               return string(buf)
+       }
+
+       c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
+       c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
+
+       loc := "acbd18db4cc2f85cedef654fccc4a4d8"
+       _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.NotNil)
+       c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
+       c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
+       c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
+       c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
+
+       err = s.volume.Put(context.Background(), loc, []byte("foo"))
+       c.Check(err, check.IsNil)
+       c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
+       c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
+
+       _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.IsNil)
+       _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
+       c.Check(err, check.IsNil)
+       c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
+}
+
+// TestConfig checks that StorageClasses given in an Azure volume's YAML
+// config is returned by GetStorageClasses.
+func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
+       var cfg Config
+       err := yaml.Unmarshal([]byte(`
+Volumes:
+  - Type: Azure
+    StorageClasses: ["class_a", "class_b"]
+`), &cfg)
+
+       // Assert rather than Check: a parse failure or an empty volume
+       // list must stop the test here instead of panicking on
+       // cfg.Volumes[0] below.
+       c.Assert(err, check.IsNil)
+       c.Assert(cfg.Volumes, check.HasLen, 1)
+       c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
+}
+
+// PutRaw stores data under locator in the volume's container by calling
+// the stub handler (v.azHandler) directly, without going through the
+// volume's own Put method.
 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
       v.azHandler.PutRaw(v.ContainerName, locator, data)
 }