1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
30 "github.com/Azure/azure-sdk-for-go/storage"
31 "github.com/ghodss/yaml"
32 check "gopkg.in/check.v1"
36 // This cannot be the fake account name "devstoreaccount1"
37 // used by Microsoft's Azure emulator: the Azure SDK
38 // recognizes that magic string and changes its behavior to
39 // cater to the Azure SDK's own test suite.
40 fakeAccountName = "fakeaccountname"
41 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
45 azureTestContainer string
46 azureTestDebug = os.Getenv("ARVADOS_DEBUG") != ""
52 "test.azure-storage-container-volume",
54 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
60 Metadata map[string]string
62 Uncommitted map[string][]byte
65 type azStubHandler struct {
67 blobs map[string]*azBlob
68 race chan chan struct{}
// newAzStubHandler returns an azStubHandler whose in-memory blob store
// starts out empty. Blobs in that store are keyed by container+"|"+hash.
71 func newAzStubHandler() *azStubHandler {
72 return &azStubHandler{
73 blobs: make(map[string]*azBlob),
77 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
78 blob, ok := h.blobs[container+"|"+hash]
// PutRaw writes a blob for container/hash directly into the stub's
// in-memory store, bypassing the HTTP API. Any existing entry under the
// same key is replaced, and the new blob starts with empty Metadata and
// Uncommitted maps.
// NOTE(review): presumably the new blob's Data is set to data; confirm.
85 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
88 h.blobs[container+"|"+hash] = &azBlob{
91 Metadata: make(map[string]string),
92 Uncommitted: make(map[string][]byte),
96 func (h *azStubHandler) unlockAndRace() {
101 // Signal caller that race is starting by reading from
102 // h.race. If we get a channel, block until that channel is
103 // ready to receive. If we get nil (or h.race is closed) just
105 if c := <-h.race; c != nil {
// rangeRegexp matches an HTTP Range header naming exactly one byte range
// with both endpoints given, e.g. "bytes=0-1023". The two submatches are
// the first and last byte offsets, and the last offset is inclusive.
// Suffix ranges ("bytes=-N"), open-ended ranges ("bytes=N-") and
// multi-range specs do not match, so ServeHTTP serves the whole blob for
// them as if no Range header had been sent.
111 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
113 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
117 defer log.Printf("azStubHandler: %+v", r)
120 path := strings.Split(r.URL.Path, "/")
127 if err := r.ParseForm(); err != nil {
128 log.Printf("azStubHandler(%+v): %s", r, err)
129 rw.WriteHeader(http.StatusBadRequest)
133 if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
134 rw.WriteHeader(http.StatusLengthRequired)
138 body, err := ioutil.ReadAll(r.Body)
143 type blockListRequestBody struct {
144 XMLName xml.Name `xml:"BlockList"`
148 blob, blobExists := h.blobs[container+"|"+hash]
151 case r.Method == "PUT" && r.Form.Get("comp") == "":
153 if _, ok := h.blobs[container+"|"+hash]; !ok {
154 // Like the real Azure service, we offer a
155 // race window during which other clients can
156 // list/get the new blob before any data is
158 h.blobs[container+"|"+hash] = &azBlob{
160 Uncommitted: make(map[string][]byte),
161 Metadata: make(map[string]string),
166 metadata := make(map[string]string)
167 for k, v := range r.Header {
168 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
169 name := k[len("x-ms-meta-"):]
170 metadata[strings.ToLower(name)] = v[0]
173 h.blobs[container+"|"+hash] = &azBlob{
176 Uncommitted: make(map[string][]byte),
180 rw.WriteHeader(http.StatusCreated)
181 case r.Method == "PUT" && r.Form.Get("comp") == "block":
184 log.Printf("Got block for nonexistent blob: %+v", r)
185 rw.WriteHeader(http.StatusBadRequest)
188 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
189 if err != nil || len(blockID) == 0 {
190 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
191 rw.WriteHeader(http.StatusBadRequest)
194 blob.Uncommitted[string(blockID)] = body
195 rw.WriteHeader(http.StatusCreated)
196 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
197 // "Put Block List" API
198 bl := &blockListRequestBody{}
199 if err := xml.Unmarshal(body, bl); err != nil {
200 log.Printf("xml Unmarshal: %s", err)
201 rw.WriteHeader(http.StatusBadRequest)
204 for _, encBlockID := range bl.Uncommitted {
205 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
206 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
207 log.Printf("Invalid blockid: %+q", encBlockID)
208 rw.WriteHeader(http.StatusBadRequest)
211 blob.Data = blob.Uncommitted[string(blockID)]
212 blob.Etag = makeEtag()
213 blob.Mtime = time.Now()
214 delete(blob.Uncommitted, string(blockID))
216 rw.WriteHeader(http.StatusCreated)
217 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
218 // "Set Metadata Headers" API. We don't bother
219 // stubbing "Get Metadata Headers": AzureBlobVolume
220 // sets metadata headers only as a way to bump Etag
221 // and Last-Modified.
223 log.Printf("Got metadata for nonexistent blob: %+v", r)
224 rw.WriteHeader(http.StatusBadRequest)
227 blob.Metadata = make(map[string]string)
228 for k, v := range r.Header {
229 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
230 name := k[len("x-ms-meta-"):]
231 blob.Metadata[strings.ToLower(name)] = v[0]
234 blob.Mtime = time.Now()
235 blob.Etag = makeEtag()
236 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
237 // "Get Blob Metadata" API
239 rw.WriteHeader(http.StatusNotFound)
242 for k, v := range blob.Metadata {
243 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
246 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
249 rw.WriteHeader(http.StatusNotFound)
253 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
254 b0, err0 := strconv.Atoi(rangeSpec[1])
255 b1, err1 := strconv.Atoi(rangeSpec[2])
256 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
257 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
258 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
261 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
262 rw.WriteHeader(http.StatusPartialContent)
263 data = data[b0 : b1+1]
265 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
266 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
267 if r.Method == "GET" {
268 if _, err := rw.Write(data); err != nil {
269 log.Printf("write %+q: %s", data, err)
273 case r.Method == "DELETE" && hash != "":
276 rw.WriteHeader(http.StatusNotFound)
279 delete(h.blobs, container+"|"+hash)
280 rw.WriteHeader(http.StatusAccepted)
281 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
283 prefix := container + "|" + r.Form.Get("prefix")
284 marker := r.Form.Get("marker")
287 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
291 resp := storage.BlobListResponse{
294 MaxResults: int64(maxResults),
296 var hashes sort.StringSlice
297 for k := range h.blobs {
298 if strings.HasPrefix(k, prefix) {
299 hashes = append(hashes, k[len(container)+1:])
303 for _, hash := range hashes {
304 if len(resp.Blobs) == maxResults {
305 resp.NextMarker = hash
308 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
309 blob := h.blobs[container+"|"+hash]
310 bmeta := map[string]string(nil)
311 if r.Form.Get("include") == "metadata" {
312 bmeta = blob.Metadata
316 Properties: storage.BlobProperties{
317 LastModified: storage.TimeRFC1123(blob.Mtime),
318 ContentLength: int64(len(blob.Data)),
323 resp.Blobs = append(resp.Blobs, b)
326 buf, err := xml.Marshal(resp)
329 rw.WriteHeader(http.StatusInternalServerError)
333 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
334 rw.WriteHeader(http.StatusNotImplemented)
338 // azStubDialer is a net.Dialer that notices when the Azure driver
339 // tries to connect to a hostname with an embedded loopback host:port,
340 // such as "fakeaccountname.blob.127.0.0.1:46067", and in such cases
// transparently dials only the loopback part ("127.0.0.1:46067") instead.
// The account name in that example is the test's fake account rather
// than the emulator's "devstoreaccount1", which the Azure SDK treats
// specially.
341 type azStubDialer struct {
// localHostPortRe finds a loopback host:port anywhere within a dial
// address. The host part is 127.0.0.1, localhost or [::1], followed by a
// numeric port. The pattern is unanchored, so it also matches when the
// loopback host is embedded in a longer hostname such as
// "acct.blob.127.0.0.1:1234".
345 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
// Dial connects using d.Dialer. If address contains a loopback host:port
// (as matched by localHostPortRe), that host:port is dialed in place of
// the full address, so requests for per-account Azure hostnames reach
// the local stub server.
347 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
348 if hp := localHostPortRe.FindString(address); hp != "" {
350 log.Println("azStubDialer: dial", hp, "instead of", address)
354 return d.Dialer.Dial(network, address)
357 type TestableAzureBlobVolume struct {
359 azHandler *azStubHandler
360 azStub *httptest.Server
364 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
365 azHandler := newAzStubHandler()
366 azStub := httptest.NewServer(azHandler)
368 var azClient storage.Client
370 container := azureTestContainer
372 // Connect to stub instead of real Azure storage service
373 stubURLBase := strings.Split(azStub.URL, "://")[1]
375 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
378 container = "fakecontainername"
380 // Connect to real Azure storage service
381 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
385 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
391 bs := azClient.GetBlobService()
392 v := &AzureBlobVolume{
393 ContainerName: container,
395 AzureReplication: replication,
397 container: &azureContainer{ctr: bs.GetContainerReference(container)},
400 return &TestableAzureBlobVolume{
402 azHandler: azHandler,
// Register StubbedAzureBlobSuite with gocheck (gopkg.in/check.v1) so that
// its Test* methods are run, each wrapped by SetUpTest/TearDownTest.
408 var _ = check.Suite(&StubbedAzureBlobSuite{})
410 type StubbedAzureBlobSuite struct {
411 volume *TestableAzureBlobVolume
412 origHTTPTransport http.RoundTripper
415 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
416 s.origHTTPTransport = http.DefaultTransport
417 http.DefaultTransport = &http.Transport{
418 Dial: (&azStubDialer{}).Dial,
420 azureWriteRaceInterval = time.Millisecond
421 azureWriteRacePollTime = time.Nanosecond
423 s.volume = NewTestableAzureBlobVolume(c, false, 3)
426 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
428 http.DefaultTransport = s.origHTTPTransport
431 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
432 defer func(t http.RoundTripper) {
433 http.DefaultTransport = t
434 }(http.DefaultTransport)
435 http.DefaultTransport = &http.Transport{
436 Dial: (&azStubDialer{}).Dial,
438 azureWriteRaceInterval = time.Millisecond
439 azureWriteRacePollTime = time.Nanosecond
440 DoGenericVolumeTests(t, func(t TB) TestableVolume {
441 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
445 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
450 defer func(t http.RoundTripper) {
451 http.DefaultTransport = t
452 }(http.DefaultTransport)
453 http.DefaultTransport = &http.Transport{
454 Dial: (&azStubDialer{}).Dial,
456 azureWriteRaceInterval = time.Millisecond
457 azureWriteRacePollTime = time.Nanosecond
458 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
459 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
460 DoGenericVolumeTests(t, func(t TB) TestableVolume {
461 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
466 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
467 defer func(t http.RoundTripper) {
468 http.DefaultTransport = t
469 }(http.DefaultTransport)
470 http.DefaultTransport = &http.Transport{
471 Dial: (&azStubDialer{}).Dial,
473 azureWriteRaceInterval = time.Millisecond
474 azureWriteRacePollTime = time.Nanosecond
475 DoGenericVolumeTests(t, func(t TB) TestableVolume {
476 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
480 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
481 defer func(t http.RoundTripper) {
482 http.DefaultTransport = t
483 }(http.DefaultTransport)
484 http.DefaultTransport = &http.Transport{
485 Dial: (&azStubDialer{}).Dial,
488 v := NewTestableAzureBlobVolume(t, false, 3)
491 for _, size := range []int{
492 2<<22 - 1, // one <max read
493 2 << 22, // one =max read
494 2<<22 + 1, // one =max read, one <max
495 2 << 23, // two =max reads
499 data := make([]byte, size)
500 for i := range data {
501 data[i] = byte((i + 7) & 0xff)
503 hash := fmt.Sprintf("%x", md5.Sum(data))
504 err := v.Put(context.Background(), hash, data)
508 gotData := make([]byte, len(data))
509 gotLen, err := v.Get(context.Background(), hash, gotData)
513 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
515 t.Errorf("length mismatch: got %d != %d", gotLen, size)
518 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
523 func TestAzureBlobVolumeReplication(t *testing.T) {
524 for r := 1; r <= 4; r++ {
525 v := NewTestableAzureBlobVolume(t, false, r)
527 if n := v.Replication(); n != r {
528 t.Errorf("Got replication %d, expected %d", n, r)
533 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
534 defer func(t http.RoundTripper) {
535 http.DefaultTransport = t
536 }(http.DefaultTransport)
537 http.DefaultTransport = &http.Transport{
538 Dial: (&azStubDialer{}).Dial,
541 v := NewTestableAzureBlobVolume(t, false, 3)
544 azureWriteRaceInterval = time.Second
545 azureWriteRacePollTime = time.Millisecond
547 var wg sync.WaitGroup
549 v.azHandler.race = make(chan chan struct{})
554 err := v.Put(context.Background(), TestHash, TestBlock)
559 continuePut := make(chan struct{})
560 // Wait for the stub's Put to create the empty blob
561 v.azHandler.race <- continuePut
565 buf := make([]byte, len(TestBlock))
566 _, err := v.Get(context.Background(), TestHash, buf)
571 // Wait for the stub's Get to get the empty blob
572 close(v.azHandler.race)
573 // Allow stub's Put to continue, so the real data is ready
574 // when the volume's Get retries
576 // Wait for Get() and Put() to finish
580 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
581 defer func(t http.RoundTripper) {
582 http.DefaultTransport = t
583 }(http.DefaultTransport)
584 http.DefaultTransport = &http.Transport{
585 Dial: (&azStubDialer{}).Dial,
588 v := NewTestableAzureBlobVolume(t, false, 3)
591 azureWriteRaceInterval = 2 * time.Second
592 azureWriteRacePollTime = 5 * time.Millisecond
594 v.PutRaw(TestHash, nil)
596 buf := new(bytes.Buffer)
599 t.Errorf("Index %+q should be empty", buf.Bytes())
602 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
604 allDone := make(chan struct{})
607 buf := make([]byte, BlockSize)
608 n, err := v.Get(context.Background(), TestHash, buf)
614 t.Errorf("Got %+q, expected empty buf", buf[:n])
619 case <-time.After(time.Second):
620 t.Error("Get should have stopped waiting for race when block was 2s old")
625 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
626 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
630 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
631 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
632 v.PutRaw(TestHash, TestBlock)
633 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
638 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
639 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
640 return v.Put(ctx, TestHash, make([]byte, BlockSize))
644 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
645 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
646 v.PutRaw(TestHash, TestBlock)
647 return v.Compare(ctx, TestHash, TestBlock2)
651 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
652 defer func(t http.RoundTripper) {
653 http.DefaultTransport = t
654 }(http.DefaultTransport)
655 http.DefaultTransport = &http.Transport{
656 Dial: (&azStubDialer{}).Dial,
659 v := NewTestableAzureBlobVolume(t, false, 3)
661 v.azHandler.race = make(chan chan struct{})
663 ctx, cancel := context.WithCancel(context.Background())
664 allDone := make(chan struct{})
667 err := testFunc(ctx, v)
668 if err != context.Canceled {
669 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
672 releaseHandler := make(chan struct{})
675 t.Error("testFunc finished without waiting for v.azHandler.race")
676 case <-time.After(10 * time.Second):
677 t.Error("timed out waiting to enter handler")
678 case v.azHandler.race <- releaseHandler:
684 case <-time.After(10 * time.Second):
685 t.Error("timed out waiting to cancel")
694 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
695 stats := func() string {
696 buf, err := json.Marshal(s.volume.InternalStats())
697 c.Check(err, check.IsNil)
701 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
702 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
704 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
705 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
706 c.Check(err, check.NotNil)
707 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
708 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
709 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
710 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
712 err = s.volume.Put(context.Background(), loc, []byte("foo"))
713 c.Check(err, check.IsNil)
714 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
715 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
717 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
718 c.Check(err, check.IsNil)
719 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
720 c.Check(err, check.IsNil)
721 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
724 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
726 err := yaml.Unmarshal([]byte(`
729 StorageClasses: ["class_a", "class_b"]
732 c.Check(err, check.IsNil)
733 c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
// PutRaw writes data for locator directly into the stub handler's store,
// under this volume's container, bypassing the volume's own Put logic.
// Tests use it to create blobs that Put would not produce, such as an
// empty (nil) blob that simulates a write still in progress.
736 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
737 v.azHandler.PutRaw(v.ContainerName, locator, data)
// TouchWithDate forwards to the stub handler's TouchWithDate for this
// volume's container. The intent is to make locator's blob look as if it
// was last written at lastPut, which lets tests control timestamp-based
// behavior.
740 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
741 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
744 func (v *TestableAzureBlobVolume) Teardown() {
// makeEtag returns a random ETag string of the form "0x<hex>", derived
// from rand.Int63. The stub assigns a fresh value whenever a block list
// is committed or blob metadata is set, so that clients observe the
// change.
748 func makeEtag() string {
749 return fmt.Sprintf("0x%x", rand.Int63())