Merge branch '8288-arv-mount-deadlock' refs #8288
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "crypto/md5"
6         "encoding/base64"
7         "encoding/xml"
8         "flag"
9         "fmt"
10         "io/ioutil"
11         "log"
12         "math/rand"
13         "net"
14         "net/http"
15         "net/http/httptest"
16         "regexp"
17         "sort"
18         "strconv"
19         "strings"
20         "sync"
21         "testing"
22         "time"
23
24         "github.com/curoverse/azure-sdk-for-go/storage"
25 )
26
// emulatorAccountName is the account name of Microsoft's Azure storage
// emulator. The tests use these same fake credentials when they talk
// to the in-process stub server.
const emulatorAccountName = "devstoreaccount1"

// emulatorAccountKey is the emulator's fake account key, used together
// with emulatorAccountName.
const emulatorAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
32
// azureTestContainer names a real Azure container to run the volume
// tests against. When it is empty (the default), the tests use an
// in-process stub server instead.
var azureTestContainer string

func init() {
	const (
		flagName  = "test.azure-storage-container-volume"
		flagUsage = "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials."
	)
	flag.StringVar(&azureTestContainer, flagName, "", flagUsage)
}
42
// azBlob is the stub server's in-memory copy of one Azure blob.
type azBlob struct {
	// Data is the committed blob content.
	Data []byte
	// Etag is the entity tag reported in blob listings.
	Etag string
	// Metadata holds the x-ms-meta-* headers from the last Set
	// Metadata request.
	Metadata map[string]string
	// Mtime is reported as the blob's Last-Modified time.
	Mtime time.Time
	// Uncommitted holds blocks uploaded by Put Block, keyed by the
	// base64-decoded block ID, until a Put Block List commits them.
	Uncommitted map[string][]byte
}

// azStubHandler is an http.Handler imitating the parts of the Azure
// Blob Storage REST API exercised by the tests. The embedded Mutex
// protects blobs and the azBlob values it points to.
type azStubHandler struct {
	sync.Mutex
	// blobs is keyed by "container|blobname".
	blobs map[string]*azBlob
	// race, when non-nil, lets a test pause the handler inside a
	// race window; see unlockAndRace.
	race chan chan struct{}
}

// newAzStubHandler returns a stub with no blobs and no race channel.
func newAzStubHandler() *azStubHandler {
	return &azStubHandler{
		blobs: make(map[string]*azBlob),
	}
}

// TouchWithDate sets the modification time of an existing blob to t.
// It does nothing if the blob does not exist.
func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
	// The HTTP server's goroutines read and write the same blobs
	// under this lock, so the test's update must take it as well.
	h.Lock()
	defer h.Unlock()
	if blob, ok := h.blobs[container+"|"+hash]; ok {
		blob.Mtime = t
	}
}

// PutRaw stores data as the committed content of a blob, replacing any
// existing blob of that name, without going through the HTTP API.
func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
	h.Lock()
	defer h.Unlock()
	h.blobs[container+"|"+hash] = &azBlob{
		Data:        data,
		Mtime:       time.Now(),
		Uncommitted: make(map[string][]byte),
	}
}

// unlockAndRace opens a race window for tests. If h.race is nil it
// returns immediately. Otherwise it releases h's lock, which the
// caller must hold, and waits until the test interacts with h.race.
// It then re-acquires the lock before returning.
func (h *azStubHandler) unlockAndRace() {
	if h.race == nil {
		return
	}
	h.Unlock()
	// Signal caller that race is starting by reading from
	// h.race. If we get a channel, block until that channel is
	// ready to receive. If we get nil (or h.race is closed) just
	// proceed.
	if c := <-h.race; c != nil {
		c <- struct{}{}
	}
	h.Lock()
}
95
// rangeRegexp matches a Range header of the form "bytes=first-last"
// with both offsets given, capturing the two offsets. Range headers
// that do not match (open-ended, suffix, or multi-range forms) are
// ignored by the stub's Get Blob handler, which then returns the whole
// blob.
var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
97
98 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
99         h.Lock()
100         defer h.Unlock()
101         // defer log.Printf("azStubHandler: %+v", r)
102
103         path := strings.Split(r.URL.Path, "/")
104         container := path[1]
105         hash := ""
106         if len(path) > 2 {
107                 hash = path[2]
108         }
109
110         if err := r.ParseForm(); err != nil {
111                 log.Printf("azStubHandler(%+v): %s", r, err)
112                 rw.WriteHeader(http.StatusBadRequest)
113                 return
114         }
115
116         body, err := ioutil.ReadAll(r.Body)
117         if err != nil {
118                 return
119         }
120
121         type blockListRequestBody struct {
122                 XMLName     xml.Name `xml:"BlockList"`
123                 Uncommitted []string
124         }
125
126         blob, blobExists := h.blobs[container+"|"+hash]
127
128         switch {
129         case r.Method == "PUT" && r.Form.Get("comp") == "":
130                 // "Put Blob" API
131                 if _, ok := h.blobs[container+"|"+hash]; !ok {
132                         // Like the real Azure service, we offer a
133                         // race window during which other clients can
134                         // list/get the new blob before any data is
135                         // committed.
136                         h.blobs[container+"|"+hash] = &azBlob{
137                                 Mtime:       time.Now(),
138                                 Uncommitted: make(map[string][]byte),
139                                 Etag:        makeEtag(),
140                         }
141                         h.unlockAndRace()
142                 }
143                 h.blobs[container+"|"+hash] = &azBlob{
144                         Data:        body,
145                         Mtime:       time.Now(),
146                         Uncommitted: make(map[string][]byte),
147                         Etag:        makeEtag(),
148                 }
149                 rw.WriteHeader(http.StatusCreated)
150         case r.Method == "PUT" && r.Form.Get("comp") == "block":
151                 // "Put Block" API
152                 if !blobExists {
153                         log.Printf("Got block for nonexistent blob: %+v", r)
154                         rw.WriteHeader(http.StatusBadRequest)
155                         return
156                 }
157                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
158                 if err != nil || len(blockID) == 0 {
159                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
160                         rw.WriteHeader(http.StatusBadRequest)
161                         return
162                 }
163                 blob.Uncommitted[string(blockID)] = body
164                 rw.WriteHeader(http.StatusCreated)
165         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
166                 // "Put Block List" API
167                 bl := &blockListRequestBody{}
168                 if err := xml.Unmarshal(body, bl); err != nil {
169                         log.Printf("xml Unmarshal: %s", err)
170                         rw.WriteHeader(http.StatusBadRequest)
171                         return
172                 }
173                 for _, encBlockID := range bl.Uncommitted {
174                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
175                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
176                                 log.Printf("Invalid blockid: %+q", encBlockID)
177                                 rw.WriteHeader(http.StatusBadRequest)
178                                 return
179                         }
180                         blob.Data = blob.Uncommitted[string(blockID)]
181                         blob.Etag = makeEtag()
182                         blob.Mtime = time.Now()
183                         delete(blob.Uncommitted, string(blockID))
184                 }
185                 rw.WriteHeader(http.StatusCreated)
186         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
187                 // "Set Metadata Headers" API. We don't bother
188                 // stubbing "Get Metadata Headers": AzureBlobVolume
189                 // sets metadata headers only as a way to bump Etag
190                 // and Last-Modified.
191                 if !blobExists {
192                         log.Printf("Got metadata for nonexistent blob: %+v", r)
193                         rw.WriteHeader(http.StatusBadRequest)
194                         return
195                 }
196                 blob.Metadata = make(map[string]string)
197                 for k, v := range r.Header {
198                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
199                                 blob.Metadata[k] = v[0]
200                         }
201                 }
202                 blob.Mtime = time.Now()
203                 blob.Etag = makeEtag()
204         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
205                 // "Get Blob" API
206                 if !blobExists {
207                         rw.WriteHeader(http.StatusNotFound)
208                         return
209                 }
210                 data := blob.Data
211                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
212                         b0, err0 := strconv.Atoi(rangeSpec[1])
213                         b1, err1 := strconv.Atoi(rangeSpec[2])
214                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
215                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
216                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
217                                 return
218                         }
219                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
220                         rw.WriteHeader(http.StatusPartialContent)
221                         data = data[b0 : b1+1]
222                 }
223                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
224                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
225                 if r.Method == "GET" {
226                         if _, err := rw.Write(data); err != nil {
227                                 log.Printf("write %+q: %s", data, err)
228                         }
229                 }
230                 h.unlockAndRace()
231         case r.Method == "DELETE" && hash != "":
232                 // "Delete Blob" API
233                 if !blobExists {
234                         rw.WriteHeader(http.StatusNotFound)
235                         return
236                 }
237                 delete(h.blobs, container+"|"+hash)
238                 rw.WriteHeader(http.StatusAccepted)
239         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
240                 // "List Blobs" API
241                 prefix := container + "|" + r.Form.Get("prefix")
242                 marker := r.Form.Get("marker")
243
244                 maxResults := 2
245                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
246                         maxResults = n
247                 }
248
249                 resp := storage.BlobListResponse{
250                         Marker:     marker,
251                         NextMarker: "",
252                         MaxResults: int64(maxResults),
253                 }
254                 var hashes sort.StringSlice
255                 for k := range h.blobs {
256                         if strings.HasPrefix(k, prefix) {
257                                 hashes = append(hashes, k[len(container)+1:])
258                         }
259                 }
260                 hashes.Sort()
261                 for _, hash := range hashes {
262                         if len(resp.Blobs) == maxResults {
263                                 resp.NextMarker = hash
264                                 break
265                         }
266                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
267                                 blob := h.blobs[container+"|"+hash]
268                                 resp.Blobs = append(resp.Blobs, storage.Blob{
269                                         Name: hash,
270                                         Properties: storage.BlobProperties{
271                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
272                                                 ContentLength: int64(len(blob.Data)),
273                                                 Etag:          blob.Etag,
274                                         },
275                                 })
276                         }
277                 }
278                 buf, err := xml.Marshal(resp)
279                 if err != nil {
280                         log.Print(err)
281                         rw.WriteHeader(http.StatusInternalServerError)
282                 }
283                 rw.Write(buf)
284         default:
285                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
286                 rw.WriteHeader(http.StatusNotImplemented)
287         }
288 }
289
// azStubDialer is a net.Dialer that notices when the Azure driver
// tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
// in such cases transparently dials "127.0.0.1:46067" instead.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback host:port ("127.0.0.1", "localhost"
// or "[::1]" followed by a port) anywhere inside a dial address.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to the loopback host:port embedded in address if there
// is one, logging the substitution. Otherwise it dials address
// unchanged.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	target := address
	if loopback := localHostPortRe.FindString(address); loopback != "" {
		log.Println("azStubDialer: dial", loopback, "instead of", address)
		target = loopback
	}
	return d.Dialer.Dial(network, target)
}
306
// TestableAzureBlobVolume bundles an AzureBlobVolume with the
// in-process stub server created alongside it. Tests can use it to edit
// the stub's blobs directly and to shut the stub down afterwards.
//
// NOTE(review): when a real container is configured, the embedded
// volume talks to Azure while azHandler is still the local stub, so
// direct edits of azHandler do not reach the real service. Confirm this
// is intended.
type TestableAzureBlobVolume struct {
	*AzureBlobVolume
	// azHandler holds the stub server's blob state.
	azHandler *azStubHandler
	// azStub is the HTTP server serving azHandler.
	azStub *httptest.Server
	// t is the test that created this volume.
	t TB
}
313
314 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
315         azHandler := newAzStubHandler()
316         azStub := httptest.NewServer(azHandler)
317
318         var azClient storage.Client
319
320         container := azureTestContainer
321         if container == "" {
322                 // Connect to stub instead of real Azure storage service
323                 stubURLBase := strings.Split(azStub.URL, "://")[1]
324                 var err error
325                 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
326                         t.Fatal(err)
327                 }
328                 container = "fakecontainername"
329         } else {
330                 // Connect to real Azure storage service
331                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
332                 if err != nil {
333                         t.Fatal(err)
334                 }
335                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
336                 if err != nil {
337                         t.Fatal(err)
338                 }
339         }
340
341         v := NewAzureBlobVolume(azClient, container, readonly, replication)
342
343         return &TestableAzureBlobVolume{
344                 AzureBlobVolume: v,
345                 azHandler:       azHandler,
346                 azStub:          azStub,
347                 t:               t,
348         }
349 }
350
351 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
352         defer func(t http.RoundTripper) {
353                 http.DefaultTransport = t
354         }(http.DefaultTransport)
355         http.DefaultTransport = &http.Transport{
356                 Dial: (&azStubDialer{}).Dial,
357         }
358         azureWriteRaceInterval = time.Millisecond
359         azureWriteRacePollTime = time.Nanosecond
360         DoGenericVolumeTests(t, func(t TB) TestableVolume {
361                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
362         })
363 }
364
365 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
366         defer func(b int) {
367                 azureMaxGetBytes = b
368         }(azureMaxGetBytes)
369
370         defer func(t http.RoundTripper) {
371                 http.DefaultTransport = t
372         }(http.DefaultTransport)
373         http.DefaultTransport = &http.Transport{
374                 Dial: (&azStubDialer{}).Dial,
375         }
376         azureWriteRaceInterval = time.Millisecond
377         azureWriteRacePollTime = time.Nanosecond
378         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
379         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
380                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
381                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
382                 })
383         }
384 }
385
386 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
387         defer func(t http.RoundTripper) {
388                 http.DefaultTransport = t
389         }(http.DefaultTransport)
390         http.DefaultTransport = &http.Transport{
391                 Dial: (&azStubDialer{}).Dial,
392         }
393         azureWriteRaceInterval = time.Millisecond
394         azureWriteRacePollTime = time.Nanosecond
395         DoGenericVolumeTests(t, func(t TB) TestableVolume {
396                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
397         })
398 }
399
400 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
401         defer func(t http.RoundTripper) {
402                 http.DefaultTransport = t
403         }(http.DefaultTransport)
404         http.DefaultTransport = &http.Transport{
405                 Dial: (&azStubDialer{}).Dial,
406         }
407
408         v := NewTestableAzureBlobVolume(t, false, 3)
409         defer v.Teardown()
410
411         for _, size := range []int{
412                 2<<22 - 1, // one <max read
413                 2 << 22,   // one =max read
414                 2<<22 + 1, // one =max read, one <max
415                 2 << 23,   // two =max reads
416                 BlockSize - 1,
417                 BlockSize,
418         } {
419                 data := make([]byte, size)
420                 for i := range data {
421                         data[i] = byte((i + 7) & 0xff)
422                 }
423                 hash := fmt.Sprintf("%x", md5.Sum(data))
424                 err := v.Put(hash, data)
425                 if err != nil {
426                         t.Error(err)
427                 }
428                 gotData, err := v.Get(hash)
429                 if err != nil {
430                         t.Error(err)
431                 }
432                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
433                 gotLen := len(gotData)
434                 bufs.Put(gotData)
435                 if gotLen != size {
436                         t.Error("length mismatch: got %d != %d", gotLen, size)
437                 }
438                 if gotHash != hash {
439                         t.Error("hash mismatch: got %s != %s", gotHash, hash)
440                 }
441         }
442 }
443
444 func TestAzureBlobVolumeReplication(t *testing.T) {
445         for r := 1; r <= 4; r++ {
446                 v := NewTestableAzureBlobVolume(t, false, r)
447                 defer v.Teardown()
448                 if n := v.Replication(); n != r {
449                         t.Errorf("Got replication %d, expected %d", n, r)
450                 }
451         }
452 }
453
454 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
455         defer func(t http.RoundTripper) {
456                 http.DefaultTransport = t
457         }(http.DefaultTransport)
458         http.DefaultTransport = &http.Transport{
459                 Dial: (&azStubDialer{}).Dial,
460         }
461
462         v := NewTestableAzureBlobVolume(t, false, 3)
463         defer v.Teardown()
464
465         azureWriteRaceInterval = time.Second
466         azureWriteRacePollTime = time.Millisecond
467
468         allDone := make(chan struct{})
469         v.azHandler.race = make(chan chan struct{})
470         go func() {
471                 err := v.Put(TestHash, TestBlock)
472                 if err != nil {
473                         t.Error(err)
474                 }
475         }()
476         continuePut := make(chan struct{})
477         // Wait for the stub's Put to create the empty blob
478         v.azHandler.race <- continuePut
479         go func() {
480                 buf, err := v.Get(TestHash)
481                 if err != nil {
482                         t.Error(err)
483                 } else {
484                         bufs.Put(buf)
485                 }
486                 close(allDone)
487         }()
488         // Wait for the stub's Get to get the empty blob
489         close(v.azHandler.race)
490         // Allow stub's Put to continue, so the real data is ready
491         // when the volume's Get retries
492         <-continuePut
493         // Wait for volume's Get to return the real data
494         <-allDone
495 }
496
497 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
498         defer func(t http.RoundTripper) {
499                 http.DefaultTransport = t
500         }(http.DefaultTransport)
501         http.DefaultTransport = &http.Transport{
502                 Dial: (&azStubDialer{}).Dial,
503         }
504
505         v := NewTestableAzureBlobVolume(t, false, 3)
506         defer v.Teardown()
507
508         azureWriteRaceInterval = 2 * time.Second
509         azureWriteRacePollTime = 5 * time.Millisecond
510
511         v.PutRaw(TestHash, nil)
512
513         buf := new(bytes.Buffer)
514         v.IndexTo("", buf)
515         if buf.Len() != 0 {
516                 t.Errorf("Index %+q should be empty", buf.Bytes())
517         }
518
519         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
520
521         allDone := make(chan struct{})
522         go func() {
523                 defer close(allDone)
524                 buf, err := v.Get(TestHash)
525                 if err != nil {
526                         t.Error(err)
527                         return
528                 }
529                 if len(buf) != 0 {
530                         t.Errorf("Got %+q, expected empty buf", buf)
531                 }
532                 bufs.Put(buf)
533         }()
534         select {
535         case <-allDone:
536         case <-time.After(time.Second):
537                 t.Error("Get should have stopped waiting for race when block was 2s old")
538         }
539
540         buf.Reset()
541         v.IndexTo("", buf)
542         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
543                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
544         }
545 }
546
547 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
548         v.azHandler.PutRaw(v.containerName, locator, data)
549 }
550
551 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
552         v.azHandler.TouchWithDate(v.containerName, locator, lastPut)
553 }
554
555 func (v *TestableAzureBlobVolume) Teardown() {
556         v.azStub.Close()
557 }
558
// makeEtag returns a random entity tag: "0x" followed by a non-negative
// random 63-bit number in lowercase hexadecimal.
func makeEtag() string {
	n := rand.Int63() // never negative, so no sign appears in the hex
	return "0x" + strconv.FormatInt(n, 16)
}