1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
30 "github.com/Azure/azure-sdk-for-go/storage"
31 "github.com/ghodss/yaml"
32 "github.com/prometheus/client_golang/prometheus"
33 check "gopkg.in/check.v1"
37 // This cannot be the fake account name "devstoreaccount1"
38 // used by Microsoft's Azure emulator: the Azure SDK
39 // recognizes that magic string and changes its behavior to
40 // cater to the Azure SDK's own test suite.
41 fakeAccountName = "fakeaccountname"
42 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
46 azureTestContainer string
47 azureTestDebug = os.Getenv("ARVADOS_DEBUG") != ""
53 "test.azure-storage-container-volume",
55 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
61 Metadata map[string]string
63 Uncommitted map[string][]byte
66 type azStubHandler struct {
68 blobs map[string]*azBlob
69 race chan chan struct{}
// newAzStubHandler returns a pointer to a new azStubHandler whose blobs
// map has been allocated (empty), so callers can insert blobs right away.
72 func newAzStubHandler() *azStubHandler {
73 return &azStubHandler{
74 blobs: make(map[string]*azBlob),
78 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
79 blob, ok := h.blobs[container+"|"+hash]
86 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
89 h.blobs[container+"|"+hash] = &azBlob{
92 Metadata: make(map[string]string),
93 Uncommitted: make(map[string][]byte),
97 func (h *azStubHandler) unlockAndRace() {
102 // Signal caller that race is starting by reading from
103 // h.race. If we get a channel, block until that channel is
104 // ready to receive. If we get nil (or h.race is closed) just
106 if c := <-h.race; c != nil {
// rangeRegexp matches a Range header value of the exact form
// "bytes=<first>-<last>", where both offsets are runs of decimal digits,
// and captures the two offsets. Open-ended ("bytes=5-") and suffix
// ("bytes=-5") forms do not match.
112 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
114 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
118 defer log.Printf("azStubHandler: %+v", r)
121 path := strings.Split(r.URL.Path, "/")
128 if err := r.ParseForm(); err != nil {
129 log.Printf("azStubHandler(%+v): %s", r, err)
130 rw.WriteHeader(http.StatusBadRequest)
134 if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
135 rw.WriteHeader(http.StatusLengthRequired)
139 body, err := ioutil.ReadAll(r.Body)
144 type blockListRequestBody struct {
145 XMLName xml.Name `xml:"BlockList"`
149 blob, blobExists := h.blobs[container+"|"+hash]
152 case r.Method == "PUT" && r.Form.Get("comp") == "":
154 if _, ok := h.blobs[container+"|"+hash]; !ok {
155 // Like the real Azure service, we offer a
156 // race window during which other clients can
157 // list/get the new blob before any data is
159 h.blobs[container+"|"+hash] = &azBlob{
161 Uncommitted: make(map[string][]byte),
162 Metadata: make(map[string]string),
167 metadata := make(map[string]string)
168 for k, v := range r.Header {
169 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
170 name := k[len("x-ms-meta-"):]
171 metadata[strings.ToLower(name)] = v[0]
174 h.blobs[container+"|"+hash] = &azBlob{
177 Uncommitted: make(map[string][]byte),
181 rw.WriteHeader(http.StatusCreated)
182 case r.Method == "PUT" && r.Form.Get("comp") == "block":
185 log.Printf("Got block for nonexistent blob: %+v", r)
186 rw.WriteHeader(http.StatusBadRequest)
189 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
190 if err != nil || len(blockID) == 0 {
191 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
192 rw.WriteHeader(http.StatusBadRequest)
195 blob.Uncommitted[string(blockID)] = body
196 rw.WriteHeader(http.StatusCreated)
197 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
198 // "Put Block List" API
199 bl := &blockListRequestBody{}
200 if err := xml.Unmarshal(body, bl); err != nil {
201 log.Printf("xml Unmarshal: %s", err)
202 rw.WriteHeader(http.StatusBadRequest)
205 for _, encBlockID := range bl.Uncommitted {
206 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
207 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
208 log.Printf("Invalid blockid: %+q", encBlockID)
209 rw.WriteHeader(http.StatusBadRequest)
212 blob.Data = blob.Uncommitted[string(blockID)]
213 blob.Etag = makeEtag()
214 blob.Mtime = time.Now()
215 delete(blob.Uncommitted, string(blockID))
217 rw.WriteHeader(http.StatusCreated)
218 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
219 // "Set Metadata Headers" API. AzureBlobVolume sets
220 // metadata headers as a way to bump Etag and
221 // Last-Modified. Reading them back ("Get Blob
222 // Metadata") is handled by a separate case below.
224 log.Printf("Got metadata for nonexistent blob: %+v", r)
225 rw.WriteHeader(http.StatusBadRequest)
228 blob.Metadata = make(map[string]string)
229 for k, v := range r.Header {
230 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
231 name := k[len("x-ms-meta-"):]
232 blob.Metadata[strings.ToLower(name)] = v[0]
235 blob.Mtime = time.Now()
236 blob.Etag = makeEtag()
237 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
238 // "Get Blob Metadata" API
240 rw.WriteHeader(http.StatusNotFound)
243 for k, v := range blob.Metadata {
244 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
247 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
250 rw.WriteHeader(http.StatusNotFound)
254 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
255 b0, err0 := strconv.Atoi(rangeSpec[1])
256 b1, err1 := strconv.Atoi(rangeSpec[2])
257 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
258 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
259 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
262 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
263 rw.WriteHeader(http.StatusPartialContent)
264 data = data[b0 : b1+1]
266 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
267 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
268 if r.Method == "GET" {
269 if _, err := rw.Write(data); err != nil {
270 log.Printf("write %+q: %s", data, err)
274 case r.Method == "DELETE" && hash != "":
277 rw.WriteHeader(http.StatusNotFound)
280 delete(h.blobs, container+"|"+hash)
281 rw.WriteHeader(http.StatusAccepted)
282 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
284 prefix := container + "|" + r.Form.Get("prefix")
285 marker := r.Form.Get("marker")
288 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
292 resp := storage.BlobListResponse{
295 MaxResults: int64(maxResults),
297 var hashes sort.StringSlice
298 for k := range h.blobs {
299 if strings.HasPrefix(k, prefix) {
300 hashes = append(hashes, k[len(container)+1:])
304 for _, hash := range hashes {
305 if len(resp.Blobs) == maxResults {
306 resp.NextMarker = hash
309 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
310 blob := h.blobs[container+"|"+hash]
311 bmeta := map[string]string(nil)
312 if r.Form.Get("include") == "metadata" {
313 bmeta = blob.Metadata
317 Properties: storage.BlobProperties{
318 LastModified: storage.TimeRFC1123(blob.Mtime),
319 ContentLength: int64(len(blob.Data)),
324 resp.Blobs = append(resp.Blobs, b)
327 buf, err := xml.Marshal(resp)
330 rw.WriteHeader(http.StatusInternalServerError)
334 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
335 rw.WriteHeader(http.StatusNotImplemented)
339 // azStubDialer is a net.Dialer that notices when the Azure driver
340 // tries to connect to "fakeaccountname.blob.127.0.0.1:46067", and
341 // in such cases transparently dials "127.0.0.1:46067" instead.
342 type azStubDialer struct {
// localHostPortRe matches a loopback host (127.0.0.1, localhost, or [::1])
// followed by ":" and a decimal port. The pattern is unanchored, so it
// finds such a host:port anywhere inside a longer address string.
346 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
348 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
349 if hp := localHostPortRe.FindString(address); hp != "" {
351 log.Println("azStubDialer: dial", hp, "instead of", address)
355 return d.Dialer.Dial(network, address)
358 type TestableAzureBlobVolume struct {
360 azHandler *azStubHandler
361 azStub *httptest.Server
365 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
366 azHandler := newAzStubHandler()
367 azStub := httptest.NewServer(azHandler)
369 var azClient storage.Client
371 container := azureTestContainer
373 // Connect to stub instead of real Azure storage service
374 stubURLBase := strings.Split(azStub.URL, "://")[1]
376 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
379 container = "fakecontainername"
381 // Connect to real Azure storage service
382 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
386 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
392 bs := azClient.GetBlobService()
393 v := &AzureBlobVolume{
394 ContainerName: container,
396 AzureReplication: replication,
398 container: &azureContainer{ctr: bs.GetContainerReference(container)},
401 return &TestableAzureBlobVolume{
403 azHandler: azHandler,
409 var _ = check.Suite(&StubbedAzureBlobSuite{})
// StubbedAzureBlobSuite holds the per-test state for the check.v1 tests
// in this file whose methods use *StubbedAzureBlobSuite as receiver.
411 type StubbedAzureBlobSuite struct {
// volume is the volume under test; SetUpTest assigns a new one to it.
412 volume *TestableAzureBlobVolume
// origHTTPTransport is the http.DefaultTransport value saved by SetUpTest
// before it installs its own transport; TearDownTest puts it back.
413 origHTTPTransport http.RoundTripper
// SetUpTest prepares the suite for one test: it saves the current
// http.DefaultTransport, replaces it with a transport that dials through
// azStubDialer, sets the package-level write-race interval and poll time
// to one millisecond and one nanosecond, and creates the volume under test
// (readonly=false, replication=3).
416 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
// Remember the transport being replaced so TearDownTest can restore it.
417 s.origHTTPTransport = http.DefaultTransport
418 http.DefaultTransport = &http.Transport{
419 Dial: (&azStubDialer{}).Dial,
421 azureWriteRaceInterval = time.Millisecond
422 azureWriteRacePollTime = time.Nanosecond
424 s.volume = NewTestableAzureBlobVolume(c, false, 3)
// TearDownTest puts back the http.DefaultTransport value that SetUpTest
// saved in s.origHTTPTransport.
427 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
429 http.DefaultTransport = s.origHTTPTransport
// TestAzureBlobVolumeWithGeneric runs DoGenericVolumeTests against
// writable (readonly=false) volumes created by NewTestableAzureBlobVolume
// with azureStorageReplication. For the duration of the test,
// http.DefaultTransport is replaced by one that dials through azStubDialer.
432 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
// The argument http.DefaultTransport is evaluated here, at defer time, so
// the deferred func restores the transport that was in place before this
// test replaced it. Inside the closure, t is that transport (it shadows
// the *testing.T).
433 defer func(t http.RoundTripper) {
434 http.DefaultTransport = t
435 }(http.DefaultTransport)
436 http.DefaultTransport = &http.Transport{
437 Dial: (&azStubDialer{}).Dial,
// Use very short write-race timings (1ms interval, 1ns poll time).
439 azureWriteRaceInterval = time.Millisecond
440 azureWriteRacePollTime = time.Nanosecond
// The factory func is handed to DoGenericVolumeTests, which decides when
// to call it to obtain a volume.
441 DoGenericVolumeTests(t, func(t TB) TestableVolume {
442 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestAzureBlobVolumeConcurrentRanges runs DoGenericVolumeTests on
// writable volumes once for each of two values of the package-level
// azureMaxGetBytes: 2<<22 (8388608) and 2<<22 - 1 (8388607). While it
// runs, http.DefaultTransport is replaced by one that dials through
// azStubDialer.
446 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
// http.DefaultTransport is evaluated at defer time, so the deferred func
// restores the transport that was in place before the replacement below.
451 defer func(t http.RoundTripper) {
452 http.DefaultTransport = t
453 }(http.DefaultTransport)
454 http.DefaultTransport = &http.Transport{
455 Dial: (&azStubDialer{}).Dial,
457 azureWriteRaceInterval = time.Millisecond
458 azureWriteRacePollTime = time.Nanosecond
459 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
// Note the "=" (not ":="): the range clause assigns each value to the
// existing azureMaxGetBytes variable rather than declaring a loop-local.
460 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
461 DoGenericVolumeTests(t, func(t TB) TestableVolume {
462 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestReadonlyAzureBlobVolumeWithGeneric runs DoGenericVolumeTests against
// volumes created by NewTestableAzureBlobVolume with readonly=true and
// azureStorageReplication. For the duration of the test,
// http.DefaultTransport is replaced by one that dials through azStubDialer.
467 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
// http.DefaultTransport is evaluated at defer time, so the deferred func
// restores the transport that was in place before the replacement below.
468 defer func(t http.RoundTripper) {
469 http.DefaultTransport = t
470 }(http.DefaultTransport)
471 http.DefaultTransport = &http.Transport{
472 Dial: (&azStubDialer{}).Dial,
// Use very short write-race timings (1ms interval, 1ns poll time).
474 azureWriteRaceInterval = time.Millisecond
475 azureWriteRacePollTime = time.Nanosecond
476 DoGenericVolumeTests(t, func(t TB) TestableVolume {
// The second argument (true) is the readonly flag.
477 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
481 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
482 defer func(t http.RoundTripper) {
483 http.DefaultTransport = t
484 }(http.DefaultTransport)
485 http.DefaultTransport = &http.Transport{
486 Dial: (&azStubDialer{}).Dial,
489 v := NewTestableAzureBlobVolume(t, false, 3)
492 for _, size := range []int{
493 2<<22 - 1, // one <max read
494 2 << 22, // one =max read
495 2<<22 + 1, // one =max read, one <max
496 2 << 23, // two =max reads
500 data := make([]byte, size)
501 for i := range data {
502 data[i] = byte((i + 7) & 0xff)
504 hash := fmt.Sprintf("%x", md5.Sum(data))
505 err := v.Put(context.Background(), hash, data)
509 gotData := make([]byte, len(data))
510 gotLen, err := v.Get(context.Background(), hash, gotData)
514 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
516 t.Errorf("length mismatch: got %d != %d", gotLen, size)
519 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
// TestAzureBlobVolumeReplication checks, for each replication value r from
// 1 through 4, that a volume created with replication r reports r from its
// Replication method.
524 func TestAzureBlobVolumeReplication(t *testing.T) {
525 for r := 1; r <= 4; r++ {
526 v := NewTestableAzureBlobVolume(t, false, r)
528 if n := v.Replication(); n != r {
529 t.Errorf("Got replication %d, expected %d", n, r)
534 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
535 defer func(t http.RoundTripper) {
536 http.DefaultTransport = t
537 }(http.DefaultTransport)
538 http.DefaultTransport = &http.Transport{
539 Dial: (&azStubDialer{}).Dial,
542 v := NewTestableAzureBlobVolume(t, false, 3)
545 azureWriteRaceInterval = time.Second
546 azureWriteRacePollTime = time.Millisecond
548 var wg sync.WaitGroup
550 v.azHandler.race = make(chan chan struct{})
555 err := v.Put(context.Background(), TestHash, TestBlock)
560 continuePut := make(chan struct{})
561 // Wait for the stub's Put to create the empty blob
562 v.azHandler.race <- continuePut
566 buf := make([]byte, len(TestBlock))
567 _, err := v.Get(context.Background(), TestHash, buf)
572 // Wait for the stub's Get to get the empty blob
573 close(v.azHandler.race)
574 // Allow stub's Put to continue, so the real data is ready
575 // when the volume's Get retries
577 // Wait for Get() and Put() to finish
581 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
582 defer func(t http.RoundTripper) {
583 http.DefaultTransport = t
584 }(http.DefaultTransport)
585 http.DefaultTransport = &http.Transport{
586 Dial: (&azStubDialer{}).Dial,
589 v := NewTestableAzureBlobVolume(t, false, 3)
592 azureWriteRaceInterval = 2 * time.Second
593 azureWriteRacePollTime = 5 * time.Millisecond
595 v.PutRaw(TestHash, nil)
597 buf := new(bytes.Buffer)
600 t.Errorf("Index %+q should be empty", buf.Bytes())
603 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
605 allDone := make(chan struct{})
608 buf := make([]byte, BlockSize)
609 n, err := v.Get(context.Background(), TestHash, buf)
615 t.Errorf("Got %+q, expected empty buf", buf[:n])
620 case <-time.After(time.Second):
621 t.Error("Get should have stopped waiting for race when block was 2s old")
626 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
627 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
// TestAzureBlobVolumeContextCancelGet runs testAzureBlobVolumeContextCancel
// with Get as the operation under test. That helper expects the operation
// to return context.Canceled.
631 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
632 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
// Store the block directly in the stub so there is something for Get to
// read.
633 v.PutRaw(TestHash, TestBlock)
634 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelPut runs testAzureBlobVolumeContextCancel
// with Put as the operation under test: it writes a zero-filled buffer of
// BlockSize bytes under TestHash using the supplied context. That helper
// expects the operation to return context.Canceled.
639 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
640 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
641 return v.Put(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelCompare runs
// testAzureBlobVolumeContextCancel with Compare as the operation under
// test. That helper expects the operation to return context.Canceled.
645 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
646 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
// TestBlock is stored directly in the stub under TestHash; Compare is then
// called for the same hash with TestBlock2 as the expected data.
647 v.PutRaw(TestHash, TestBlock)
648 return v.Compare(ctx, TestHash, TestBlock2)
652 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
653 defer func(t http.RoundTripper) {
654 http.DefaultTransport = t
655 }(http.DefaultTransport)
656 http.DefaultTransport = &http.Transport{
657 Dial: (&azStubDialer{}).Dial,
660 v := NewTestableAzureBlobVolume(t, false, 3)
662 v.azHandler.race = make(chan chan struct{})
664 ctx, cancel := context.WithCancel(context.Background())
665 allDone := make(chan struct{})
668 err := testFunc(ctx, v)
669 if err != context.Canceled {
670 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
673 releaseHandler := make(chan struct{})
676 t.Error("testFunc finished without waiting for v.azHandler.race")
677 case <-time.After(10 * time.Second):
678 t.Error("timed out waiting to enter handler")
679 case v.azHandler.race <- releaseHandler:
685 case <-time.After(10 * time.Second):
686 t.Error("timed out waiting to cancel")
// TestStats checks the counters in the JSON encoding of
// s.volume.InternalStats() after a failed Get, a Put, and two successful
// Gets of a 3-byte block.
695 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
// stats JSON-encodes the volume's current InternalStats and checks that
// encoding succeeded.
696 stats := func() string {
697 buf, err := json.Marshal(s.volume.InternalStats())
698 c.Check(err, check.IsNil)
// Before any volume operation, both counters are expected to be zero.
702 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
703 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
// loc is the MD5 hex digest of "foo", the data written by the Put below.
705 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
// Nothing has been stored yet, so this Get is expected to fail.
706 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
707 c.Check(err, check.NotNil)
// In these patterns, `[^0],` matches a one-character value other than "0"
// directly followed by a comma.
708 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
709 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
// The failure is expected to be tallied under a 404 error key.
710 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
// The failed Get must not have counted any bytes read.
711 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
// Writing the 3 bytes "foo" should show up as OutBytes 3 and CreateOps 1.
713 err = s.volume.Put(context.Background(), loc, []byte("foo"))
714 c.Check(err, check.IsNil)
715 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
716 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
// Two successful reads of the 3-byte block should add up to InBytes 6.
718 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
719 c.Check(err, check.IsNil)
720 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
721 c.Check(err, check.IsNil)
722 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
725 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
727 err := yaml.Unmarshal([]byte(`
730 StorageClasses: ["class_a", "class_b"]
733 c.Check(err, check.IsNil)
734 c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
// PutRaw hands data to the stub handler's PutRaw for the blob named
// locator in this volume's container (v.ContainerName). It calls the stub
// directly and does not go through the volume's own Put.
737 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
738 v.azHandler.PutRaw(v.ContainerName, locator, data)
// TouchWithDate forwards to the stub handler's TouchWithDate for the blob
// named locator in this volume's container (v.ContainerName), passing
// lastPut as the time.
// NOTE(review): presumably the stub records lastPut as the blob's
// modification time -- confirm against azStubHandler.TouchWithDate.
741 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
742 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
745 func (v *TestableAzureBlobVolume) Teardown() {
// ReadWriteOperationLabelValues returns the fixed strings "get" (as r) and
// "create" (as w).
// NOTE(review): presumably these are the operation label values that the
// volume's read and write metrics are recorded under -- confirm with the
// caller.
749 func (v *TestableAzureBlobVolume) ReadWriteOperationLabelValues() (r, w string) {
750 return "get", "create"
// DeviceID returns a fixed placeholder identifier,
// "azure://azure_blob_volume_test", regardless of the volume's settings.
753 func (v *TestableAzureBlobVolume) DeviceID() string {
754 // Dummy device id for testing purposes
755 return "azure://azure_blob_volume_test"
// Start calls vm.getCounterVecsFor with a single "device_id" label set to
// v.DeviceID() and stores the three returned values in the container
// stats' opsCounters, errCounters and ioBytes fields, in that order.
758 func (v *TestableAzureBlobVolume) Start(vm *volumeMetricsVecs) error {
759 // Override original Start() to be able to assign CounterVecs with a dummy DeviceID
760 v.container.stats.opsCounters, v.container.stats.errCounters, v.container.stats.ioBytes = vm.getCounterVecsFor(prometheus.Labels{"device_id": v.DeviceID()})
// makeEtag returns a new etag string for a stub blob: "0x" followed by the
// lowercase hexadecimal form of the value returned by rand.Int63().
764 func makeEtag() string {
765 return fmt.Sprintf("0x%x", rand.Int63())