Merge branch 'patch-1' of https://github.com/mr-c/arvados into mr-c-patch-1
[arvados.git] / tools / keep-exercise / keep-exercise.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 // Testing tool for Keep services.
6 //
7 // keepexercise helps measure throughput and test reliability under
8 // various usage patterns.
9 //
10 // By default, it reads and writes blocks containing 2^26 NUL
11 // bytes. This generates network traffic without consuming much disk
12 // space.
13 //
14 // For a more realistic test, enable -vary-request. Warning: this will
15 // fill your storage volumes with random data if you leave it running,
16 // which can cost you money or leave you with too little room for
17 // useful data.
18 //
19 package main
20
21 import (
22         "crypto/rand"
23         "encoding/binary"
24         "flag"
25         "fmt"
26         "io"
27         "io/ioutil"
28         "log"
29         "net/http"
30         "os"
31         "os/signal"
32         "syscall"
33         "time"
34
35         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
36         "git.arvados.org/arvados.git/sdk/go/keepclient"
37 )
38
39 var version = "dev"
40
41 // Command line config knobs
42 var (
43         BlockSize     = flag.Int("block-size", keepclient.BLOCKSIZE, "bytes per read/write op")
44         ReadThreads   = flag.Int("rthreads", 1, "number of concurrent readers")
45         WriteThreads  = flag.Int("wthreads", 1, "number of concurrent writers")
46         VaryRequest   = flag.Bool("vary-request", false, "vary the data for each request: consumes disk space, exercises write behavior")
47         VaryThread    = flag.Bool("vary-thread", false, "use -wthreads different data blocks")
48         Replicas      = flag.Int("replicas", 1, "replication level for writing")
49         StatsInterval = flag.Duration("stats-interval", time.Second, "time interval between IO stats reports, or 0 to disable")
50         ServiceURL    = flag.String("url", "", "specify scheme://host of a single keep service to exercise (instead of using all advertised services like normal clients)")
51         ServiceUUID   = flag.String("uuid", "", "specify UUID of a single advertised keep service to exercise")
52         getVersion    = flag.Bool("version", false, "Print version information and exit.")
53         RunTime       = flag.Duration("run-time", 0, "time to run (e.g. 60s), or 0 to run indefinitely (default)")
54 )
55
56 func main() {
57         flag.Parse()
58
59         // Print version information if requested
60         if *getVersion {
61                 fmt.Printf("keep-exercise %s\n", version)
62                 os.Exit(0)
63         }
64
65         stderr := log.New(os.Stderr, "", log.LstdFlags)
66
67         arv, err := arvadosclient.MakeArvadosClient()
68         if err != nil {
69                 stderr.Fatal(err)
70         }
71         kc, err := keepclient.MakeKeepClient(arv)
72         if err != nil {
73                 stderr.Fatal(err)
74         }
75         kc.Want_replicas = *Replicas
76
77         transport := *(http.DefaultTransport.(*http.Transport))
78         transport.TLSClientConfig = arvadosclient.MakeTLSConfig(arv.ApiInsecure)
79         kc.HTTPClient = &http.Client{
80                 Timeout:   10 * time.Minute,
81                 Transport: &transport,
82         }
83
84         overrideServices(kc, stderr)
85
86         nextLocator := make(chan string, *ReadThreads+*WriteThreads)
87
88         go countBeans(nextLocator, stderr)
89         for i := 0; i < *WriteThreads; i++ {
90                 nextBuf := make(chan []byte, 1)
91                 go makeBufs(nextBuf, i, stderr)
92                 go doWrites(kc, nextBuf, nextLocator, stderr)
93         }
94         for i := 0; i < *ReadThreads; i++ {
95                 go doReads(kc, nextLocator, stderr)
96         }
97         <-make(chan struct{})
98 }
99
// Unbuffered channels through which the worker goroutines report
// their progress to countBeans.
var (
	// bytesInChan receives the number of bytes read from keepstore
	// after each successful read (send 1234 after receiving 1234
	// bytes).
	bytesInChan = make(chan uint64)

	// bytesOutChan receives the number of bytes in each block that
	// was written successfully.
	bytesOutChan = make(chan uint64)

	// errorsChan receives one struct{}{} for every error that
	// happens.
	errorsChan = make(chan struct{})
)
106
107 func countBeans(nextLocator chan string, stderr *log.Logger) {
108         t0 := time.Now()
109         var tickChan <-chan time.Time
110         var endChan <-chan time.Time
111         c := make(chan os.Signal)
112         signal.Notify(c, os.Interrupt, syscall.SIGTERM)
113         if *StatsInterval > 0 {
114                 tickChan = time.NewTicker(*StatsInterval).C
115         }
116         if *RunTime > 0 {
117                 endChan = time.NewTicker(*RunTime).C
118         }
119         var bytesIn uint64
120         var bytesOut uint64
121         var errors uint64
122         var rateIn, rateOut float64
123         var maxRateIn, maxRateOut float64
124         var abort, printCsv bool
125         csv := log.New(os.Stdout, "", 0)
126         csv.Println("Timestamp,Elapsed,Read (bytes),Avg Read Speed (MiB/s),Peak Read Speed (MiB/s),Written (bytes),Avg Write Speed (MiB/s),Peak Write Speed (MiB/s),Errors,ReadThreads,WriteThreads,VaryRequest,VaryThread,BlockSize,Replicas,StatsInterval,ServiceURL,ServiceUUID,RunTime")
127         for {
128                 select {
129                 case <-tickChan:
130                         printCsv = true
131                 case <-endChan:
132                         printCsv = true
133                         abort = true
134                 case <-c:
135                         printCsv = true
136                         abort = true
137                         fmt.Print("\r") // Suppress the ^C print
138                 case i := <-bytesInChan:
139                         bytesIn += i
140                 case o := <-bytesOutChan:
141                         bytesOut += o
142                 case <-errorsChan:
143                         errors++
144                 }
145                 if printCsv {
146                         elapsed := time.Since(t0)
147                         rateIn = float64(bytesIn) / elapsed.Seconds() / 1048576
148                         if rateIn > maxRateIn {
149                                 maxRateIn = rateIn
150                         }
151                         rateOut = float64(bytesOut) / elapsed.Seconds() / 1048576
152                         if rateOut > maxRateOut {
153                                 maxRateOut = rateOut
154                         }
155                         csv.Printf("%v,%v,%v,%.1f,%.1f,%v,%.1f,%.1f,%d,%d,%d,%t,%t,%d,%d,%s,%s,%s,%s",
156                                 time.Now().Format("2006-01-02 15:04:05"),
157                                 elapsed,
158                                 bytesIn, rateIn, maxRateIn,
159                                 bytesOut, rateOut, maxRateOut,
160                                 errors,
161                                 *ReadThreads,
162                                 *WriteThreads,
163                                 *VaryRequest,
164                                 *VaryThread,
165                                 *BlockSize,
166                                 *Replicas,
167                                 *StatsInterval,
168                                 *ServiceURL,
169                                 *ServiceUUID,
170                                 *RunTime,
171                         )
172                         printCsv = false
173                 }
174                 if abort {
175                         os.Exit(0)
176                 }
177         }
178 }
179
180 func makeBufs(nextBuf chan<- []byte, threadID int, stderr *log.Logger) {
181         buf := make([]byte, *BlockSize)
182         if *VaryThread {
183                 binary.PutVarint(buf, int64(threadID))
184         }
185         randSize := 524288
186         if randSize > *BlockSize {
187                 randSize = *BlockSize
188         }
189         for {
190                 if *VaryRequest {
191                         rnd := make([]byte, randSize)
192                         if _, err := io.ReadFull(rand.Reader, rnd); err != nil {
193                                 stderr.Fatal(err)
194                         }
195                         buf = append(rnd, buf[randSize:]...)
196                 }
197                 nextBuf <- buf
198         }
199 }
200
201 func doWrites(kc *keepclient.KeepClient, nextBuf <-chan []byte, nextLocator chan<- string, stderr *log.Logger) {
202         for buf := range nextBuf {
203                 locator, _, err := kc.PutB(buf)
204                 if err != nil {
205                         stderr.Print(err)
206                         errorsChan <- struct{}{}
207                         continue
208                 }
209                 bytesOutChan <- uint64(len(buf))
210                 for cap(nextLocator) > len(nextLocator)+*WriteThreads {
211                         // Give the readers something to do, unless
212                         // they have lots queued up already.
213                         nextLocator <- locator
214                 }
215         }
216 }
217
218 func doReads(kc *keepclient.KeepClient, nextLocator <-chan string, stderr *log.Logger) {
219         for locator := range nextLocator {
220                 rdr, size, url, err := kc.Get(locator)
221                 if err != nil {
222                         stderr.Print(err)
223                         errorsChan <- struct{}{}
224                         continue
225                 }
226                 n, err := io.Copy(ioutil.Discard, rdr)
227                 rdr.Close()
228                 if n != size || err != nil {
229                         stderr.Printf("Got %d bytes (expected %d) from %s: %v", n, size, url, err)
230                         errorsChan <- struct{}{}
231                         continue
232                         // Note we don't count the bytes received in
233                         // partial/corrupt responses: we are measuring
234                         // throughput, not resource consumption.
235                 }
236                 bytesInChan <- uint64(n)
237         }
238 }
239
240 func overrideServices(kc *keepclient.KeepClient, stderr *log.Logger) {
241         roots := make(map[string]string)
242         if *ServiceURL != "" {
243                 roots["zzzzz-bi6l4-000000000000000"] = *ServiceURL
244         } else if *ServiceUUID != "" {
245                 for uuid, url := range kc.GatewayRoots() {
246                         if uuid == *ServiceUUID {
247                                 roots[uuid] = url
248                                 break
249                         }
250                 }
251                 if len(roots) == 0 {
252                         stderr.Fatalf("Service %q was not in list advertised by API %+q", *ServiceUUID, kc.GatewayRoots())
253                 }
254         } else {
255                 return
256         }
257         kc.SetServiceRoots(roots, roots, roots)
258 }