Merge branch '7167-propagate-error' refs #7167
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "encoding/base64"
6         "encoding/xml"
7         "flag"
8         "fmt"
9         "io/ioutil"
10         "log"
11         "math/rand"
12         "net"
13         "net/http"
14         "net/http/httptest"
15         "regexp"
16         "sort"
17         "strconv"
18         "strings"
19         "sync"
20         "testing"
21         "time"
22
23         "github.com/curoverse/azure-sdk-for-go/storage"
24 )
25
// Credentials for talking to the in-process stub. These are the
// well-known fake credentials that Microsoft's Azure storage emulator
// accepts, so it is fine for them to appear in test code.
const (
	emulatorAccountName = "devstoreaccount1"
	emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
)

// azureTestContainer is the name of a real Azure container to test
// against. When it is empty (the default), NewTestableAzureBlobVolume
// points the volume at the in-process stub instead.
var azureTestContainer string

func init() {
	const usage = "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials."
	flag.StringVar(&azureTestContainer, "test.azure-storage-container-volume", "", usage)
}
41
// azBlob is the stub service's in-memory record of a single blob.
type azBlob struct {
	// Data is the committed blob content returned by "Get Blob".
	Data []byte
	// Etag is set from makeEtag by the HTTP handlers whenever they
	// create or modify the blob (PutRaw leaves it empty).
	Etag string
	// Metadata holds the x-ms-meta-* headers from the most recent
	// "Set Metadata" request.
	Metadata map[string]string
	// Mtime is reported as Last-Modified by "Get Blob" and "List Blobs".
	Mtime time.Time
	// Uncommitted holds blocks uploaded with "Put Block", keyed by
	// decoded block ID, until "Put Block List" commits them.
	Uncommitted map[string][]byte
}
49
// azStubHandler is an http.Handler that imitates enough of the Azure
// Blob service REST API for AzureBlobVolume tests, storing all blobs in
// memory.
type azStubHandler struct {
	// Mutex guards blobs; ServeHTTP holds it for the duration of a
	// request, except inside unlockAndRace.
	sync.Mutex
	// blobs maps "container|blobname" to the blob's state.
	blobs map[string]*azBlob
	// race, when non-nil, lets a test pause the handler inside
	// unlockAndRace. The handler receives from race; if it gets a
	// channel, it sends one value on it before continuing. Closing
	// race lets every handler pass through without pausing.
	race  chan chan struct{}
}
55
56 func newAzStubHandler() *azStubHandler {
57         return &azStubHandler{
58                 blobs: make(map[string]*azBlob),
59         }
60 }
61
62 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
63         blob, ok := h.blobs[container+"|"+hash]
64         if !ok {
65                 return
66         }
67         blob.Mtime = t
68 }
69
70 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
71         h.Lock()
72         defer h.Unlock()
73         h.blobs[container+"|"+hash] = &azBlob{
74                 Data:        data,
75                 Mtime:       time.Now(),
76                 Uncommitted: make(map[string][]byte),
77         }
78 }
79
80 func (h *azStubHandler) unlockAndRace() {
81         if h.race == nil {
82                 return
83         }
84         h.Unlock()
85         // Signal caller that race is starting by reading from
86         // h.race. If we get a channel, block until that channel is
87         // ready to receive. If we get nil (or h.race is closed) just
88         // proceed.
89         if c := <-h.race; c != nil {
90                 c <- struct{}{}
91         }
92         h.Lock()
93 }
94
95 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
96         h.Lock()
97         defer h.Unlock()
98         // defer log.Printf("azStubHandler: %+v", r)
99
100         path := strings.Split(r.URL.Path, "/")
101         container := path[1]
102         hash := ""
103         if len(path) > 2 {
104                 hash = path[2]
105         }
106
107         if err := r.ParseForm(); err != nil {
108                 log.Printf("azStubHandler(%+v): %s", r, err)
109                 rw.WriteHeader(http.StatusBadRequest)
110                 return
111         }
112
113         body, err := ioutil.ReadAll(r.Body)
114         if err != nil {
115                 return
116         }
117
118         type blockListRequestBody struct {
119                 XMLName     xml.Name `xml:"BlockList"`
120                 Uncommitted []string
121         }
122
123         blob, blobExists := h.blobs[container+"|"+hash]
124
125         switch {
126         case r.Method == "PUT" && r.Form.Get("comp") == "":
127                 // "Put Blob" API
128                 if _, ok := h.blobs[container+"|"+hash]; !ok {
129                         // Like the real Azure service, we offer a
130                         // race window during which other clients can
131                         // list/get the new blob before any data is
132                         // committed.
133                         h.blobs[container+"|"+hash] = &azBlob{
134                                 Mtime:       time.Now(),
135                                 Uncommitted: make(map[string][]byte),
136                                 Etag:        makeEtag(),
137                         }
138                         h.unlockAndRace()
139                 }
140                 h.blobs[container+"|"+hash] = &azBlob{
141                         Data:        body,
142                         Mtime:       time.Now(),
143                         Uncommitted: make(map[string][]byte),
144                         Etag:        makeEtag(),
145                 }
146                 rw.WriteHeader(http.StatusCreated)
147         case r.Method == "PUT" && r.Form.Get("comp") == "block":
148                 // "Put Block" API
149                 if !blobExists {
150                         log.Printf("Got block for nonexistent blob: %+v", r)
151                         rw.WriteHeader(http.StatusBadRequest)
152                         return
153                 }
154                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
155                 if err != nil || len(blockID) == 0 {
156                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
157                         rw.WriteHeader(http.StatusBadRequest)
158                         return
159                 }
160                 blob.Uncommitted[string(blockID)] = body
161                 rw.WriteHeader(http.StatusCreated)
162         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
163                 // "Put Block List" API
164                 bl := &blockListRequestBody{}
165                 if err := xml.Unmarshal(body, bl); err != nil {
166                         log.Printf("xml Unmarshal: %s", err)
167                         rw.WriteHeader(http.StatusBadRequest)
168                         return
169                 }
170                 for _, encBlockID := range bl.Uncommitted {
171                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
172                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
173                                 log.Printf("Invalid blockid: %+q", encBlockID)
174                                 rw.WriteHeader(http.StatusBadRequest)
175                                 return
176                         }
177                         blob.Data = blob.Uncommitted[string(blockID)]
178                         blob.Etag = makeEtag()
179                         blob.Mtime = time.Now()
180                         delete(blob.Uncommitted, string(blockID))
181                 }
182                 rw.WriteHeader(http.StatusCreated)
183         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
184                 // "Set Metadata Headers" API. We don't bother
185                 // stubbing "Get Metadata Headers": AzureBlobVolume
186                 // sets metadata headers only as a way to bump Etag
187                 // and Last-Modified.
188                 if !blobExists {
189                         log.Printf("Got metadata for nonexistent blob: %+v", r)
190                         rw.WriteHeader(http.StatusBadRequest)
191                         return
192                 }
193                 blob.Metadata = make(map[string]string)
194                 for k, v := range r.Header {
195                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
196                                 blob.Metadata[k] = v[0]
197                         }
198                 }
199                 blob.Mtime = time.Now()
200                 blob.Etag = makeEtag()
201         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
202                 // "Get Blob" API
203                 if !blobExists {
204                         rw.WriteHeader(http.StatusNotFound)
205                         return
206                 }
207                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
208                 rw.Header().Set("Content-Length", strconv.Itoa(len(blob.Data)))
209                 if r.Method == "GET" {
210                         if _, err := rw.Write(blob.Data); err != nil {
211                                 log.Printf("write %+q: %s", blob.Data, err)
212                         }
213                 }
214                 h.unlockAndRace()
215         case r.Method == "DELETE" && hash != "":
216                 // "Delete Blob" API
217                 if !blobExists {
218                         rw.WriteHeader(http.StatusNotFound)
219                         return
220                 }
221                 delete(h.blobs, container+"|"+hash)
222                 rw.WriteHeader(http.StatusAccepted)
223         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
224                 // "List Blobs" API
225                 prefix := container + "|" + r.Form.Get("prefix")
226                 marker := r.Form.Get("marker")
227
228                 maxResults := 2
229                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
230                         maxResults = n
231                 }
232
233                 resp := storage.BlobListResponse{
234                         Marker:     marker,
235                         NextMarker: "",
236                         MaxResults: int64(maxResults),
237                 }
238                 var hashes sort.StringSlice
239                 for k := range h.blobs {
240                         if strings.HasPrefix(k, prefix) {
241                                 hashes = append(hashes, k[len(container)+1:])
242                         }
243                 }
244                 hashes.Sort()
245                 for _, hash := range hashes {
246                         if len(resp.Blobs) == maxResults {
247                                 resp.NextMarker = hash
248                                 break
249                         }
250                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
251                                 blob := h.blobs[container+"|"+hash]
252                                 resp.Blobs = append(resp.Blobs, storage.Blob{
253                                         Name: hash,
254                                         Properties: storage.BlobProperties{
255                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
256                                                 ContentLength: int64(len(blob.Data)),
257                                                 Etag:          blob.Etag,
258                                         },
259                                 })
260                         }
261                 }
262                 buf, err := xml.Marshal(resp)
263                 if err != nil {
264                         log.Print(err)
265                         rw.WriteHeader(http.StatusInternalServerError)
266                 }
267                 rw.Write(buf)
268         default:
269                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
270                 rw.WriteHeader(http.StatusNotImplemented)
271         }
272 }
273
// azStubDialer is a net.Dialer that notices when the Azure driver
// tries to connect to a virtual-host style address such as
// "devstoreaccount1.blob.127.0.0.1:46067", and in such cases
// transparently dials the embedded "127.0.0.1:46067" instead.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback host:port (127.0.0.1, localhost or
// [::1]) anywhere inside a dial address.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to the first loopback host:port embedded in address,
// logging the substitution. Addresses without one are dialed unchanged.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	target := address
	if hostPort := localHostPortRe.FindString(address); len(hostPort) > 0 {
		log.Println("azStubDialer: dial", hostPort, "instead of", address)
		target = hostPort
	}
	return d.Dialer.Dial(network, target)
}
290
// TestableAzureBlobVolume wraps an AzureBlobVolume with the extra hooks
// (PutRaw, TouchWithDate, Teardown) used by the generic volume tests.
// NOTE(review): PutRaw and TouchWithDate always act on the in-process
// stub, even when the volume was configured to use a real Azure
// container.
type TestableAzureBlobVolume struct {
	*AzureBlobVolume
	azHandler *azStubHandler   // in-memory state behind azStub
	azStub    *httptest.Server // stub HTTP server; closed by Teardown
	t         *testing.T
}
297
298 func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *TestableAzureBlobVolume {
299         azHandler := newAzStubHandler()
300         azStub := httptest.NewServer(azHandler)
301
302         var azClient storage.Client
303
304         container := azureTestContainer
305         if container == "" {
306                 // Connect to stub instead of real Azure storage service
307                 stubURLBase := strings.Split(azStub.URL, "://")[1]
308                 var err error
309                 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
310                         t.Fatal(err)
311                 }
312                 container = "fakecontainername"
313         } else {
314                 // Connect to real Azure storage service
315                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
316                 if err != nil {
317                         t.Fatal(err)
318                 }
319                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
320                 if err != nil {
321                         t.Fatal(err)
322                 }
323         }
324
325         v := NewAzureBlobVolume(azClient, container, readonly, replication)
326
327         return &TestableAzureBlobVolume{
328                 AzureBlobVolume: v,
329                 azHandler:       azHandler,
330                 azStub:          azStub,
331                 t:               t,
332         }
333 }
334
335 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
336         defer func(t http.RoundTripper) {
337                 http.DefaultTransport = t
338         }(http.DefaultTransport)
339         http.DefaultTransport = &http.Transport{
340                 Dial: (&azStubDialer{}).Dial,
341         }
342         azureWriteRaceInterval = time.Millisecond
343         azureWriteRacePollTime = time.Nanosecond
344         DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
345                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
346         })
347 }
348
349 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
350         defer func(t http.RoundTripper) {
351                 http.DefaultTransport = t
352         }(http.DefaultTransport)
353         http.DefaultTransport = &http.Transport{
354                 Dial: (&azStubDialer{}).Dial,
355         }
356         azureWriteRaceInterval = time.Millisecond
357         azureWriteRacePollTime = time.Nanosecond
358         DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
359                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
360         })
361 }
362
363 func TestAzureBlobVolumeReplication(t *testing.T) {
364         for r := 1; r <= 4; r++ {
365                 v := NewTestableAzureBlobVolume(t, false, r)
366                 defer v.Teardown()
367                 if n := v.Replication(); n != r {
368                         t.Errorf("Got replication %d, expected %d", n, r)
369                 }
370         }
371 }
372
373 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
374         defer func(t http.RoundTripper) {
375                 http.DefaultTransport = t
376         }(http.DefaultTransport)
377         http.DefaultTransport = &http.Transport{
378                 Dial: (&azStubDialer{}).Dial,
379         }
380
381         v := NewTestableAzureBlobVolume(t, false, 3)
382         defer v.Teardown()
383
384         azureWriteRaceInterval = time.Second
385         azureWriteRacePollTime = time.Millisecond
386
387         allDone := make(chan struct{})
388         v.azHandler.race = make(chan chan struct{})
389         go func() {
390                 err := v.Put(TestHash, TestBlock)
391                 if err != nil {
392                         t.Error(err)
393                 }
394         }()
395         continuePut := make(chan struct{})
396         // Wait for the stub's Put to create the empty blob
397         v.azHandler.race <- continuePut
398         go func() {
399                 buf, err := v.Get(TestHash)
400                 if err != nil {
401                         t.Error(err)
402                 } else {
403                         bufs.Put(buf)
404                 }
405                 close(allDone)
406         }()
407         // Wait for the stub's Get to get the empty blob
408         close(v.azHandler.race)
409         // Allow stub's Put to continue, so the real data is ready
410         // when the volume's Get retries
411         <-continuePut
412         // Wait for volume's Get to return the real data
413         <-allDone
414 }
415
416 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
417         defer func(t http.RoundTripper) {
418                 http.DefaultTransport = t
419         }(http.DefaultTransport)
420         http.DefaultTransport = &http.Transport{
421                 Dial: (&azStubDialer{}).Dial,
422         }
423
424         v := NewTestableAzureBlobVolume(t, false, 3)
425         defer v.Teardown()
426
427         azureWriteRaceInterval = 2 * time.Second
428         azureWriteRacePollTime = 5 * time.Millisecond
429
430         v.PutRaw(TestHash, nil)
431
432         buf := new(bytes.Buffer)
433         v.IndexTo("", buf)
434         if buf.Len() != 0 {
435                 t.Errorf("Index %+q should be empty", buf.Bytes())
436         }
437
438         v.TouchWithDate(TestHash, time.Now().Add(-1982 * time.Millisecond))
439
440         allDone := make(chan struct{})
441         go func() {
442                 defer close(allDone)
443                 buf, err := v.Get(TestHash)
444                 if err != nil {
445                         t.Error(err)
446                         return
447                 }
448                 if len(buf) != 0 {
449                         t.Errorf("Got %+q, expected empty buf", buf)
450                 }
451                 bufs.Put(buf)
452         }()
453         select {
454         case <-allDone:
455         case <-time.After(time.Second):
456                 t.Error("Get should have stopped waiting for race when block was 2s old")
457         }
458
459         buf.Reset()
460         v.IndexTo("", buf)
461         if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
462                 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
463         }
464 }
465
466 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
467         v.azHandler.PutRaw(v.containerName, locator, data)
468 }
469
470 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
471         v.azHandler.TouchWithDate(v.containerName, locator, lastPut)
472 }
473
474 func (v *TestableAzureBlobVolume) Teardown() {
475         v.azStub.Close()
476 }
477
// makeEtag returns a random Etag of the form "0x" followed by lowercase
// hex digits. The stub handler uses it to stamp blobs it creates or
// modifies.
func makeEtag() string {
	id := rand.Int63()
	return fmt.Sprintf("0x%x", id)
}