1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
29 "github.com/Azure/azure-sdk-for-go/storage"
30 "github.com/ghodss/yaml"
31 check "gopkg.in/check.v1"
// Credentials handed to storage.NewClient when the tests talk to the
// in-process stub server rather than a real Azure account.
35 // This cannot be the fake account name "devstoreaccount1"
36 // used by Microsoft's Azure emulator: the Azure SDK
37 // recognizes that magic string and changes its behavior to
38 // cater to the Azure SDK's own test suite.
39 fakeAccountName = "fakeaccountname"
// fakeAccountKey is only ever sent to the local stub.
// NOTE(review): this looks like the publicly documented Azure storage
// emulator key (not a real secret) — confirm.
40 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
// azureTestContainer names an Azure container to test against. When it
// is set, NewTestableAzureBlobVolume presumably connects to the real Azure
// service instead of the in-process stub — confirm.
43 var azureTestContainer string
// Flag name and help text for azureTestContainer; presumably passed to
// flag.StringVar in an init function.
// NOTE(review): the help text mentions -azure-storage-key-file, while the
// key-file variable read by NewTestableAzureBlobVolume is named
// azureStorageAccountKeyFile — confirm the advertised flag name is current.
48 "test.azure-storage-container-volume",
50 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
// Metadata holds user metadata taken from x-ms-meta-* request headers,
// keyed by the lower-cased header name with the prefix removed.
56 Metadata map[string]string
// Uncommitted holds blocks staged by "Put Block", keyed by the decoded
// block ID, until "Put Block List" commits them into the blob.
58 Uncommitted map[string][]byte
// azStubHandler is an http.Handler that imitates, in memory, the parts of
// the Azure Blob service API used by AzureBlobVolume in these tests.
61 type azStubHandler struct {
// blobs maps "<container>|<blob name>" to the stored blob.
63 blobs map[string]*azBlob
// race is the rendezvous channel used by unlockAndRace. Tests that want
// to pause the handler mid-request assign it explicitly (for example,
// TestAzureBlobVolumeCreateBlobRace).
64 race chan chan struct{}
67 func newAzStubHandler() *azStubHandler {
68 return &azStubHandler{
69 blobs: make(map[string]*azBlob),
// TouchWithDate looks up the stored blob for container/hash so that its
// modification time can be overridden with t.
// NOTE(review): the rest of the body is not visible here; presumably it
// returns when ok is false and otherwise sets blob.Mtime = t — confirm.
73 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
74 blob, ok := h.blobs[container+"|"+hash]
// PutRaw installs a new azBlob for container/hash directly in h.blobs,
// bypassing the HTTP API and replacing any existing entry. The new blob
// starts with empty Metadata and Uncommitted maps.
// NOTE(review): the remaining field initializers are not visible here;
// presumably Data: data and a current Mtime — confirm.
81 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
84 h.blobs[container+"|"+hash] = &azBlob{
87 Metadata: make(map[string]string),
88 Uncommitted: make(map[string][]byte),
// unlockAndRace lets a test observe the handler in the middle of a
// request, using h.race as a rendezvous channel (the protocol is
// described in the comment inside the body).
// NOTE(review): judging by the name, it presumably releases the handler's
// lock while waiting and re-acquires it before returning — confirm.
92 func (h *azStubHandler) unlockAndRace() {
97 // Signal caller that race is starting by reading from
98 // h.race. If we get a channel, block until that channel is
99 // ready to receive. If we get nil (or h.race is closed) just
101 if c := <-h.race; c != nil {
// rangeRegexp recognizes the only Range header form the stub supports:
// a single closed byte range "bytes=<first>-<last>". It captures both
// offsets as decimal digit strings.
var (
	rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
)
// ServeHTTP implements, in memory, the subset of the Azure Blob REST API
// that these tests use:
//   - Put Blob, Put Block and Put Block List
//   - Set Metadata and Get Blob Metadata
//   - Get/Head Blob, with an optional single "bytes=first-last" Range
//   - Delete Blob and List Blobs
// Any other request is logged and answered with 501 Not Implemented.
109 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
112 // defer log.Printf("azStubHandler: %+v", r)
// Request paths have the form /<container>[/<blob>]. Presumably path[1]
// becomes container and path[2] (when present) becomes hash.
114 path := strings.Split(r.URL.Path, "/")
// ParseForm makes query parameters such as "comp", "restype", "blockid",
// "prefix", "marker" and "maxresults" available through r.Form.
121 if err := r.ParseForm(); err != nil {
122 log.Printf("azStubHandler(%+v): %s", r, err)
123 rw.WriteHeader(http.StatusBadRequest)
// Uploads must declare their length; otherwise reply 411.
127 if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
128 rw.WriteHeader(http.StatusLengthRequired)
132 body, err := ioutil.ReadAll(r.Body)
// XML body of a "Put Block List" request. Only its Uncommitted block IDs
// are consulted below.
137 type blockListRequestBody struct {
138 XMLName xml.Name `xml:"BlockList"`
// Look up the addressed blob once. Keys are "<container>|<hash>".
142 blob, blobExists := h.blobs[container+"|"+hash]
// "Put Blob": create (or replace) the blob. x-ms-meta-* request headers
// are recorded as metadata under lower-cased names.
145 case r.Method == "PUT" && r.Form.Get("comp") == "":
147 if _, ok := h.blobs[container+"|"+hash]; !ok {
148 // Like the real Azure service, we offer a
149 // race window during which other clients can
150 // list/get the new blob before any data is
152 h.blobs[container+"|"+hash] = &azBlob{
154 Uncommitted: make(map[string][]byte),
155 Metadata: make(map[string]string),
160 metadata := make(map[string]string)
161 for k, v := range r.Header {
162 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
163 name := k[len("x-ms-meta-"):]
164 metadata[strings.ToLower(name)] = v[0]
167 h.blobs[container+"|"+hash] = &azBlob{
170 Uncommitted: make(map[string][]byte),
174 rw.WriteHeader(http.StatusCreated)
// "Put Block": stage body under the base64-decoded blockid. The blob must
// already exist, and the block ID must decode to a non-empty value.
175 case r.Method == "PUT" && r.Form.Get("comp") == "block":
178 log.Printf("Got block for nonexistent blob: %+v", r)
179 rw.WriteHeader(http.StatusBadRequest)
182 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
183 if err != nil || len(blockID) == 0 {
184 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
185 rw.WriteHeader(http.StatusBadRequest)
188 blob.Uncommitted[string(blockID)] = body
189 rw.WriteHeader(http.StatusCreated)
190 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
191 // "Put Block List" API
192 bl := &blockListRequestBody{}
193 if err := xml.Unmarshal(body, bl); err != nil {
194 log.Printf("xml Unmarshal: %s", err)
195 rw.WriteHeader(http.StatusBadRequest)
// Commit each listed block. Each committed block moves out of Uncommitted
// and bumps Etag/Mtime.
// NOTE(review): every committed block *replaces* blob.Data instead of
// being appended, so multi-block commits are not modelled faithfully.
// Presumably that is acceptable because AzureBlobVolume commits a single
// block — confirm.
198 for _, encBlockID := range bl.Uncommitted {
199 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
200 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
201 log.Printf("Invalid blockid: %+q", encBlockID)
202 rw.WriteHeader(http.StatusBadRequest)
205 blob.Data = blob.Uncommitted[string(blockID)]
206 blob.Etag = makeEtag()
207 blob.Mtime = time.Now()
208 delete(blob.Uncommitted, string(blockID))
210 rw.WriteHeader(http.StatusCreated)
211 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
212 // "Set Metadata Headers" API. We don't bother
213 // stubbing "Get Metadata Headers": AzureBlobVolume
214 // sets metadata headers only as a way to bump Etag
215 // and Last-Modified.
217 log.Printf("Got metadata for nonexistent blob: %+v", r)
218 rw.WriteHeader(http.StatusBadRequest)
// The new metadata replaces the old set wholesale.
221 blob.Metadata = make(map[string]string)
222 for k, v := range r.Header {
223 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
224 name := k[len("x-ms-meta-"):]
225 blob.Metadata[strings.ToLower(name)] = v[0]
228 blob.Mtime = time.Now()
229 blob.Etag = makeEtag()
230 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
231 // "Get Blob Metadata" API
233 rw.WriteHeader(http.StatusNotFound)
// Echo stored metadata back as x-ms-meta-<name> response headers.
236 for k, v := range blob.Metadata {
237 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
// "Get Blob" (body) / HEAD (headers only). A single "bytes=first-last"
// Range yields 206 with the inclusive slice. A range that falls outside
// the data, or has first > last, yields 416.
240 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
243 rw.WriteHeader(http.StatusNotFound)
247 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
248 b0, err0 := strconv.Atoi(rangeSpec[1])
249 b1, err1 := strconv.Atoi(rangeSpec[2])
250 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
251 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
252 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
255 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
256 rw.WriteHeader(http.StatusPartialContent)
257 data = data[b0 : b1+1]
// NOTE(review): time.RFC1123 renders blob.Mtime in its own (local) zone,
// whereas HTTP dates are normally GMT (http.TimeFormat). This presumably
// works because the client and the stub share a zone — confirm.
// NOTE(review): in the partial-content path above, WriteHeader has already
// been called before these headers are set, so they may not be sent —
// confirm.
259 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
260 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
261 if r.Method == "GET" {
262 if _, err := rw.Write(data); err != nil {
263 log.Printf("write %+q: %s", data, err)
// "Delete Blob": reply 404 (presumably when the blob is missing); on
// success remove it and reply 202 Accepted.
267 case r.Method == "DELETE" && hash != "":
270 rw.WriteHeader(http.StatusNotFound)
273 delete(h.blobs, container+"|"+hash)
274 rw.WriteHeader(http.StatusAccepted)
// "List Blobs": names in container starting with the "prefix" parameter,
// listed from "marker" onward. "maxresults" is honored only when it is in
// 1..5000. When the page fills up, NextMarker is set to the first name
// not returned.
275 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
277 prefix := container + "|" + r.Form.Get("prefix")
278 marker := r.Form.Get("marker")
281 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
285 resp := storage.BlobListResponse{
288 MaxResults: int64(maxResults),
// Collect matching blob names with the "<container>|" key prefix removed.
// Presumably they are sorted before paging (hashes is a sort.StringSlice).
290 var hashes sort.StringSlice
291 for k := range h.blobs {
292 if strings.HasPrefix(k, prefix) {
293 hashes = append(hashes, k[len(container)+1:])
297 for _, hash := range hashes {
298 if len(resp.Blobs) == maxResults {
299 resp.NextMarker = hash
// Skip names until marker is reached. Once one entry has been emitted,
// emit every following one.
302 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
303 blob := h.blobs[container+"|"+hash]
304 bmeta := map[string]string(nil)
// Metadata is included only when the request asks for include=metadata.
305 if r.Form.Get("include") == "metadata" {
306 bmeta = blob.Metadata
310 Properties: storage.BlobProperties{
311 LastModified: storage.TimeRFC1123(blob.Mtime),
312 ContentLength: int64(len(blob.Data)),
317 resp.Blobs = append(resp.Blobs, b)
320 buf, err := xml.Marshal(resp)
323 rw.WriteHeader(http.StatusInternalServerError)
// Unsupported request: log it and reply 501.
327 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
328 rw.WriteHeader(http.StatusNotImplemented)
// azStubDialer is a net.Dialer for tests that talk to the in-process stub.
// The Azure driver prepends the account name to the stub server's address
// and tries to connect to a host such as
// "fakeaccountname.blob.127.0.0.1:46067". In that case the dialer
// transparently dials the bare loopback host:port ("127.0.0.1:46067")
// instead.
//
// Any address that contains a 127.0.0.1, localhost or [::1] host:port is
// redirected this way. All other addresses are dialed unchanged.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback "host:port" anywhere inside a dial
// address, e.g. "127.0.0.1:46067" within
// "fakeaccountname.blob.127.0.0.1:46067".
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to address. If address embeds a loopback host:port, only
// that part is dialed. Each redirect is logged so test output shows where
// SDK traffic really went. Dial is meant for use as http.Transport.Dial.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	if hp := localHostPortRe.FindString(address); hp != "" {
		log.Println("azStubDialer: dial", hp, "instead of", address)
		address = hp
	}
	return d.Dialer.Dial(network, address)
}
// TestableAzureBlobVolume wraps a volume under test together with the
// stub handler and httptest server that NewTestableAzureBlobVolume
// creates for it.
// NOTE(review): it presumably also embeds *AzureBlobVolume, since Put, Get
// and ContainerName are used directly on it — confirm.
349 type TestableAzureBlobVolume struct {
// azHandler is the in-memory stub. Tests use it to inject blobs and races.
351 azHandler *azStubHandler
// azStub is the HTTP server that serves azHandler.
352 azStub *httptest.Server
// NewTestableAzureBlobVolume returns an AzureBlobVolume wrapped for tests,
// together with an in-process stub server, which is always started.
//
// The client setup has two branches:
//   - Stub branch: the client uses fakeAccountName/fakeAccountKey against
//     the stub's host:port, and the container is "fakecontainername".
//   - Real-service branch: the client uses azureStorageAccountName, with
//     the key read from azureStorageAccountKeyFile.
//
// NOTE(review): the stub branch is presumably chosen when
// azureTestContainer is empty — confirm.
// replication becomes AzureReplication; readonly presumably sets the
// volume's read-only flag — confirm.
356 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
357 azHandler := newAzStubHandler()
358 azStub := httptest.NewServer(azHandler)
360 var azClient storage.Client
362 container := azureTestContainer
364 // Connect to stub instead of real Azure storage service
// httptest server URLs look like "http://127.0.0.1:<port>". Keep only
// the part after "://" as the SDK's service base URL.
365 stubURLBase := strings.Split(azStub.URL, "://")[1]
// The trailing false disables HTTPS, because httptest.NewServer serves
// plain HTTP.
367 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
370 container = "fakecontainername"
372 // Connect to real Azure storage service
373 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
377 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
383 bs := azClient.GetBlobService()
384 v := &AzureBlobVolume{
385 ContainerName: container,
387 AzureReplication: replication,
389 container: &azureContainer{ctr: bs.GetContainerReference(container)},
392 return &TestableAzureBlobVolume{
394 azHandler: azHandler,
// Register StubbedAzureBlobSuite with gocheck so that its test methods run.
400 var _ = check.Suite(&StubbedAzureBlobSuite{})
// StubbedAzureBlobSuite holds per-test state for the gocheck tests in
// this file.
402 type StubbedAzureBlobSuite struct {
// volume is created fresh for each test by SetUpTest.
403 volume *TestableAzureBlobVolume
// origHTTPTransport is the http.DefaultTransport that SetUpTest replaced
// and TearDownTest restores.
404 origHTTPTransport http.RoundTripper
// SetUpTest routes HTTP traffic through azStubDialer, saving the previous
// http.DefaultTransport for TearDownTest. It then sets the package-level
// write-race timings to a 1ms interval and 1ns poll time, and creates a
// fresh writable test volume with replication 3.
407 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
408 s.origHTTPTransport = http.DefaultTransport
409 http.DefaultTransport = &http.Transport{
410 Dial: (&azStubDialer{}).Dial,
412 azureWriteRaceInterval = time.Millisecond
413 azureWriteRacePollTime = time.Nanosecond
415 s.volume = NewTestableAzureBlobVolume(c, false, 3)
// TearDownTest restores the http.DefaultTransport saved by SetUpTest.
// NOTE(review): it presumably also tears down s.volume — confirm.
418 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
420 http.DefaultTransport = s.origHTTPTransport
// TestAzureBlobVolumeWithGeneric runs the generic volume test suite
// against writable volumes from NewTestableAzureBlobVolume. HTTP traffic
// is routed through azStubDialer, and the original http.DefaultTransport
// is restored on return.
// NOTE(review): the package-level write-race timings set here are not
// restored afterwards — confirm later tests don't depend on the defaults.
423 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
424 defer func(t http.RoundTripper) {
425 http.DefaultTransport = t
426 }(http.DefaultTransport)
427 http.DefaultTransport = &http.Transport{
428 Dial: (&azStubDialer{}).Dial,
430 azureWriteRaceInterval = time.Millisecond
431 azureWriteRacePollTime = time.Nanosecond
432 DoGenericVolumeTests(t, func(t TB) TestableVolume {
433 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestAzureBlobVolumeConcurrentRanges runs the generic volume suite once
// for each azureMaxGetBytes value in the list below: 2<<22 bytes (8 MiB)
// and one byte less. Presumably Get splits reads into ranges of at most
// azureMaxGetBytes, so the two values cover the cases where BlockSize is
// and is not an exact multiple of the range size.
437 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
442 defer func(t http.RoundTripper) {
443 http.DefaultTransport = t
444 }(http.DefaultTransport)
445 http.DefaultTransport = &http.Transport{
446 Dial: (&azStubDialer{}).Dial,
448 azureWriteRaceInterval = time.Millisecond
449 azureWriteRacePollTime = time.Nanosecond
450 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
// The range clause assigns directly to the package-level azureMaxGetBytes.
451 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
452 DoGenericVolumeTests(t, func(t TB) TestableVolume {
453 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestReadonlyAzureBlobVolumeWithGeneric runs the generic volume test
// suite against read-only volumes (readonly=true) from
// NewTestableAzureBlobVolume. HTTP traffic is routed through
// azStubDialer, and the original http.DefaultTransport is restored on
// return.
458 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
459 defer func(t http.RoundTripper) {
460 http.DefaultTransport = t
461 }(http.DefaultTransport)
462 http.DefaultTransport = &http.Transport{
463 Dial: (&azStubDialer{}).Dial,
465 azureWriteRaceInterval = time.Millisecond
466 azureWriteRacePollTime = time.Nanosecond
467 DoGenericVolumeTests(t, func(t TB) TestableVolume {
468 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
// TestAzureBlobVolumeRangeFenceposts stores blobs whose sizes sit just
// below, at, and just above multiples of 2<<22 bytes (the max read size,
// per the inline comments). It then reads each one back and checks that
// the length and MD5 hash match what was written.
472 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
473 defer func(t http.RoundTripper) {
474 http.DefaultTransport = t
475 }(http.DefaultTransport)
476 http.DefaultTransport = &http.Transport{
477 Dial: (&azStubDialer{}).Dial,
480 v := NewTestableAzureBlobVolume(t, false, 3)
483 for _, size := range []int{
484 2<<22 - 1, // one <max read
485 2 << 22, // one =max read
486 2<<22 + 1, // one =max read, one <max
487 2 << 23, // two =max reads
// Fill with a repeating, offset byte pattern so that misplaced ranges
// change the hash.
491 data := make([]byte, size)
492 for i := range data {
493 data[i] = byte((i + 7) & 0xff)
// The MD5 hex digest of the data doubles as the block locator.
495 hash := fmt.Sprintf("%x", md5.Sum(data))
496 err := v.Put(context.Background(), hash, data)
500 gotData := make([]byte, len(data))
501 gotLen, err := v.Get(context.Background(), hash, gotData)
505 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
507 t.Errorf("length mismatch: got %d != %d", gotLen, size)
510 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
// TestAzureBlobVolumeReplication checks that Replication() reports the
// replication value passed to NewTestableAzureBlobVolume, for values 1
// through 4.
515 func TestAzureBlobVolumeReplication(t *testing.T) {
516 for r := 1; r <= 4; r++ {
517 v := NewTestableAzureBlobVolume(t, false, r)
519 if n := v.Replication(); n != r {
520 t.Errorf("Got replication %d, expected %d", n, r)
// TestAzureBlobVolumeCreateBlobRace simulates a reader that sees the empty
// blob a Put creates before its data is written. It uses the stub's race
// channel to hold the Put inside that window while a concurrent Get runs.
// It then lets the Put finish and expects the volume's Get to retry and
// return the real data.
525 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
526 defer func(t http.RoundTripper) {
527 http.DefaultTransport = t
528 }(http.DefaultTransport)
529 http.DefaultTransport = &http.Transport{
530 Dial: (&azStubDialer{}).Dial,
533 v := NewTestableAzureBlobVolume(t, false, 3)
536 azureWriteRaceInterval = time.Second
537 azureWriteRacePollTime = time.Millisecond
539 allDone := make(chan struct{})
// Enable the stub's race hook so that the test can pause requests.
540 v.azHandler.race = make(chan chan struct{})
542 err := v.Put(context.Background(), TestHash, TestBlock)
547 continuePut := make(chan struct{})
548 // Wait for the stub's Put to create the empty blob
549 v.azHandler.race <- continuePut
551 buf := make([]byte, len(TestBlock))
552 _, err := v.Get(context.Background(), TestHash, buf)
558 // Wait for the stub's Get to get the empty blob
559 close(v.azHandler.race)
560 // Allow stub's Put to continue, so the real data is ready
561 // when the volume's Get retries
563 // Wait for volume's Get to return the real data
// TestAzureBlobVolumeCreateBlobRaceDeadline stores an empty blob directly
// in the stub and checks how the volume treats it:
//   - While the blob is fresh, it must not appear in the index.
//   - After its mtime is backdated to 1982ms ago (just under the 2s
//     azureWriteRaceInterval), Get must stop waiting for data within 1s
//     and return an empty buffer.
//   - At that point, the index must list it as TestHash+"+0".
567 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
568 defer func(t http.RoundTripper) {
569 http.DefaultTransport = t
570 }(http.DefaultTransport)
571 http.DefaultTransport = &http.Transport{
572 Dial: (&azStubDialer{}).Dial,
575 v := NewTestableAzureBlobVolume(t, false, 3)
578 azureWriteRaceInterval = 2 * time.Second
579 azureWriteRacePollTime = 5 * time.Millisecond
// Simulate a Put that created the blob but has not written data yet.
581 v.PutRaw(TestHash, nil)
583 buf := new(bytes.Buffer)
586 t.Errorf("Index %+q should be empty", buf.Bytes())
// Backdate the blob so that it crosses the 2s race interval ~18ms from now.
589 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
591 allDone := make(chan struct{})
594 buf := make([]byte, BlockSize)
595 n, err := v.Get(context.Background(), TestHash, buf)
601 t.Errorf("Got %+q, expected empty buf", buf[:n])
606 case <-time.After(time.Second):
607 t.Error("Get should have stopped waiting for race when block was 2s old")
612 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
613 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
// TestAzureBlobVolumeContextCancelGet checks that Get returns
// context.Canceled when its context is cancelled while the request is
// blocked in the stub handler (see testAzureBlobVolumeContextCancel).
617 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
618 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
619 v.PutRaw(TestHash, TestBlock)
620 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelPut checks that Put of a full-size
// block returns context.Canceled when its context is cancelled while the
// request is blocked in the stub handler (see
// testAzureBlobVolumeContextCancel).
625 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
626 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
627 return v.Put(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelCompare checks that Compare returns
// context.Canceled when its context is cancelled while the request is
// blocked in the stub handler (see testAzureBlobVolumeContextCancel).
631 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
632 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
633 v.PutRaw(TestHash, TestBlock)
634 return v.Compare(ctx, TestHash, TestBlock2)
// testAzureBlobVolumeContextCancel runs testFunc in a goroutine with a
// cancellable context and requires it to return context.Canceled. The
// sequence is:
//  1. Wait until the stub handler offers to rendezvous on v.azHandler.race.
//  2. Cancel the context (presumably via cancel() on a line not shown).
//  3. Wait for testFunc to finish.
//
// The test fails if testFunc finishes before reaching the handler, or if
// either wait exceeds 10 seconds.
638 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
639 defer func(t http.RoundTripper) {
640 http.DefaultTransport = t
641 }(http.DefaultTransport)
642 http.DefaultTransport = &http.Transport{
643 Dial: (&azStubDialer{}).Dial,
646 v := NewTestableAzureBlobVolume(t, false, 3)
// Enable the stub's race hook so that the handler blocks mid-request.
648 v.azHandler.race = make(chan chan struct{})
650 ctx, cancel := context.WithCancel(context.Background())
651 allDone := make(chan struct{})
654 err := testFunc(ctx, v)
655 if err != context.Canceled {
656 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
659 releaseHandler := make(chan struct{})
662 t.Error("testFunc finished without waiting for v.azHandler.race")
663 case <-time.After(10 * time.Second):
664 t.Error("timed out waiting to enter handler")
// The handler has entered unlockAndRace and accepted releaseHandler.
665 case v.azHandler.race <- releaseHandler:
671 case <-time.After(10 * time.Second):
672 t.Error("timed out waiting to cancel")
// TestStats checks the volume's InternalStats counters, rendered as JSON:
//   - A Get of a missing block counts an op and an error, records a
//     "storage.AzureStorageServiceError 404 (404 Not Found)" error type,
//     and leaves InBytes at 0.
//   - A 3-byte Put records OutBytes 3 and one CreateOps.
//   - Two 3-byte Gets bring InBytes to 6.
681 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
// stats returns the current counters as a JSON string for regexp matching.
682 stats := func() string {
683 buf, err := json.Marshal(s.volume.InternalStats())
684 c.Check(err, check.IsNil)
688 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
689 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
691 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
692 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
693 c.Check(err, check.NotNil)
694 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
695 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
696 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
697 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
699 err = s.volume.Put(context.Background(), loc, []byte("foo"))
700 c.Check(err, check.IsNil)
701 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
702 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
704 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
705 c.Check(err, check.IsNil)
706 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
707 c.Check(err, check.IsNil)
708 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
// TestConfig checks that a StorageClasses list in a YAML volume config is
// returned, in order, by the first volume's GetStorageClasses.
711 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
713 err := yaml.Unmarshal([]byte(`
716 StorageClasses: ["class_a", "class_b"]
719 c.Check(err, check.IsNil)
720 c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
723 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
724 v.azHandler.PutRaw(v.ContainerName, locator, data)
727 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
728 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
// Teardown releases the resources held by the test volume.
// NOTE(review): the body is not visible here; presumably it closes v.azStub
// (the httptest server) — confirm.
731 func (v *TestableAzureBlobVolume) Teardown() {
// makeEtag returns a fresh pseudo-random ETag of the form "0x<lowercase hex>".
// The stub assigns a new one whenever a block list is committed or
// metadata is set.
func makeEtag() string {
	n := rand.Int63()
	return fmt.Sprintf("0x%x", n)
}