Cleaned up unit tests. (refs #2620)
[arvados.git] / services / keep / src / keep / volume_test.go
1 package main
2
3 import (
4         "bytes"
5         "fmt"
6         "io/ioutil"
7         "os"
8         "testing"
9         "time"
10 )
11
12 func TempUnixVolume(t *testing.T, queue chan *IORequest) UnixVolume {
13         d, err := ioutil.TempDir("", "volume_test")
14         if err != nil {
15                 t.Fatal(err)
16         }
17         return MakeUnixVolume(d, queue)
18 }
19
20 func _teardown(v UnixVolume) {
21         if v.queue != nil {
22                 close(v.queue)
23         }
24         os.RemoveAll(v.root)
25 }
26
27 // store writes a Keep block directly into a UnixVolume, for testing
28 // UnixVolume methods.
29 //
30 func _store(t *testing.T, vol UnixVolume, filename string, block []byte) {
31         blockdir := fmt.Sprintf("%s/%s", vol.root, filename[:3])
32         if err := os.MkdirAll(blockdir, 0755); err != nil {
33                 t.Fatal(err)
34         }
35
36         blockpath := fmt.Sprintf("%s/%s", blockdir, filename)
37         if f, err := os.Create(blockpath); err == nil {
38                 f.Write(block)
39                 f.Close()
40         } else {
41                 t.Fatal(err)
42         }
43 }
44
45 func TestGet(t *testing.T) {
46         v := TempUnixVolume(t, nil)
47         defer _teardown(v)
48         _store(t, v, TEST_HASH, TEST_BLOCK)
49
50         buf, err := v.Get(TEST_HASH)
51         if err != nil {
52                 t.Error(err)
53         }
54         if bytes.Compare(buf, TEST_BLOCK) != 0 {
55                 t.Errorf("expected %s, got %s", string(TEST_BLOCK), string(buf))
56         }
57 }
58
59 func TestGetNotFound(t *testing.T) {
60         v := TempUnixVolume(t, nil)
61         defer _teardown(v)
62         _store(t, v, TEST_HASH, TEST_BLOCK)
63
64         buf, err := v.Get(TEST_HASH_2)
65         switch {
66         case os.IsNotExist(err):
67                 break
68         case err == nil:
69                 t.Errorf("Read should have failed, returned %s", string(buf))
70         default:
71                 t.Errorf("Read expected ErrNotExist, got: %s", err)
72         }
73 }
74
75 func TestPut(t *testing.T) {
76         v := TempUnixVolume(t, nil)
77         defer _teardown(v)
78
79         err := v.Put(TEST_HASH, TEST_BLOCK)
80         if err != nil {
81                 t.Error(err)
82         }
83         p := fmt.Sprintf("%s/%s/%s", v.root, TEST_HASH[:3], TEST_HASH)
84         if buf, err := ioutil.ReadFile(p); err != nil {
85                 t.Error(err)
86         } else if bytes.Compare(buf, TEST_BLOCK) != 0 {
87                 t.Errorf("Write should have stored %s, did store %s",
88                         string(TEST_BLOCK), string(buf))
89         }
90 }
91
92 func TestPutBadVolume(t *testing.T) {
93         v := TempUnixVolume(t, nil)
94         defer _teardown(v)
95
96         os.Chmod(v.root, 000)
97         err := v.Put(TEST_HASH, TEST_BLOCK)
98         if err == nil {
99                 t.Error("Write should have failed")
100         }
101 }
102
103 // Serialization tests: launch a bunch of concurrent
104 //
105 // TODO(twp): show that the underlying Read/Write operations executed
106 // serially and not concurrently. The easiest way to do this is
107 // probably to activate verbose or debug logging, capture log output
108 // and examine it to confirm that Reads and Writes did not overlap.
109 //
110 // TODO(twp): a proper test of I/O serialization requires that a
111 // second request start while the first one is still underway.
112 // Guaranteeing that the test behaves this way requires some tricky
113 // synchronization and mocking.  For now we'll just launch a bunch of
114 // requests simultaenously in goroutines and demonstrate that they
115 // return accurate results.
116 //
117 func TestGetSerialized(t *testing.T) {
118         v := TempUnixVolume(t, make(chan *IORequest))
119         defer _teardown(v)
120
121         _store(t, v, TEST_HASH, TEST_BLOCK)
122         _store(t, v, TEST_HASH_2, TEST_BLOCK_2)
123         _store(t, v, TEST_HASH_3, TEST_BLOCK_3)
124
125         sem := make(chan int)
126         go func(sem chan int) {
127                 buf, err := v.Get(TEST_HASH)
128                 if err != nil {
129                         t.Errorf("err1: %v", err)
130                 }
131                 if bytes.Compare(buf, TEST_BLOCK) != 0 {
132                         t.Errorf("buf should be %s, is %s", string(TEST_BLOCK), string(buf))
133                 }
134                 sem <- 1
135         }(sem)
136
137         go func(sem chan int) {
138                 buf, err := v.Get(TEST_HASH_2)
139                 if err != nil {
140                         t.Errorf("err2: %v", err)
141                 }
142                 if bytes.Compare(buf, TEST_BLOCK_2) != 0 {
143                         t.Errorf("buf should be %s, is %s", string(TEST_BLOCK_2), string(buf))
144                 }
145                 sem <- 1
146         }(sem)
147
148         go func(sem chan int) {
149                 buf, err := v.Get(TEST_HASH_3)
150                 if err != nil {
151                         t.Errorf("err3: %v", err)
152                 }
153                 if bytes.Compare(buf, TEST_BLOCK_3) != 0 {
154                         t.Errorf("buf should be %s, is %s", string(TEST_BLOCK_3), string(buf))
155                 }
156                 sem <- 1
157         }(sem)
158
159         // Wait for all goroutines to finish
160         for done := 0; done < 2; {
161                 done += <-sem
162         }
163 }
164
165 func TestPutSerialized(t *testing.T) {
166         v := TempUnixVolume(t, make(chan *IORequest))
167         defer _teardown(v)
168
169         sem := make(chan int)
170         go func(sem chan int) {
171                 err := v.Put(TEST_HASH, TEST_BLOCK)
172                 if err != nil {
173                         t.Errorf("err1: %v", err)
174                 }
175                 sem <- 1
176         }(sem)
177
178         go func(sem chan int) {
179                 err := v.Put(TEST_HASH_2, TEST_BLOCK_2)
180                 if err != nil {
181                         t.Errorf("err2: %v", err)
182                 }
183                 sem <- 1
184         }(sem)
185
186         go func(sem chan int) {
187                 err := v.Put(TEST_HASH_3, TEST_BLOCK_3)
188                 if err != nil {
189                         t.Errorf("err3: %v", err)
190                 }
191                 sem <- 1
192         }(sem)
193
194         // Wait for all goroutines to finish
195         for done := 0; done < 2; {
196                 done += <-sem
197         }
198
199         // Double check that we actually wrote the blocks we expected to write.
200         buf, err := v.Get(TEST_HASH)
201         if err != nil {
202                 t.Errorf("Get #1: %v", err)
203         }
204         if bytes.Compare(buf, TEST_BLOCK) != 0 {
205                 t.Errorf("Get #1: expected %s, got %s", string(TEST_BLOCK), string(buf))
206         }
207
208         buf, err = v.Get(TEST_HASH_2)
209         if err != nil {
210                 t.Errorf("Get #2: %v", err)
211         }
212         if bytes.Compare(buf, TEST_BLOCK_2) != 0 {
213                 t.Errorf("Get #2: expected %s, got %s", string(TEST_BLOCK_2), string(buf))
214         }
215
216         buf, err = v.Get(TEST_HASH_3)
217         if err != nil {
218                 t.Errorf("Get #3: %v", err)
219         }
220         if bytes.Compare(buf, TEST_BLOCK_3) != 0 {
221                 t.Errorf("Get #3: expected %s, got %s", string(TEST_BLOCK_3), string(buf))
222         }
223 }
224
225 func TestIsFull(t *testing.T) {
226         v := TempUnixVolume(t, nil)
227         defer _teardown(v)
228
229         full_path := v.root + "/full"
230         now := fmt.Sprintf("%d", time.Now().Unix())
231         os.Symlink(now, full_path)
232         if !v.IsFull() {
233                 t.Errorf("%s: claims not to be full", v)
234         }
235         os.Remove(full_path)
236
237         // Test with an expired /full link.
238         expired := fmt.Sprintf("%d", time.Now().Unix()-3605)
239         os.Symlink(expired, full_path)
240         if v.IsFull() {
241                 t.Errorf("%s: should no longer be full", v)
242         }
243 }