1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
30 "git.curoverse.com/arvados.git/sdk/go/arvados"
31 "github.com/Azure/azure-sdk-for-go/storage"
32 "github.com/ghodss/yaml"
33 "github.com/prometheus/client_golang/prometheus"
34 check "gopkg.in/check.v1"
38 // This cannot be the fake account name "devstoreaccount1"
39 // used by Microsoft's Azure emulator: the Azure SDK
40 // recognizes that magic string and changes its behavior to
41 // cater to the Azure SDK's own test suite.
42 fakeAccountName = "fakeaccountname"
// fakeAccountKey is the base64 account key passed to storage.NewClient
// together with fakeAccountName when the tests talk to the local stub.
// NOTE(review): this looks like the publicly documented Azure storage
// emulator key rather than a real secret -- confirm.
43 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
// azureTestContainer is the container name NewTestableAzureBlobVolume
// starts from. NOTE(review): presumably populated by the
// -test.azure-storage-container-volume flag -- confirm.
47 azureTestContainer string
// azureTestDebug is true when the ARVADOS_DEBUG environment variable
// is set to a non-empty value.
48 azureTestDebug = os.Getenv("ARVADOS_DEBUG") != ""
54 "test.azure-storage-container-volume",
56 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
62 Metadata map[string]string
64 Uncommitted map[string][]byte
// azStubHandler is an in-memory stand-in for the Azure blob storage
// HTTP API. Tests serve it through an httptest.Server.
67 type azStubHandler struct {
// blobs holds every stored blob, keyed by container+"|"+hash.
69 blobs map[string]*azBlob
// race is the hook used by unlockAndRace: the handler receives a
// channel from it and, if that channel is non-nil, blocks until the
// channel is ready to receive.
70 race chan chan struct{}
// newAzStubHandler returns a stub handler with an empty blob store.
74 func newAzStubHandler() *azStubHandler {
75 return &azStubHandler{
76 blobs: make(map[string]*azBlob),
// TouchWithDate looks up the stub blob stored for hash in container.
// NOTE(review): presumably it then sets that blob's modification time
// to t so tests can backdate a block -- TODO confirm.
80 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
81 blob, ok := h.blobs[container+"|"+hash]
// PutRaw stores a blob for hash in container directly in the stub's
// map, without going through the HTTP API. Any existing entry under the
// same key is replaced, and the new blob starts with empty Metadata and
// Uncommitted maps.
// NOTE(review): presumably data becomes the blob's content -- confirm.
88 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
91 h.blobs[container+"|"+hash] = &azBlob{
94 Metadata: make(map[string]string),
95 Uncommitted: make(map[string][]byte),
// unlockAndRace is the stub's synchronization point for race tests: it
// lets a test hold a request in the middle of an operation.
// NOTE(review): the name suggests the handler's lock is released while
// waiting -- confirm.
99 func (h *azStubHandler) unlockAndRace() {
104 // Signal caller that race is starting by reading from
105 // h.race. If we get a channel, block until that channel is
106 // ready to receive. If we get nil (or h.race is closed) just
108 if c := <-h.race; c != nil {
// rangeRegexp matches a Range header value holding exactly one closed
// byte range, e.g. "bytes=0-1023", capturing the first and last offsets.
114 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
// ServeHTTP emulates a subset of the Azure blob REST API. Requests are
// dispatched on the HTTP method and on the "comp"/"restype" form
// values: create blob, put block, put block list, set metadata, get
// metadata, get/head blob (with optional Range), delete blob, and list
// blobs. Anything else is logged and answered with 501 Not Implemented.
116 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
120 defer log.Printf("azStubHandler: %+v", r)
// NOTE(review): presumably the container and blob hash used below are
// taken from these path components -- confirm.
123 path := strings.Split(r.URL.Path, "/")
130 if err := r.ParseForm(); err != nil {
131 log.Printf("azStubHandler(%+v): %s", r, err)
132 rw.WriteHeader(http.StatusBadRequest)
// Uploads must declare a Content-Length; otherwise answer 411.
136 if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
137 rw.WriteHeader(http.StatusLengthRequired)
141 body, err := ioutil.ReadAll(r.Body)
// XML shape of a "Put Block List" request body.
146 type blockListRequestBody struct {
147 XMLName xml.Name `xml:"BlockList"`
151 blob, blobExists := h.blobs[container+"|"+hash]
// PUT without "comp": create or replace the whole blob.
154 case r.Method == "PUT" && r.Form.Get("comp") == "":
156 if _, ok := h.blobs[container+"|"+hash]; !ok {
157 // Like the real Azure service, we offer a
158 // race window during which other clients can
159 // list/get the new blob before any data is
161 h.blobs[container+"|"+hash] = &azBlob{
163 Uncommitted: make(map[string][]byte),
164 Metadata: make(map[string]string),
// Collect x-ms-meta-* request headers as lower-cased metadata keys.
169 metadata := make(map[string]string)
170 for k, v := range r.Header {
171 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
172 name := k[len("x-ms-meta-"):]
173 metadata[strings.ToLower(name)] = v[0]
176 h.blobs[container+"|"+hash] = &azBlob{
179 Uncommitted: make(map[string][]byte),
183 rw.WriteHeader(http.StatusCreated)
// comp=block: stash the request body as an uncommitted block, keyed by
// the base64-decoded blockid.
184 case r.Method == "PUT" && r.Form.Get("comp") == "block":
187 log.Printf("Got block for nonexistent blob: %+v", r)
188 rw.WriteHeader(http.StatusBadRequest)
191 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
192 if err != nil || len(blockID) == 0 {
193 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
194 rw.WriteHeader(http.StatusBadRequest)
197 blob.Uncommitted[string(blockID)] = body
198 rw.WriteHeader(http.StatusCreated)
199 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
200 // "Put Block List" API
201 bl := &blockListRequestBody{}
202 if err := xml.Unmarshal(body, bl); err != nil {
203 log.Printf("xml Unmarshal: %s", err)
204 rw.WriteHeader(http.StatusBadRequest)
// Every listed ID must decode and name a stored uncommitted block.
// Note that Data is replaced rather than appended on each iteration,
// so with several IDs only the last block's bytes are kept.
207 for _, encBlockID := range bl.Uncommitted {
208 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
209 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
210 log.Printf("Invalid blockid: %+q", encBlockID)
211 rw.WriteHeader(http.StatusBadRequest)
214 blob.Data = blob.Uncommitted[string(blockID)]
215 blob.Etag = makeEtag()
216 blob.Mtime = time.Now()
217 delete(blob.Uncommitted, string(blockID))
219 rw.WriteHeader(http.StatusCreated)
220 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
221 // "Set Metadata Headers" API. We don't bother
222 // stubbing "Get Metadata Headers": AzureBlobVolume
223 // sets metadata headers only as a way to bump Etag
224 // and Last-Modified.
226 log.Printf("Got metadata for nonexistent blob: %+v", r)
227 rw.WriteHeader(http.StatusBadRequest)
// Existing metadata is discarded and rebuilt from the x-ms-meta-*
// request headers.
230 blob.Metadata = make(map[string]string)
231 for k, v := range r.Header {
232 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
233 name := k[len("x-ms-meta-"):]
234 blob.Metadata[strings.ToLower(name)] = v[0]
237 blob.Mtime = time.Now()
238 blob.Etag = makeEtag()
239 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
240 // "Get Blob Metadata" API
242 rw.WriteHeader(http.StatusNotFound)
245 for k, v := range blob.Metadata {
246 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
// GET/HEAD of a blob: return its content, honouring a Range header.
249 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
252 rw.WriteHeader(http.StatusNotFound)
// A range is rejected with 416 when an offset fails to parse, either
// offset is at or past the end of the data, or first > last. Otherwise
// 206 is sent and data is narrowed to the inclusive range [b0, b1].
256 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
257 b0, err0 := strconv.Atoi(rangeSpec[1])
258 b1, err1 := strconv.Atoi(rangeSpec[2])
259 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
260 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
261 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
264 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
265 rw.WriteHeader(http.StatusPartialContent)
266 data = data[b0 : b1+1]
268 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
269 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
// HEAD gets headers only; the body is written for GET alone.
270 if r.Method == "GET" {
271 if _, err := rw.Write(data); err != nil {
272 log.Printf("write %+q: %s", data, err)
// DELETE of a blob: remove it from the map and answer 202 Accepted.
276 case r.Method == "DELETE" && hash != "":
279 rw.WriteHeader(http.StatusNotFound)
282 delete(h.blobs, container+"|"+hash)
283 rw.WriteHeader(http.StatusAccepted)
// comp=list&restype=container: list the container's blobs.
284 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
288 rw.WriteHeader(http.StatusServiceUnavailable)
291 prefix := container + "|" + r.Form.Get("prefix")
292 marker := r.Form.Get("marker")
// A maxresults value is honoured only when it parses and is in 1..5000.
295 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
299 resp := storage.BlobListResponse{
302 MaxResults: int64(maxResults),
// Collect the hashes of all blobs in this container that match prefix.
304 var hashes sort.StringSlice
305 for k := range h.blobs {
306 if strings.HasPrefix(k, prefix) {
307 hashes = append(hashes, k[len(container)+1:])
// Pagination: once maxResults entries are collected, the next hash is
// recorded as NextMarker. Entries are skipped until the requested
// marker is reached; with no marker, output starts at the first hash.
311 for _, hash := range hashes {
312 if len(resp.Blobs) == maxResults {
313 resp.NextMarker = hash
316 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
317 blob := h.blobs[container+"|"+hash]
// Metadata is included only when the request has include=metadata.
318 bmeta := map[string]string(nil)
319 if r.Form.Get("include") == "metadata" {
320 bmeta = blob.Metadata
324 Properties: storage.BlobProperties{
325 LastModified: storage.TimeRFC1123(blob.Mtime),
326 ContentLength: int64(len(blob.Data)),
331 resp.Blobs = append(resp.Blobs, b)
334 buf, err := xml.Marshal(resp)
337 rw.WriteHeader(http.StatusInternalServerError)
// Any request not matched by a case is logged and answered with 501.
341 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
342 rw.WriteHeader(http.StatusNotImplemented)
346 // azStubDialer wraps a net.Dialer and notices when the Azure driver
347 // tries to connect to "fakeaccountname.blob.127.0.0.1:46067", and
348 // in such cases transparently dials "127.0.0.1:46067" instead.
349 type azStubDialer struct {
// localHostPortRe finds a loopback host:port (127.0.0.1, localhost, or
// [::1], followed by a numeric port) anywhere inside a dial address.
353 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
// Dial connects to address on network through the wrapped Dialer. When
// address contains a loopback host:port, the substitution is logged.
// NOTE(review): presumably address is replaced by that host:port before
// dialing, and the log line may depend on the debug setting -- confirm.
355 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
356 if hp := localHostPortRe.FindString(address); hp != "" {
358 log.Println("azStubDialer: dial", hp, "instead of", address)
362 return d.Dialer.Dial(network, address)
// TestableAzureBlobVolume bundles an Azure blob volume with the stub
// handler and httptest server backing it, so tests can manipulate
// stored blobs directly.
365 type TestableAzureBlobVolume struct {
// azHandler is the in-memory stub whose blobs the PutRaw and
// TouchWithDate helpers modify.
367 azHandler *azStubHandler
// azStub is the HTTP test server serving the stub.
368 azStub *httptest.Server
// NewTestableAzureBlobVolume returns a volume for tests with the given
// replication level. It always starts an in-process stub server. In the
// stub branch the Azure client is pointed at that server using the fake
// account credentials and the container name "fakecontainername". In
// the other branch it connects to the real Azure storage service with
// the configured account name and key file.
// NOTE(review): the branch condition is presumably whether
// azureTestContainer is empty, and readonly is presumably passed
// through to the volume -- confirm.
372 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
373 azHandler := newAzStubHandler()
374 azStub := httptest.NewServer(azHandler)
376 var azClient storage.Client
378 container := azureTestContainer
380 // Connect to stub instead of real Azure storage service
// Strip the URL scheme; only host:port is passed as the base URL.
381 stubURLBase := strings.Split(azStub.URL, "://")[1]
383 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
386 container = "fakecontainername"
388 // Connect to real Azure storage service
389 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
393 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
398 azClient.Sender = &singleSender{}
400 bs := azClient.GetBlobService()
401 v := &AzureBlobVolume{
402 ContainerName: container,
404 AzureReplication: replication,
405 ListBlobsMaxAttempts: 2,
406 ListBlobsRetryDelay: arvados.Duration(time.Millisecond),
408 container: &azureContainer{ctr: bs.GetContainerReference(container)},
411 return &TestableAzureBlobVolume{
413 azHandler: azHandler,
// Register StubbedAzureBlobSuite with gocheck so its test methods run.
419 var _ = check.Suite(&StubbedAzureBlobSuite{})
// StubbedAzureBlobSuite is a gocheck suite whose tests run against a
// stub-backed Azure volume.
421 type StubbedAzureBlobSuite struct {
// volume is the stub-backed volume created by SetUpTest.
422 volume *TestableAzureBlobVolume
// origHTTPTransport holds the http.DefaultTransport value saved by
// SetUpTest so TearDownTest can restore it.
423 origHTTPTransport http.RoundTripper
// SetUpTest prepares the suite fixture: it saves http.DefaultTransport
// and replaces it with one that dials through azStubDialer, shortens
// the package-level write-race timings, and creates a writable
// stub-backed volume with replication 3.
426 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
427 s.origHTTPTransport = http.DefaultTransport
428 http.DefaultTransport = &http.Transport{
429 Dial: (&azStubDialer{}).Dial,
431 azureWriteRaceInterval = time.Millisecond
432 azureWriteRacePollTime = time.Nanosecond
434 s.volume = NewTestableAzureBlobVolume(c, false, 3)
// TearDownTest restores the http.DefaultTransport saved by SetUpTest.
// NOTE(review): presumably it also tears down s.volume -- confirm.
437 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
439 http.DefaultTransport = s.origHTTPTransport
// TestAzureBlobVolumeWithGeneric runs the generic volume tests against
// a writable stub-backed volume. http.DefaultTransport is swapped for
// one dialing through azStubDialer and restored on return. The
// package-level write-race timings set here are not restored afterwards.
442 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
443 defer func(t http.RoundTripper) {
444 http.DefaultTransport = t
445 }(http.DefaultTransport)
446 http.DefaultTransport = &http.Transport{
447 Dial: (&azStubDialer{}).Dial,
449 azureWriteRaceInterval = time.Millisecond
450 azureWriteRacePollTime = time.Nanosecond
451 DoGenericVolumeTests(t, func(t TB) TestableVolume {
452 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestAzureBlobVolumeConcurrentRanges runs the generic volume tests
// once for each of two azureMaxGetBytes settings, so that blocks are
// fetched in several ranged reads.
// NOTE(review): the loop assigns the package-level azureMaxGetBytes
// directly; presumably the previous value is restored by a defer at the
// top of the function -- confirm.
456 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
461 defer func(t http.RoundTripper) {
462 http.DefaultTransport = t
463 }(http.DefaultTransport)
464 http.DefaultTransport = &http.Transport{
465 Dial: (&azStubDialer{}).Dial,
467 azureWriteRaceInterval = time.Millisecond
468 azureWriteRacePollTime = time.Nanosecond
469 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
470 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
471 DoGenericVolumeTests(t, func(t TB) TestableVolume {
472 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestReadonlyAzureBlobVolumeWithGeneric runs the generic volume tests
// against a stub-backed volume created with readonly=true.
// http.DefaultTransport is swapped for one dialing through azStubDialer
// and restored on return.
477 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
478 defer func(t http.RoundTripper) {
479 http.DefaultTransport = t
480 }(http.DefaultTransport)
481 http.DefaultTransport = &http.Transport{
482 Dial: (&azStubDialer{}).Dial,
484 azureWriteRaceInterval = time.Millisecond
485 azureWriteRacePollTime = time.Nanosecond
486 DoGenericVolumeTests(t, func(t TB) TestableVolume {
487 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
// TestAzureBlobVolumeRangeFenceposts writes and reads back blocks whose
// sizes sit just below, at, and just above multiples of the maximum
// ranged-read size. For each size it checks that the returned length
// and the MD5 of the returned data match what was written.
491 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
492 defer func(t http.RoundTripper) {
493 http.DefaultTransport = t
494 }(http.DefaultTransport)
495 http.DefaultTransport = &http.Transport{
496 Dial: (&azStubDialer{}).Dial,
499 v := NewTestableAzureBlobVolume(t, false, 3)
502 for _, size := range []int{
503 2<<22 - 1, // one <max read
504 2 << 22, // one =max read
505 2<<22 + 1, // one =max read, one <max
506 2 << 23, // two =max reads
// Fill the block with a deterministic, non-constant byte pattern.
510 data := make([]byte, size)
511 for i := range data {
512 data[i] = byte((i + 7) & 0xff)
// The block's locator is the hex MD5 of its content.
514 hash := fmt.Sprintf("%x", md5.Sum(data))
515 err := v.Put(context.Background(), hash, data)
519 gotData := make([]byte, len(data))
520 gotLen, err := v.Get(context.Background(), hash, gotData)
524 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
526 t.Errorf("length mismatch: got %d != %d", gotLen, size)
529 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
// TestAzureBlobVolumeReplication checks that, for replication levels 1
// through 4, a volume built with that level reports it from
// Replication().
534 func TestAzureBlobVolumeReplication(t *testing.T) {
535 for r := 1; r <= 4; r++ {
536 v := NewTestableAzureBlobVolume(t, false, r)
538 if n := v.Replication(); n != r {
539 t.Errorf("Got replication %d, expected %d", n, r)
// TestAzureBlobVolumeCreateBlobRace exercises the window in which a
// blob exists but has no data yet. A Put is paused inside the stub (via
// azHandler.race) right after the empty blob is created, a Get is
// issued while it is paused, and the Put is then released so the data
// is available when Get retries.
544 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
545 defer func(t http.RoundTripper) {
546 http.DefaultTransport = t
547 }(http.DefaultTransport)
548 http.DefaultTransport = &http.Transport{
549 Dial: (&azStubDialer{}).Dial,
552 v := NewTestableAzureBlobVolume(t, false, 3)
555 azureWriteRaceInterval = time.Second
556 azureWriteRacePollTime = time.Millisecond
558 var wg sync.WaitGroup
// Setting race makes the stub handler pause at its race point.
560 v.azHandler.race = make(chan chan struct{})
565 err := v.Put(context.Background(), TestHash, TestBlock)
570 continuePut := make(chan struct{})
571 // Wait for the stub's Put to create the empty blob
572 v.azHandler.race <- continuePut
576 buf := make([]byte, len(TestBlock))
577 _, err := v.Get(context.Background(), TestHash, buf)
582 // Wait for the stub's Get to get the empty blob
583 close(v.azHandler.race)
584 // Allow stub's Put to continue, so the real data is ready
585 // when the volume's Get retries
587 // Wait for Get() and Put() to finish
// TestAzureBlobVolumeCreateBlobRaceDeadline checks the time limit on
// waiting out the create-blob race. An empty blob is stored directly in
// the stub and must not yet appear in the index. It is then backdated
// to just under the 2s race interval, after which Get must return an
// empty result within one second and the index must list the block
// with size 0.
591 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
592 defer func(t http.RoundTripper) {
593 http.DefaultTransport = t
594 }(http.DefaultTransport)
595 http.DefaultTransport = &http.Transport{
596 Dial: (&azStubDialer{}).Dial,
599 v := NewTestableAzureBlobVolume(t, false, 3)
602 azureWriteRaceInterval = 2 * time.Second
603 azureWriteRacePollTime = 5 * time.Millisecond
605 v.PutRaw(TestHash, nil)
607 buf := new(bytes.Buffer)
610 t.Errorf("Index %+q should be empty", buf.Bytes())
// Backdate the blob so it is 18ms short of the 2s race interval.
613 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
615 allDone := make(chan struct{})
618 buf := make([]byte, BlockSize)
619 n, err := v.Get(context.Background(), TestHash, buf)
625 t.Errorf("Got %+q, expected empty buf", buf[:n])
630 case <-time.After(time.Second):
631 t.Error("Get should have stopped waiting for race when block was 2s old")
636 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
637 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
// TestAzureBlobVolumeContextCancelGet runs the shared cancellation
// check against Get, after storing TestBlock directly in the stub.
641 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
642 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
643 v.PutRaw(TestHash, TestBlock)
644 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelPut runs the shared cancellation
// check against Put, writing a zero-filled buffer of BlockSize bytes.
649 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
650 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
651 return v.Put(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelCompare runs the shared cancellation
// check against Compare. TestBlock is stored directly in the stub and
// then compared with TestBlock2.
655 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
656 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
657 v.PutRaw(TestHash, TestBlock)
658 return v.Compare(ctx, TestHash, TestBlock2)
// testAzureBlobVolumeContextCancel runs testFunc against a stub-backed
// volume whose handler is held at its race point, and checks that
// testFunc returns context.Canceled. The test fails if testFunc
// finishes before the handler is reached, or if waiting for the handler
// or for the cancellation takes longer than 10 seconds.
// NOTE(review): presumably testFunc runs in its own goroutine and the
// context is cancelled once the handler has been entered -- confirm.
662 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
663 defer func(t http.RoundTripper) {
664 http.DefaultTransport = t
665 }(http.DefaultTransport)
666 http.DefaultTransport = &http.Transport{
667 Dial: (&azStubDialer{}).Dial,
670 v := NewTestableAzureBlobVolume(t, false, 3)
// Setting race makes the stub handler pause at its race point.
672 v.azHandler.race = make(chan chan struct{})
674 ctx, cancel := context.WithCancel(context.Background())
675 allDone := make(chan struct{})
678 err := testFunc(ctx, v)
679 if err != context.Canceled {
680 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
683 releaseHandler := make(chan struct{})
686 t.Error("testFunc finished without waiting for v.azHandler.race")
687 case <-time.After(10 * time.Second):
688 t.Error("timed out waiting to enter handler")
// The send succeeds once the handler is blocked in unlockAndRace.
689 case v.azHandler.race <- releaseHandler:
695 case <-time.After(10 * time.Second):
696 t.Error("timed out waiting to cancel")
// TestStats checks the JSON form of the volume's InternalStats.
// Counters start at zero. A Get of a missing block bumps Ops and
// Errors and records a 404 error type while InBytes stays 0. A 3-byte
// Put yields OutBytes 3 and CreateOps 1, and two successful Gets of
// that block yield InBytes 6.
705 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
// stats marshals the current InternalStats to JSON for regexp matching.
706 stats := func() string {
707 buf, err := json.Marshal(s.volume.InternalStats())
708 c.Check(err, check.IsNil)
712 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
713 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
715 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
716 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
717 c.Check(err, check.NotNil)
// [^0] matches exactly one non-zero character, so these patterns
// expect a single-digit count.
718 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
719 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
720 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
721 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
723 err = s.volume.Put(context.Background(), loc, []byte("foo"))
724 c.Check(err, check.IsNil)
725 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
726 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
728 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
729 c.Check(err, check.IsNil)
730 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
731 c.Check(err, check.IsNil)
732 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
// TestConfig checks that a YAML configuration listing StorageClasses
// unmarshals without error and that the first configured volume
// reports exactly those classes, in order.
735 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
737 err := yaml.Unmarshal([]byte(`
740 StorageClasses: ["class_a", "class_b"]
743 c.Check(err, check.IsNil)
744 c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
// PutRaw stores data under locator directly in the stub handler, in
// this volume's container, bypassing the volume's own Put path.
747 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
748 v.azHandler.PutRaw(v.ContainerName, locator, data)
// TouchWithDate forwards to the stub handler's TouchWithDate for
// locator in this volume's container, passing lastPut as the time.
751 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
752 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
// Teardown releases the resources held by the testable volume.
// NOTE(review): presumably this shuts down the azStub test server --
// confirm.
755 func (v *TestableAzureBlobVolume) Teardown() {
// ReadWriteOperationLabelValues returns the operation label values for
// this volume type: "get" for reads (r) and "create" for writes (w).
759 func (v *TestableAzureBlobVolume) ReadWriteOperationLabelValues() (r, w string) {
760 return "get", "create"
// DeviceID returns a fixed placeholder identifier, so tests do not
// depend on a real storage account.
763 func (v *TestableAzureBlobVolume) DeviceID() string {
764 // Dummy device id for testing purposes
765 return "azure://azure_blob_volume_test"
// Start assigns the container stats' operation, error, and I/O byte
// counters from vm, labelled with this volume's DeviceID.
768 func (v *TestableAzureBlobVolume) Start(vm *volumeMetricsVecs) error {
769 // Override original Start() to be able to assign CounterVecs with a dummy DeviceID
770 v.container.stats.opsCounters, v.container.stats.errCounters, v.container.stats.ioBytes = vm.getCounterVecsFor(prometheus.Labels{"device_id": v.DeviceID()})
// makeEtag returns a new Etag value: "0x" followed by the lower-case
// hexadecimal form of a random non-negative 63-bit integer.
774 func makeEtag() string {
775 return fmt.Sprintf("0x%x", rand.Int63())