Fixup test_node_undrained_when_shutdown_cancelled and test_alloc_node_undrained_when_...
[arvados.git] / services / keepstore / azure_blob_volume_test.go
index c3fea9a80d8be19fc7ab3bf3b01c42196c6147b6..439b40221465ada53c805c7b7afb47ba974652a9 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "bytes"
+       "crypto/md5"
        "encoding/base64"
        "encoding/xml"
        "flag"
@@ -60,11 +61,11 @@ func newAzStubHandler() *azStubHandler {
 }
 
 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
-       if blob, ok := h.blobs[container+"|"+hash]; !ok {
+       blob, ok := h.blobs[container+"|"+hash]
+       if !ok {
                return
-       } else {
-               blob.Mtime = t
        }
+       blob.Mtime = t
 }
 
 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
@@ -92,6 +93,8 @@ func (h *azStubHandler) unlockAndRace() {
        h.Lock()
 }
 
+var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
+
 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
        h.Lock()
        defer h.Unlock()
@@ -204,11 +207,24 @@ func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
                        rw.WriteHeader(http.StatusNotFound)
                        return
                }
+               data := blob.Data
+               if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
+                       b0, err0 := strconv.Atoi(rangeSpec[1])
+                       b1, err1 := strconv.Atoi(rangeSpec[2])
+                       if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
+                               rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
+                               rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
+                               return
+                       }
+                       rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
+                       rw.WriteHeader(http.StatusPartialContent)
+                       data = data[b0 : b1+1]
+               }
                rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
-               rw.Header().Set("Content-Length", strconv.Itoa(len(blob.Data)))
+               rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
                if r.Method == "GET" {
-                       if _, err := rw.Write(blob.Data); err != nil {
-                               log.Printf("write %+q: %s", blob.Data, err)
+                       if _, err := rw.Write(data); err != nil {
+                               log.Printf("write %+q: %s", data, err)
                        }
                }
                h.unlockAndRace()
@@ -292,10 +308,10 @@ type TestableAzureBlobVolume struct {
        *AzureBlobVolume
        azHandler *azStubHandler
        azStub    *httptest.Server
-       t         *testing.T
+       t         TB
 }
 
-func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) *TestableAzureBlobVolume {
+func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
        azHandler := newAzStubHandler()
        azStub := httptest.NewServer(azHandler)
 
@@ -339,11 +355,34 @@ func TestAzureBlobVolumeWithGeneric(t *testing.T) {
        http.DefaultTransport = &http.Transport{
                Dial: (&azStubDialer{}).Dial,
        }
-       DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+       DoGenericVolumeTests(t, func(t TB) TestableVolume {
                return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
        })
 }
 
+func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
+       defer func(b int) {
+               azureMaxGetBytes = b
+       }(azureMaxGetBytes)
+
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+       // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
+       for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
+               DoGenericVolumeTests(t, func(t TB) TestableVolume {
+                       return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
+               })
+       }
+}
+
 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
        defer func(t http.RoundTripper) {
                http.DefaultTransport = t
@@ -351,11 +390,57 @@ func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
        http.DefaultTransport = &http.Transport{
                Dial: (&azStubDialer{}).Dial,
        }
-       DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
+       azureWriteRaceInterval = time.Millisecond
+       azureWriteRacePollTime = time.Nanosecond
+       DoGenericVolumeTests(t, func(t TB) TestableVolume {
                return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
        })
 }
 
+func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
+       defer func(t http.RoundTripper) {
+               http.DefaultTransport = t
+       }(http.DefaultTransport)
+       http.DefaultTransport = &http.Transport{
+               Dial: (&azStubDialer{}).Dial,
+       }
+
+       v := NewTestableAzureBlobVolume(t, false, 3)
+       defer v.Teardown()
+
+       for _, size := range []int{
+               2<<22 - 1, // one <max read
+               2 << 22,   // one =max read
+               2<<22 + 1, // one =max read, one <max
+               2 << 23,   // two =max reads
+               BlockSize - 1,
+               BlockSize,
+       } {
+               data := make([]byte, size)
+               for i := range data {
+                       data[i] = byte((i + 7) & 0xff)
+               }
+               hash := fmt.Sprintf("%x", md5.Sum(data))
+               err := v.Put(hash, data)
+               if err != nil {
+                       t.Error(err)
+               }
+               gotData, err := v.Get(hash)
+               if err != nil {
+                       t.Error(err)
+               }
+               gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
+               gotLen := len(gotData)
+               bufs.Put(gotData)
+               if gotLen != size {
+                       t.Error("length mismatch: got %d != %d", gotLen, size)
+               }
+               if gotHash != hash {
+                       t.Error("hash mismatch: got %s != %s", gotHash, hash)
+               }
+       }
+}
+
 func TestAzureBlobVolumeReplication(t *testing.T) {
        for r := 1; r <= 4; r++ {
                v := NewTestableAzureBlobVolume(t, false, r)
@@ -423,7 +508,7 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
        azureWriteRaceInterval = 2 * time.Second
        azureWriteRacePollTime = 5 * time.Millisecond
 
-       v.PutRaw(TestHash, []byte{})
+       v.PutRaw(TestHash, nil)
 
        buf := new(bytes.Buffer)
        v.IndexTo("", buf)
@@ -431,7 +516,7 @@ func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
                t.Errorf("Index %+q should be empty", buf.Bytes())
        }
 
-       v.TouchWithDate(TestHash, time.Now().Add(-1982 * time.Millisecond))
+       v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
 
        allDone := make(chan struct{})
        go func() {