Merge branch '13593-async-perm-graph-update'
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "bytes"
9         "context"
10         "crypto/md5"
11         "encoding/base64"
12         "encoding/json"
13         "encoding/xml"
14         "flag"
15         "fmt"
16         "io/ioutil"
17         "math/rand"
18         "net"
19         "net/http"
20         "net/http/httptest"
21         "os"
22         "regexp"
23         "sort"
24         "strconv"
25         "strings"
26         "sync"
27         "testing"
28         "time"
29
30         "github.com/Azure/azure-sdk-for-go/storage"
31         "github.com/ghodss/yaml"
32         "github.com/prometheus/client_golang/prometheus"
33         check "gopkg.in/check.v1"
34 )
35
36 const (
37         // This cannot be the fake account name "devstoreaccount1"
38         // used by Microsoft's Azure emulator: the Azure SDK
39         // recognizes that magic string and changes its behavior to
40         // cater to the Azure SDK's own test suite.
41         fakeAccountName = "fakeaccountname"
42         fakeAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
43 )
44
45 var (
46         azureTestContainer string
47         azureTestDebug     = os.Getenv("ARVADOS_DEBUG") != ""
48 )
49
50 func init() {
51         flag.StringVar(
52                 &azureTestContainer,
53                 "test.azure-storage-container-volume",
54                 "",
55                 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
56 }
57
58 type azBlob struct {
59         Data        []byte
60         Etag        string
61         Metadata    map[string]string
62         Mtime       time.Time
63         Uncommitted map[string][]byte
64 }
65
66 type azStubHandler struct {
67         sync.Mutex
68         blobs map[string]*azBlob
69         race  chan chan struct{}
70 }
71
72 func newAzStubHandler() *azStubHandler {
73         return &azStubHandler{
74                 blobs: make(map[string]*azBlob),
75         }
76 }
77
78 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
79         blob, ok := h.blobs[container+"|"+hash]
80         if !ok {
81                 return
82         }
83         blob.Mtime = t
84 }
85
86 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
87         h.Lock()
88         defer h.Unlock()
89         h.blobs[container+"|"+hash] = &azBlob{
90                 Data:        data,
91                 Mtime:       time.Now(),
92                 Metadata:    make(map[string]string),
93                 Uncommitted: make(map[string][]byte),
94         }
95 }
96
97 func (h *azStubHandler) unlockAndRace() {
98         if h.race == nil {
99                 return
100         }
101         h.Unlock()
102         // Signal caller that race is starting by reading from
103         // h.race. If we get a channel, block until that channel is
104         // ready to receive. If we get nil (or h.race is closed) just
105         // proceed.
106         if c := <-h.race; c != nil {
107                 c <- struct{}{}
108         }
109         h.Lock()
110 }
111
112 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
113
114 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
115         h.Lock()
116         defer h.Unlock()
117         if azureTestDebug {
118                 defer log.Printf("azStubHandler: %+v", r)
119         }
120
121         path := strings.Split(r.URL.Path, "/")
122         container := path[1]
123         hash := ""
124         if len(path) > 2 {
125                 hash = path[2]
126         }
127
128         if err := r.ParseForm(); err != nil {
129                 log.Printf("azStubHandler(%+v): %s", r, err)
130                 rw.WriteHeader(http.StatusBadRequest)
131                 return
132         }
133
134         if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
135                 rw.WriteHeader(http.StatusLengthRequired)
136                 return
137         }
138
139         body, err := ioutil.ReadAll(r.Body)
140         if err != nil {
141                 return
142         }
143
144         type blockListRequestBody struct {
145                 XMLName     xml.Name `xml:"BlockList"`
146                 Uncommitted []string
147         }
148
149         blob, blobExists := h.blobs[container+"|"+hash]
150
151         switch {
152         case r.Method == "PUT" && r.Form.Get("comp") == "":
153                 // "Put Blob" API
154                 if _, ok := h.blobs[container+"|"+hash]; !ok {
155                         // Like the real Azure service, we offer a
156                         // race window during which other clients can
157                         // list/get the new blob before any data is
158                         // committed.
159                         h.blobs[container+"|"+hash] = &azBlob{
160                                 Mtime:       time.Now(),
161                                 Uncommitted: make(map[string][]byte),
162                                 Metadata:    make(map[string]string),
163                                 Etag:        makeEtag(),
164                         }
165                         h.unlockAndRace()
166                 }
167                 metadata := make(map[string]string)
168                 for k, v := range r.Header {
169                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
170                                 name := k[len("x-ms-meta-"):]
171                                 metadata[strings.ToLower(name)] = v[0]
172                         }
173                 }
174                 h.blobs[container+"|"+hash] = &azBlob{
175                         Data:        body,
176                         Mtime:       time.Now(),
177                         Uncommitted: make(map[string][]byte),
178                         Metadata:    metadata,
179                         Etag:        makeEtag(),
180                 }
181                 rw.WriteHeader(http.StatusCreated)
182         case r.Method == "PUT" && r.Form.Get("comp") == "block":
183                 // "Put Block" API
184                 if !blobExists {
185                         log.Printf("Got block for nonexistent blob: %+v", r)
186                         rw.WriteHeader(http.StatusBadRequest)
187                         return
188                 }
189                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
190                 if err != nil || len(blockID) == 0 {
191                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
192                         rw.WriteHeader(http.StatusBadRequest)
193                         return
194                 }
195                 blob.Uncommitted[string(blockID)] = body
196                 rw.WriteHeader(http.StatusCreated)
197         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
198                 // "Put Block List" API
199                 bl := &blockListRequestBody{}
200                 if err := xml.Unmarshal(body, bl); err != nil {
201                         log.Printf("xml Unmarshal: %s", err)
202                         rw.WriteHeader(http.StatusBadRequest)
203                         return
204                 }
205                 for _, encBlockID := range bl.Uncommitted {
206                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
207                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
208                                 log.Printf("Invalid blockid: %+q", encBlockID)
209                                 rw.WriteHeader(http.StatusBadRequest)
210                                 return
211                         }
212                         blob.Data = blob.Uncommitted[string(blockID)]
213                         blob.Etag = makeEtag()
214                         blob.Mtime = time.Now()
215                         delete(blob.Uncommitted, string(blockID))
216                 }
217                 rw.WriteHeader(http.StatusCreated)
218         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
219                 // "Set Metadata Headers" API. We don't bother
220                 // stubbing "Get Metadata Headers": AzureBlobVolume
221                 // sets metadata headers only as a way to bump Etag
222                 // and Last-Modified.
223                 if !blobExists {
224                         log.Printf("Got metadata for nonexistent blob: %+v", r)
225                         rw.WriteHeader(http.StatusBadRequest)
226                         return
227                 }
228                 blob.Metadata = make(map[string]string)
229                 for k, v := range r.Header {
230                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
231                                 name := k[len("x-ms-meta-"):]
232                                 blob.Metadata[strings.ToLower(name)] = v[0]
233                         }
234                 }
235                 blob.Mtime = time.Now()
236                 blob.Etag = makeEtag()
237         case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
238                 // "Get Blob Metadata" API
239                 if !blobExists {
240                         rw.WriteHeader(http.StatusNotFound)
241                         return
242                 }
243                 for k, v := range blob.Metadata {
244                         rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
245                 }
246                 return
247         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
248                 // "Get Blob" API
249                 if !blobExists {
250                         rw.WriteHeader(http.StatusNotFound)
251                         return
252                 }
253                 data := blob.Data
254                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
255                         b0, err0 := strconv.Atoi(rangeSpec[1])
256                         b1, err1 := strconv.Atoi(rangeSpec[2])
257                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
258                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
259                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
260                                 return
261                         }
262                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
263                         rw.WriteHeader(http.StatusPartialContent)
264                         data = data[b0 : b1+1]
265                 }
266                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
267                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
268                 if r.Method == "GET" {
269                         if _, err := rw.Write(data); err != nil {
270                                 log.Printf("write %+q: %s", data, err)
271                         }
272                 }
273                 h.unlockAndRace()
274         case r.Method == "DELETE" && hash != "":
275                 // "Delete Blob" API
276                 if !blobExists {
277                         rw.WriteHeader(http.StatusNotFound)
278                         return
279                 }
280                 delete(h.blobs, container+"|"+hash)
281                 rw.WriteHeader(http.StatusAccepted)
282         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
283                 // "List Blobs" API
284                 prefix := container + "|" + r.Form.Get("prefix")
285                 marker := r.Form.Get("marker")
286
287                 maxResults := 2
288                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
289                         maxResults = n
290                 }
291
292                 resp := storage.BlobListResponse{
293                         Marker:     marker,
294                         NextMarker: "",
295                         MaxResults: int64(maxResults),
296                 }
297                 var hashes sort.StringSlice
298                 for k := range h.blobs {
299                         if strings.HasPrefix(k, prefix) {
300                                 hashes = append(hashes, k[len(container)+1:])
301                         }
302                 }
303                 hashes.Sort()
304                 for _, hash := range hashes {
305                         if len(resp.Blobs) == maxResults {
306                                 resp.NextMarker = hash
307                                 break
308                         }
309                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
310                                 blob := h.blobs[container+"|"+hash]
311                                 bmeta := map[string]string(nil)
312                                 if r.Form.Get("include") == "metadata" {
313                                         bmeta = blob.Metadata
314                                 }
315                                 b := storage.Blob{
316                                         Name: hash,
317                                         Properties: storage.BlobProperties{
318                                                 LastModified:  storage.TimeRFC1123(blob.Mtime),
319                                                 ContentLength: int64(len(blob.Data)),
320                                                 Etag:          blob.Etag,
321                                         },
322                                         Metadata: bmeta,
323                                 }
324                                 resp.Blobs = append(resp.Blobs, b)
325                         }
326                 }
327                 buf, err := xml.Marshal(resp)
328                 if err != nil {
329                         log.Print(err)
330                         rw.WriteHeader(http.StatusInternalServerError)
331                 }
332                 rw.Write(buf)
333         default:
334                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
335                 rw.WriteHeader(http.StatusNotImplemented)
336         }
337 }
338
339 // azStubDialer is a net.Dialer that notices when the Azure driver
340 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
341 // in such cases transparently dials "127.0.0.1:46067" instead.
342 type azStubDialer struct {
343         net.Dialer
344 }
345
346 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
347
348 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
349         if hp := localHostPortRe.FindString(address); hp != "" {
350                 if azureTestDebug {
351                         log.Println("azStubDialer: dial", hp, "instead of", address)
352                 }
353                 address = hp
354         }
355         return d.Dialer.Dial(network, address)
356 }
357
358 type TestableAzureBlobVolume struct {
359         *AzureBlobVolume
360         azHandler *azStubHandler
361         azStub    *httptest.Server
362         t         TB
363 }
364
365 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
366         azHandler := newAzStubHandler()
367         azStub := httptest.NewServer(azHandler)
368
369         var azClient storage.Client
370
371         container := azureTestContainer
372         if container == "" {
373                 // Connect to stub instead of real Azure storage service
374                 stubURLBase := strings.Split(azStub.URL, "://")[1]
375                 var err error
376                 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
377                         t.Fatal(err)
378                 }
379                 container = "fakecontainername"
380         } else {
381                 // Connect to real Azure storage service
382                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
383                 if err != nil {
384                         t.Fatal(err)
385                 }
386                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
387                 if err != nil {
388                         t.Fatal(err)
389                 }
390         }
391
392         bs := azClient.GetBlobService()
393         v := &AzureBlobVolume{
394                 ContainerName:    container,
395                 ReadOnly:         readonly,
396                 AzureReplication: replication,
397                 azClient:         azClient,
398                 container:        &azureContainer{ctr: bs.GetContainerReference(container)},
399         }
400
401         return &TestableAzureBlobVolume{
402                 AzureBlobVolume: v,
403                 azHandler:       azHandler,
404                 azStub:          azStub,
405                 t:               t,
406         }
407 }
408
409 var _ = check.Suite(&StubbedAzureBlobSuite{})
410
411 type StubbedAzureBlobSuite struct {
412         volume            *TestableAzureBlobVolume
413         origHTTPTransport http.RoundTripper
414 }
415
416 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
417         s.origHTTPTransport = http.DefaultTransport
418         http.DefaultTransport = &http.Transport{
419                 Dial: (&azStubDialer{}).Dial,
420         }
421         azureWriteRaceInterval = time.Millisecond
422         azureWriteRacePollTime = time.Nanosecond
423
424         s.volume = NewTestableAzureBlobVolume(c, false, 3)
425 }
426
427 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
428         s.volume.Teardown()
429         http.DefaultTransport = s.origHTTPTransport
430 }
431
432 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
433         defer func(t http.RoundTripper) {
434                 http.DefaultTransport = t
435         }(http.DefaultTransport)
436         http.DefaultTransport = &http.Transport{
437                 Dial: (&azStubDialer{}).Dial,
438         }
439         azureWriteRaceInterval = time.Millisecond
440         azureWriteRacePollTime = time.Nanosecond
441         DoGenericVolumeTests(t, func(t TB) TestableVolume {
442                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
443         })
444 }
445
446 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
447         defer func(b int) {
448                 azureMaxGetBytes = b
449         }(azureMaxGetBytes)
450
451         defer func(t http.RoundTripper) {
452                 http.DefaultTransport = t
453         }(http.DefaultTransport)
454         http.DefaultTransport = &http.Transport{
455                 Dial: (&azStubDialer{}).Dial,
456         }
457         azureWriteRaceInterval = time.Millisecond
458         azureWriteRacePollTime = time.Nanosecond
459         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
460         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
461                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
462                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
463                 })
464         }
465 }
466
467 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
468         defer func(t http.RoundTripper) {
469                 http.DefaultTransport = t
470         }(http.DefaultTransport)
471         http.DefaultTransport = &http.Transport{
472                 Dial: (&azStubDialer{}).Dial,
473         }
474         azureWriteRaceInterval = time.Millisecond
475         azureWriteRacePollTime = time.Nanosecond
476         DoGenericVolumeTests(t, func(t TB) TestableVolume {
477                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
478         })
479 }
480
481 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
482         defer func(t http.RoundTripper) {
483                 http.DefaultTransport = t
484         }(http.DefaultTransport)
485         http.DefaultTransport = &http.Transport{
486                 Dial: (&azStubDialer{}).Dial,
487         }
488
489         v := NewTestableAzureBlobVolume(t, false, 3)
490         defer v.Teardown()
491
492         for _, size := range []int{
493                 2<<22 - 1, // one <max read
494                 2 << 22,   // one =max read
495                 2<<22 + 1, // one =max read, one <max
496                 2 << 23,   // two =max reads
497                 BlockSize - 1,
498                 BlockSize,
499         } {
500                 data := make([]byte, size)
501                 for i := range data {
502                         data[i] = byte((i + 7) & 0xff)
503                 }
504                 hash := fmt.Sprintf("%x", md5.Sum(data))
505                 err := v.Put(context.Background(), hash, data)
506                 if err != nil {
507                         t.Error(err)
508                 }
509                 gotData := make([]byte, len(data))
510                 gotLen, err := v.Get(context.Background(), hash, gotData)
511                 if err != nil {
512                         t.Error(err)
513                 }
514                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
515                 if gotLen != size {
516                         t.Errorf("length mismatch: got %d != %d", gotLen, size)
517                 }
518                 if gotHash != hash {
519                         t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
520                 }
521         }
522 }
523
524 func TestAzureBlobVolumeReplication(t *testing.T) {
525         for r := 1; r <= 4; r++ {
526                 v := NewTestableAzureBlobVolume(t, false, r)
527                 defer v.Teardown()
528                 if n := v.Replication(); n != r {
529                         t.Errorf("Got replication %d, expected %d", n, r)
530                 }
531         }
532 }
533
534 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
535         defer func(t http.RoundTripper) {
536                 http.DefaultTransport = t
537         }(http.DefaultTransport)
538         http.DefaultTransport = &http.Transport{
539                 Dial: (&azStubDialer{}).Dial,
540         }
541
542         v := NewTestableAzureBlobVolume(t, false, 3)
543         defer v.Teardown()
544
545         azureWriteRaceInterval = time.Second
546         azureWriteRacePollTime = time.Millisecond
547
548         var wg sync.WaitGroup
549
550         v.azHandler.race = make(chan chan struct{})
551
552         wg.Add(1)
553         go func() {
554                 defer wg.Done()
555                 err := v.Put(context.Background(), TestHash, TestBlock)
556                 if err != nil {
557                         t.Error(err)
558                 }
559         }()
560         continuePut := make(chan struct{})
561         // Wait for the stub's Put to create the empty blob
562         v.azHandler.race <- continuePut
563         wg.Add(1)
564         go func() {
565                 defer wg.Done()
566                 buf := make([]byte, len(TestBlock))
567                 _, err := v.Get(context.Background(), TestHash, buf)
568                 if err != nil {
569                         t.Error(err)
570                 }
571         }()
572         // Wait for the stub's Get to get the empty blob
573         close(v.azHandler.race)
574         // Allow stub's Put to continue, so the real data is ready
575         // when the volume's Get retries
576         <-continuePut
577         // Wait for Get() and Put() to finish
578         wg.Wait()
579 }
580
581 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
582         defer func(t http.RoundTripper) {
583                 http.DefaultTransport = t
584         }(http.DefaultTransport)
585         http.DefaultTransport = &http.Transport{
586                 Dial: (&azStubDialer{}).Dial,
587         }
588
589         v := NewTestableAzureBlobVolume(t, false, 3)
590         defer v.Teardown()
591
592         azureWriteRaceInterval = 2 * time.Second
593         azureWriteRacePollTime = 5 * time.Millisecond
594
595         v.PutRaw(TestHash, nil)
596
597         buf := new(bytes.Buffer)
598         v.IndexTo("", buf)
599         if buf.Len() != 0 {
600                 t.Errorf("Index %+q should be empty", buf.Bytes())
601         }
602
603         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
604
605         allDone := make(chan struct{})
606         go func() {
607                 defer close(allDone)
608                 buf := make([]byte, BlockSize)
609                 n, err := v.Get(context.Background(), TestHash, buf)
610                 if err != nil {
611                         t.Error(err)
612                         return
613                 }
614                 if n != 0 {
615                         t.Errorf("Got %+q, expected empty buf", buf[:n])
616                 }
617         }()
618         select {
619         case <-allDone:
620         case <-time.After(time.Second):
621                 t.Error("Get should have stopped waiting for race when block was 2s old")
622         }
623
624         buf.Reset()
625         v.IndexTo("", buf)
626         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
627                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
628         }
629 }
630
631 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
632         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
633                 v.PutRaw(TestHash, TestBlock)
634                 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
635                 return err
636         })
637 }
638
639 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
640         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
641                 return v.Put(ctx, TestHash, make([]byte, BlockSize))
642         })
643 }
644
645 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
646         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
647                 v.PutRaw(TestHash, TestBlock)
648                 return v.Compare(ctx, TestHash, TestBlock2)
649         })
650 }
651
652 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
653         defer func(t http.RoundTripper) {
654                 http.DefaultTransport = t
655         }(http.DefaultTransport)
656         http.DefaultTransport = &http.Transport{
657                 Dial: (&azStubDialer{}).Dial,
658         }
659
660         v := NewTestableAzureBlobVolume(t, false, 3)
661         defer v.Teardown()
662         v.azHandler.race = make(chan chan struct{})
663
664         ctx, cancel := context.WithCancel(context.Background())
665         allDone := make(chan struct{})
666         go func() {
667                 defer close(allDone)
668                 err := testFunc(ctx, v)
669                 if err != context.Canceled {
670                         t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
671                 }
672         }()
673         releaseHandler := make(chan struct{})
674         select {
675         case <-allDone:
676                 t.Error("testFunc finished without waiting for v.azHandler.race")
677         case <-time.After(10 * time.Second):
678                 t.Error("timed out waiting to enter handler")
679         case v.azHandler.race <- releaseHandler:
680         }
681
682         cancel()
683
684         select {
685         case <-time.After(10 * time.Second):
686                 t.Error("timed out waiting to cancel")
687         case <-allDone:
688         }
689
690         go func() {
691                 <-releaseHandler
692         }()
693 }
694
695 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
696         stats := func() string {
697                 buf, err := json.Marshal(s.volume.InternalStats())
698                 c.Check(err, check.IsNil)
699                 return string(buf)
700         }
701
702         c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
703         c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
704
705         loc := "acbd18db4cc2f85cedef654fccc4a4d8"
706         _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
707         c.Check(err, check.NotNil)
708         c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
709         c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
710         c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
711         c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
712
713         err = s.volume.Put(context.Background(), loc, []byte("foo"))
714         c.Check(err, check.IsNil)
715         c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
716         c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
717
718         _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
719         c.Check(err, check.IsNil)
720         _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
721         c.Check(err, check.IsNil)
722         c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
723 }
724
725 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
726         var cfg Config
727         err := yaml.Unmarshal([]byte(`
728 Volumes:
729   - Type: Azure
730     StorageClasses: ["class_a", "class_b"]
731 `), &cfg)
732
733         c.Check(err, check.IsNil)
734         c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
735 }
736
737 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
738         v.azHandler.PutRaw(v.ContainerName, locator, data)
739 }
740
741 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
742         v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
743 }
744
745 func (v *TestableAzureBlobVolume) Teardown() {
746         v.azStub.Close()
747 }
748
749 func (v *TestableAzureBlobVolume) ReadWriteOperationLabelValues() (r, w string) {
750         return "get", "create"
751 }
752
753 func (v *TestableAzureBlobVolume) DeviceID() string {
754         // Dummy device id for testing purposes
755         return "azure://azure_blob_volume_test"
756 }
757
758 func (v *TestableAzureBlobVolume) Start(vm *volumeMetricsVecs) error {
759         // Override original Start() to be able to assign CounterVecs with a dummy DeviceID
760         v.container.stats.opsCounters, v.container.stats.errCounters, v.container.stats.ioBytes = vm.getCounterVecsFor(prometheus.Labels{"device_id": v.DeviceID()})
761         return nil
762 }
763
764 func makeEtag() string {
765         return fmt.Sprintf("0x%x", rand.Int63())
766 }