24 "github.com/curoverse/azure-sdk-for-go/storage"
28 // The same fake credentials used by Microsoft's Azure emulator
29 emulatorAccountName = "devstoreaccount1"
30 emulatorAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
33 var azureTestContainer string
38 "test.azure-storage-container-volume",
40 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
46 Metadata map[string]string
48 Uncommitted map[string][]byte
51 type azStubHandler struct {
53 blobs map[string]*azBlob
54 race chan chan struct{}
57 func newAzStubHandler() *azStubHandler {
58 return &azStubHandler{
59 blobs: make(map[string]*azBlob),
63 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
64 blob, ok := h.blobs[container+"|"+hash]
71 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
74 h.blobs[container+"|"+hash] = &azBlob{
77 Metadata: make(map[string]string),
78 Uncommitted: make(map[string][]byte),
82 func (h *azStubHandler) unlockAndRace() {
87 // Signal caller that race is starting by reading from
88 // h.race. If we get a channel, block until that channel is
89 // ready to receive. If we get nil (or h.race is closed) just
91 if c := <-h.race; c != nil {
97 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
99 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
102 // defer log.Printf("azStubHandler: %+v", r)
104 path := strings.Split(r.URL.Path, "/")
111 if err := r.ParseForm(); err != nil {
112 log.Printf("azStubHandler(%+v): %s", r, err)
113 rw.WriteHeader(http.StatusBadRequest)
117 body, err := ioutil.ReadAll(r.Body)
122 type blockListRequestBody struct {
123 XMLName xml.Name `xml:"BlockList"`
127 blob, blobExists := h.blobs[container+"|"+hash]
130 case r.Method == "PUT" && r.Form.Get("comp") == "":
132 if _, ok := h.blobs[container+"|"+hash]; !ok {
133 // Like the real Azure service, we offer a
134 // race window during which other clients can
135 // list/get the new blob before any data is
137 h.blobs[container+"|"+hash] = &azBlob{
139 Uncommitted: make(map[string][]byte),
140 Metadata: make(map[string]string),
145 metadata := make(map[string]string)
146 for k, v := range r.Header {
147 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
148 name := k[len("x-ms-meta-"):]
149 metadata[strings.ToLower(name)] = v[0]
152 h.blobs[container+"|"+hash] = &azBlob{
155 Uncommitted: make(map[string][]byte),
159 rw.WriteHeader(http.StatusCreated)
160 case r.Method == "PUT" && r.Form.Get("comp") == "block":
163 log.Printf("Got block for nonexistent blob: %+v", r)
164 rw.WriteHeader(http.StatusBadRequest)
167 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
168 if err != nil || len(blockID) == 0 {
169 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
170 rw.WriteHeader(http.StatusBadRequest)
173 blob.Uncommitted[string(blockID)] = body
174 rw.WriteHeader(http.StatusCreated)
175 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
176 // "Put Block List" API
177 bl := &blockListRequestBody{}
178 if err := xml.Unmarshal(body, bl); err != nil {
179 log.Printf("xml Unmarshal: %s", err)
180 rw.WriteHeader(http.StatusBadRequest)
183 for _, encBlockID := range bl.Uncommitted {
184 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
185 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
186 log.Printf("Invalid blockid: %+q", encBlockID)
187 rw.WriteHeader(http.StatusBadRequest)
190 blob.Data = blob.Uncommitted[string(blockID)]
191 blob.Etag = makeEtag()
192 blob.Mtime = time.Now()
193 delete(blob.Uncommitted, string(blockID))
195 rw.WriteHeader(http.StatusCreated)
196 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
197 // "Set Metadata Headers" API. We don't bother
198 // stubbing "Get Metadata Headers": AzureBlobVolume
199 // sets metadata headers only as a way to bump Etag
200 // and Last-Modified.
202 log.Printf("Got metadata for nonexistent blob: %+v", r)
203 rw.WriteHeader(http.StatusBadRequest)
206 blob.Metadata = make(map[string]string)
207 for k, v := range r.Header {
208 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
209 name := k[len("x-ms-meta-"):]
210 blob.Metadata[strings.ToLower(name)] = v[0]
213 blob.Mtime = time.Now()
214 blob.Etag = makeEtag()
215 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
216 // "Get Blob Metadata" API
218 rw.WriteHeader(http.StatusNotFound)
221 for k, v := range blob.Metadata {
222 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
225 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
228 rw.WriteHeader(http.StatusNotFound)
232 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
233 b0, err0 := strconv.Atoi(rangeSpec[1])
234 b1, err1 := strconv.Atoi(rangeSpec[2])
235 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
236 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
237 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
240 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
241 rw.WriteHeader(http.StatusPartialContent)
242 data = data[b0 : b1+1]
244 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
245 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
246 if r.Method == "GET" {
247 if _, err := rw.Write(data); err != nil {
248 log.Printf("write %+q: %s", data, err)
252 case r.Method == "DELETE" && hash != "":
255 rw.WriteHeader(http.StatusNotFound)
258 delete(h.blobs, container+"|"+hash)
259 rw.WriteHeader(http.StatusAccepted)
260 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
262 prefix := container + "|" + r.Form.Get("prefix")
263 marker := r.Form.Get("marker")
266 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
270 resp := storage.BlobListResponse{
273 MaxResults: int64(maxResults),
275 var hashes sort.StringSlice
276 for k := range h.blobs {
277 if strings.HasPrefix(k, prefix) {
278 hashes = append(hashes, k[len(container)+1:])
282 for _, hash := range hashes {
283 if len(resp.Blobs) == maxResults {
284 resp.NextMarker = hash
287 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
288 blob := h.blobs[container+"|"+hash]
289 bmeta := map[string]string(nil)
290 if r.Form.Get("include") == "metadata" {
291 bmeta = blob.Metadata
295 Properties: storage.BlobProperties{
296 LastModified: blob.Mtime.Format(time.RFC1123),
297 ContentLength: int64(len(blob.Data)),
302 resp.Blobs = append(resp.Blobs, b)
305 buf, err := xml.Marshal(resp)
308 rw.WriteHeader(http.StatusInternalServerError)
312 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
313 rw.WriteHeader(http.StatusNotImplemented)
317 // azStubDialer is a net.Dialer that notices when the Azure driver
318 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
319 // in such cases transparently dials "127.0.0.1:46067" instead.
320 type azStubDialer struct {
324 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
326 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
327 if hp := localHostPortRe.FindString(address); hp != "" {
328 log.Println("azStubDialer: dial", hp, "instead of", address)
331 return d.Dialer.Dial(network, address)
334 type TestableAzureBlobVolume struct {
336 azHandler *azStubHandler
337 azStub *httptest.Server
341 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
342 azHandler := newAzStubHandler()
343 azStub := httptest.NewServer(azHandler)
345 var azClient storage.Client
347 container := azureTestContainer
349 // Connect to stub instead of real Azure storage service
350 stubURLBase := strings.Split(azStub.URL, "://")[1]
352 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
355 container = "fakecontainername"
357 // Connect to real Azure storage service
358 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
362 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
368 v := &AzureBlobVolume{
369 ContainerName: container,
371 AzureReplication: replication,
373 bsClient: azClient.GetBlobService(),
376 return &TestableAzureBlobVolume{
378 azHandler: azHandler,
384 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
385 defer func(t http.RoundTripper) {
386 http.DefaultTransport = t
387 }(http.DefaultTransport)
388 http.DefaultTransport = &http.Transport{
389 Dial: (&azStubDialer{}).Dial,
391 azureWriteRaceInterval = time.Millisecond
392 azureWriteRacePollTime = time.Nanosecond
393 DoGenericVolumeTests(t, func(t TB) TestableVolume {
394 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
398 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
403 defer func(t http.RoundTripper) {
404 http.DefaultTransport = t
405 }(http.DefaultTransport)
406 http.DefaultTransport = &http.Transport{
407 Dial: (&azStubDialer{}).Dial,
409 azureWriteRaceInterval = time.Millisecond
410 azureWriteRacePollTime = time.Nanosecond
411 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
412 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
413 DoGenericVolumeTests(t, func(t TB) TestableVolume {
414 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
419 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
420 defer func(t http.RoundTripper) {
421 http.DefaultTransport = t
422 }(http.DefaultTransport)
423 http.DefaultTransport = &http.Transport{
424 Dial: (&azStubDialer{}).Dial,
426 azureWriteRaceInterval = time.Millisecond
427 azureWriteRacePollTime = time.Nanosecond
428 DoGenericVolumeTests(t, func(t TB) TestableVolume {
429 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
433 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
434 defer func(t http.RoundTripper) {
435 http.DefaultTransport = t
436 }(http.DefaultTransport)
437 http.DefaultTransport = &http.Transport{
438 Dial: (&azStubDialer{}).Dial,
441 v := NewTestableAzureBlobVolume(t, false, 3)
444 for _, size := range []int{
445 2<<22 - 1, // one <max read
446 2 << 22, // one =max read
447 2<<22 + 1, // one =max read, one <max
448 2 << 23, // two =max reads
452 data := make([]byte, size)
453 for i := range data {
454 data[i] = byte((i + 7) & 0xff)
456 hash := fmt.Sprintf("%x", md5.Sum(data))
457 err := v.Put(hash, data)
461 gotData := make([]byte, len(data))
462 gotLen, err := v.Get(hash, gotData)
466 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
468 t.Errorf("length mismatch: got %d != %d", gotLen, size) // Errorf (not Error): message contains format verbs
471 t.Errorf("hash mismatch: got %s != %s", gotHash, hash) // Errorf (not Error): message contains format verbs
476 func TestAzureBlobVolumeReplication(t *testing.T) {
477 for r := 1; r <= 4; r++ {
478 v := NewTestableAzureBlobVolume(t, false, r)
480 if n := v.Replication(); n != r {
481 t.Errorf("Got replication %d, expected %d", n, r)
486 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
487 defer func(t http.RoundTripper) {
488 http.DefaultTransport = t
489 }(http.DefaultTransport)
490 http.DefaultTransport = &http.Transport{
491 Dial: (&azStubDialer{}).Dial,
494 v := NewTestableAzureBlobVolume(t, false, 3)
497 azureWriteRaceInterval = time.Second
498 azureWriteRacePollTime = time.Millisecond
500 allDone := make(chan struct{})
501 v.azHandler.race = make(chan chan struct{})
503 err := v.Put(TestHash, TestBlock)
508 continuePut := make(chan struct{})
509 // Wait for the stub's Put to create the empty blob
510 v.azHandler.race <- continuePut
512 buf := make([]byte, len(TestBlock))
513 _, err := v.Get(TestHash, buf)
519 // Wait for the stub's Get to get the empty blob
520 close(v.azHandler.race)
521 // Allow stub's Put to continue, so the real data is ready
522 // when the volume's Get retries
524 // Wait for volume's Get to return the real data
528 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
529 defer func(t http.RoundTripper) {
530 http.DefaultTransport = t
531 }(http.DefaultTransport)
532 http.DefaultTransport = &http.Transport{
533 Dial: (&azStubDialer{}).Dial,
536 v := NewTestableAzureBlobVolume(t, false, 3)
539 azureWriteRaceInterval = 2 * time.Second
540 azureWriteRacePollTime = 5 * time.Millisecond
542 v.PutRaw(TestHash, nil)
544 buf := new(bytes.Buffer)
547 t.Errorf("Index %+q should be empty", buf.Bytes())
550 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
552 allDone := make(chan struct{})
555 buf := make([]byte, BlockSize)
556 n, err := v.Get(TestHash, buf)
562 t.Errorf("Got %+q, expected empty buf", buf[:n])
567 case <-time.After(time.Second):
568 t.Error("Get should have stopped waiting for race when block was 2s old")
573 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
574 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
578 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
579 v.azHandler.PutRaw(v.ContainerName, locator, data)
582 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
583 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
586 func (v *TestableAzureBlobVolume) Teardown() {
590 func makeEtag() string {
591 return fmt.Sprintf("0x%x", rand.Int63())