24 log "github.com/Sirupsen/logrus"
25 "github.com/curoverse/azure-sdk-for-go/storage"
29 // The same fake credentials used by Microsoft's Azure emulator
30 emulatorAccountName = "devstoreaccount1"
31 emulatorAccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
// azureTestContainer is the container name NewTestableAzureBlobVolume
// starts from when choosing the volume's ContainerName.
// NOTE(review): presumably populated from the
// -test.azure-storage-container-volume flag; the registration call is
// not visible here -- confirm.
34 var azureTestContainer string
39 "test.azure-storage-container-volume",
41 "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials.")
47 Metadata map[string]string
49 Uncommitted map[string][]byte
// azStubHandler is an in-memory stand-in for the Azure blob storage
// HTTP API. The tests install it as the handler of an httptest.Server.
52 type azStubHandler struct {
// blobs holds every stored blob, keyed by container+"|"+hash.
54 blobs map[string]*azBlob
// race lets a test pause a request part-way through: unlockAndRace
// receives a channel from race and, if it is non-nil, uses it to
// synchronize with the test.
55 race chan chan struct{}
// newAzStubHandler returns a stub handler whose blob map is allocated
// and empty, so callers can store blobs without further setup.
58 func newAzStubHandler() *azStubHandler {
59 return &azStubHandler{
60 blobs: make(map[string]*azBlob),
// TouchWithDate looks up the blob stored under container+"|"+hash.
// NOTE(review): presumably sets that blob's modification time to t when
// it exists; the assignment is not visible here -- confirm.
64 func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
65 blob, ok := h.blobs[container+"|"+hash]
// PutRaw writes a blob entry straight into the stub's map under
// container+"|"+hash, bypassing the HTTP API. Any existing entry at that
// key is replaced; the new entry gets fresh, empty Metadata and
// Uncommitted maps.
// NOTE(review): how data is stored in the entry is on lines not visible
// here -- confirm.
72 func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
75 h.blobs[container+"|"+hash] = &azBlob{
78 Metadata: make(map[string]string),
79 Uncommitted: make(map[string][]byte),
// unlockAndRace is the hook a request handler calls at a point where a
// test may want to interleave other requests. It receives one value from
// h.race and only acts further when that value is a non-nil channel.
// NOTE(review): the name suggests a lock is released around the wait;
// that code is not visible here -- confirm.
83 func (h *azStubHandler) unlockAndRace() {
88 // Signal caller that race is starting by reading from
89 // h.race. If we get a channel, block until that channel is
90 // ready to receive. If we get nil (or h.race is closed) just
92 if c := <-h.race; c != nil {
// rangeRegexp matches a Range header of exactly the form
// "bytes=<first>-<last>", capturing both offsets as decimal digit
// strings. Open-ended ("bytes=5-") and multi-range forms do not match.
98 var rangeRegexp = regexp.MustCompile(`^bytes=(\d+)-(\d+)$`)
// ServeHTTP implements http.Handler, emulating the part of the Azure
// blob REST API these tests exercise. It dispatches on the HTTP method
// and the "comp"/"restype" form values: put blob, put block, put block
// list, set and get blob metadata, get/head blob (with an optional
// single byte range), delete blob, and list blobs. Any other request is
// logged and answered with 501 Not Implemented.
100 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
103 // defer log.Printf("azStubHandler: %+v", r)
// NOTE(review): container and hash are presumably taken from these path
// segments on lines not visible here -- confirm.
105 path := strings.Split(r.URL.Path, "/")
112 if err := r.ParseForm(); err != nil {
113 log.Printf("azStubHandler(%+v): %s", r, err)
114 rw.WriteHeader(http.StatusBadRequest)
// The whole request body is buffered up front; the block, block-list and
// not-implemented branches below all use it.
118 body, err := ioutil.ReadAll(r.Body)
123 type blockListRequestBody struct {
124 XMLName xml.Name `xml:"BlockList"`
// blob is nil (and blobExists false) when nothing is stored under this
// container/hash key.
128 blob, blobExists := h.blobs[container+"|"+hash]
// PUT without "comp": create or replace the blob itself.
131 case r.Method == "PUT" && r.Form.Get("comp") == "":
133 if _, ok := h.blobs[container+"|"+hash]; !ok {
134 // Like the real Azure service, we offer a
135 // race window during which other clients can
136 // list/get the new blob before any data is
138 h.blobs[container+"|"+hash] = &azBlob{
140 Uncommitted: make(map[string][]byte),
141 Metadata: make(map[string]string),
// Collect x-ms-meta-* request headers. The prefix match is
// case-insensitive, keys are stored lower-cased without the prefix, and
// only the first value of each header is kept.
146 metadata := make(map[string]string)
147 for k, v := range r.Header {
148 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
149 name := k[len("x-ms-meta-"):]
150 metadata[strings.ToLower(name)] = v[0]
153 h.blobs[container+"|"+hash] = &azBlob{
156 Uncommitted: make(map[string][]byte),
160 rw.WriteHeader(http.StatusCreated)
// PUT comp=block: stash the request body as an uncommitted block.
161 case r.Method == "PUT" && r.Form.Get("comp") == "block":
164 log.Printf("Got block for nonexistent blob: %+v", r)
165 rw.WriteHeader(http.StatusBadRequest)
// Block IDs arrive base64-encoded; the decoded bytes (as a string) are
// the key into blob.Uncommitted. Empty or undecodable IDs are rejected.
168 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
169 if err != nil || len(blockID) == 0 {
170 log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
171 rw.WriteHeader(http.StatusBadRequest)
174 blob.Uncommitted[string(blockID)] = body
175 rw.WriteHeader(http.StatusCreated)
// NOTE(review): no blobExists check is visible in this branch; if blob
// is nil, the blob.Uncommitted access below would panic -- confirm.
176 case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
177 // "Put Block List" API
178 bl := &blockListRequestBody{}
179 if err := xml.Unmarshal(body, bl); err != nil {
180 log.Printf("xml Unmarshal: %s", err)
181 rw.WriteHeader(http.StatusBadRequest)
184 for _, encBlockID := range bl.Uncommitted {
185 blockID, err := base64.StdEncoding.DecodeString(encBlockID)
186 if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
187 log.Printf("Invalid blockid: %+q", encBlockID)
188 rw.WriteHeader(http.StatusBadRequest)
// Each listed block overwrites blob.Data rather than being appended, so
// with several blocks only the last one listed ends up as the content.
191 blob.Data = blob.Uncommitted[string(blockID)]
192 blob.Etag = makeEtag()
193 blob.Mtime = time.Now()
194 delete(blob.Uncommitted, string(blockID))
196 rw.WriteHeader(http.StatusCreated)
197 case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
198 // "Set Metadata Headers" API. We don't bother
199 // stubbing "Get Metadata Headers": AzureBlobVolume
200 // sets metadata headers only as a way to bump Etag
201 // and Last-Modified.
203 log.Printf("Got metadata for nonexistent blob: %+v", r)
204 rw.WriteHeader(http.StatusBadRequest)
// Existing metadata is replaced wholesale, not merged with the new
// headers.
207 blob.Metadata = make(map[string]string)
208 for k, v := range r.Header {
209 if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
210 name := k[len("x-ms-meta-"):]
211 blob.Metadata[strings.ToLower(name)] = v[0]
214 blob.Mtime = time.Now()
215 blob.Etag = makeEtag()
216 case (r.Method == "GET" || r.Method == "HEAD") && r.Form.Get("comp") == "metadata" && hash != "":
217 // "Get Blob Metadata" API
219 rw.WriteHeader(http.StatusNotFound)
// Each stored metadata entry goes back out as an x-ms-meta-<key>
// response header.
222 for k, v := range blob.Metadata {
223 rw.Header().Set(fmt.Sprintf("x-ms-meta-%s", k), v)
// GET/HEAD on a blob path: return the blob content (GET) or just its
// headers (HEAD).
226 case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
229 rw.WriteHeader(http.StatusNotFound)
// b0 and b1 are inclusive byte offsets. A range that starts or ends at
// or past len(data), or whose start exceeds its end, is answered with
// 416 rather than being clamped.
233 if rangeSpec := rangeRegexp.FindStringSubmatch(r.Header.Get("Range")); rangeSpec != nil {
234 b0, err0 := strconv.Atoi(rangeSpec[1])
235 b1, err1 := strconv.Atoi(rangeSpec[2])
236 if err0 != nil || err1 != nil || b0 >= len(data) || b1 >= len(data) || b0 > b1 {
237 rw.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(data)))
238 rw.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
241 rw.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", b0, b1, len(data)))
242 rw.WriteHeader(http.StatusPartialContent)
243 data = data[b0 : b1+1]
245 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
246 rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
// Only GET writes a body; HEAD stops at the headers.
247 if r.Method == "GET" {
248 if _, err := rw.Write(data); err != nil {
249 log.Printf("write %+q: %s", data, err)
// DELETE on a blob path: remove the entry and answer 202 Accepted.
253 case r.Method == "DELETE" && hash != "":
256 rw.WriteHeader(http.StatusNotFound)
259 delete(h.blobs, container+"|"+hash)
260 rw.WriteHeader(http.StatusAccepted)
// GET comp=list&restype=container: list the container's blobs.
261 case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
// Map keys are container+"|"+hash, so the requested prefix is matched
// against the full key.
263 prefix := container + "|" + r.Form.Get("prefix")
264 marker := r.Form.Get("marker")
// A client-supplied maxresults is honored only when it parses as an
// integer in 1..5000.
267 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
271 resp := storage.BlobListResponse{
274 MaxResults: int64(maxResults),
// Gather the hash part (the key with "container|" stripped) of every
// key matching the prefix.
276 var hashes sort.StringSlice
277 for k := range h.blobs {
278 if strings.HasPrefix(k, prefix) {
279 hashes = append(hashes, k[len(container)+1:])
// NOTE(review): hashes is presumably sorted on a line not visible here;
// the marker logic below depends on a stable order -- confirm.
283 for _, hash := range hashes {
// Page is full: hand this hash back as the marker for the next page.
284 if len(resp.Blobs) == maxResults {
285 resp.NextMarker = hash
// Skip entries until the marker is reached (or emit from the start when
// no marker was given). Once one blob has been emitted, every later one
// is emitted too.
288 if len(resp.Blobs) > 0 || marker == "" || marker == hash {
289 blob := h.blobs[container+"|"+hash]
// Metadata goes into the listing only when include=metadata was
// requested; otherwise it stays nil.
290 bmeta := map[string]string(nil)
291 if r.Form.Get("include") == "metadata" {
292 bmeta = blob.Metadata
296 Properties: storage.BlobProperties{
297 LastModified: blob.Mtime.Format(time.RFC1123),
298 ContentLength: int64(len(blob.Data)),
303 resp.Blobs = append(resp.Blobs, b)
306 buf, err := xml.Marshal(resp)
309 rw.WriteHeader(http.StatusInternalServerError)
// No branch matched: log the request and its body for debugging and
// answer 501.
313 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
314 rw.WriteHeader(http.StatusNotImplemented)
318 // azStubDialer is a net.Dialer that notices when the Azure driver
319 // tries to connect to "devstoreaccount1.blob.127.0.0.1:46067", and
320 // in such cases transparently dials "127.0.0.1:46067" instead.
// NOTE(review): the struct body is not visible here; Dial calls
// d.Dialer.Dial, so it presumably embeds net.Dialer -- confirm.
321 type azStubDialer struct {
// localHostPortRe finds a loopback host:port -- 127.0.0.1, localhost or
// [::1], followed by a numeric port -- anywhere inside a dial address.
// It is unanchored so it can pick the loopback part out of a longer
// host name.
325 var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)
// Dial connects via the underlying Dialer. When address contains a
// loopback host:port (per localHostPortRe), the substitution is logged
// first.
// NOTE(review): address is presumably replaced by hp on a line not
// visible here, before the final Dial call -- confirm.
327 func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
328 if hp := localHostPortRe.FindString(address); hp != "" {
329 log.Println("azStubDialer: dial", hp, "instead of", address)
332 return d.Dialer.Dial(network, address)
// TestableAzureBlobVolume bundles a volume under test with the stub
// that backs it.
// NOTE(review): the embedded volume field is not visible here; methods
// use v.ContainerName, so it presumably embeds *AzureBlobVolume --
// confirm.
335 type TestableAzureBlobVolume struct {
// azHandler is the in-memory stub; tests reach into it directly (for
// example PutRaw, TouchWithDate, race).
337 azHandler *azStubHandler
// azStub is the httptest server that serves azHandler.
338 azStub *httptest.Server
// NewTestableAzureBlobVolume builds an AzureBlobVolume for tests, along
// with the stub handler and httptest server that can back it.
//
// Two client setups are visible. One points a client at the stub with
// the emulator credentials and the container name "fakecontainername".
// The other talks to the real Azure service with a key read from
// azureStorageAccountKeyFile.
// NOTE(review): the condition choosing between them, and how readonly
// and t are used, are on lines not visible here -- confirm.
342 func NewTestableAzureBlobVolume(t TB, readonly bool, replication int) *TestableAzureBlobVolume {
// The stub server is started before the stub/real choice below.
343 azHandler := newAzStubHandler()
344 azStub := httptest.NewServer(azHandler)
346 var azClient storage.Client
348 container := azureTestContainer
350 // Connect to stub instead of real Azure storage service
// Strip the URL scheme so only host:port is passed as the service base.
351 stubURLBase := strings.Split(azStub.URL, "://")[1]
353 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
356 container = "fakecontainername"
358 // Connect to real Azure storage service
359 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
363 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
369 v := &AzureBlobVolume{
370 ContainerName: container,
372 AzureReplication: replication,
374 bsClient: azClient.GetBlobService(),
377 return &TestableAzureBlobVolume{
379 azHandler: azHandler,
// TestAzureBlobVolumeWithGeneric runs the generic volume test suite
// against a writable (readonly=false) testable Azure volume.
385 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
// The current http.DefaultTransport is captured as the deferred call's
// argument now, and reinstated when the test returns.
386 defer func(t http.RoundTripper) {
387 http.DefaultTransport = t
388 }(http.DefaultTransport)
// Route outgoing connections through azStubDialer for this test.
389 http.DefaultTransport = &http.Transport{
390 Dial: (&azStubDialer{}).Dial,
// Shrink the write-race timings for this test.
// NOTE(review): no restore of these package-level values is visible
// here -- confirm.
392 azureWriteRaceInterval = time.Millisecond
393 azureWriteRacePollTime = time.Nanosecond
394 DoGenericVolumeTests(t, func(t TB) TestableVolume {
395 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestAzureBlobVolumeConcurrentRanges runs the generic volume test suite
// once for each of two azureMaxGetBytes settings.
// NOTE(review): the lines before the transport swap are not visible;
// they presumably save and restore azureMaxGetBytes -- confirm.
399 func TestAzureBlobVolumeConcurrentRanges(t *testing.T) {
// Swap in a transport that dials through azStubDialer; the deferred call
// restores the previous DefaultTransport on return.
404 defer func(t http.RoundTripper) {
405 http.DefaultTransport = t
406 }(http.DefaultTransport)
407 http.DefaultTransport = &http.Transport{
408 Dial: (&azStubDialer{}).Dial,
410 azureWriteRaceInterval = time.Millisecond
411 azureWriteRacePollTime = time.Nanosecond
412 // Test (BlockSize mod azureMaxGetBytes)==0 and !=0 cases
// The loop variable is the package-level azureMaxGetBytes itself, so
// each iteration changes the global before the suite runs.
413 for _, azureMaxGetBytes = range []int{2 << 22, 2<<22 - 1} {
414 DoGenericVolumeTests(t, func(t TB) TestableVolume {
415 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
// TestReadonlyAzureBlobVolumeWithGeneric runs the generic volume test
// suite against a testable Azure volume created with readonly=true.
420 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
// Swap in a transport that dials through azStubDialer; the deferred call
// restores the previous DefaultTransport on return.
421 defer func(t http.RoundTripper) {
422 http.DefaultTransport = t
423 }(http.DefaultTransport)
424 http.DefaultTransport = &http.Transport{
425 Dial: (&azStubDialer{}).Dial,
427 azureWriteRaceInterval = time.Millisecond
428 azureWriteRacePollTime = time.Nanosecond
429 DoGenericVolumeTests(t, func(t TB) TestableVolume {
430 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
// TestAzureBlobVolumeRangeFenceposts writes blocks whose sizes sit on
// and just around the read-size boundaries listed below, reads each one
// back, and compares the returned length and MD5 with what was written.
// This catches off-by-one errors where a block is fetched in several
// ranged reads.
434 func TestAzureBlobVolumeRangeFenceposts(t *testing.T) {
// Swap in a transport that dials through azStubDialer; the deferred call
// restores the previous DefaultTransport on return.
435 defer func(t http.RoundTripper) {
436 http.DefaultTransport = t
437 }(http.DefaultTransport)
438 http.DefaultTransport = &http.Transport{
439 Dial: (&azStubDialer{}).Dial,
442 v := NewTestableAzureBlobVolume(t, false, 3)
// NOTE(review): "max read" in the size comments presumably refers to
// azureMaxGetBytes -- confirm.
445 for _, size := range []int{
446 2<<22 - 1, // one <max read
447 2 << 22, // one =max read
448 2<<22 + 1, // one =max read, one <max
449 2 << 23, // two =max reads
// Fill the block with a position-dependent byte pattern so misplaced or
// duplicated ranges change the hash.
453 data := make([]byte, size)
454 for i := range data {
455 data[i] = byte((i + 7) & 0xff)
457 hash := fmt.Sprintf("%x", md5.Sum(data))
458 err := v.Put(context.Background(), hash, data)
462 gotData := make([]byte, len(data))
463 gotLen, err := v.Get(context.Background(), hash, gotData)
467 gotHash := fmt.Sprintf("%x", md5.Sum(gotData))
// Errorf, not Error: these messages contain format verbs, which Error
// would print literally instead of substituting the values.
469 t.Errorf("length mismatch: got %d != %d", gotLen, size)
472 t.Errorf("hash mismatch: got %s != %s", gotHash, hash)
// TestAzureBlobVolumeReplication checks, for replication values 1
// through 4, that a volume built with that value reports the same number
// from Replication().
477 func TestAzureBlobVolumeReplication(t *testing.T) {
478 for r := 1; r <= 4; r++ {
479 v := NewTestableAzureBlobVolume(t, false, r)
481 if n := v.Replication(); n != r {
482 t.Errorf("Got replication %d, expected %d", n, r)
// TestAzureBlobVolumeCreateBlobRace exercises the window in which a
// newly created blob exists but has no data yet. A Put is paused inside
// the stub via azHandler.race, a Get is issued while the blob is still
// empty, and the Put is then released so the Get can pick up the real
// data on retry.
// NOTE(review): the goroutines and final assertions are on lines not
// visible here -- confirm.
487 func TestAzureBlobVolumeCreateBlobRace(t *testing.T) {
// Swap in a transport that dials through azStubDialer; the deferred call
// restores the previous DefaultTransport on return.
488 defer func(t http.RoundTripper) {
489 http.DefaultTransport = t
490 }(http.DefaultTransport)
491 http.DefaultTransport = &http.Transport{
492 Dial: (&azStubDialer{}).Dial,
495 v := NewTestableAzureBlobVolume(t, false, 3)
498 azureWriteRaceInterval = time.Second
499 azureWriteRacePollTime = time.Millisecond
501 allDone := make(chan struct{})
// Setting race makes the stub's unlockAndRace hook block until this
// test sends on or closes the channel.
502 v.azHandler.race = make(chan chan struct{})
504 err := v.Put(context.Background(), TestHash, TestBlock)
509 continuePut := make(chan struct{})
510 // Wait for the stub's Put to create the empty blob
511 v.azHandler.race <- continuePut
513 buf := make([]byte, len(TestBlock))
514 _, err := v.Get(context.Background(), TestHash, buf)
520 // Wait for the stub's Get to get the empty blob
521 close(v.azHandler.race)
522 // Allow stub's Put to continue, so the real data is ready
523 // when the volume's Get retries
525 // Wait for volume's Get to return the real data
// TestAzureBlobVolumeCreateBlobRaceDeadline checks that an empty blob is
// treated as still being written only while it is younger than
// azureWriteRaceInterval. Per the assertions below, the index omits the
// blob during that window, Get stops waiting once the blob is old
// enough, and the index then lists the blob as TestHash+"+0".
529 func TestAzureBlobVolumeCreateBlobRaceDeadline(t *testing.T) {
// Swap in a transport that dials through azStubDialer; the deferred call
// restores the previous DefaultTransport on return.
530 defer func(t http.RoundTripper) {
531 http.DefaultTransport = t
532 }(http.DefaultTransport)
533 http.DefaultTransport = &http.Transport{
534 Dial: (&azStubDialer{}).Dial,
537 v := NewTestableAzureBlobVolume(t, false, 3)
540 azureWriteRaceInterval = 2 * time.Second
541 azureWriteRacePollTime = 5 * time.Millisecond
// Store the blob with nil data directly in the stub, bypassing the
// HTTP API.
543 v.PutRaw(TestHash, nil)
545 buf := new(bytes.Buffer)
548 t.Errorf("Index %+q should be empty", buf.Bytes())
// Backdate the blob to 18ms short of the 2s race interval, so the
// deadline passes almost immediately after Get starts waiting.
551 v.TouchWithDate(TestHash, time.Now().Add(-1982*time.Millisecond))
553 allDone := make(chan struct{})
556 buf := make([]byte, BlockSize)
557 n, err := v.Get(context.Background(), TestHash, buf)
563 t.Errorf("Got %+q, expected empty buf", buf[:n])
// Allow a full second for Get to give up: far longer than the ~18ms
// left on the deadline, but shorter than the 2s interval.
568 case <-time.After(time.Second):
569 t.Error("Get should have stopped waiting for race when block was 2s old")
574 if !bytes.HasPrefix(buf.Bytes(), []byte(TestHash+"+0")) {
575 t.Errorf("Index %+q should have %+q", buf.Bytes(), TestHash+"+0")
// PutRaw stores data for locator directly in the stub handler, in this
// volume's container, without going through the volume's own Put path.
579 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
580 v.azHandler.PutRaw(v.ContainerName, locator, data)
// TouchWithDate forwards to the stub handler's TouchWithDate for locator
// in this volume's container, passing lastPut as the timestamp.
583 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
584 v.azHandler.TouchWithDate(v.ContainerName, locator, lastPut)
// Teardown releases the resources held by the testable volume.
// NOTE(review): the body is not visible here; presumably it closes the
// azStub httptest server -- confirm.
587 func (v *TestableAzureBlobVolume) Teardown() {
// makeEtag returns a new etag string: "0x" followed by the hexadecimal
// form of a value from rand.Int63(). The stub calls it whenever a blob's
// content or metadata changes, so each change gets a different-looking
// etag.
591 func makeEtag() string {
592 return fmt.Sprintf("0x%x", rand.Int63())