Remove obsolete comment. No issue #
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "crypto/md5"
6         "encoding/base64"
7         "encoding/xml"
8         "flag"
9         "fmt"
10         "io/ioutil"
11         "log"
12         "math/rand"
13         "net"
14         "net/http"
15         "net/http/httptest"
16         "regexp"
17         "sort"
18         "strconv"
19         "strings"
20         "sync"
21         "testing"
22         "time"
23
24         "github.com/curoverse/azure-sdk-for-go/storage"
25 )
26
// Fake credentials: the account name and key that Microsoft's Azure
// storage emulator uses. NewTestableAzureBlobVolume passes them to the
// SDK client when it connects to the local stub server.
const (
	emulatorAccountName = "devstoreaccount1"
	emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
)

// azureTestContainer names a real Azure container to run the tests
// against. When empty (the default), the tests use the in-process stub.
var azureTestContainer string

// init registers the -test.azure-storage-container-volume flag.
func init() {
	const usage = "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials."
	flag.StringVar(&azureTestContainer, "test.azure-storage-container-volume", "", usage)
}
42
// azBlob is the stub server's in-memory record of one blob.
type azBlob struct {
	Data        []byte            // committed content
	Etag        string            // reported by List Blobs
	Metadata    map[string]string // x-ms-meta-* values, keys lowercased
	Mtime       time.Time         // reported as Last-Modified
	Uncommitted map[string][]byte // blocks from Put Block, keyed by decoded block ID
}

// azStubHandler imitates a subset of the Azure Blob Storage REST API,
// keeping blobs in memory.
//
// The embedded Mutex guards blobs and the fields of every azBlob in it;
// all methods that touch them hold the lock.
type azStubHandler struct {
	sync.Mutex
	blobs map[string]*azBlob // keyed by container+"|"+blob name
	// race, when non-nil, lets a test pause a request in progress; see
	// unlockAndRace.
	race chan chan struct{}
}

// newAzStubHandler returns a handler with no blobs and no race hook.
func newAzStubHandler() *azStubHandler {
	return &azStubHandler{
		blobs: make(map[string]*azBlob),
	}
}

// TouchWithDate sets the Mtime of the blob container/hash to t. It does
// nothing if that blob does not exist.
//
// Like PutRaw and ServeHTTP, it holds the handler lock, so it is safe to
// call while the stub server is handling requests.
func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
	h.Lock()
	defer h.Unlock()
	if blob, ok := h.blobs[container+"|"+hash]; ok {
		blob.Mtime = t
	}
}

// PutRaw stores data as the committed content of container/hash,
// replacing any existing blob. Mtime is set to now; Etag is left empty.
func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
	h.Lock()
	defer h.Unlock()
	h.blobs[container+"|"+hash] = &azBlob{
		Data:        data,
		Mtime:       time.Now(),
		Metadata:    make(map[string]string),
		Uncommitted: make(map[string][]byte),
	}
}

// unlockAndRace briefly releases the handler lock so a test can interleave
// other requests at this point. The caller must hold h's lock; it is held
// again when unlockAndRace returns.
//
// If h.race is nil, it returns immediately without unlocking. Otherwise it
// receives once from h.race. If it gets a non-nil channel c, it sends on c,
// blocking until the test receives. If it gets nil (including the zero
// value from a closed h.race), it proceeds at once.
func (h *azStubHandler) unlockAndRace() {
	if h.race == nil {
		return
	}
	h.Unlock()
	// Signal caller that race is starting by reading from
	// h.race. If we get a channel, block until that channel is
	// ready to receive. If we get nil (or h.race is closed) just
	// proceed.
	if c := <-h.race; c != nil {
		c <- struct{}{}
	}
	h.Lock()
}
96
97 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
98
99 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
100         h.Lock()
101         defer h.Unlock()
102         // defer log.Printf("azStubHandler: %+v", r)
103
104         path := strings.Split(r.URL.Path, "/")
105         container := path[1]
106         hash := ""
107         if len(path) > 2 {
108                 hash = path[2]
109         }
110
111         if err := r.ParseForm(); err != nil {
112                 log.Printf("azStubHandler(%+v): %s", r, err)
113                 rw.WriteHeader(http.StatusBadRequest)
114                 return
115         }
116
117         body, err := ioutil.ReadAll(r.Body)
118         if err != nil {
119                 return
120         }
121
122         type blockListRequestBody struct {
123                 XMLName     xml.Name `xml:"BlockList"`
124                 Uncommitted []string
125         }
126
127         blob, blobExists := h.blobs[container+"|"+hash]
128
129         switch {
130         case r.Method == "PUT" && r.Form.Get("comp") == "":
131                 // "Put Blob" API
132                 if _, ok := h.blobs[container+"|"+hash]; !ok {
133                         // Like the real Azure service, we offer a
134                         // race window during which other clients can
135                         // list/get the new blob before any data is
136                         // committed.
137                         h.blobs[container+"|"+hash] = &azBlob{
138                                 Mtime:       time.Now(),
139                                 Uncommitted: make(map[string][]byte),
140                                 Metadata:    make(map[string]string),
141                                 Etag:        makeEtag(),
142                         }
143                         h.unlockAndRace()
144                 }
145                 metadata := make(map[string]string)
146                 for k, v := range r.Header {
147                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
148                                 name := k[len("x-ms-meta-"):]
149                                 metadata[strings.ToLower(name)] = v[0]
150                         }
151                 }
152                 h.blobs[container+"|"+hash] = &azBlob{
153                         Data:        body,
154                         Mtime:       time.Now(),
155                         Uncommitted: make(map[string][]byte),
156                         Metadata:    metadata,
157                         Etag:        makeEtag(),
158                 }
159                 rw.WriteHeader(http.StatusCreated)
160         case r.Method == "PUT" && r.Form.Get("comp") == "block":
161                 // "Put Block" API
162                 if !blobExists {
163                         log.Printf("Got block for nonexistent blob: %+v", r)
164                         rw.WriteHeader(http.StatusBadRequest)
165                         return
166                 }
167                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
168                 if err != nil || len(blockID) == 0 {
169                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
170                         rw.WriteHeader(http.StatusBadRequest)
171                         return
172                 }
173                 blob.Uncommitted[string(blockID)] = body
174                 rw.WriteHeader(http.StatusCreated)
175         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
176                 // "Put Block List" API
177                 bl := &blockListRequestBody{}
178                 if err := xml.Unmarshal(body, bl); err != nil {
179                         log.Printf("xml Unmarshal: %s", err)
180                         rw.WriteHeader(http.StatusBadRequest)
181                         return
182                 }
183                 for _, encBlockID := range bl.Uncommitted {
184                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
185                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
186                                 log.Printf("Invalid blockid: %+q", encBlockID)
187                                 rw.WriteHeader(http.StatusBadRequest)
188                                 return
189                         }
190                         blob.Data = blob.Uncommitted[string(blockID)]
191                         blob.Etag = makeEtag()
192                         blob.Mtime = time.Now()
193                         delete(blob.Uncommitted, string(blockID))
194                 }
195                 rw.WriteHeader(http.StatusCreated)
196         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
197                 // "Set Metadata Headers" API. We don't bother
198                 // stubbing "Get Metadata Headers": AzureBlobVolume
199                 // sets metadata headers only as a way to bump Etag
200                 // and Last-Modified.
201                 if !blobExists {
202                         log.Printf("Got metadata for nonexistent blob: %+v", r)
203                         rw.WriteHeader(http.StatusBadRequest)
204                         return
205                 }
206                 blob.Metadata = make(map[string]string)
207                 for k, v := range r.Header {
208                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
209                                 name := k[len("x-ms-meta-"):]
210                                 blob.Metadata[strings.ToLower(name)] = v[0]
211                         }
212                 }
213                 blob.Mtime = time.Now()
214                 blob.Etag = makeEtag()
215         case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
216                 // "Get Blob Metadata" API
217                 if !blobExists {
218                         rw.WriteHeader(http.StatusNotFound)
219                         return
220                 }
221                 for k, v := range blob.Metadata {
222                         rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
223                 }
224                 return
225         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
226                 // "Get Blob" API
227                 if !blobExists {
228                         rw.WriteHeader(http.StatusNotFound)
229                         return
230                 }
231                 data := blob.Data
232                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
233                         b0, err0 := strconv.Atoi(rangeSpec[1])
234                         b1, err1 := strconv.Atoi(rangeSpec[2])
235                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
236                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
237                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
238                                 return
239                         }
240                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
241                         rw.WriteHeader(http.StatusPartialContent)
242                         data = data[b0 : b1+1]
243                 }
244                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
245                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
246                 if r.Method == "GET" {
247                         if _, err := rw.Write(data); err != nil {
248                                 log.Printf("write %+q: %s", data, err)
249                         }
250                 }
251                 h.unlockAndRace()
252         case r.Method == "DELETE" && hash != "":
253                 // "Delete Blob" API
254                 if !blobExists {
255                         rw.WriteHeader(http.StatusNotFound)
256                         return
257                 }
258                 delete(h.blobs, container+"|"+hash)
259                 rw.WriteHeader(http.StatusAccepted)
260         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
261                 // "List Blobs" API
262                 prefix := container + "|" + r.Form.Get("prefix")
263                 marker := r.Form.Get("marker")
264
265                 maxResults := 2
266                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
267                         maxResults = n
268                 }
269
270                 resp := storage.BlobListResponse{
271                         Marker:     marker,
272                         NextMarker: "",
273                         MaxResults: int64(maxResults),
274                 }
275                 var hashes sort.StringSlice
276                 for k := range h.blobs {
277                         if strings.HasPrefix(k, prefix) {
278                                 hashes = append(hashes, k[len(container)+1:])
279                         }
280                 }
281                 hashes.Sort()
282                 for _, hash := range hashes {
283                         if len(resp.Blobs) == maxResults {
284                                 resp.NextMarker = hash
285                                 break
286                         }
287                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
288                                 blob := h.blobs[container+"|"+hash]
289                                 bmeta := map[string]string(nil)
290                                 if r.Form.Get("include") == "metadata" {
291                                         bmeta = blob.Metadata
292                                 }
293                                 b := storage.Blob{
294                                         Name: hash,
295                                         Properties: storage.BlobProperties{
296                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
297                                                 ContentLength: int64(len(blob.Data)),
298                                                 Etag:          blob.Etag,
299                                         },
300                                         Metadata: bmeta,
301                                 }
302                                 resp.Blobs = append(resp.Blobs, b)
303                         }
304                 }
305                 buf, err := xml.Marshal(resp)
306                 if err != nil {
307                         log.Print(err)
308                         rw.WriteHeader(http.StatusInternalServerError)
309                 }
310                 rw.Write(buf)
311         default:
312                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
313                 rw.WriteHeader(http.StatusNotImplemented)
314         }
315 }
316
// azStubDialer wraps net.Dialer. When asked to dial an address that
// contains a loopback host:port, such as
// "devstoreaccount1.blob.127.0.0.1:46067", it dials just that host:port
// ("127.0.0.1:46067") instead.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback host (127.0.0.1, localhost or [::1])
// followed by a port, anywhere in a dial address.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to the loopback host:port embedded in address if there is
// one, and otherwise to address unchanged.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	loopback := localHostPortRe.FindString(address)
	if loopback == "" {
		return d.Dialer.Dial(network, address)
	}
	log.Println("azStubDialer: dial", loopback, "instead of", address)
	return d.Dialer.Dial(network, loopback)
}
333
334 type TestableAzureBlobVolume struct {
335         *AzureBlobVolume
336         azHandler *azStubHandler
337         azStub    *httptest.Server
338         t         TB
339 }
340
341 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
342         azHandler := newAzStubHandler()
343         azStub := httptest.NewServer(azHandler)
344
345         var azClient storage.Client
346
347         container := azureTestContainer
348         if container == "" {
349                 // Connect to stub instead of real Azure storage service
350                 stubURLBase := strings.Split(azStub.URL, "://")[1]
351                 var err error
352                 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
353                         t.Fatal(err)
354                 }
355                 container = "fakecontainername"
356         } else {
357                 // Connect to real Azure storage service
358                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
359                 if err != nil {
360                         t.Fatal(err)
361                 }
362                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
363                 if err != nil {
364                         t.Fatal(err)
365                 }
366         }
367
368         v := &AzureBlobVolume{
369                 ContainerName:    container,
370                 ReadOnly:         readonly,
371                 AzureReplication: replication,
372                 azClient:         azClient,
373                 bsClient:         azClient.GetBlobService(),
374         }
375
376         return &TestableAzureBlobVolume{
377                 AzureBlobVolume: v,
378                 azHandler:       azHandler,
379                 azStub:          azStub,
380                 t:               t,
381         }
382 }
383
384 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
385         defer func(t http.RoundTripper) {
386                 http.DefaultTransport = t
387         }(http.DefaultTransport)
388         http.DefaultTransport = &http.Transport{
389                 Dial: (&azStubDialer{}).Dial,
390         }
391         azureWriteRaceInterval = time.Millisecond
392         azureWriteRacePollTime = time.Nanosecond
393         DoGenericVolumeTests(t, func(t TB) TestableVolume {
394                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
395         })
396 }
397
398 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
399         defer func(b int) {
400                 azureMaxGetBytes = b
401         }(azureMaxGetBytes)
402
403         defer func(t http.RoundTripper) {
404                 http.DefaultTransport = t
405         }(http.DefaultTransport)
406         http.DefaultTransport = &http.Transport{
407                 Dial: (&azStubDialer{}).Dial,
408         }
409         azureWriteRaceInterval = time.Millisecond
410         azureWriteRacePollTime = time.Nanosecond
411         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
412         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
413                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
414                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
415                 })
416         }
417 }
418
419 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
420         defer func(t http.RoundTripper) {
421                 http.DefaultTransport = t
422         }(http.DefaultTransport)
423         http.DefaultTransport = &http.Transport{
424                 Dial: (&azStubDialer{}).Dial,
425         }
426         azureWriteRaceInterval = time.Millisecond
427         azureWriteRacePollTime = time.Nanosecond
428         DoGenericVolumeTests(t, func(t TB) TestableVolume {
429                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
430         })
431 }
432
433 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
434         defer func(t http.RoundTripper) {
435                 http.DefaultTransport = t
436         }(http.DefaultTransport)
437         http.DefaultTransport = &http.Transport{
438                 Dial: (&azStubDialer{}).Dial,
439         }
440
441         v := NewTestableAzureBlobVolume(t, false, 3)
442         defer v.Teardown()
443
444         for _, size := range []int{
445                 2<<22 - 1, // one <max read
446                 2 << 22,   // one =max read
447                 2<<22 + 1, // one =max read, one <max
448                 2 << 23,   // two =max reads
449                 BlockSize - 1,
450                 BlockSize,
451         } {
452                 data := make([]byte, size)
453                 for i := range data {
454                         data[i] = byte((i + 7) & 0xff)
455                 }
456                 hash := fmt.Sprintf("%x", md5.Sum(data))
457                 err := v.Put(hash, data)
458                 if err != nil {
459                         t.Error(err)
460                 }
461                 gotData := make([]byte, len(data))
462                 gotLen, err := v.Get(hash, gotData)
463                 if err != nil {
464                         t.Error(err)
465                 }
466                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
467                 if gotLen != size {
468                         t.Error("length mismatch: got %d != %d", gotLen, size)
469                 }
470                 if gotHash != hash {
471                         t.Error("hash mismatch: got %s != %s", gotHash, hash)
472                 }
473         }
474 }
475
476 func TestAzureBlobVolumeReplication(t *testing.T) {
477         for r := 1; r <= 4; r++ {
478                 v := NewTestableAzureBlobVolume(t, false, r)
479                 defer v.Teardown()
480                 if n := v.Replication(); n != r {
481                         t.Errorf("Got replication %d, expected %d", n, r)
482                 }
483         }
484 }
485
486 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
487         defer func(t http.RoundTripper) {
488                 http.DefaultTransport = t
489         }(http.DefaultTransport)
490         http.DefaultTransport = &http.Transport{
491                 Dial: (&azStubDialer{}).Dial,
492         }
493
494         v := NewTestableAzureBlobVolume(t, false, 3)
495         defer v.Teardown()
496
497         azureWriteRaceInterval = time.Second
498         azureWriteRacePollTime = time.Millisecond
499
500         allDone := make(chan struct{})
501         v.azHandler.race = make(chan chan struct{})
502         go func() {
503                 err := v.Put(TestHash, TestBlock)
504                 if err != nil {
505                         t.Error(err)
506                 }
507         }()
508         continuePut := make(chan struct{})
509         // Wait for the stub's Put to create the empty blob
510         v.azHandler.race <- continuePut
511         go func() {
512                 buf := make([]byte, len(TestBlock))
513                 _, err := v.Get(TestHash, buf)
514                 if err != nil {
515                         t.Error(err)
516                 }
517                 close(allDone)
518         }()
519         // Wait for the stub's Get to get the empty blob
520         close(v.azHandler.race)
521         // Allow stub's Put to continue, so the real data is ready
522         // when the volume's Get retries
523         <-continuePut
524         // Wait for volume's Get to return the real data
525         <-allDone
526 }
527
528 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
529         defer func(t http.RoundTripper) {
530                 http.DefaultTransport = t
531         }(http.DefaultTransport)
532         http.DefaultTransport = &http.Transport{
533                 Dial: (&azStubDialer{}).Dial,
534         }
535
536         v := NewTestableAzureBlobVolume(t, false, 3)
537         defer v.Teardown()
538
539         azureWriteRaceInterval = 2 * time.Second
540         azureWriteRacePollTime = 5 * time.Millisecond
541
542         v.PutRaw(TestHash, nil)
543
544         buf := new(bytes.Buffer)
545         v.IndexTo("", buf)
546         if buf.Len() != 0 {
547                 t.Errorf("Index %+q should be empty", buf.Bytes())
548         }
549
550         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
551
552         allDone := make(chan struct{})
553         go func() {
554                 defer close(allDone)
555                 buf := make([]byte, BlockSize)
556                 n, err := v.Get(TestHash, buf)
557                 if err != nil {
558                         t.Error(err)
559                         return
560                 }
561                 if n != 0 {
562                         t.Errorf("Got %+q, expected empty buf", buf[:n])
563                 }
564         }()
565         select {
566         case <-allDone:
567         case <-time.After(time.Second):
568                 t.Error("Get should have stopped waiting for race when block was 2s old")
569         }
570
571         buf.Reset()
572         v.IndexTo("", buf)
573         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
574                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
575         }
576 }
577
578 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
579         v.azHandler.PutRaw(v.ContainerName, locator, data)
580 }
581
582 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
583         v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
584 }
585
586 func (v *TestableAzureBlobVolume) Teardown() {
587         v.azStub.Close()
588 }
589
// makeEtag returns a random ETag: "0x" followed by the lowercase
// hexadecimal form of a non-negative random int63.
func makeEtag() string {
	return "0x" + strconv.FormatInt(rand.Int63(), 16)
}