X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/863570108a2c901a8eff22dc8a9bc72635ba7b95..380a54a7d97b34119cbaa3bee05d6b6cd241eee5:/services/keepstore/azure_blob_volume_test.go diff --git a/services/keepstore/azure_blob_volume_test.go b/services/keepstore/azure_blob_volume_test.go index bb57dcd266..8d02def144 100644 --- a/services/keepstore/azure_blob_volume_test.go +++ b/services/keepstore/azure_blob_volume_test.go @@ -1,3 +1,7 @@ +// Copyright (C) The Arvados Authors. All rights reserved. +// +// SPDX-License-Identifier: AGPL-3.0 + package main import ( @@ -5,15 +9,16 @@ import ( "context" "crypto/md5" "encoding/base64" + "encoding/json" "encoding/xml" "flag" "fmt" "io/ioutil" - "log" "math/rand" "net" "net/http" "net/http/httptest" + "os" "regexp" "sort" "strconv" @@ -22,16 +27,26 @@ import ( "testing" "time" - "github.com/curoverse/azure-sdk-for-go/storage" + "git.curoverse.com/arvados.git/sdk/go/arvados" + "github.com/Azure/azure-sdk-for-go/storage" + "github.com/ghodss/yaml" + "github.com/prometheus/client_golang/prometheus" + check "gopkg.in/check.v1" ) const ( - // The same fake credentials used by Microsoft's Azure emulator - emulatorAccountName = "devstoreaccount1" - emulatorAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" + // This cannot be the fake account name "devstoreaccount1" + // used by Microsoft's Azure emulator: the Azure SDK + // recognizes that magic string and changes its behavior to + // cater to the Azure SDK's own test suite. + fakeAccountName = "fakeaccountname" + fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" ) -var azureTestContainer string +var ( + azureTestContainer string + azureTestDebug = os.Getenv("ARVADOS_DEBUG") != "" +) func init() { flag.StringVar( @@ -51,8 +66,9 @@ type azBlob struct { type azStubHandler struct { sync.Mutex - blobs map[string]*azBlob - race chan chan struct{} + blobs map[string]*azBlob + race chan chan struct{} + didlist503 bool } func newAzStubHandler() *azStubHandler { @@ -100,7 +116,9 @@ var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`) func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.Lock() defer h.Unlock() - // defer log.Printf("azStubHandler: %+v", r) + if azureTestDebug { + defer log.Printf("azStubHandler: %+v", r) + } path := strings.Split(r.URL.Path, "/") container := path[1] @@ -115,6 +133,11 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } + if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" { + rw.WriteHeader(http.StatusLengthRequired) + return + } + body, err := ioutil.ReadAll(r.Body) if err != nil { return @@ -260,6 +283,11 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusAccepted) case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container": // "List Blobs" API + if !h.didlist503 { + h.didlist503 = true + rw.WriteHeader(http.StatusServiceUnavailable) + return + } prefix := container + "|" + r.Form.Get("prefix") marker := r.Form.Get("marker") @@ -294,7 +322,7 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { b := storage.Blob{ Name: hash, Properties: storage.BlobProperties{ - LastModified: blob.Mtime.Format(time.RFC1123), + LastModified: storage.TimeRFC1123(blob.Mtime), ContentLength: int64(len(blob.Data)), Etag: blob.Etag, }, @@ -326,7 +354,9 @@ var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`) func (d *azStubDialer) Dial(network, address string) (net.Conn, error) { if hp := localHostPortRe.FindString(address); hp != "" { - log.Println("azStubDialer: dial", hp, "instead of", address) + if azureTestDebug { + log.Println("azStubDialer: dial", hp, "instead of", address) + } address = hp } return d.Dialer.Dial(network, address) @@ -350,7 +380,7 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA // Connect to stub instead of real Azure storage service stubURLBase := strings.Split(azStub.URL, "://")[1] var err error - if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil { + if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil { t.Fatal(err) } container = "fakecontainername" @@ -365,13 +395,17 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA t.Fatal(err) } } + azClient.Sender = &singleSender{} + bs := azClient.GetBlobService() v := &AzureBlobVolume{ - ContainerName: container, - ReadOnly: readonly, - AzureReplication: replication, - azClient: azClient, - bsClient: azClient.GetBlobService(), + ContainerName: container, + ReadOnly: readonly, + AzureReplication: replication, + ListBlobsMaxAttempts: 2, + ListBlobsRetryDelay: arvados.Duration(time.Millisecond), + azClient: azClient, + container: &azureContainer{ctr: bs.GetContainerReference(container)}, } return &TestableAzureBlobVolume{ @@ -382,6 +416,29 @@ func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableA } } +var _ = check.Suite(&StubbedAzureBlobSuite{}) + +type StubbedAzureBlobSuite struct { + volume *TestableAzureBlobVolume + origHTTPTransport http.RoundTripper +} + +func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) { + s.origHTTPTransport = http.DefaultTransport + http.DefaultTransport = &http.Transport{ + Dial: (&azStubDialer{}).Dial, + } + azureWriteRaceInterval = time.Millisecond + azureWriteRacePollTime = time.Nanosecond + + s.volume = NewTestableAzureBlobVolume(c, false, 3) +} + +func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) { + s.volume.Teardown() + http.DefaultTransport = s.origHTTPTransport +} + func TestAzureBlobVolumeWithGeneric(t *testing.T) { defer func(t http.RoundTripper) { http.DefaultTransport = t @@ -455,21 +512,21 @@ func TestAzureBlobVolumeRangeFenceposts(t *testing.T) { data[i] = byte((i + 7) & 0xff) } hash := fmt.Sprintf("%x", md5.Sum(data)) - err := v.Put(hash, data) + err := v.Put(context.Background(), hash, data) if err != nil { t.Error(err) } gotData := make([]byte, len(data)) - gotLen, err := v.Get(context.TODO(), hash, gotData) + gotLen, err := v.Get(context.Background(), hash, gotData) if err != nil { t.Error(err) } gotHash := fmt.Sprintf("%x", md5.Sum(gotData)) if gotLen != size { - t.Error("length mismatch: got %d != %d", gotLen, size) + t.Errorf("length mismatch: got %d != %d", gotLen, size) } if gotHash != hash { - t.Error("hash mismatch: got %s != %s", gotHash, hash) + t.Errorf("hash mismatch: got %s != %s", gotHash, hash) } } } @@ -498,10 +555,14 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) { azureWriteRaceInterval = time.Second azureWriteRacePollTime = time.Millisecond - allDone := make(chan struct{}) + var wg sync.WaitGroup + v.azHandler.race = make(chan chan struct{}) + + wg.Add(1) go func() { - err := v.Put(TestHash, TestBlock) + defer wg.Done() + err := v.Put(context.Background(), TestHash, TestBlock) if err != nil { t.Error(err) } @@ -509,21 +570,22 @@ func TestAzureBlobVolumeCreateBlobRace(t *testing.T) { continuePut := make(chan struct{}) // Wait for the stub's Put to create the empty blob v.azHandler.race <- continuePut + wg.Add(1) go func() { + defer wg.Done() buf := make([]byte, len(TestBlock)) - _, err := v.Get(context.TODO(), TestHash, buf) + _, err := v.Get(context.Background(), TestHash, buf) if err != nil { t.Error(err) } - close(allDone) }() // Wait for the stub's Get to get the empty blob close(v.azHandler.race) // Allow stub's Put to continue, so the real data is ready // when the volume's Get retries <-continuePut - // Wait for volume's Get to return the real data - <-allDone + // Wait for Get() and Put() to finish + wg.Wait() } func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) { @@ -554,7 +616,7 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) { go func() { defer close(allDone) buf := make([]byte, BlockSize) - n, err := v.Get(context.TODO(), TestHash, buf) + n, err := v.Get(context.Background(), TestHash, buf) if err != nil { t.Error(err) return @@ -576,6 +638,112 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) { } } +func TestAzureBlobVolumeContextCancelGet(t *testing.T) { + testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error { + v.PutRaw(TestHash, TestBlock) + _, err := v.Get(ctx, TestHash, make([]byte, BlockSize)) + return err + }) +} + +func TestAzureBlobVolumeContextCancelPut(t *testing.T) { + testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error { + return v.Put(ctx, TestHash, make([]byte, BlockSize)) + }) +} + +func TestAzureBlobVolumeContextCancelCompare(t *testing.T) { + testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error { + v.PutRaw(TestHash, TestBlock) + return v.Compare(ctx, TestHash, TestBlock2) + }) +} + +func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) { + defer func(t http.RoundTripper) { + http.DefaultTransport = t + }(http.DefaultTransport) + http.DefaultTransport = &http.Transport{ + Dial: (&azStubDialer{}).Dial, + } + + v := NewTestableAzureBlobVolume(t, false, 3) + defer v.Teardown() + v.azHandler.race = make(chan chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + allDone := make(chan struct{}) + go func() { + defer close(allDone) + err := testFunc(ctx, v) + if err != context.Canceled { + t.Errorf("got %T %q, expected %q", err, err, context.Canceled) + } + }() + releaseHandler := make(chan struct{}) + select { + case <-allDone: + t.Error("testFunc finished without waiting for v.azHandler.race") + case <-time.After(10 * time.Second): + t.Error("timed out waiting to enter handler") + case v.azHandler.race <- releaseHandler: + } + + cancel() + + select { + case <-time.After(10 * time.Second): + t.Error("timed out waiting to cancel") + case <-allDone: + } + + go func() { + <-releaseHandler + }() +} + +func (s *StubbedAzureBlobSuite) TestStats(c *check.C) { + stats := func() string { + buf, err := json.Marshal(s.volume.InternalStats()) + c.Check(err, check.IsNil) + return string(buf) + } + + c.Check(stats(), check.Matches, `.*"Ops":0,.*`) + c.Check(stats(), check.Matches, `.*"Errors":0,.*`) + + loc := "acbd18db4cc2f85cedef654fccc4a4d8" + _, err := s.volume.Get(context.Background(), loc, make([]byte, 3)) + c.Check(err, check.NotNil) + c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`) + c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`) + c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`) + c.Check(stats(), check.Matches, `.*"InBytes":0,.*`) + + err = s.volume.Put(context.Background(), loc, []byte("foo")) + c.Check(err, check.IsNil) + c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`) + c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`) + + _, err = s.volume.Get(context.Background(), loc, make([]byte, 3)) + c.Check(err, check.IsNil) + _, err = s.volume.Get(context.Background(), loc, make([]byte, 3)) + c.Check(err, check.IsNil) + c.Check(stats(), check.Matches, `.*"InBytes":6,.*`) +} + +func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) { + var cfg Config + err := yaml.Unmarshal([]byte(` +Volumes: + - Type: Azure + StorageClasses: ["class_a", "class_b"] +`), &cfg) + + c.Check(err, check.IsNil) + c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"}) +} + func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) { v.azHandler.PutRaw(v.ContainerName, locator, data) } @@ -588,6 +756,21 @@ func (v *TestableAzureBlobVolume) Teardown() { v.azStub.Close() } +func (v *TestableAzureBlobVolume) ReadWriteOperationLabelValues() (r, w string) { + return "get", "create" +} + +func (v *TestableAzureBlobVolume) DeviceID() string { + // Dummy device id for testing purposes + return "azure://azure_blob_volume_test" +} + +func (v *TestableAzureBlobVolume) Start(vm *volumeMetricsVecs) error { + // Override original Start() to be able to assign CounterVecs with a dummy DeviceID + v.container.stats.opsCounters, v.container.stats.errCounters, v.container.stats.ioBytes = vm.getCounterVecsFor(prometheus.Labels{"device_id": v.DeviceID()}) + return nil +} + func makeEtag() string { return fmt.Sprintf("0x%x", rand.Int63()) }