8d02def1445c3f0d7f6ed5806c4c226b75e41644
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "bytes"
9         "context"
10         "crypto/md5"
11         "encoding/base64"
12         "encoding/json"
13         "encoding/xml"
14         "flag"
15         "fmt"
16         "io/ioutil"
17         "math/rand"
18         "net"
19         "net/http"
20         "net/http/httptest"
21         "os"
22         "regexp"
23         "sort"
24         "strconv"
25         "strings"
26         "sync"
27         "testing"
28         "time"
29
30         "git.curoverse.com/arvados.git/sdk/go/arvados"
31         "github.com/Azure/azure-sdk-for-go/storage"
32         "github.com/ghodss/yaml"
33         "github.com/prometheus/client_golang/prometheus"
34         check "gopkg.in/check.v1"
35 )
36
37 const (
38         // This cannot be the fake account name "devstoreaccount1"
39         // used by Microsoft's Azure emulator: the Azure SDK
40         // recognizes that magic string and changes its behavior to
41         // cater to the Azure SDK's own test suite.
42         fakeAccountName = "fakeaccountname"
43         fakeAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
44 )
45
46 var (
47         azureTestContainer string
48         azureTestDebug     = os.Getenv("ARVADOS_DEBUG") != ""
49 )
50
51 func init() {
52         flag.StringVar(
53                 &azureTestContainer,
54                 "test.azure-storage-container-volume",
55                 "",
56                 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
57 }
58
59 type azBlob struct {
60         Data        []byte
61         Etag        string
62         Metadata    map[string]string
63         Mtime       time.Time
64         Uncommitted map[string][]byte
65 }
66
67 type azStubHandler struct {
68         sync.Mutex
69         blobs      map[string]*azBlob
70         race       chan chan struct{}
71         didlist503 bool
72 }
73
74 func newAzStubHandler() *azStubHandler {
75         return &azStubHandler{
76                 blobs: make(map[string]*azBlob),
77         }
78 }
79
80 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
81         blob, ok := h.blobs[container+"|"+hash]
82         if !ok {
83                 return
84         }
85         blob.Mtime = t
86 }
87
88 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
89         h.Lock()
90         defer h.Unlock()
91         h.blobs[container+"|"+hash] = &azBlob{
92                 Data:        data,
93                 Mtime:       time.Now(),
94                 Metadata:    make(map[string]string),
95                 Uncommitted: make(map[string][]byte),
96         }
97 }
98
99 func (h *azStubHandler) unlockAndRace() {
100         if h.race == nil {
101                 return
102         }
103         h.Unlock()
104         // Signal caller that race is starting by reading from
105         // h.race. If we get a channel, block until that channel is
106         // ready to receive. If we get nil (or h.race is closed) just
107         // proceed.
108         if c := <-h.race; c != nil {
109                 c <- struct{}{}
110         }
111         h.Lock()
112 }
113
114 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
115
116 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
117         h.Lock()
118         defer h.Unlock()
119         if azureTestDebug {
120                 defer log.Printf("azStubHandler: %+v", r)
121         }
122
123         path := strings.Split(r.URL.Path, "/")
124         container := path[1]
125         hash := ""
126         if len(path) > 2 {
127                 hash = path[2]
128         }
129
130         if err := r.ParseForm(); err != nil {
131                 log.Printf("azStubHandler(%+v): %s", r, err)
132                 rw.WriteHeader(http.StatusBadRequest)
133                 return
134         }
135
136         if (r.Method == "PUT" || r.Method == "POST") && r.Header.Get("Content-Length") == "" {
137                 rw.WriteHeader(http.StatusLengthRequired)
138                 return
139         }
140
141         body, err := ioutil.ReadAll(r.Body)
142         if err != nil {
143                 return
144         }
145
146         type blockListRequestBody struct {
147                 XMLName     xml.Name `xml:"BlockList"`
148                 Uncommitted []string
149         }
150
151         blob, blobExists := h.blobs[container+"|"+hash]
152
153         switch {
154         case r.Method == "PUT" && r.Form.Get("comp") == "":
155                 // "Put Blob" API
156                 if _, ok := h.blobs[container+"|"+hash]; !ok {
157                         // Like the real Azure service, we offer a
158                         // race window during which other clients can
159                         // list/get the new blob before any data is
160                         // committed.
161                         h.blobs[container+"|"+hash] = &azBlob{
162                                 Mtime:       time.Now(),
163                                 Uncommitted: make(map[string][]byte),
164                                 Metadata:    make(map[string]string),
165                                 Etag:        makeEtag(),
166                         }
167                         h.unlockAndRace()
168                 }
169                 metadata := make(map[string]string)
170                 for k, v := range r.Header {
171                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
172                                 name := k[len("x-ms-meta-"):]
173                                 metadata[strings.ToLower(name)] = v[0]
174                         }
175                 }
176                 h.blobs[container+"|"+hash] = &azBlob{
177                         Data:        body,
178                         Mtime:       time.Now(),
179                         Uncommitted: make(map[string][]byte),
180                         Metadata:    metadata,
181                         Etag:        makeEtag(),
182                 }
183                 rw.WriteHeader(http.StatusCreated)
184         case r.Method == "PUT" && r.Form.Get("comp") == "block":
185                 // "Put Block" API
186                 if !blobExists {
187                         log.Printf("Got block for nonexistent blob: %+v", r)
188                         rw.WriteHeader(http.StatusBadRequest)
189                         return
190                 }
191                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
192                 if err != nil || len(blockID) == 0 {
193                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
194                         rw.WriteHeader(http.StatusBadRequest)
195                         return
196                 }
197                 blob.Uncommitted[string(blockID)] = body
198                 rw.WriteHeader(http.StatusCreated)
199         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
200                 // "Put Block List" API
201                 bl := &blockListRequestBody{}
202                 if err := xml.Unmarshal(body, bl); err != nil {
203                         log.Printf("xml Unmarshal: %s", err)
204                         rw.WriteHeader(http.StatusBadRequest)
205                         return
206                 }
207                 for _, encBlockID := range bl.Uncommitted {
208                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
209                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
210                                 log.Printf("Invalid blockid: %+q", encBlockID)
211                                 rw.WriteHeader(http.StatusBadRequest)
212                                 return
213                         }
214                         blob.Data = blob.Uncommitted[string(blockID)]
215                         blob.Etag = makeEtag()
216                         blob.Mtime = time.Now()
217                         delete(blob.Uncommitted, string(blockID))
218                 }
219                 rw.WriteHeader(http.StatusCreated)
220         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
221                 // "Set Metadata Headers" API. We don't bother
222                 // stubbing "Get Metadata Headers": AzureBlobVolume
223                 // sets metadata headers only as a way to bump Etag
224                 // and Last-Modified.
225                 if !blobExists {
226                         log.Printf("Got metadata for nonexistent blob: %+v", r)
227                         rw.WriteHeader(http.StatusBadRequest)
228                         return
229                 }
230                 blob.Metadata = make(map[string]string)
231                 for k, v := range r.Header {
232                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
233                                 name := k[len("x-ms-meta-"):]
234                                 blob.Metadata[strings.ToLower(name)] = v[0]
235                         }
236                 }
237                 blob.Mtime = time.Now()
238                 blob.Etag = makeEtag()
239         case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
240                 // "Get Blob Metadata" API
241                 if !blobExists {
242                         rw.WriteHeader(http.StatusNotFound)
243                         return
244                 }
245                 for k, v := range blob.Metadata {
246                         rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
247                 }
248                 return
249         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
250                 // "Get Blob" API
251                 if !blobExists {
252                         rw.WriteHeader(http.StatusNotFound)
253                         return
254                 }
255                 data := blob.Data
256                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
257                         b0, err0 := strconv.Atoi(rangeSpec[1])
258                         b1, err1 := strconv.Atoi(rangeSpec[2])
259                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
260                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
261                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
262                                 return
263                         }
264                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
265                         rw.WriteHeader(http.StatusPartialContent)
266                         data = data[b0 : b1+1]
267                 }
268                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
269                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
270                 if r.Method == "GET" {
271                         if _, err := rw.Write(data); err != nil {
272                                 log.Printf("write %+q: %s", data, err)
273                         }
274                 }
275                 h.unlockAndRace()
276         case r.Method == "DELETE" && hash != "":
277                 // "Delete Blob" API
278                 if !blobExists {
279                         rw.WriteHeader(http.StatusNotFound)
280                         return
281                 }
282                 delete(h.blobs, container+"|"+hash)
283                 rw.WriteHeader(http.StatusAccepted)
284         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
285                 // "List Blobs" API
286                 if !h.didlist503 {
287                         h.didlist503 = true
288                         rw.WriteHeader(http.StatusServiceUnavailable)
289                         return
290                 }
291                 prefix := container + "|" + r.Form.Get("prefix")
292                 marker := r.Form.Get("marker")
293
294                 maxResults := 2
295                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
296                         maxResults = n
297                 }
298
299                 resp := storage.BlobListResponse{
300                         Marker:     marker,
301                         NextMarker: "",
302                         MaxResults: int64(maxResults),
303                 }
304                 var hashes sort.StringSlice
305                 for k := range h.blobs {
306                         if strings.HasPrefix(k, prefix) {
307                                 hashes = append(hashes, k[len(container)+1:])
308                         }
309                 }
310                 hashes.Sort()
311                 for _, hash := range hashes {
312                         if len(resp.Blobs) == maxResults {
313                                 resp.NextMarker = hash
314                                 break
315                         }
316                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
317                                 blob := h.blobs[container+"|"+hash]
318                                 bmeta := map[string]string(nil)
319                                 if r.Form.Get("include") == "metadata" {
320                                         bmeta = blob.Metadata
321                                 }
322                                 b := storage.Blob{
323                                         Name: hash,
324                                         Properties: storage.BlobProperties{
325                                                 LastModified:  storage.TimeRFC1123(blob.Mtime),
326                                                 ContentLength: int64(len(blob.Data)),
327                                                 Etag:          blob.Etag,
328                                         },
329                                         Metadata: bmeta,
330                                 }
331                                 resp.Blobs = append(resp.Blobs, b)
332                         }
333                 }
334                 buf, err := xml.Marshal(resp)
335                 if err != nil {
336                         log.Print(err)
337                         rw.WriteHeader(http.StatusInternalServerError)
338                 }
339                 rw.Write(buf)
340         default:
341                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
342                 rw.WriteHeader(http.StatusNotImplemented)
343         }
344 }
345
346 // azStubDialer is a net.Dialer that notices when the Azure driver
347 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
348 // in such cases transparently dials "127.0.0.1:46067" instead.
349 type azStubDialer struct {
350         net.Dialer
351 }
352
353 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
354
355 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
356         if hp := localHostPortRe.FindString(address); hp != "" {
357                 if azureTestDebug {
358                         log.Println("azStubDialer: dial", hp, "instead of", address)
359                 }
360                 address = hp
361         }
362         return d.Dialer.Dial(network, address)
363 }
364
365 type TestableAzureBlobVolume struct {
366         *AzureBlobVolume
367         azHandler *azStubHandler
368         azStub    *httptest.Server
369         t         TB
370 }
371
372 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
373         azHandler := newAzStubHandler()
374         azStub := httptest.NewServer(azHandler)
375
376         var azClient storage.Client
377
378         container := azureTestContainer
379         if container == "" {
380                 // Connect to stub instead of real Azure storage service
381                 stubURLBase := strings.Split(azStub.URL, "://")[1]
382                 var err error
383                 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
384                         t.Fatal(err)
385                 }
386                 container = "fakecontainername"
387         } else {
388                 // Connect to real Azure storage service
389                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
390                 if err != nil {
391                         t.Fatal(err)
392                 }
393                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
394                 if err != nil {
395                         t.Fatal(err)
396                 }
397         }
398         azClient.Sender = &singleSender{}
399
400         bs := azClient.GetBlobService()
401         v := &AzureBlobVolume{
402                 ContainerName:        container,
403                 ReadOnly:             readonly,
404                 AzureReplication:     replication,
405                 ListBlobsMaxAttempts: 2,
406                 ListBlobsRetryDelay:  arvados.Duration(time.Millisecond),
407                 azClient:             azClient,
408                 container:            &azureContainer{ctr: bs.GetContainerReference(container)},
409         }
410
411         return &TestableAzureBlobVolume{
412                 AzureBlobVolume: v,
413                 azHandler:       azHandler,
414                 azStub:          azStub,
415                 t:               t,
416         }
417 }
418
419 var _ = check.Suite(&StubbedAzureBlobSuite{})
420
421 type StubbedAzureBlobSuite struct {
422         volume            *TestableAzureBlobVolume
423         origHTTPTransport http.RoundTripper
424 }
425
426 func (s *StubbedAzureBlobSuite) SetUpTest(c *check.C) {
427         s.origHTTPTransport = http.DefaultTransport
428         http.DefaultTransport = &http.Transport{
429                 Dial: (&azStubDialer{}).Dial,
430         }
431         azureWriteRaceInterval = time.Millisecond
432         azureWriteRacePollTime = time.Nanosecond
433
434         s.volume = NewTestableAzureBlobVolume(c, false, 3)
435 }
436
437 func (s *StubbedAzureBlobSuite) TearDownTest(c *check.C) {
438         s.volume.Teardown()
439         http.DefaultTransport = s.origHTTPTransport
440 }
441
442 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
443         defer func(t http.RoundTripper) {
444                 http.DefaultTransport = t
445         }(http.DefaultTransport)
446         http.DefaultTransport = &http.Transport{
447                 Dial: (&azStubDialer{}).Dial,
448         }
449         azureWriteRaceInterval = time.Millisecond
450         azureWriteRacePollTime = time.Nanosecond
451         DoGenericVolumeTests(t, func(t TB) TestableVolume {
452                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
453         })
454 }
455
456 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
457         defer func(b int) {
458                 azureMaxGetBytes = b
459         }(azureMaxGetBytes)
460
461         defer func(t http.RoundTripper) {
462                 http.DefaultTransport = t
463         }(http.DefaultTransport)
464         http.DefaultTransport = &http.Transport{
465                 Dial: (&azStubDialer{}).Dial,
466         }
467         azureWriteRaceInterval = time.Millisecond
468         azureWriteRacePollTime = time.Nanosecond
469         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
470         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
471                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
472                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
473                 })
474         }
475 }
476
477 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
478         defer func(t http.RoundTripper) {
479                 http.DefaultTransport = t
480         }(http.DefaultTransport)
481         http.DefaultTransport = &http.Transport{
482                 Dial: (&azStubDialer{}).Dial,
483         }
484         azureWriteRaceInterval = time.Millisecond
485         azureWriteRacePollTime = time.Nanosecond
486         DoGenericVolumeTests(t, func(t TB) TestableVolume {
487                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
488         })
489 }
490
491 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
492         defer func(t http.RoundTripper) {
493                 http.DefaultTransport = t
494         }(http.DefaultTransport)
495         http.DefaultTransport = &http.Transport{
496                 Dial: (&azStubDialer{}).Dial,
497         }
498
499         v := NewTestableAzureBlobVolume(t, false, 3)
500         defer v.Teardown()
501
502         for _, size := range []int{
503                 2<<22 - 1, // one <max read
504                 2 << 22,   // one =max read
505                 2<<22 + 1, // one =max read, one <max
506                 2 << 23,   // two =max reads
507                 BlockSize - 1,
508                 BlockSize,
509         } {
510                 data := make([]byte, size)
511                 for i := range data {
512                         data[i] = byte((i + 7) & 0xff)
513                 }
514                 hash := fmt.Sprintf("%x", md5.Sum(data))
515                 err := v.Put(context.Background(), hash, data)
516                 if err != nil {
517                         t.Error(err)
518                 }
519                 gotData := make([]byte, len(data))
520                 gotLen, err := v.Get(context.Background(), hash, gotData)
521                 if err != nil {
522                         t.Error(err)
523                 }
524                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
525                 if gotLen != size {
526                         t.Errorf("length mismatch: got %d != %d", gotLen, size)
527                 }
528                 if gotHash != hash {
529                         t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
530                 }
531         }
532 }
533
534 func TestAzureBlobVolumeReplication(t *testing.T) {
535         for r := 1; r <= 4; r++ {
536                 v := NewTestableAzureBlobVolume(t, false, r)
537                 defer v.Teardown()
538                 if n := v.Replication(); n != r {
539                         t.Errorf("Got replication %d, expected %d", n, r)
540                 }
541         }
542 }
543
544 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
545         defer func(t http.RoundTripper) {
546                 http.DefaultTransport = t
547         }(http.DefaultTransport)
548         http.DefaultTransport = &http.Transport{
549                 Dial: (&azStubDialer{}).Dial,
550         }
551
552         v := NewTestableAzureBlobVolume(t, false, 3)
553         defer v.Teardown()
554
555         azureWriteRaceInterval = time.Second
556         azureWriteRacePollTime = time.Millisecond
557
558         var wg sync.WaitGroup
559
560         v.azHandler.race = make(chan chan struct{})
561
562         wg.Add(1)
563         go func() {
564                 defer wg.Done()
565                 err := v.Put(context.Background(), TestHash, TestBlock)
566                 if err != nil {
567                         t.Error(err)
568                 }
569         }()
570         continuePut := make(chan struct{})
571         // Wait for the stub's Put to create the empty blob
572         v.azHandler.race <- continuePut
573         wg.Add(1)
574         go func() {
575                 defer wg.Done()
576                 buf := make([]byte, len(TestBlock))
577                 _, err := v.Get(context.Background(), TestHash, buf)
578                 if err != nil {
579                         t.Error(err)
580                 }
581         }()
582         // Wait for the stub's Get to get the empty blob
583         close(v.azHandler.race)
584         // Allow stub's Put to continue, so the real data is ready
585         // when the volume's Get retries
586         <-continuePut
587         // Wait for Get() and Put() to finish
588         wg.Wait()
589 }
590
591 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
592         defer func(t http.RoundTripper) {
593                 http.DefaultTransport = t
594         }(http.DefaultTransport)
595         http.DefaultTransport = &http.Transport{
596                 Dial: (&azStubDialer{}).Dial,
597         }
598
599         v := NewTestableAzureBlobVolume(t, false, 3)
600         defer v.Teardown()
601
602         azureWriteRaceInterval = 2 * time.Second
603         azureWriteRacePollTime = 5 * time.Millisecond
604
605         v.PutRaw(TestHash, nil)
606
607         buf := new(bytes.Buffer)
608         v.IndexTo("", buf)
609         if buf.Len() != 0 {
610                 t.Errorf("Index %+q should be empty", buf.Bytes())
611         }
612
613         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
614
615         allDone := make(chan struct{})
616         go func() {
617                 defer close(allDone)
618                 buf := make([]byte, BlockSize)
619                 n, err := v.Get(context.Background(), TestHash, buf)
620                 if err != nil {
621                         t.Error(err)
622                         return
623                 }
624                 if n != 0 {
625                         t.Errorf("Got %+q, expected empty buf", buf[:n])
626                 }
627         }()
628         select {
629         case <-allDone:
630         case <-time.After(time.Second):
631                 t.Error("Get should have stopped waiting for race when block was 2s old")
632         }
633
634         buf.Reset()
635         v.IndexTo("", buf)
636         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
637                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
638         }
639 }
640
641 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
642         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
643                 v.PutRaw(TestHash, TestBlock)
644                 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
645                 return err
646         })
647 }
648
649 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
650         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
651                 return v.Put(ctx, TestHash, make([]byte, BlockSize))
652         })
653 }
654
655 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
656         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
657                 v.PutRaw(TestHash, TestBlock)
658                 return v.Compare(ctx, TestHash, TestBlock2)
659         })
660 }
661
662 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
663         defer func(t http.RoundTripper) {
664                 http.DefaultTransport = t
665         }(http.DefaultTransport)
666         http.DefaultTransport = &http.Transport{
667                 Dial: (&azStubDialer{}).Dial,
668         }
669
670         v := NewTestableAzureBlobVolume(t, false, 3)
671         defer v.Teardown()
672         v.azHandler.race = make(chan chan struct{})
673
674         ctx, cancel := context.WithCancel(context.Background())
675         allDone := make(chan struct{})
676         go func() {
677                 defer close(allDone)
678                 err := testFunc(ctx, v)
679                 if err != context.Canceled {
680                         t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
681                 }
682         }()
683         releaseHandler := make(chan struct{})
684         select {
685         case <-allDone:
686                 t.Error("testFunc finished without waiting for v.azHandler.race")
687         case <-time.After(10 * time.Second):
688                 t.Error("timed out waiting to enter handler")
689         case v.azHandler.race <- releaseHandler:
690         }
691
692         cancel()
693
694         select {
695         case <-time.After(10 * time.Second):
696                 t.Error("timed out waiting to cancel")
697         case <-allDone:
698         }
699
700         go func() {
701                 <-releaseHandler
702         }()
703 }
704
705 func (s *StubbedAzureBlobSuite) TestStats(c *check.C) {
706         stats := func() string {
707                 buf, err := json.Marshal(s.volume.InternalStats())
708                 c.Check(err, check.IsNil)
709                 return string(buf)
710         }
711
712         c.Check(stats(), check.Matches, `.*"Ops":0,.*`)
713         c.Check(stats(), check.Matches, `.*"Errors":0,.*`)
714
715         loc := "acbd18db4cc2f85cedef654fccc4a4d8"
716         _, err := s.volume.Get(context.Background(), loc, make([]byte, 3))
717         c.Check(err, check.NotNil)
718         c.Check(stats(), check.Matches, `.*"Ops":[^0],.*`)
719         c.Check(stats(), check.Matches, `.*"Errors":[^0],.*`)
720         c.Check(stats(), check.Matches, `.*"storage\.AzureStorageServiceError 404 \(404 Not Found\)":[^0].*`)
721         c.Check(stats(), check.Matches, `.*"InBytes":0,.*`)
722
723         err = s.volume.Put(context.Background(), loc, []byte("foo"))
724         c.Check(err, check.IsNil)
725         c.Check(stats(), check.Matches, `.*"OutBytes":3,.*`)
726         c.Check(stats(), check.Matches, `.*"CreateOps":1,.*`)
727
728         _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
729         c.Check(err, check.IsNil)
730         _, err = s.volume.Get(context.Background(), loc, make([]byte, 3))
731         c.Check(err, check.IsNil)
732         c.Check(stats(), check.Matches, `.*"InBytes":6,.*`)
733 }
734
735 func (s *StubbedAzureBlobSuite) TestConfig(c *check.C) {
736         var cfg Config
737         err := yaml.Unmarshal([]byte(`
738 Volumes:
739   - Type: Azure
740     StorageClasses: ["class_a", "class_b"]
741 `), &cfg)
742
743         c.Check(err, check.IsNil)
744         c.Check(cfg.Volumes[0].GetStorageClasses(), check.DeepEquals, []string{"class_a", "class_b"})
745 }
746
747 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
748         v.azHandler.PutRaw(v.ContainerName, locator, data)
749 }
750
751 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
752         v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
753 }
754
755 func (v *TestableAzureBlobVolume) Teardown() {
756         v.azStub.Close()
757 }
758
759 func (v *TestableAzureBlobVolume) ReadWriteOperationLabelValues() (r, w string) {
760         return "get", "create"
761 }
762
763 func (v *TestableAzureBlobVolume) DeviceID() string {
764         // Dummy device id for testing purposes
765         return "azure://azure_blob_volume_test"
766 }
767
768 func (v *TestableAzureBlobVolume) Start(vm *volumeMetricsVecs) error {
769         // Override original Start() to be able to assign CounterVecs with a dummy DeviceID
770         v.container.stats.opsCounters, v.container.stats.errCounters, v.container.stats.ioBytes = vm.getCounterVecsFor(prometheus.Labels{"device_id": v.DeviceID()})
771         return nil
772 }
773
774 func makeEtag() string {
775         return fmt.Sprintf("0x%x", rand.Int63())
776 }