8460: Merge branch 'master' into 8460-websocket-go
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "context"
6         "crypto/md5"
7         "encoding/base64"
8         "encoding/xml"
9         "flag"
10         "fmt"
11         "io/ioutil"
12         "math/rand"
13         "net"
14         "net/http"
15         "net/http/httptest"
16         "regexp"
17         "sort"
18         "strconv"
19         "strings"
20         "sync"
21         "testing"
22         "time"
23
24         log "github.com/Sirupsen/logrus"
25         "github.com/curoverse/azure-sdk-for-go/storage"
26 )
27
const (
	// The same fake credentials used by Microsoft's Azure emulator.
	// They are handed to storage.NewClient when the tests run against
	// the local stub server rather than a real storage account.
	emulatorAccountName = "devstoreaccount1"
	emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
)
33
// azureTestContainer is the name of a real Azure container to run the
// tests against. When it is empty (the default), the tests talk to
// the in-process stub server instead.
var azureTestContainer string

// init registers the command line flag that sets azureTestContainer.
func init() {
	const usage = "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials."
	flag.StringVar(&azureTestContainer, "test.azure-storage-container-volume", "", usage)
}
43
// azBlob is the stub server's in-memory representation of one blob.
type azBlob struct {
	// Data is the committed content served by the "Get Blob" API.
	Data []byte
	// Etag is regenerated with makeEtag() whenever the stub handler
	// writes or commits the blob or replaces its metadata.
	Etag string
	// Metadata holds the x-ms-meta-* headers, keyed by the
	// lower-cased header name with the "x-ms-meta-" prefix removed.
	Metadata map[string]string
	// Mtime is reported to clients as the blob's last-modified time.
	Mtime time.Time
	// Uncommitted holds blocks uploaded with "Put Block" that have
	// not yet been committed with "Put Block List", keyed by the
	// decoded block ID.
	Uncommitted map[string][]byte
}
51
// azStubHandler is an http.Handler that emulates the subset of the
// Azure blob storage REST API exercised by these tests.
type azStubHandler struct {
	// The embedded mutex guards blobs; ServeHTTP holds it for the
	// duration of each request except inside unlockAndRace.
	sync.Mutex
	// blobs maps "container|hash" to the stored blob.
	blobs map[string]*azBlob
	// race, when non-nil, lets a test pause the handler at the
	// points where it calls unlockAndRace.
	race chan chan struct{}
}
57
58 func newAzStubHandler() *azStubHandler {
59         return &azStubHandler{
60                 blobs: make(map[string]*azBlob),
61         }
62 }
63
64 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
65         blob, ok := h.blobs[container+"|"+hash]
66         if !ok {
67                 return
68         }
69         blob.Mtime = t
70 }
71
72 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
73         h.Lock()
74         defer h.Unlock()
75         h.blobs[container+"|"+hash] = &azBlob{
76                 Data:        data,
77                 Mtime:       time.Now(),
78                 Metadata:    make(map[string]string),
79                 Uncommitted: make(map[string][]byte),
80         }
81 }
82
83 func (h *azStubHandler) unlockAndRace() {
84         if h.race == nil {
85                 return
86         }
87         h.Unlock()
88         // Signal caller that race is starting by reading from
89         // h.race. If we get a channel, block until that channel is
90         // ready to receive. If we get nil (or h.race is closed) just
91         // proceed.
92         if c := <-h.race; c != nil {
93                 c <- struct{}{}
94         }
95         h.Lock()
96 }
97
98 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
99
100 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
101         h.Lock()
102         defer h.Unlock()
103         // defer log.Printf("azStubHandler: %+v", r)
104
105         path := strings.Split(r.URL.Path, "/")
106         container := path[1]
107         hash := ""
108         if len(path) > 2 {
109                 hash = path[2]
110         }
111
112         if err := r.ParseForm(); err != nil {
113                 log.Printf("azStubHandler(%+v): %s", r, err)
114                 rw.WriteHeader(http.StatusBadRequest)
115                 return
116         }
117
118         body, err := ioutil.ReadAll(r.Body)
119         if err != nil {
120                 return
121         }
122
123         type blockListRequestBody struct {
124                 XMLName     xml.Name `xml:"BlockList"`
125                 Uncommitted []string
126         }
127
128         blob, blobExists := h.blobs[container+"|"+hash]
129
130         switch {
131         case r.Method == "PUT" && r.Form.Get("comp") == "":
132                 // "Put Blob" API
133                 if _, ok := h.blobs[container+"|"+hash]; !ok {
134                         // Like the real Azure service, we offer a
135                         // race window during which other clients can
136                         // list/get the new blob before any data is
137                         // committed.
138                         h.blobs[container+"|"+hash] = &azBlob{
139                                 Mtime:       time.Now(),
140                                 Uncommitted: make(map[string][]byte),
141                                 Metadata:    make(map[string]string),
142                                 Etag:        makeEtag(),
143                         }
144                         h.unlockAndRace()
145                 }
146                 metadata := make(map[string]string)
147                 for k, v := range r.Header {
148                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
149                                 name := k[len("x-ms-meta-"):]
150                                 metadata[strings.ToLower(name)] = v[0]
151                         }
152                 }
153                 h.blobs[container+"|"+hash] = &azBlob{
154                         Data:        body,
155                         Mtime:       time.Now(),
156                         Uncommitted: make(map[string][]byte),
157                         Metadata:    metadata,
158                         Etag:        makeEtag(),
159                 }
160                 rw.WriteHeader(http.StatusCreated)
161         case r.Method == "PUT" && r.Form.Get("comp") == "block":
162                 // "Put Block" API
163                 if !blobExists {
164                         log.Printf("Got block for nonexistent blob: %+v", r)
165                         rw.WriteHeader(http.StatusBadRequest)
166                         return
167                 }
168                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
169                 if err != nil || len(blockID) == 0 {
170                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
171                         rw.WriteHeader(http.StatusBadRequest)
172                         return
173                 }
174                 blob.Uncommitted[string(blockID)] = body
175                 rw.WriteHeader(http.StatusCreated)
176         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
177                 // "Put Block List" API
178                 bl := &blockListRequestBody{}
179                 if err := xml.Unmarshal(body, bl); err != nil {
180                         log.Printf("xml Unmarshal: %s", err)
181                         rw.WriteHeader(http.StatusBadRequest)
182                         return
183                 }
184                 for _, encBlockID := range bl.Uncommitted {
185                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
186                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
187                                 log.Printf("Invalid blockid: %+q", encBlockID)
188                                 rw.WriteHeader(http.StatusBadRequest)
189                                 return
190                         }
191                         blob.Data = blob.Uncommitted[string(blockID)]
192                         blob.Etag = makeEtag()
193                         blob.Mtime = time.Now()
194                         delete(blob.Uncommitted, string(blockID))
195                 }
196                 rw.WriteHeader(http.StatusCreated)
197         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
198                 // "Set Metadata Headers" API. We don't bother
199                 // stubbing "Get Metadata Headers": AzureBlobVolume
200                 // sets metadata headers only as a way to bump Etag
201                 // and Last-Modified.
202                 if !blobExists {
203                         log.Printf("Got metadata for nonexistent blob: %+v", r)
204                         rw.WriteHeader(http.StatusBadRequest)
205                         return
206                 }
207                 blob.Metadata = make(map[string]string)
208                 for k, v := range r.Header {
209                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
210                                 name := k[len("x-ms-meta-"):]
211                                 blob.Metadata[strings.ToLower(name)] = v[0]
212                         }
213                 }
214                 blob.Mtime = time.Now()
215                 blob.Etag = makeEtag()
216         case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
217                 // "Get Blob Metadata" API
218                 if !blobExists {
219                         rw.WriteHeader(http.StatusNotFound)
220                         return
221                 }
222                 for k, v := range blob.Metadata {
223                         rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
224                 }
225                 return
226         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
227                 // "Get Blob" API
228                 if !blobExists {
229                         rw.WriteHeader(http.StatusNotFound)
230                         return
231                 }
232                 data := blob.Data
233                 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
234                         b0, err0 := strconv.Atoi(rangeSpec[1])
235                         b1, err1 := strconv.Atoi(rangeSpec[2])
236                         if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
237                                 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
238                                 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
239                                 return
240                         }
241                         rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
242                         rw.WriteHeader(http.StatusPartialContent)
243                         data = data[b0 : b1+1]
244                 }
245                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
246                 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
247                 if r.Method == "GET" {
248                         if _, err := rw.Write(data); err != nil {
249                                 log.Printf("write %+q: %s", data, err)
250                         }
251                 }
252                 h.unlockAndRace()
253         case r.Method == "DELETE" && hash != "":
254                 // "Delete Blob" API
255                 if !blobExists {
256                         rw.WriteHeader(http.StatusNotFound)
257                         return
258                 }
259                 delete(h.blobs, container+"|"+hash)
260                 rw.WriteHeader(http.StatusAccepted)
261         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
262                 // "List Blobs" API
263                 prefix := container + "|" + r.Form.Get("prefix")
264                 marker := r.Form.Get("marker")
265
266                 maxResults := 2
267                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
268                         maxResults = n
269                 }
270
271                 resp := storage.BlobListResponse{
272                         Marker:     marker,
273                         NextMarker: "",
274                         MaxResults: int64(maxResults),
275                 }
276                 var hashes sort.StringSlice
277                 for k := range h.blobs {
278                         if strings.HasPrefix(k, prefix) {
279                                 hashes = append(hashes, k[len(container)+1:])
280                         }
281                 }
282                 hashes.Sort()
283                 for _, hash := range hashes {
284                         if len(resp.Blobs) == maxResults {
285                                 resp.NextMarker = hash
286                                 break
287                         }
288                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
289                                 blob := h.blobs[container+"|"+hash]
290                                 bmeta := map[string]string(nil)
291                                 if r.Form.Get("include") == "metadata" {
292                                         bmeta = blob.Metadata
293                                 }
294                                 b := storage.Blob{
295                                         Name: hash,
296                                         Properties: storage.BlobProperties{
297                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
298                                                 ContentLength: int64(len(blob.Data)),
299                                                 Etag:          blob.Etag,
300                                         },
301                                         Metadata: bmeta,
302                                 }
303                                 resp.Blobs = append(resp.Blobs, b)
304                         }
305                 }
306                 buf, err := xml.Marshal(resp)
307                 if err != nil {
308                         log.Print(err)
309                         rw.WriteHeader(http.StatusInternalServerError)
310                 }
311                 rw.Write(buf)
312         default:
313                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
314                 rw.WriteHeader(http.StatusNotImplemented)
315         }
316 }
317
// azStubDialer wraps a net.Dialer. When the Azure driver asks for an
// address such as "devstoreaccount1.blob.127.0.0.1:46067", it
// transparently connects to "127.0.0.1:46067" instead.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback host:port pair (127.0.0.1,
// localhost or [::1]) anywhere inside a dial address.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to the loopback host:port embedded in address if
// there is one, and to address itself otherwise.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	local := localHostPortRe.FindString(address)
	if local == "" {
		return d.Dialer.Dial(network, address)
	}
	log.Println("azStubDialer: dial", local, "instead of", address)
	return d.Dialer.Dial(network, local)
}
334
// TestableAzureBlobVolume bundles an AzureBlobVolume with the stub
// Azure server it was created alongside, plus helper methods that
// manipulate the stub's state directly.
type TestableAzureBlobVolume struct {
	*AzureBlobVolume
	// azHandler is the stub server's handler, which holds the blobs.
	azHandler *azStubHandler
	// azStub is the HTTP test server wrapping azHandler; Teardown
	// closes it.
	azStub *httptest.Server
	// t is the test handle passed to NewTestableAzureBlobVolume.
	t TB
}
341
342 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
343         azHandler := newAzStubHandler()
344         azStub := httptest.NewServer(azHandler)
345
346         var azClient storage.Client
347
348         container := azureTestContainer
349         if container == "" {
350                 // Connect to stub instead of real Azure storage service
351                 stubURLBase := strings.Split(azStub.URL, "://")[1]
352                 var err error
353                 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
354                         t.Fatal(err)
355                 }
356                 container = "fakecontainername"
357         } else {
358                 // Connect to real Azure storage service
359                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
360                 if err != nil {
361                         t.Fatal(err)
362                 }
363                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
364                 if err != nil {
365                         t.Fatal(err)
366                 }
367         }
368
369         v := &AzureBlobVolume{
370                 ContainerName:    container,
371                 ReadOnly:         readonly,
372                 AzureReplication: replication,
373                 azClient:         azClient,
374                 bsClient:         azClient.GetBlobService(),
375         }
376
377         return &TestableAzureBlobVolume{
378                 AzureBlobVolume: v,
379                 azHandler:       azHandler,
380                 azStub:          azStub,
381                 t:               t,
382         }
383 }
384
385 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
386         defer func(t http.RoundTripper) {
387                 http.DefaultTransport = t
388         }(http.DefaultTransport)
389         http.DefaultTransport = &http.Transport{
390                 Dial: (&azStubDialer{}).Dial,
391         }
392         azureWriteRaceInterval = time.Millisecond
393         azureWriteRacePollTime = time.Nanosecond
394         DoGenericVolumeTests(t, func(t TB) TestableVolume {
395                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
396         })
397 }
398
399 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
400         defer func(b int) {
401                 azureMaxGetBytes = b
402         }(azureMaxGetBytes)
403
404         defer func(t http.RoundTripper) {
405                 http.DefaultTransport = t
406         }(http.DefaultTransport)
407         http.DefaultTransport = &http.Transport{
408                 Dial: (&azStubDialer{}).Dial,
409         }
410         azureWriteRaceInterval = time.Millisecond
411         azureWriteRacePollTime = time.Nanosecond
412         // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
413         for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
414                 DoGenericVolumeTests(t, func(t TB) TestableVolume {
415                         return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
416                 })
417         }
418 }
419
420 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
421         defer func(t http.RoundTripper) {
422                 http.DefaultTransport = t
423         }(http.DefaultTransport)
424         http.DefaultTransport = &http.Transport{
425                 Dial: (&azStubDialer{}).Dial,
426         }
427         azureWriteRaceInterval = time.Millisecond
428         azureWriteRacePollTime = time.Nanosecond
429         DoGenericVolumeTests(t, func(t TB) TestableVolume {
430                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
431         })
432 }
433
434 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
435         defer func(t http.RoundTripper) {
436                 http.DefaultTransport = t
437         }(http.DefaultTransport)
438         http.DefaultTransport = &http.Transport{
439                 Dial: (&azStubDialer{}).Dial,
440         }
441
442         v := NewTestableAzureBlobVolume(t, false, 3)
443         defer v.Teardown()
444
445         for _, size := range []int{
446                 2<<22 - 1, // one <max read
447                 2 << 22,   // one =max read
448                 2<<22 + 1, // one =max read, one <max
449                 2 << 23,   // two =max reads
450                 BlockSize - 1,
451                 BlockSize,
452         } {
453                 data := make([]byte, size)
454                 for i := range data {
455                         data[i] = byte((i + 7) & 0xff)
456                 }
457                 hash := fmt.Sprintf("%x", md5.Sum(data))
458                 err := v.Put(context.Background(), hash, data)
459                 if err != nil {
460                         t.Error(err)
461                 }
462                 gotData := make([]byte, len(data))
463                 gotLen, err := v.Get(context.Background(), hash, gotData)
464                 if err != nil {
465                         t.Error(err)
466                 }
467                 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
468                 if gotLen != size {
469                         t.Error("length mismatch: got %d != %d", gotLen, size)
470                 }
471                 if gotHash != hash {
472                         t.Error("hash mismatch: got %s != %s", gotHash, hash)
473                 }
474         }
475 }
476
477 func TestAzureBlobVolumeReplication(t *testing.T) {
478         for r := 1; r <= 4; r++ {
479                 v := NewTestableAzureBlobVolume(t, false, r)
480                 defer v.Teardown()
481                 if n := v.Replication(); n != r {
482                         t.Errorf("Got replication %d, expected %d", n, r)
483                 }
484         }
485 }
486
487 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
488         defer func(t http.RoundTripper) {
489                 http.DefaultTransport = t
490         }(http.DefaultTransport)
491         http.DefaultTransport = &http.Transport{
492                 Dial: (&azStubDialer{}).Dial,
493         }
494
495         v := NewTestableAzureBlobVolume(t, false, 3)
496         defer v.Teardown()
497
498         azureWriteRaceInterval = time.Second
499         azureWriteRacePollTime = time.Millisecond
500
501         allDone := make(chan struct{})
502         v.azHandler.race = make(chan chan struct{})
503         go func() {
504                 err := v.Put(context.Background(), TestHash, TestBlock)
505                 if err != nil {
506                         t.Error(err)
507                 }
508         }()
509         continuePut := make(chan struct{})
510         // Wait for the stub's Put to create the empty blob
511         v.azHandler.race <- continuePut
512         go func() {
513                 buf := make([]byte, len(TestBlock))
514                 _, err := v.Get(context.Background(), TestHash, buf)
515                 if err != nil {
516                         t.Error(err)
517                 }
518                 close(allDone)
519         }()
520         // Wait for the stub's Get to get the empty blob
521         close(v.azHandler.race)
522         // Allow stub's Put to continue, so the real data is ready
523         // when the volume's Get retries
524         <-continuePut
525         // Wait for volume's Get to return the real data
526         <-allDone
527 }
528
529 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
530         defer func(t http.RoundTripper) {
531                 http.DefaultTransport = t
532         }(http.DefaultTransport)
533         http.DefaultTransport = &http.Transport{
534                 Dial: (&azStubDialer{}).Dial,
535         }
536
537         v := NewTestableAzureBlobVolume(t, false, 3)
538         defer v.Teardown()
539
540         azureWriteRaceInterval = 2 * time.Second
541         azureWriteRacePollTime = 5 * time.Millisecond
542
543         v.PutRaw(TestHash, nil)
544
545         buf := new(bytes.Buffer)
546         v.IndexTo("", buf)
547         if buf.Len() != 0 {
548                 t.Errorf("Index %+q should be empty", buf.Bytes())
549         }
550
551         v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
552
553         allDone := make(chan struct{})
554         go func() {
555                 defer close(allDone)
556                 buf := make([]byte, BlockSize)
557                 n, err := v.Get(context.Background(), TestHash, buf)
558                 if err != nil {
559                         t.Error(err)
560                         return
561                 }
562                 if n != 0 {
563                         t.Errorf("Got %+q, expected empty buf", buf[:n])
564                 }
565         }()
566         select {
567         case <-allDone:
568         case <-time.After(time.Second):
569                 t.Error("Get should have stopped waiting for race when block was 2s old")
570         }
571
572         buf.Reset()
573         v.IndexTo("", buf)
574         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
575                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
576         }
577 }
578
// PutRaw stores data under locator directly in the stub server's blob
// map (in this volume's container), without going through the
// volume's own Put method.
func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
	v.azHandler.PutRaw(v.ContainerName, locator, data)
}
582
// TouchWithDate sets the modification time the stub server records
// for locator (in this volume's container) to lastPut.
func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
	v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
}
586
// Teardown shuts down the stub HTTP server that was started for this
// volume.
func (v *TestableAzureBlobVolume) Teardown() {
	v.azStub.Close()
}
590
591 func makeEtag() string {
592         return fmt.Sprintf("0x%x", rand.Int63())
593 }