1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
29 log "github.com/Sirupsen/logrus"
30 "github.com/curoverse/azure-sdk-for-go/storage"
31 check "gopkg.in/check.v1"
35 // This cannot be the fake account name "devstoreaccount1"
36 // used by Microsoft's Azure emulator: the Azure SDK
37 // recognizes that magic string and changes its behavior to
38 // cater to the Azure SDK's own test suite.
39 fakeAccountName = "fakeAccountName"
// fakeAccountKey is the account key passed to storage.NewClient,
// together with fakeAccountName, when the tests connect to the local
// stub server instead of the real Azure storage service.
40 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
// azureTestContainer names the Azure container to run tests against.
// NewTestableAzureBlobVolume reads it when choosing its container.
// NOTE(review): presumably populated by the command-line flag whose
// name and help text follow -- confirm against the flag registration.
43 var azureTestContainer string
48 "test.azure-storage-container-volume",
50 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
// Metadata holds the blob's metadata, keyed by the lower-cased part
// of the header name that follows the "x-ms-meta-" prefix.
56 Metadata map[string]string
// Uncommitted holds block data uploaded with comp=block, keyed by
// the decoded block ID, until a comp=blocklist request commits it.
58 Uncommitted map[string][]byte
// azStubHandler is an in-memory stand-in for the Azure blob storage
// HTTP API, used by the tests in this file (see ServeHTTP).
61 type azStubHandler struct {
// blobs maps container+"|"+hash to the stored blob.
63 blobs map[string]*azBlob
// race is received from by unlockAndRace. Tests send a channel on it
// (or close it) to coordinate with a request that is in progress.
64 race chan chan struct{}
// newAzStubHandler returns an azStubHandler whose blob map is
// allocated and empty.
67 func newAzStubHandler() *azStubHandler {
68 return &azStubHandler{
69 blobs: make(map[string]*azBlob),
// TouchWithDate looks up the blob stored under container+"|"+hash.
// NOTE(review): presumably it sets that blob's modification time to t
// and does nothing when the blob is missing -- confirm.
73 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
74 blob, ok := h.blobs[container+"|"+hash]
// PutRaw stores an azBlob under container+"|"+hash directly in the
// stub's blob map, with freshly allocated Metadata and Uncommitted
// maps.
// NOTE(review): presumably data becomes the blob's content -- confirm.
81 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
84 h.blobs[container+"|"+hash] = &azBlob{
87 Metadata: make(map[string]string),
88 Uncommitted: make(map[string][]byte),
// unlockAndRace receives from h.race so that a test can synchronize
// with a request that has reached a race window.
// NOTE(review): the name suggests the handler's lock is released
// while waiting and re-acquired afterwards -- confirm.
92 func (h *azStubHandler) unlockAndRace() {
97 // Signal caller that race is starting by reading from
98 // h.race. If we get a channel, block until that channel is
99 // ready to receive. If we get nil (or h.race is closed) just
101 if c := <-h.race; c != nil {
// rangeRegexp matches a Range header value of the form
// "bytes=FIRST-LAST" and captures both offsets as decimal digits.
107 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
// ServeHTTP emulates the subset of the Azure blob storage HTTP API
// that the tests need, dispatching on the request method and the
// "comp" form parameter: creating a blob, uploading a block,
// committing a block list, setting and getting metadata, reading a
// blob (optionally with a Range header), deleting a blob, and listing
// a container. Requests it does not recognize are logged and answered
// with 501 Not Implemented.
109 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
112 // defer log.Printf("azStubHandler: %+v", r)
114 path := strings.Split(r.URL.Path, "/")
121 if err := r.ParseForm(); err != nil {
122 log.Printf("azStubHandler(%+v): %s", r, err)
123 rw.WriteHeader(http.StatusBadRequest)
127 body, err := ioutil.ReadAll(r.Body)
// blockListRequestBody is the XML document decoded from the body of
// a comp=blocklist request.
132 type blockListRequestBody struct {
133 XMLName xml.Name `xml:"BlockList"`
// Look up the addressed blob once, before dispatching.
137 blob, blobExists := h.blobs[container+"|"+hash]
// PUT without a "comp" parameter: create or replace the whole blob.
140 case r.Method == "PUT" && r.Form.Get("comp") == "":
142 if _, ok := h.blobs[container+"|"+hash]; !ok {
143 // Like the real Azure service, we offer a
144 // race window during which other clients can
145 // list/get the new blob before any data is
147 h.blobs[container+"|"+hash] = &azBlob{
149 Uncommitted: make(map[string][]byte),
150 Metadata: make(map[string]string),
// Collect x-ms-meta-* request headers, storing each under the
// lower-cased remainder of its name.
155 metadata := make(map[string]string)
156 for k, v := range r.Header {
157 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
158 name := k[len("x-ms-meta-"):]
159 metadata[strings.ToLower(name)] = v[0]
162 h.blobs[container+"|"+hash] = &azBlob{
165 Uncommitted: make(map[string][]byte),
169 rw.WriteHeader(http.StatusCreated)
// PUT comp=block: keep the request body as an uncommitted block,
// keyed by the base64-decoded "blockid" parameter.
170 case r.Method == "PUT" && r.Form.Get("comp") == "block":
173 log.Printf("Got block for nonexistent blob: %+v", r)
174 rw.WriteHeader(http.StatusBadRequest)
177 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
178 if err != nil || len(blockID) == 0 {
179 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
180 rw.WriteHeader(http.StatusBadRequest)
183 blob.Uncommitted[string(blockID)] = body
184 rw.WriteHeader(http.StatusCreated)
185 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
186 // "Put Block List" API
187 bl := &blockListRequestBody{}
188 if err := xml.Unmarshal(body, bl); err != nil {
189 log.Printf("xml Unmarshal: %s", err)
190 rw.WriteHeader(http.StatusBadRequest)
// Each listed block ID must decode, be non-empty, and refer to a
// previously uploaded uncommitted block.
193 for _, encBlockID := range bl.Uncommitted {
194 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
195 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
196 log.Printf("Invalid blockid: %+q", encBlockID)
197 rw.WriteHeader(http.StatusBadRequest)
// The committed block replaces the blob's data; it is not appended.
200 blob.Data = blob.Uncommitted[string(blockID)]
201 blob.Etag = makeEtag()
202 blob.Mtime = time.Now()
203 delete(blob.Uncommitted, string(blockID))
205 rw.WriteHeader(http.StatusCreated)
206 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
207 // "Set Metadata Headers" API. We don't bother
208 // stubbing "Get Metadata Headers": AzureBlobVolume
209 // sets metadata headers only as a way to bump Etag
210 // and Last-Modified.
212 log.Printf("Got metadata for nonexistent blob: %+v", r)
213 rw.WriteHeader(http.StatusBadRequest)
// Existing metadata is discarded and rebuilt from the request's
// x-ms-meta-* headers.
216 blob.Metadata = make(map[string]string)
217 for k, v := range r.Header {
218 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
219 name := k[len("x-ms-meta-"):]
220 blob.Metadata[strings.ToLower(name)] = v[0]
223 blob.Mtime = time.Now()
224 blob.Etag = makeEtag()
225 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
226 // "Get Blob Metadata" API
228 rw.WriteHeader(http.StatusNotFound)
231 for k, v := range blob.Metadata {
232 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
// GET/HEAD of a blob, optionally restricted by a Range header.
235 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
238 rw.WriteHeader(http.StatusNotFound)
242 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
243 b0, err0 := strconv.Atoi(rangeSpec[1])
244 b1, err1 := strconv.Atoi(rangeSpec[2])
// Reject unparsable, out-of-bounds, or inverted ranges with 416.
245 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
246 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
247 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
250 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
251 rw.WriteHeader(http.StatusPartialContent)
// Both range offsets are inclusive.
252 data = data[b0 : b1+1]
254 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
255 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
// HEAD responses carry the headers only.
256 if r.Method == "GET" {
257 if _, err := rw.Write(data); err != nil {
258 log.Printf("write %+q: %s", data, err)
// DELETE: remove the blob from the map.
262 case r.Method == "DELETE" && hash != "":
265 rw.WriteHeader(http.StatusNotFound)
268 delete(h.blobs, container+"|"+hash)
269 rw.WriteHeader(http.StatusAccepted)
// GET comp=list&restype=container: list the container's blobs whose
// names start with the "prefix" parameter.
270 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
272 prefix := container + "|" + r.Form.Get("prefix")
273 marker := r.Form.Get("marker")
// A "maxresults" value is honored only when it parses and lies
// between 1 and 5000.
276 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
280 resp := storage.BlobListResponse{
283 MaxResults: int64(maxResults),
285 var hashes sort.StringSlice
286 for k := range h.blobs {
287 if strings.HasPrefix(k, prefix) {
288 hashes = append(hashes, k[len(container)+1:])
292 for _, hash := range hashes {
// The page is full: report this hash as the marker for the next page.
293 if len(resp.Blobs) == maxResults {
294 resp.NextMarker = hash
// Start emitting at the marker, or at the first hash when no marker
// was given; once started, keep emitting.
297 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
298 blob := h.blobs[container+"|"+hash]
// Metadata is included only when include=metadata was requested.
299 bmeta := map[string]string(nil)
300 if r.Form.Get("include") == "metadata" {
301 bmeta = blob.Metadata
305 Properties: storage.BlobProperties{
306 LastModified: blob.Mtime.Format(time.RFC1123),
307 ContentLength: int64(len(blob.Data)),
312 resp.Blobs = append(resp.Blobs, b)
315 buf, err := xml.Marshal(resp)
318 rw.WriteHeader(http.StatusInternalServerError)
// Anything else is not emulated by this stub.
322 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
323 rw.WriteHeader(http.StatusNotImplemented)
327 // azStubDialer is a net.Dialer that notices when the Azure driver
328 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
329 // in such cases transparently dials "127.0.0.1:46067" instead.
//
// NOTE(review): the example host above mentions "devstoreaccount1",
// but the stub client in this file is created with fakeAccountName;
// the example may be stale -- confirm and update.
330 type azStubDialer struct {
// localHostPortRe matches a loopback host (127.0.0.1, localhost, or
// [::1]) followed by ":" and a port number. It is unanchored, so it
// matches anywhere within a longer address.
334 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
// Dial connects to address on the given network using d.Dialer. When
// address contains a loopback host:port (see localHostPortRe), the
// substitution is logged first.
// NOTE(review): presumably address is replaced by that host:port
// before dialing -- confirm.
336 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
337 if hp := localHostPortRe.FindString(address); hp != "" {
338 log.Println("azStubDialer: dial", hp, "instead of", address)
341 return d.Dialer.Dial(network, address)
// TestableAzureBlobVolume pairs a volume under test with the stub
// handler and HTTP test server that back it.
// NOTE(review): presumably embeds *AzureBlobVolume, since methods in
// this file use v.ContainerName -- confirm.
344 type TestableAzureBlobVolume struct {
// azHandler is the in-memory stub that tests manipulate directly
// (PutRaw, TouchWithDate, race).
346 azHandler *azStubHandler
// azStub is the test server created to serve azHandler.
347 azStub *httptest.Server
// NewTestableAzureBlobVolume starts a stub Azure server and returns a
// volume wired to a storage client, using the given replication level.
//
// Per the inline comments there are two modes: connect to the stub
// with the fake account credentials and the container name
// "fakecontainername", or connect to the real Azure storage service
// with the key read from azureStorageAccountKeyFile.
// NOTE(review): presumably the stub is used when azureTestContainer
// is empty, and readonly is stored on the volume -- confirm.
351 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
352 azHandler := newAzStubHandler()
353 azStub := httptest.NewServer(azHandler)
355 var azClient storage.Client
357 container := azureTestContainer
359 // Connect to stub instead of real Azure storage service
// Strip the URL scheme, keeping only the stub's host:port.
360 stubURLBase := strings.Split(azStub.URL, "://")[1]
362 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
365 container = "fakecontainername"
367 // Connect to real Azure storage service
368 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
372 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
378 bs := azClient.GetBlobService()
379 v := &AzureBlobVolume{
380 ContainerName: container,
382 AzureReplication: replication,
384 bsClient: &azureBlobClient{client: &bs},
387 return &TestableAzureBlobVolume{
389 azHandler: azHandler,
// Register StubbedAzureBlobSuite with the check test framework.
395 var _ = check.Suite(&StubbedAzureBlobSuite{})
// StubbedAzureBlobSuite is a check suite whose tests run against a
// stub-backed Azure blob volume.
397 type StubbedAzureBlobSuite struct {
// volume is created afresh by SetUpTest.
398 volume *TestableAzureBlobVolume
// origHTTPTransport is the http.DefaultTransport saved by SetUpTest
// and restored by TearDownTest.
399 origHTTPTransport http.RoundTripper
// SetUpTest saves http.DefaultTransport and replaces it with a
// transport that dials through azStubDialer. It shortens the write
// race interval and poll time, then creates a writable stub volume
// with replication 3.
402 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
403 s.origHTTPTransport = http.DefaultTransport
404 http.DefaultTransport = &http.Transport{
405 Dial: (&azStubDialer{}).Dial,
407 azureWriteRaceInterval = time.Millisecond
408 azureWriteRacePollTime = time.Nanosecond
410 s.volume = NewTestableAzureBlobVolume(c, false, 3)
// TearDownTest restores the http.DefaultTransport that SetUpTest
// saved.
413 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
415 http.DefaultTransport = s.origHTTPTransport
// TestAzureBlobVolumeWithGeneric runs the generic volume test suite
// against writable stub-backed Azure volumes. http.DefaultTransport is
// replaced for the duration of the test and restored on return.
418 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
// Restore the transport that was in effect when the test started.
419 defer func(t http.RoundTripper) {
420 http.DefaultTransport = t
421 }(http.DefaultTransport)
422 http.DefaultTransport = &http.Transport{
423 Dial: (&azStubDialer{}).Dial,
425 azureWriteRaceInterval = time.Millisecond
426 azureWriteRacePollTime = time.Nanosecond
427 DoGenericVolumeTests(t, func(t TB) TestableVolume {
428 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestAzureBlobVolumeConcurrentRanges runs the generic volume test
// suite once for each of two azureMaxGetBytes values, 2<<22 and
// 2<<22 - 1. http.DefaultTransport is replaced for the duration of the
// test and restored on return.
432 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
437 defer func(t http.RoundTripper) {
438 http.DefaultTransport = t
439 }(http.DefaultTransport)
440 http.DefaultTransport = &http.Transport{
441 Dial: (&azStubDialer{}).Dial,
443 azureWriteRaceInterval = time.Millisecond
444 azureWriteRacePollTime = time.Nanosecond
445 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
// The range clause assigns to the existing azureMaxGetBytes variable
// rather than declaring a new loop variable.
446 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
447 DoGenericVolumeTests(t, func(t TB) TestableVolume {
448 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestReadonlyAzureBlobVolumeWithGeneric runs the generic volume test
// suite against stub-backed Azure volumes created with readonly set
// to true. http.DefaultTransport is replaced for the duration of the
// test and restored on return.
453 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
454 defer func(t http.RoundTripper) {
455 http.DefaultTransport = t
456 }(http.DefaultTransport)
457 http.DefaultTransport = &http.Transport{
458 Dial: (&azStubDialer{}).Dial,
460 azureWriteRaceInterval = time.Millisecond
461 azureWriteRacePollTime = time.Nanosecond
462 DoGenericVolumeTests(t, func(t TB) TestableVolume {
463 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
// TestAzureBlobVolumeRangeFenceposts writes and reads back blocks
// whose sizes sit on and around multiples of 2<<22, checking that the
// length and MD5 hash of the data read back match what was written.
// NOTE(review): the size comments imply the maximum single read is
// 2<<22 bytes; presumably azureMaxGetBytes is set accordingly --
// confirm.
467 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
468 defer func(t http.RoundTripper) {
469 http.DefaultTransport = t
470 }(http.DefaultTransport)
471 http.DefaultTransport = &http.Transport{
472 Dial: (&azStubDialer{}).Dial,
475 v := NewTestableAzureBlobVolume(t, false, 3)
478 for _, size := range []int{
479 2<<22 - 1, // one <max read
480 2 << 22, // one =max read
481 2<<22 + 1, // one =max read, one <max
482 2 << 23, // two =max reads
// Fill the block with a deterministic byte pattern.
486 data := make([]byte, size)
487 for i := range data {
488 data[i] = byte((i + 7) & 0xff)
// The block is stored under the hex MD5 of its content.
490 hash := fmt.Sprintf("%x", md5.Sum(data))
491 err := v.Put(context.Background(), hash, data)
495 gotData := make([]byte, len(data))
496 gotLen, err := v.Get(context.Background(), hash, gotData)
500 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
502 t.Errorf("length mismatch: got %d != %d", gotLen, size)
505 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
// TestAzureBlobVolumeReplication creates volumes with replication
// levels 1 through 4 and checks that Replication() reports the level
// each volume was created with.
510 func TestAzureBlobVolumeReplication(t *testing.T) {
511 for r := 1; r <= 4; r++ {
512 v := NewTestableAzureBlobVolume(t, false, r)
514 if n := v.Replication(); n != r {
515 t.Errorf("Got replication %d, expected %d", n, r)
// TestAzureBlobVolumeCreateBlobRace exercises the window in which the
// stub's Put has created an empty blob but not yet stored its data.
// Per the inline comments, a Get issued during that window should end
// up returning the real data once the Put is allowed to continue.
// NOTE(review): the Put and Get presumably run in separate goroutines
// coordinated through v.azHandler.race and allDone -- confirm.
520 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
521 defer func(t http.RoundTripper) {
522 http.DefaultTransport = t
523 }(http.DefaultTransport)
524 http.DefaultTransport = &http.Transport{
525 Dial: (&azStubDialer{}).Dial,
528 v := NewTestableAzureBlobVolume(t, false, 3)
531 azureWriteRaceInterval = time.Second
532 azureWriteRacePollTime = time.Millisecond
534 allDone := make(chan struct{})
// Giving the handler a race channel makes its requests synchronize
// with this test through unlockAndRace.
535 v.azHandler.race = make(chan chan struct{})
537 err := v.Put(context.Background(), TestHash, TestBlock)
542 continuePut := make(chan struct{})
543 // Wait for the stub's Put to create the empty blob
544 v.azHandler.race <- continuePut
546 buf := make([]byte, len(TestBlock))
547 _, err := v.Get(context.Background(), TestHash, buf)
553 // Wait for the stub's Get to get the empty blob
554 close(v.azHandler.race)
555 // Allow stub's Put to continue, so the real data is ready
556 // when the volume's Get retries
558 // Wait for volume's Get to return the real data
// TestAzureBlobVolumeCreateBlobRaceDeadline checks how an empty blob
// is treated around the write race deadline. With a 2s race interval
// it stores an empty blob and expects the index to be empty. It then
// back-dates the blob by 1982ms and expects Get to return empty data
// within one second rather than keep waiting. Finally it expects the
// index to list the block with size 0.
562 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
563 defer func(t http.RoundTripper) {
564 http.DefaultTransport = t
565 }(http.DefaultTransport)
566 http.DefaultTransport = &http.Transport{
567 Dial: (&azStubDialer{}).Dial,
570 v := NewTestableAzureBlobVolume(t, false, 3)
573 azureWriteRaceInterval = 2 * time.Second
574 azureWriteRacePollTime = 5 * time.Millisecond
576 v.PutRaw(TestHash, nil)
578 buf := new(bytes.Buffer)
581 t.Errorf("Index %+q should be empty", buf.Bytes())
// 1982ms is 18ms short of the 2s race interval.
584 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
586 allDone := make(chan struct{})
589 buf := make([]byte, BlockSize)
590 n, err := v.Get(context.Background(), TestHash, buf)
596 t.Errorf("Got %+q, expected empty buf", buf[:n])
601 case <-time.After(time.Second):
602 t.Error("Get should have stopped waiting for race when block was 2s old")
607 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
608 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
// TestAzureBlobVolumeContextCancelGet runs the shared context
// cancellation check against Get, after storing TestBlock directly in
// the stub.
612 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
613 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
614 v.PutRaw(TestHash, TestBlock)
615 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelPut runs the shared context
// cancellation check against Put of a BlockSize-byte zero-filled
// buffer.
620 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
621 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
622 return v.Put(ctx, TestHash, make([]byte, BlockSize))
// TestAzureBlobVolumeContextCancelCompare runs the shared context
// cancellation check against Compare. TestBlock is stored in the stub
// and compared with TestBlock2.
626 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
627 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
628 v.PutRaw(TestHash, TestBlock)
629 return v.Compare(ctx, TestHash, TestBlock2)
// testAzureBlobVolumeContextCancel runs testFunc against a stub volume
// whose handler has a race channel set, using a cancellable context.
// It reports an error if testFunc returns anything other than
// context.Canceled. It also reports an error if testFunc finishes
// before reaching the handler, or if entering the handler or
// completing the cancellation takes more than 10 seconds.
// NOTE(review): testFunc presumably runs in its own goroutine that
// signals allDone, with cancel() called while the handler is held in
// its race window -- confirm.
633 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
634 defer func(t http.RoundTripper) {
635 http.DefaultTransport = t
636 }(http.DefaultTransport)
637 http.DefaultTransport = &http.Transport{
638 Dial: (&azStubDialer{}).Dial,
641 v := NewTestableAzureBlobVolume(t, false, 3)
643 v.azHandler.race = make(chan chan struct{})
645 ctx, cancel := context.WithCancel(context.Background())
646 allDone := make(chan struct{})
649 err := testFunc(ctx, v)
650 if err != context.Canceled {
651 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
654 releaseHandler := make(chan struct{})
657 t.Error("testFunc finished without waiting for v.azHandler.race")
658 case <-time.After(10 * time.Second):
659 t.Error("timed out waiting to enter handler")
// The send succeeds once a request is waiting in unlockAndRace.
660 case v.azHandler.race <- releaseHandler:
666 case <-time.After(10 * time.Second):
667 t.Error("timed out waiting to cancel")
// TestStats checks the counters in the volume's InternalStats JSON.
// They start at zero. A Get of a missing block is expected to count
// an op and a 404 error and leave InBytes at 0. A 3-byte Put is
// expected to give OutBytes 3 and CreateOps 1, and two successful
// 3-byte Gets are expected to give InBytes 6.
676 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
// stats marshals the volume's current internal stats to JSON.
677 stats := func() string {
678 buf, err := json.Marshal(s.volume.InternalStats())
679 c.Check(err, check.IsNil)
683 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
684 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
686 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
687 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
688 c.Check(err, check.NotNil)
689 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
690 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
691 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
692 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
694 err = s.volume.Put(context.Background(), loc, []byte("foo"))
695 c.Check(err, check.IsNil)
696 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
697 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
699 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
700 c.Check(err, check.IsNil)
701 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
702 c.Check(err, check.IsNil)
703 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
// PutRaw stores data for locator directly in the stub handler, under
// the volume's container name, without going through the volume's own
// Put.
706 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
707 v.azHandler.PutRaw(v.ContainerName, locator, data)
// TouchWithDate forwards locator and lastPut to the stub handler's
// TouchWithDate, using the volume's container name.
710 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
711 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
// Teardown releases the resources held by the testable volume.
// NOTE(review): presumably it closes the stub HTTP server (v.azStub)
// -- confirm.
714 func (v *TestableAzureBlobVolume) Teardown() {
// makeEtag returns a new random Etag value: "0x" followed by the
// hexadecimal form of a random non-negative 63-bit integer.
718 func makeEtag() string {
719 return fmt.Sprintf("0x%x", rand.Int63())