10682: Avoid invoking special "test mode" behavior in Azure SDK.
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "context"
6         "crypto/md5"
7         "encoding/base64"
8         "encoding/xml"
9         "flag"
10         "fmt"
11         "io/ioutil"
12         "math/rand"
13         "net"
14         "net/http"
15         "net/http/httptest"
16         "regexp"
17         "sort"
18         "strconv"
19         "strings"
20         "sync"
21         "testing"
22         "time"
23
24         log "github.com/Sirupsen/logrus"
25         "github.com/curoverse/azure-sdk-for-go/storage"
26 )
27
28 const (
29         // This cannot be the fake account name "devstoreaccount1"
30         // used by Microsoft's Azure emulator: the Azure SDK
31         // recognizes that magic string and changes its behavior to
32         // cater to the Azure SDK's own test suite.
33         fakeAccountName = "fakeAccountName"
34         fakeAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
35 )
36
37 var azureTestContainer string
38
39 func init() {
40         flag.StringVar(
41                 &azureTestContainer,
42                 "test.azure-storage-container-volume",
43                 "",
44                 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
45 }
46
47 type azBlob struct {
48         Data        []byte
49         Etag        string
50         Metadata    map[string]string
51         Mtime       time.Time
52         Uncommitted map[string][]byte
53 }
54
55 type azStubHandler struct {
56         sync.Mutex
57         blobs map[string]*azBlob
58         race  chan chan struct{}
59 }
60
61 func newAzStubHandler() *azStubHandler {
62         return &azStubHandler{
63                 blobs: make(map[string]*azBlob),
64         }
65 }
66
67 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
68         blob, ok := h.blobs[container+"|"+hash]
69         if !ok {
70                 return
71         }
72         blob.Mtime = t
73 }
74
75 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
76         h.Lock()
77         defer h.Unlock()
78         h.blobs[container+"|"+hash] = &azBlob{
79                 Data:        data,
80                 Mtime:       time.Now(),
81                 Metadata:    make(map[string]string),
82                 Uncommitted: make(map[string][]byte),
83         }
84 }
85
86 func (h *azStubHandler) unlockAndRace() {
87         if h.race == nil {
88                 return
89         }
90         h.Unlock()
91         // Signal caller that race is starting by reading from
92         // h.race. If we get a channel, block until that channel is
93         // ready to receive. If we get nil (or h.race is closed) just
94         // proceed.
95         if c := <-h.race; c != nil {
96                 c <- struct{}{}
97         }
98         h.Lock()
99 }
100
101 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
102
103 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
104         h.Lock()
105         defer h.Unlock()
106         // defer log.Printf("azStubHandler: %+v", r)
107
108         path := strings.Split(r.URL.Path, "/")
109         container := path[1]
110         hash := ""
111         if len(path) > 2 {
112                 hash = path[2]
113         }
114
115         if err := r.ParseForm(); err != nil {
116                 log.Printf("azStubHandler(%+v): %s", r, err)
117                 rw.WriteHeader(http.StatusBadRequest)
118                 return
119         }
120
121         body, err := ioutil.ReadAll(r.Body)
122         if err != nil {
123                 return
124         }
125
126         type blockListRequestBody struct {
127                 XMLName     xml.Name `xml:"BlockList"`
128                 Uncommitted []string
129         }
130
131         blob, blobExists := h.blobs[container+"|"+hash]
132
133         switch {
134         case r.Method == "PUT" && r.Form.Get("comp") == "":
135                 // "Put Blob" API
136                 if _, ok := h.blobs[container+"|"+hash]; !ok {
137                         // Like the real Azure service, we offer a
138                         // race window during which other clients can
139                         // list/get the new blob before any data is
140                         // committed.
141                         h.blobs[container+"|"+hash] = &azBlob{
142                                 Mtime:       time.Now(),
143                                 Uncommitted: make(map[string][]byte),
144                                 Metadata:    make(map[string]string),
145                                 Etag:        makeEtag(),
146                         }
147                         h.unlockAndRace()
148                 }
149                 metadata := make(map[string]string)
150                 for k, v := range r.Header {
151                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
152                                 name := k[len("x-ms-meta-"):]
153                                 metadata[strings.ToLower(name)] = v[0]
154                         }
155                 }
156                 h.blobs[container+"|"+hash] = &azBlob{
157                         Data:        body,
158                         Mtime:       time.Now(),
159                         Uncommitted: make(map[string][]byte),
160                         Metadata:    metadata,
161                         Etag:        makeEtag(),
162                 }
163                 rw.WriteHeader(http.StatusCreated)
164         case r.Method == "PUT" && r.Form.Get("comp") == "block":
165                 // "Put Block" API
166                 if !blobExists {
167                         log.Printf("Got block for nonexistent blob: %+v", r)
168                         rw.WriteHeader(http.StatusBadRequest)
169                         return
170                 }
171                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
172                 if err != nil || len(blockID) == 0 {
173                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
174                         rw.WriteHeader(http.StatusBadRequest)
175                         return
176                 }
177                 blob.Uncommitted[string(blockID)] = body
178                 rw.WriteHeader(http.StatusCreated)
179         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
180                 // "Put Block List" API
181                 bl := &blockListRequestBody{}
182                 if err := xml.Unmarshal(body, bl); err != nil {
183                         log.Printf("xml Unmarshal: %s", err)
184                         rw.WriteHeader(http.StatusBadRequest)
185                         return
186                 }
187                 for _, encBlockID := range bl.Uncommitted {
188                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
189                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
190                                 log.Printf("Invalid blockid: %+q", encBlockID)
191                                 rw.WriteHeader(http.StatusBadRequest)
192                                 return
193                         }
194                         blob.Data = blob.Uncommitted[string(blockID)]
195                         blob.Etag = makeEtag()
196                         blob.Mtime = time.Now()
197                         delete(blob.Uncommitted, string(blockID))
198                 }
199                 rw.WriteHeader(http.StatusCreated)
200         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
201                 // "Set Metadata Headers" API. We don't bother
202                 // stubbing "Get Metadata Headers": AzureBlobVolume
203                 // sets metadata headers only as a way to bump Etag
204                 // and Last-Modified.
205                 if !blobExists {
206                         log.Printf("Got metadata for nonexistent blob: %+v", r)
207                         rw.WriteHeader(http.StatusBadRequest)
208                         return
209                 }
210                 blob.Metadata = make(map[string]string)
211                 for k, v := range r.Header {
212                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
213                                 name := k[len("x-ms-meta-"):]
214                                 blob.Metadata[strings.ToLower(name)] = v[0]
215                         }
216                 }
217                 blob.Mtime = time.Now()
218                 blob.Etag = makeEtag()
219         case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
220                 // "Get Blob Metadata" API
221                 if !blobExists {
222                         rw.WriteHeader(http.StatusNotFound)
223                         return
224                 }
225                 for k, v := range blob.Metadata {
226                         rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
227                 }
228                 return
229         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
230                 // "Get Blob" API
231                 if !blobExists {
232                         rw.WriteHeader(http.StatusNotFound)
233                         return
234                 }
235                 data := blob.Data
236                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
237                         b0, err0 := strconv.Atoi(rangeSpec[1])
238                         b1, err1 := strconv.Atoi(rangeSpec[2])
239                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
240                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
241                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
242                                 return
243                         }
244                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
245                         rw.WriteHeader(http.StatusPartialContent)
246                         data = data[b0 : b1+1]
247                 }
248                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
249                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
250                 if r.Method == "GET" {
251                         if _, err := rw.Write(data); err != nil {
252                                 log.Printf("write %+q: %s", data, err)
253                         }
254                 }
255                 h.unlockAndRace()
256         case r.Method == "DELETE" && hash != "":
257                 // "Delete Blob" API
258                 if !blobExists {
259                         rw.WriteHeader(http.StatusNotFound)
260                         return
261                 }
262                 delete(h.blobs, container+"|"+hash)
263                 rw.WriteHeader(http.StatusAccepted)
264         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
265                 // "List Blobs" API
266                 prefix := container + "|" + r.Form.Get("prefix")
267                 marker := r.Form.Get("marker")
268
269                 maxResults := 2
270                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
271                         maxResults = n
272                 }
273
274                 resp := storage.BlobListResponse{
275                         Marker:     marker,
276                         NextMarker: "",
277                         MaxResults: int64(maxResults),
278                 }
279                 var hashes sort.StringSlice
280                 for k := range h.blobs {
281                         if strings.HasPrefix(k, prefix) {
282                                 hashes = append(hashes, k[len(container)+1:])
283                         }
284                 }
285                 hashes.Sort()
286                 for _, hash := range hashes {
287                         if len(resp.Blobs) == maxResults {
288                                 resp.NextMarker = hash
289                                 break
290                         }
291                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
292                                 blob := h.blobs[container+"|"+hash]
293                                 bmeta := map[string]string(nil)
294                                 if r.Form.Get("include") == "metadata" {
295                                         bmeta = blob.Metadata
296                                 }
297                                 b := storage.Blob{
298                                         Name: hash,
299                                         Properties: storage.BlobProperties{
300                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
301                                                 ContentLength: int64(len(blob.Data)),
302                                                 Etag:          blob.Etag,
303                                         },
304                                         Metadata: bmeta,
305                                 }
306                                 resp.Blobs = append(resp.Blobs, b)
307                         }
308                 }
309                 buf, err := xml.Marshal(resp)
310                 if err != nil {
311                         log.Print(err)
312                         rw.WriteHeader(http.StatusInternalServerError)
313                 }
314                 rw.Write(buf)
315         default:
316                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
317                 rw.WriteHeader(http.StatusNotImplemented)
318         }
319 }
320
321 // azStubDialer is a net.Dialer that notices when the Azure driver
322 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
323 // in such cases transparently dials "127.0.0.1:46067" instead.
324 type azStubDialer struct {
325         net.Dialer
326 }
327
328 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
329
330 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
331         if hp := localHostPortRe.FindString(address); hp != "" {
332                 log.Println("azStubDialer: dial", hp, "instead of", address)
333                 address = hp
334         }
335         return d.Dialer.Dial(network, address)
336 }
337
338 type TestableAzureBlobVolume struct {
339         *AzureBlobVolume
340         azHandler *azStubHandler
341         azStub    *httptest.Server
342         t         TB
343 }
344
345 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
346         azHandler := newAzStubHandler()
347         azStub := httptest.NewServer(azHandler)
348
349         var azClient storage.Client
350
351         container := azureTestContainer
352         if container == "" {
353                 // Connect to stub instead of real Azure storage service
354                 stubURLBase := strings.Split(azStub.URL, "://")[1]
355                 var err error
356                 if azClient, err = storage.NewClient(fakeAccountName, fakeAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
357                         t.Fatal(err)
358                 }
359                 container = "fakecontainername"
360         } else {
361                 // Connect to real Azure storage service
362                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
363                 if err != nil {
364                         t.Fatal(err)
365                 }
366                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
367                 if err != nil {
368                         t.Fatal(err)
369                 }
370         }
371
372         v := &AzureBlobVolume{
373                 ContainerName:    container,
374                 ReadOnly:         readonly,
375                 AzureReplication: replication,
376                 azClient:         azClient,
377                 bsClient:         azClient.GetBlobService(),
378         }
379
380         return &TestableAzureBlobVolume{
381                 AzureBlobVolume: v,
382                 azHandler:       azHandler,
383                 azStub:          azStub,
384                 t:               t,
385         }
386 }
387
388 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
389         defer func(t http.RoundTripper) {
390                 http.DefaultTransport = t
391         }(http.DefaultTransport)
392         http.DefaultTransport = &http.Transport{
393                 Dial: (&azStubDialer{}).Dial,
394         }
395         azureWriteRaceInterval = time.Millisecond
396         azureWriteRacePollTime = time.Nanosecond
397         DoGenericVolumeTests(t, func(t TB) TestableVolume {
398                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
399         })
400 }
401
402 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
403         defer func(b int) {
404                 azureMaxGetBytes = b
405         }(azureMaxGetBytes)
406
407         defer func(t http.RoundTripper) {
408                 http.DefaultTransport = t
409         }(http.DefaultTransport)
410         http.DefaultTransport = &http.Transport{
411                 Dial: (&azStubDialer{}).Dial,
412         }
413         azureWriteRaceInterval = time.Millisecond
414         azureWriteRacePollTime = time.Nanosecond
415         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
416         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
417                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
418                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
419                 })
420         }
421 }
422
423 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
424         defer func(t http.RoundTripper) {
425                 http.DefaultTransport = t
426         }(http.DefaultTransport)
427         http.DefaultTransport = &http.Transport{
428                 Dial: (&azStubDialer{}).Dial,
429         }
430         azureWriteRaceInterval = time.Millisecond
431         azureWriteRacePollTime = time.Nanosecond
432         DoGenericVolumeTests(t, func(t TB) TestableVolume {
433                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
434         })
435 }
436
437 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
438         defer func(t http.RoundTripper) {
439                 http.DefaultTransport = t
440         }(http.DefaultTransport)
441         http.DefaultTransport = &http.Transport{
442                 Dial: (&azStubDialer{}).Dial,
443         }
444
445         v := NewTestableAzureBlobVolume(t, false, 3)
446         defer v.Teardown()
447
448         for _, size := range []int{
449                 2<<22 - 1, // one <max read
450                 2 << 22,   // one =max read
451                 2<<22 + 1, // one =max read, one <max
452                 2 << 23,   // two =max reads
453                 BlockSize - 1,
454                 BlockSize,
455         } {
456                 data := make([]byte, size)
457                 for i := range data {
458                         data[i] = byte((i + 7) & 0xff)
459                 }
460                 hash := fmt.Sprintf("%x", md5.Sum(data))
461                 err := v.Put(context.Background(), hash, data)
462                 if err != nil {
463                         t.Error(err)
464                 }
465                 gotData := make([]byte, len(data))
466                 gotLen, err := v.Get(context.Background(), hash, gotData)
467                 if err != nil {
468                         t.Error(err)
469                 }
470                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
471                 if gotLen != size {
472                         t.Error("length mismatch: got %d != %d", gotLen, size)
473                 }
474                 if gotHash != hash {
475                         t.Error("hash mismatch: got %s != %s", gotHash, hash)
476                 }
477         }
478 }
479
480 func TestAzureBlobVolumeReplication(t *testing.T) {
481         for r := 1; r <= 4; r++ {
482                 v := NewTestableAzureBlobVolume(t, false, r)
483                 defer v.Teardown()
484                 if n := v.Replication(); n != r {
485                         t.Errorf("Got replication %d, expected %d", n, r)
486                 }
487         }
488 }
489
490 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
491         defer func(t http.RoundTripper) {
492                 http.DefaultTransport = t
493         }(http.DefaultTransport)
494         http.DefaultTransport = &http.Transport{
495                 Dial: (&azStubDialer{}).Dial,
496         }
497
498         v := NewTestableAzureBlobVolume(t, false, 3)
499         defer v.Teardown()
500
501         azureWriteRaceInterval = time.Second
502         azureWriteRacePollTime = time.Millisecond
503
504         allDone := make(chan struct{})
505         v.azHandler.race = make(chan chan struct{})
506         go func() {
507                 err := v.Put(context.Background(), TestHash, TestBlock)
508                 if err != nil {
509                         t.Error(err)
510                 }
511         }()
512         continuePut := make(chan struct{})
513         // Wait for the stub's Put to create the empty blob
514         v.azHandler.race <- continuePut
515         go func() {
516                 buf := make([]byte, len(TestBlock))
517                 _, err := v.Get(context.Background(), TestHash, buf)
518                 if err != nil {
519                         t.Error(err)
520                 }
521                 close(allDone)
522         }()
523         // Wait for the stub's Get to get the empty blob
524         close(v.azHandler.race)
525         // Allow stub's Put to continue, so the real data is ready
526         // when the volume's Get retries
527         <-continuePut
528         // Wait for volume's Get to return the real data
529         <-allDone
530 }
531
532 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
533         defer func(t http.RoundTripper) {
534                 http.DefaultTransport = t
535         }(http.DefaultTransport)
536         http.DefaultTransport = &http.Transport{
537                 Dial: (&azStubDialer{}).Dial,
538         }
539
540         v := NewTestableAzureBlobVolume(t, false, 3)
541         defer v.Teardown()
542
543         azureWriteRaceInterval = 2 * time.Second
544         azureWriteRacePollTime = 5 * time.Millisecond
545
546         v.PutRaw(TestHash, nil)
547
548         buf := new(bytes.Buffer)
549         v.IndexTo("", buf)
550         if buf.Len() != 0 {
551                 t.Errorf("Index %+q should be empty", buf.Bytes())
552         }
553
554         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
555
556         allDone := make(chan struct{})
557         go func() {
558                 defer close(allDone)
559                 buf := make([]byte, BlockSize)
560                 n, err := v.Get(context.Background(), TestHash, buf)
561                 if err != nil {
562                         t.Error(err)
563                         return
564                 }
565                 if n != 0 {
566                         t.Errorf("Got %+q, expected empty buf", buf[:n])
567                 }
568         }()
569         select {
570         case <-allDone:
571         case <-time.After(time.Second):
572                 t.Error("Get should have stopped waiting for race when block was 2s old")
573         }
574
575         buf.Reset()
576         v.IndexTo("", buf)
577         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
578                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
579         }
580 }
581
582 func TestAzureBlobVolumeContextCancelGet(t *testing.T) {
583         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
584                 v.PutRaw(TestHash, TestBlock)
585                 _, err := v.Get(ctx, TestHash, make([]byte, BlockSize))
586                 return err
587         })
588 }
589
590 func TestAzureBlobVolumeContextCancelPut(t *testing.T) {
591         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
592                 return v.Put(ctx, TestHash, make([]byte, BlockSize))
593         })
594 }
595
596 func TestAzureBlobVolumeContextCancelCompare(t *testing.T) {
597         testAzureBlobVolumeContextCancel(t, func(ctx context.Context, v *TestableAzureBlobVolume) error {
598                 v.PutRaw(TestHash, TestBlock)
599                 return v.Compare(ctx, TestHash, TestBlock2)
600         })
601 }
602
603 func testAzureBlobVolumeContextCancel(t *testing.T, testFunc func(context.Context, *TestableAzureBlobVolume) error) {
604         defer func(t http.RoundTripper) {
605                 http.DefaultTransport = t
606         }(http.DefaultTransport)
607         http.DefaultTransport = &http.Transport{
608                 Dial: (&azStubDialer{}).Dial,
609         }
610
611         v := NewTestableAzureBlobVolume(t, false, 3)
612         defer v.Teardown()
613         v.azHandler.race = make(chan chan struct{})
614
615         ctx, cancel := context.WithCancel(context.Background())
616         allDone := make(chan struct{})
617         go func() {
618                 defer close(allDone)
619                 err := testFunc(ctx, v)
620                 if err != context.Canceled {
621                         t.Errorf("got %T %q, expected %q", err, err, context.Canceled)
622                 }
623         }()
624         releaseHandler := make(chan struct{})
625         select {
626         case <-allDone:
627                 t.Error("testFunc finished without waiting for v.azHandler.race")
628         case <-time.After(10 * time.Second):
629                 t.Error("timed out waiting to enter handler")
630         case v.azHandler.race <- releaseHandler:
631         }
632
633         cancel()
634
635         select {
636         case <-time.After(10 * time.Second):
637                 t.Error("timed out waiting to cancel")
638         case <-allDone:
639         }
640
641         go func() {
642                 <-releaseHandler
643         }()
644 }
645
646 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
647         v.azHandler.PutRaw(v.ContainerName, locator, data)
648 }
649
650 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
651         v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
652 }
653
654 func (v *TestableAzureBlobVolume) Teardown() {
655         v.azStub.Close()
656 }
657
658 func makeEtag() string {
659         return fmt.Sprintf("0x%x", rand.Int63())
660 }