1 // Copyright (C) The Arvados Authors. All rights reserved.
3 // SPDX-License-Identifier: AGPL-3.0
29 "github.com/Azure/azure-sdk-for-go/storage"
30 check "gopkg.in/check.v1"
34 // This cannot be the fake account name "devstoreaccount1"
35 // used by Microsoft's Azure emulator: the Azure SDK
36 // recognizes that magic string and changes its behavior to
37 // cater to the Azure SDK's own test suite.
38 fakeAccountName = "fakeaccountname"
39 fakeAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
// azureTestContainer is the container name NewTestableAzureBlobVolume
// starts from when choosing which container to use.
// NOTE(review): presumably set by the -test.azure-storage-container-volume
// flag; the registration call is not visible in this excerpt -- confirm.
42 var azureTestContainer string
47 "test.azure-storage-container-volume",
49 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
55 Metadata map[string]string
57 Uncommitted map[string][]byte
60 type azStubHandler struct {
62 blobs map[string]*azBlob
63 race chan chan struct{}
66 func newAzStubHandler() *azStubHandler {
67 return &azStubHandler{
68 blobs: make(map[string]*azBlob),
72 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
73 blob, ok := h.blobs[container+"|"+hash]
80 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
83 h.blobs[container+"|"+hash] = &azBlob{
86 Metadata: make(map[string]string),
87 Uncommitted: make(map[string][]byte),
91 func (h *azStubHandler) unlockAndRace() {
96 // Signal caller that race is starting by reading from
97 // h.race. If we get a channel, block until that channel is
98 // ready to receive. If we get nil (or h.race is closed) just
100 if c := <-h.race; c != nil {
// rangeRegexp matches a complete Range header value of the form
// "bytes=<first>-<last>", capturing both decimal offsets. Open-ended and
// suffix forms ("bytes=5-", "bytes=-5") do not match.
106 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
108 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
111 // defer log.Printf("azStubHandler: %+v", r)
113 path := strings.Split(r.URL.Path, "/")
120 if err := r.ParseForm(); err != nil {
121 log.Printf("azStubHandler(%+v): %s", r, err)
122 rw.WriteHeader(http.StatusBadRequest)
126 if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
127 rw.WriteHeader(http.StatusLengthRequired)
131 body, err := ioutil.ReadAll(r.Body)
136 type blockListRequestBody struct {
137 XMLName xml.Name `xml:"BlockList"`
141 blob, blobExists := h.blobs[container+"|"+hash]
144 case r.Method == "PUT" && r.Form.Get("comp") == "":
146 if _, ok := h.blobs[container+"|"+hash]; !ok {
147 // Like the real Azure service, we offer a
148 // race window during which other clients can
149 // list/get the new blob before any data is
151 h.blobs[container+"|"+hash] = &azBlob{
153 Uncommitted: make(map[string][]byte),
154 Metadata: make(map[string]string),
159 metadata := make(map[string]string)
160 for k, v := range r.Header {
161 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
162 name := k[len("x-ms-meta-"):]
163 metadata[strings.ToLower(name)] = v[0]
166 h.blobs[container+"|"+hash] = &azBlob{
169 Uncommitted: make(map[string][]byte),
173 rw.WriteHeader(http.StatusCreated)
174 case r.Method == "PUT" && r.Form.Get("comp") == "block":
177 log.Printf("Got block for nonexistent blob: %+v", r)
178 rw.WriteHeader(http.StatusBadRequest)
181 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
182 if err != nil || len(blockID) == 0 {
183 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
184 rw.WriteHeader(http.StatusBadRequest)
187 blob.Uncommitted[string(blockID)] = body
188 rw.WriteHeader(http.StatusCreated)
189 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
190 // "Put Block List" API
191 bl := &blockListRequestBody{}
192 if err := xml.Unmarshal(body, bl); err != nil {
193 log.Printf("xml Unmarshal: %s", err)
194 rw.WriteHeader(http.StatusBadRequest)
197 for _, encBlockID := range bl.Uncommitted {
198 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
199 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
200 log.Printf("Invalid blockid: %+q", encBlockID)
201 rw.WriteHeader(http.StatusBadRequest)
204 blob.Data = blob.Uncommitted[string(blockID)]
205 blob.Etag = makeEtag()
206 blob.Mtime = time.Now()
207 delete(blob.Uncommitted, string(blockID))
209 rw.WriteHeader(http.StatusCreated)
210 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
211 // "Set Blob Metadata" API. ("Get Blob Metadata" is
212 // stubbed in a separate case below.) AzureBlobVolume
213 // sets metadata headers as a way to bump Etag
214 // and Last-Modified.
216 log.Printf("Got metadata for nonexistent blob: %+v", r)
217 rw.WriteHeader(http.StatusBadRequest)
220 blob.Metadata = make(map[string]string)
221 for k, v := range r.Header {
222 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
223 name := k[len("x-ms-meta-"):]
224 blob.Metadata[strings.ToLower(name)] = v[0]
227 blob.Mtime = time.Now()
228 blob.Etag = makeEtag()
229 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
230 // "Get Blob Metadata" API
232 rw.WriteHeader(http.StatusNotFound)
235 for k, v := range blob.Metadata {
236 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
239 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
242 rw.WriteHeader(http.StatusNotFound)
246 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
247 b0, err0 := strconv.Atoi(rangeSpec[1])
248 b1, err1 := strconv.Atoi(rangeSpec[2])
249 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
250 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
251 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
254 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
255 rw.WriteHeader(http.StatusPartialContent)
256 data = data[b0 : b1+1]
258 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
259 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
260 if r.Method == "GET" {
261 if _, err := rw.Write(data); err != nil {
262 log.Printf("write %+q: %s", data, err)
266 case r.Method == "DELETE" && hash != "":
269 rw.WriteHeader(http.StatusNotFound)
272 delete(h.blobs, container+"|"+hash)
273 rw.WriteHeader(http.StatusAccepted)
274 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
276 prefix := container + "|" + r.Form.Get("prefix")
277 marker := r.Form.Get("marker")
280 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
284 resp := storage.BlobListResponse{
287 MaxResults: int64(maxResults),
289 var hashes sort.StringSlice
290 for k := range h.blobs {
291 if strings.HasPrefix(k, prefix) {
292 hashes = append(hashes, k[len(container)+1:])
296 for _, hash := range hashes {
297 if len(resp.Blobs) == maxResults {
298 resp.NextMarker = hash
301 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
302 blob := h.blobs[container+"|"+hash]
303 bmeta := map[string]string(nil)
304 if r.Form.Get("include") == "metadata" {
305 bmeta = blob.Metadata
309 Properties: storage.BlobProperties{
310 LastModified: storage.TimeRFC1123(blob.Mtime),
311 ContentLength: int64(len(blob.Data)),
316 resp.Blobs = append(resp.Blobs, b)
319 buf, err := xml.Marshal(resp)
322 rw.WriteHeader(http.StatusInternalServerError)
326 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
327 rw.WriteHeader(http.StatusNotImplemented)
331 // azStubDialer is a net.Dialer that notices when the Azure driver
332 // tries to connect to "fakeaccountname.blob.127.0.0.1:46067", and
333 // in such cases transparently dials "127.0.0.1:46067" instead.
334 type azStubDialer struct {
// localHostPortRe finds a loopback host (127.0.0.1, localhost or [::1])
// followed by ":<port>" anywhere inside an address. The pattern is
// unanchored, so any prefix before the loopback host is ignored by
// FindString.
338 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
340 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
341 if hp := localHostPortRe.FindString(address); hp != "" {
342 log.Println("azStubDialer: dial", hp, "instead of", address)
345 return d.Dialer.Dial(network, address)
348 type TestableAzureBlobVolume struct {
350 azHandler *azStubHandler
351 azStub *httptest.Server
355 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
356 azHandler := newAzStubHandler()
357 azStub := httptest.NewServer(azHandler)
359 var azClient storage.Client
361 container := azureTestContainer
363 // Connect to stub instead of real Azure storage service
364 stubURLBase := strings.Split(azStub.URL, "://")[1]
366 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
369 container = "fakecontainername"
371 // Connect to real Azure storage service
372 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
376 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
382 bs := azClient.GetBlobService()
383 v := &AzureBlobVolume{
384 ContainerName: container,
386 AzureReplication: replication,
388 container: &azureContainer{ctr: bs.GetContainerReference(container)},
391 return &TestableAzureBlobVolume{
393 azHandler: azHandler,
// Register StubbedAzureBlobSuite with the gocheck runner.
399 var _ = check.Suite(&StubbedAzureBlobSuite{})
401 type StubbedAzureBlobSuite struct {
402 volume *TestableAzureBlobVolume
403 origHTTPTransport http.RoundTripper
406 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
407 s.origHTTPTransport = http.DefaultTransport
408 http.DefaultTransport = &http.Transport{
409 Dial: (&azStubDialer{}).Dial,
411 azureWriteRaceInterval = time.Millisecond
412 azureWriteRacePollTime = time.Nanosecond
414 s.volume = NewTestableAzureBlobVolume(c, false, 3)
417 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
419 http.DefaultTransport = s.origHTTPTransport
422 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
423 defer func(t http.RoundTripper) {
424 http.DefaultTransport = t
425 }(http.DefaultTransport)
426 http.DefaultTransport = &http.Transport{
427 Dial: (&azStubDialer{}).Dial,
429 azureWriteRaceInterval = time.Millisecond
430 azureWriteRacePollTime = time.Nanosecond
431 DoGenericVolumeTests(t, func(t TB) TestableVolume {
432 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
436 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
441 defer func(t http.RoundTripper) {
442 http.DefaultTransport = t
443 }(http.DefaultTransport)
444 http.DefaultTransport = &http.Transport{
445 Dial: (&azStubDialer{}).Dial,
447 azureWriteRaceInterval = time.Millisecond
448 azureWriteRacePollTime = time.Nanosecond
449 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
450 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
451 DoGenericVolumeTests(t, func(t TB) TestableVolume {
452 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
457 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
458 defer func(t http.RoundTripper) {
459 http.DefaultTransport = t
460 }(http.DefaultTransport)
461 http.DefaultTransport = &http.Transport{
462 Dial: (&azStubDialer{}).Dial,
464 azureWriteRaceInterval = time.Millisecond
465 azureWriteRacePollTime = time.Nanosecond
466 DoGenericVolumeTests(t, func(t TB) TestableVolume {
467 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
471 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
472 defer func(t http.RoundTripper) {
473 http.DefaultTransport = t
474 }(http.DefaultTransport)
475 http.DefaultTransport = &http.Transport{
476 Dial: (&azStubDialer{}).Dial,
479 v := NewTestableAzureBlobVolume(t, false, 3)
482 for _, size := range []int{
483 2<<22 - 1, // one <max read
484 2 << 22, // one =max read
485 2<<22 + 1, // one =max read, one <max
486 2 << 23, // two =max reads
490 data := make([]byte, size)
491 for i := range data {
492 data[i] = byte((i + 7) & 0xff)
494 hash := fmt.Sprintf("%x", md5.Sum(data))
495 err := v.Put(context.Background(), hash, data)
499 gotData := make([]byte, len(data))
500 gotLen, err := v.Get(context.Background(), hash, gotData)
504 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
506 t.Errorf("length mismatch: got %d != %d", gotLen, size)
509 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
514 func TestAzureBlobVolumeReplication(t *testing.T) {
515 for r := 1; r <= 4; r++ {
516 v := NewTestableAzureBlobVolume(t, false, r)
518 if n := v.Replication(); n != r {
519 t.Errorf("Got replication %d, expected %d", n, r)
524 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
525 defer func(t http.RoundTripper) {
526 http.DefaultTransport = t
527 }(http.DefaultTransport)
528 http.DefaultTransport = &http.Transport{
529 Dial: (&azStubDialer{}).Dial,
532 v := NewTestableAzureBlobVolume(t, false, 3)
535 azureWriteRaceInterval = time.Second
536 azureWriteRacePollTime = time.Millisecond
538 allDone := make(chan struct{})
539 v.azHandler.race = make(chan chan struct{})
541 err := v.Put(context.Background(), TestHash, TestBlock)
546 continuePut := make(chan struct{})
547 // Wait for the stub's Put to create the empty blob
548 v.azHandler.race <- continuePut
550 buf := make([]byte, len(TestBlock))
551 _, err := v.Get(context.Background(), TestHash, buf)
557 // Wait for the stub's Get to get the empty blob
558 close(v.azHandler.race)
559 // Allow stub's Put to continue, so the real data is ready
560 // when the volume's Get retries
562 // Wait for volume's Get to return the real data
566 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
567 defer func(t http.RoundTripper) {
568 http.DefaultTransport = t
569 }(http.DefaultTransport)
570 http.DefaultTransport = &http.Transport{
571 Dial: (&azStubDialer{}).Dial,
574 v := NewTestableAzureBlobVolume(t, false, 3)
577 azureWriteRaceInterval = 2 * time.Second
578 azureWriteRacePollTime = 5 * time.Millisecond
580 v.PutRaw(TestHash, nil)
582 buf := new(bytes.Buffer)
585 t.Errorf("Index %+q should be empty", buf.Bytes())
588 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
590 allDone := make(chan struct{})
593 buf := make([]byte, BlockSize)
594 n, err := v.Get(context.Background(), TestHash, buf)
600 t.Errorf("Got %+q, expected empty buf", buf[:n])
605 case <-time.After(time.Second):
606 t.Error("Get should have stopped waiting for race when block was 2s old")
611 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
612 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
616 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
617 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
618 v.PutRaw(TestHash, TestBlock)
619 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
624 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
625 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
626 return v.Put(ctx, TestHash, make([]byte, BlockSize))
630 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
631 testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
632 v.PutRaw(TestHash, TestBlock)
633 return v.Compare(ctx, TestHash, TestBlock2)
637 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
638 defer func(t http.RoundTripper) {
639 http.DefaultTransport = t
640 }(http.DefaultTransport)
641 http.DefaultTransport = &http.Transport{
642 Dial: (&azStubDialer{}).Dial,
645 v := NewTestableAzureBlobVolume(t, false, 3)
647 v.azHandler.race = make(chan chan struct{})
649 ctx, cancel := context.WithCancel(context.Background())
650 allDone := make(chan struct{})
653 err := testFunc(ctx, v)
654 if err != context.Canceled {
655 t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
658 releaseHandler := make(chan struct{})
661 t.Error("testFunc finished without waiting for v.azHandler.race")
662 case <-time.After(10 * time.Second):
663 t.Error("timed out waiting to enter handler")
664 case v.azHandler.race <- releaseHandler:
670 case <-time.After(10 * time.Second):
671 t.Error("timed out waiting to cancel")
680 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
681 stats := func() string {
682 buf, err := json.Marshal(s.volume.InternalStats())
683 c.Check(err, check.IsNil)
687 c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
688 c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
690 loc := "acbd18db4cc2f85cedef654fccc4a4d8"
691 _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
692 c.Check(err, check.NotNil)
693 c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
694 c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
695 c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
696 c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
698 err = s.volume.Put(context.Background(), loc, []byte("foo"))
699 c.Check(err, check.IsNil)
700 c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
701 c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
703 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
704 c.Check(err, check.IsNil)
705 _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
706 c.Check(err, check.IsNil)
707 c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
710 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
711 v.azHandler.PutRaw(v.ContainerName, locator, data)
714 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
715 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
718 func (v *TestableAzureBlobVolume) Teardown() {
722 func makeEtag() string {
723 return fmt.Sprintf("0x%x", rand.Int63())