7167: move perms code from keepstore into keepclient go SDK.
[arvados.git] / services / keepstore / azure_blob_volume_test.go
1 package main
2
3 import (
4         "encoding/base64"
5         "encoding/xml"
6         "flag"
7         "fmt"
8         "io/ioutil"
9         "log"
10         "math/rand"
11         "net"
12         "net/http"
13         "net/http/httptest"
14         "regexp"
15         "sort"
16         "strconv"
17         "strings"
18         "sync"
19         "testing"
20         "time"
21
22         "github.com/curoverse/azure-sdk-for-go/storage"
23 )
24
// Fake storage-account credentials, identical to the ones accepted by
// Microsoft's Azure storage emulator. NewTestableAzureBlobVolume uses them to
// authenticate against the in-process stub server.
const (
	emulatorAccountKey  = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
	emulatorAccountName = "devstoreaccount1"
)
30
// azureTestContainer holds the value of -test.azure-storage-container-volume.
// When it is empty, the Azure volume tests run against an in-process stub
// instead of a real container.
var azureTestContainer string

func init() {
	const usage = "Name of Azure container to use for testing. Do not use a container with real data! Use -azure-storage-account-name and -azure-storage-key-file arguments to supply credentials."
	flag.StringVar(&azureTestContainer, "test.azure-storage-container-volume", "", usage)
}
40
// azBlob is the in-memory state of one blob held by azStubHandler.
type azBlob struct {
	Data        []byte            // committed content returned by Get Blob
	Etag        string            // set via makeEtag on ServeHTTP writes; empty after PutRaw
	Metadata    map[string]string // x-ms-meta-* headers from the last Set Metadata request
	Mtime       time.Time         // reported as Last-Modified
	Uncommitted map[string][]byte // staged Put Block data, keyed by decoded block ID
}

// azStubHandler is an http.Handler that stands in for the Azure Blob Storage
// service during tests, keeping all blobs in memory.
type azStubHandler struct {
	// The embedded mutex guards blobs.
	sync.Mutex
	// blobs is keyed by container + "|" + blob name.
	blobs map[string]*azBlob
}

// newAzStubHandler returns an azStubHandler holding no blobs.
func newAzStubHandler() *azStubHandler {
	return &azStubHandler{
		blobs: make(map[string]*azBlob),
	}
}

// TouchWithDate sets the modification time of the named blob to t, letting
// tests simulate old data. It does nothing if the blob does not exist.
//
// The lock is required because ServeHTTP reads and modifies the same blobs
// from the HTTP server's goroutines.
func (h *azStubHandler) TouchWithDate(container, hash string, t time.Time) {
	h.Lock()
	defer h.Unlock()
	blob, ok := h.blobs[container+"|"+hash]
	if !ok {
		return
	}
	blob.Mtime = t
}

// PutRaw stores data as the committed content of the named blob, replacing
// any existing blob and bypassing the HTTP interface. Mtime is set to the
// current time; Etag and Metadata are left unset.
func (h *azStubHandler) PutRaw(container, hash string, data []byte) {
	h.Lock()
	defer h.Unlock()
	h.blobs[container+"|"+hash] = &azBlob{
		Data:        data,
		Mtime:       time.Now(),
		Uncommitted: make(map[string][]byte),
	}
}
76 }
77
78 func (h *azStubHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
79         h.Lock()
80         defer h.Unlock()
81         // defer log.Printf("azStubHandler: %+v", r)
82
83         path := strings.Split(r.URL.Path, "/")
84         container := path[1]
85         hash := ""
86         if len(path) > 2 {
87                 hash = path[2]
88         }
89
90         if err := r.ParseForm(); err != nil {
91                 log.Printf("azStubHandler(%+v): %s", r, err)
92                 rw.WriteHeader(http.StatusBadRequest)
93                 return
94         }
95
96         body, err := ioutil.ReadAll(r.Body)
97         if err != nil {
98                 return
99         }
100
101         type blockListRequestBody struct {
102                 XMLName     xml.Name `xml:"BlockList"`
103                 Uncommitted []string
104         }
105
106         blob, blobExists := h.blobs[container+"|"+hash]
107
108         switch {
109         case r.Method == "PUT" && r.Form.Get("comp") == "":
110                 // "Put Blob" API
111                 h.blobs[container+"|"+hash] = &azBlob{
112                         Data:        body,
113                         Mtime:       time.Now(),
114                         Uncommitted: make(map[string][]byte),
115                         Etag:        makeEtag(),
116                 }
117                 rw.WriteHeader(http.StatusCreated)
118         case r.Method == "PUT" && r.Form.Get("comp") == "block":
119                 // "Put Block" API
120                 if !blobExists {
121                         log.Printf("Got block for nonexistent blob: %+v", r)
122                         rw.WriteHeader(http.StatusBadRequest)
123                         return
124                 }
125                 blockID, err := base64.StdEncoding.DecodeString(r.Form.Get("blockid"))
126                 if err != nil || len(blockID) == 0 {
127                         log.Printf("Invalid blockid: %+q", r.Form.Get("blockid"))
128                         rw.WriteHeader(http.StatusBadRequest)
129                         return
130                 }
131                 blob.Uncommitted[string(blockID)] = body
132                 rw.WriteHeader(http.StatusCreated)
133         case r.Method == "PUT" && r.Form.Get("comp") == "blocklist":
134                 // "Put Block List" API
135                 bl := &blockListRequestBody{}
136                 if err := xml.Unmarshal(body, bl); err != nil {
137                         log.Printf("xml Unmarshal: %s", err)
138                         rw.WriteHeader(http.StatusBadRequest)
139                         return
140                 }
141                 for _, encBlockID := range bl.Uncommitted {
142                         blockID, err := base64.StdEncoding.DecodeString(encBlockID)
143                         if err != nil || len(blockID) == 0 || blob.Uncommitted[string(blockID)] == nil {
144                                 log.Printf("Invalid blockid: %+q", encBlockID)
145                                 rw.WriteHeader(http.StatusBadRequest)
146                                 return
147                         }
148                         blob.Data = blob.Uncommitted[string(blockID)]
149                         blob.Etag = makeEtag()
150                         blob.Mtime = time.Now()
151                         delete(blob.Uncommitted, string(blockID))
152                 }
153                 rw.WriteHeader(http.StatusCreated)
154         case r.Method == "PUT" && r.Form.Get("comp") == "metadata":
155                 // "Set Metadata Headers" API. We don't bother
156                 // stubbing "Get Metadata Headers": AzureBlobVolume
157                 // sets metadata headers only as a way to bump Etag
158                 // and Last-Modified.
159                 if !blobExists {
160                         log.Printf("Got metadata for nonexistent blob: %+v", r)
161                         rw.WriteHeader(http.StatusBadRequest)
162                         return
163                 }
164                 blob.Metadata = make(map[string]string)
165                 for k, v := range r.Header {
166                         if strings.HasPrefix(strings.ToLower(k), "x-ms-meta-") {
167                                 blob.Metadata[k] = v[0]
168                         }
169                 }
170                 blob.Mtime = time.Now()
171                 blob.Etag = makeEtag()
172         case (r.Method == "GET" || r.Method == "HEAD") && hash != "":
173                 // "Get Blob" API
174                 if !blobExists {
175                         rw.WriteHeader(http.StatusNotFound)
176                         return
177                 }
178                 rw.Header().Set("Last-Modified", blob.Mtime.Format(time.RFC1123))
179                 rw.Header().Set("Content-Length", strconv.Itoa(len(blob.Data)))
180                 if r.Method == "GET" {
181                         if _, err := rw.Write(blob.Data); err != nil {
182                                 log.Printf("write %+q: %s", blob.Data, err)
183                         }
184                 }
185         case r.Method == "DELETE" && hash != "":
186                 // "Delete Blob" API
187                 if !blobExists {
188                         rw.WriteHeader(http.StatusNotFound)
189                         return
190                 }
191                 delete(h.blobs, container+"|"+hash)
192                 rw.WriteHeader(http.StatusAccepted)
193         case r.Method == "GET" && r.Form.Get("comp") == "list" && r.Form.Get("restype") == "container":
194                 // "List Blobs" API
195                 prefix := container + "|" + r.Form.Get("prefix")
196                 marker := r.Form.Get("marker")
197
198                 maxResults := 2
199                 if n, err := strconv.Atoi(r.Form.Get("maxresults")); err == nil && n >= 1 && n <= 5000 {
200                         maxResults = n
201                 }
202
203                 resp := storage.BlobListResponse{
204                         Marker:     marker,
205                         NextMarker: "",
206                         MaxResults: int64(maxResults),
207                 }
208                 var hashes sort.StringSlice
209                 for k := range h.blobs {
210                         if strings.HasPrefix(k, prefix) {
211                                 hashes = append(hashes, k[len(container)+1:])
212                         }
213                 }
214                 hashes.Sort()
215                 for _, hash := range hashes {
216                         if len(resp.Blobs) == maxResults {
217                                 resp.NextMarker = hash
218                                 break
219                         }
220                         if len(resp.Blobs) > 0 || marker == "" || marker == hash {
221                                 blob := h.blobs[container+"|"+hash]
222                                 resp.Blobs = append(resp.Blobs, storage.Blob{
223                                         Name: hash,
224                                         Properties: storage.BlobProperties{
225                                                 LastModified:  blob.Mtime.Format(time.RFC1123),
226                                                 ContentLength: int64(len(blob.Data)),
227                                                 Etag:          blob.Etag,
228                                         },
229                                 })
230                         }
231                 }
232                 buf, err := xml.Marshal(resp)
233                 if err != nil {
234                         log.Print(err)
235                         rw.WriteHeader(http.StatusInternalServerError)
236                 }
237                 rw.Write(buf)
238         default:
239                 log.Printf("azStubHandler: not implemented: %+v Body:%+q", r, body)
240                 rw.WriteHeader(http.StatusNotImplemented)
241         }
242 }
243
// azStubDialer wraps net.Dialer. When the Azure client dials an address such
// as "devstoreaccount1.blob.127.0.0.1:46067", it connects to the embedded
// loopback endpoint "127.0.0.1:46067" instead.
type azStubDialer struct {
	net.Dialer
}

// localHostPortRe finds a loopback host:port (IPv4 127.0.0.1, "localhost",
// or bracketed IPv6 ::1) anywhere inside a dial address.
var localHostPortRe = regexp.MustCompile(`(127\.0\.0\.1|localhost|\[::1\]):\d+`)

// Dial connects to the loopback host:port embedded in address, if there is
// one; any other address is dialed unchanged.
func (d *azStubDialer) Dial(network, address string) (net.Conn, error) {
	loopback := localHostPortRe.FindString(address)
	if loopback == "" {
		return d.Dialer.Dial(network, address)
	}
	log.Println("azStubDialer: dial", loopback, "instead of", address)
	return d.Dialer.Dial(network, loopback)
}
260
261 type TestableAzureBlobVolume struct {
262         *AzureBlobVolume
263         azHandler *azStubHandler
264         azStub    *httptest.Server
265         t         *testing.T
266 }
267
268 func NewTestableAzureBlobVolume(t *testing.T, readonly bool, replication int) TestableVolume {
269         azHandler := newAzStubHandler()
270         azStub := httptest.NewServer(azHandler)
271
272         var azClient storage.Client
273
274         container := azureTestContainer
275         if container == "" {
276                 // Connect to stub instead of real Azure storage service
277                 stubURLBase := strings.Split(azStub.URL, "://")[1]
278                 var err error
279                 if azClient, err = storage.NewClient(emulatorAccountName, emulatorAccountKey, stubURLBase, storage.DefaultAPIVersion, false); err != nil {
280                         t.Fatal(err)
281                 }
282                 container = "fakecontainername"
283         } else {
284                 // Connect to real Azure storage service
285                 accountKey, err := readKeyFromFile(azureStorageAccountKeyFile)
286                 if err != nil {
287                         t.Fatal(err)
288                 }
289                 azClient, err = storage.NewBasicClient(azureStorageAccountName, accountKey)
290                 if err != nil {
291                         t.Fatal(err)
292                 }
293         }
294
295         v := NewAzureBlobVolume(azClient, container, readonly, replication)
296
297         return &TestableAzureBlobVolume{
298                 AzureBlobVolume: v,
299                 azHandler:       azHandler,
300                 azStub:          azStub,
301                 t:               t,
302         }
303 }
304
305 func TestAzureBlobVolumeWithGeneric(t *testing.T) {
306         defer func(t http.RoundTripper) {
307                 http.DefaultTransport = t
308         }(http.DefaultTransport)
309         http.DefaultTransport = &http.Transport{
310                 Dial: (&azStubDialer{}).Dial,
311         }
312         DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
313                 return NewTestableAzureBlobVolume(t, false, azureStorageReplication)
314         })
315 }
316
317 func TestReadonlyAzureBlobVolumeWithGeneric(t *testing.T) {
318         defer func(t http.RoundTripper) {
319                 http.DefaultTransport = t
320         }(http.DefaultTransport)
321         http.DefaultTransport = &http.Transport{
322                 Dial: (&azStubDialer{}).Dial,
323         }
324         DoGenericVolumeTests(t, func(t *testing.T) TestableVolume {
325                 return NewTestableAzureBlobVolume(t, true, azureStorageReplication)
326         })
327 }
328
329 func TestAzureBlobVolumeReplication(t *testing.T) {
330         for r := 1; r <= 4; r++ {
331                 v := NewTestableAzureBlobVolume(t, false, r)
332                 defer v.Teardown()
333                 if n := v.Replication(); n != r {
334                         t.Errorf("Got replication %d, expected %d", n, r)
335                 }
336         }
337 }
338
339 func (v *TestableAzureBlobVolume) PutRaw(locator string, data []byte) {
340         v.azHandler.PutRaw(v.containerName, locator, data)
341 }
342
343 func (v *TestableAzureBlobVolume) TouchWithDate(locator string, lastPut time.Time) {
344         v.azHandler.TouchWithDate(v.containerName, locator, lastPut)
345 }
346
347 func (v *TestableAzureBlobVolume) Teardown() {
348         v.azStub.Close()
349 }
350
// makeEtag returns a fresh pseudo-random Etag: "0x" followed by the
// lowercase hex digits of a non-negative random int63.
func makeEtag() string {
	n := rand.Int63()
	return fmt.Sprintf("%#x", n)
}