16585: simplify the code: use a context instead of a channel for the
author Ward Vandewege <ward@curii.com>
Tue, 7 Jul 2020 20:56:21 +0000 (16:56 -0400)
committer Ward Vandewege <ward@curii.com>
Tue, 7 Jul 2020 21:37:35 +0000 (17:37 -0400)
       timer and catching signals.

Arvados-DCO-1.1-Signed-off-by: Ward Vandewege <ward@curii.com>

tools/keep-exercise/keep-exercise.go

index 7641465aa329056db3a559db3f032181246e4a32..19d46efbd8771c447ea0c9f6adce2631f57f21ad 100644 (file)
@@ -19,6 +19,7 @@
 package main
 
 import (
 package main
 
 import (
+       "context"
        "crypto/rand"
        "encoding/binary"
        "flag"
        "crypto/rand"
        "encoding/binary"
        "flag"
@@ -29,7 +30,6 @@ import (
        "net/http"
        "os"
        "os/signal"
        "net/http"
        "os"
        "os/signal"
-       "sync"
        "sync/atomic"
        "syscall"
        "time"
        "sync/atomic"
        "syscall"
        "time"
@@ -56,8 +56,12 @@ var (
        Repeat        = flag.Int("repeat", 1, "number of times to repeat the experiment (default 1)")
 )
 
        Repeat        = flag.Int("repeat", 1, "number of times to repeat the experiment (default 1)")
 )
 
-var summary string
-var csvHeader string
+// Send 1234 to bytesInChan when we receive 1234 bytes from keepstore.
+var bytesInChan = make(chan uint64)
+var bytesOutChan = make(chan uint64)
+
+// Send struct{}{} to errorsChan when an error happens.
+var errorsChan = make(chan struct{})
 
 func main() {
        flag.Parse()
 
 func main() {
        flag.Parse()
@@ -100,91 +104,104 @@ func main() {
                },
        }
 
                },
        }
 
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       sigChan := make(chan os.Signal, 1)
+       signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
+       go func() {
+               <-sigChan
+               fmt.Print("\r") // Suppress the ^C print
+               cancel()
+       }()
+
        overrideServices(kc, stderr)
        overrideServices(kc, stderr)
-       csvHeader = "Timestamp,Elapsed,Read (bytes),Avg Read Speed (MiB/s),Peak Read Speed (MiB/s),Written (bytes),Avg Write Speed (MiB/s),Peak Write Speed (MiB/s),Errors,ReadThreads,WriteThreads,VaryRequest,VaryThread,BlockSize,Replicas,StatsInterval,ServiceURL,ServiceUUID,RunTime,Repeat"
+       csvHeader := "Timestamp,Elapsed,Read (bytes),Avg Read Speed (MiB/s),Peak Read Speed (MiB/s),Written (bytes),Avg Write Speed (MiB/s),Peak Write Speed (MiB/s),Errors,ReadThreads,WriteThreads,VaryRequest,VaryThread,BlockSize,Replicas,StatsInterval,ServiceURL,ServiceUUID,RunTime,Repeat"
+       var summary string
 
        for i := 0; i < *Repeat; i++ {
 
        for i := 0; i < *Repeat; i++ {
-               runExperiment(kc, stderr)
-               stderr.Printf("*************************** experiment %d complete ******************************\n", i)
-               summary += fmt.Sprintf(",%d\n", i)
+               if ctx.Err() == nil {
+                       summary = runExperiment(ctx, kc, summary, csvHeader, stderr)
+                       stderr.Printf("*************************** experiment %d complete ******************************\n", i)
+                       summary += fmt.Sprintf(",%d\n", i)
+               }
        }
        stderr.Println("Summary:")
        stderr.Println()
        }
        stderr.Println("Summary:")
        stderr.Println()
+       fmt.Println()
        fmt.Println(csvHeader + ",Experiment")
        fmt.Println(summary)
 }
 
        fmt.Println(csvHeader + ",Experiment")
        fmt.Println(summary)
 }
 
-func runExperiment(kc *keepclient.KeepClient, stderr *log.Logger) {
-       var wg sync.WaitGroup
+func runExperiment(ctx context.Context, kc *keepclient.KeepClient, summary string, csvHeader string, stderr *log.Logger) (newSummary string) {
+       newSummary = summary
        var nextLocator atomic.Value
 
        var nextLocator atomic.Value
 
-       wg.Add(1)
-       stopCh := make(chan struct{})
+       // Start warmup
+       ready := make(chan struct{})
+       var warmup bool
        if *ReadThreads > 0 {
        if *ReadThreads > 0 {
+               warmup = true
                stderr.Printf("Start warmup phase, waiting for 1 available block before reading starts\n")
        }
                stderr.Printf("Start warmup phase, waiting for 1 available block before reading starts\n")
        }
-       for i := 0; i < *WriteThreads; i++ {
-               nextBuf := make(chan []byte, 1)
-               wg.Add(1)
-               go makeBufs(&wg, nextBuf, i, stopCh, stderr)
-               wg.Add(1)
-               go doWrites(&wg, kc, nextBuf, &nextLocator, stopCh, stderr)
-       }
-       if *ReadThreads > 0 {
-               for nextLocator.Load() == nil {
-                       select {
-                       case _ = <-bytesOutChan:
+       nextBuf := make(chan []byte, 1)
+       go makeBufs(nextBuf, 0, stderr)
+       if warmup {
+               go func() {
+                       locator, _, err := kc.PutB(<-nextBuf)
+                       if err != nil {
+                               stderr.Print(err)
+                               errorsChan <- struct{}{}
                        }
                        }
-               }
-               stderr.Printf("Warmup complete")
+                       nextLocator.Store(locator)
+                       stderr.Println("Warmup complete!")
+                       close(ready)
+               }()
+       } else {
+               close(ready)
        }
        }
-       go countBeans(&wg, stopCh, stderr)
-       for i := 0; i < *ReadThreads; i++ {
-               wg.Add(1)
-               go doReads(&wg, kc, &nextLocator, stopCh, stderr)
+       select {
+       case <-ctx.Done():
+               return
+       case <-ready:
        }
        }
-       wg.Wait()
-}
 
 
-// Send 1234 to bytesInChan when we receive 1234 bytes from keepstore.
-var bytesInChan = make(chan uint64)
-var bytesOutChan = make(chan uint64)
+       // Warmup complete
+       ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*RunTime))
+       defer cancel()
 
 
-// Send struct{}{} to errorsChan when an error happens.
-var errorsChan = make(chan struct{})
+       for i := 0; i < *WriteThreads; i++ {
+               if i > 0 {
+                       // the makeBufs goroutine with index 0 was already started for the warmup phase, above
+                       nextBuf := make(chan []byte, 1)
+                       go makeBufs(nextBuf, i, stderr)
+               }
+               go doWrites(ctx, kc, nextBuf, &nextLocator, stderr)
+       }
+       for i := 0; i < *ReadThreads; i++ {
+               go doReads(ctx, kc, &nextLocator, stderr)
+       }
 
 
-func countBeans(wg *sync.WaitGroup, stopCh chan struct{}, stderr *log.Logger) {
-       defer wg.Done()
        t0 := time.Now()
        var tickChan <-chan time.Time
        t0 := time.Now()
        var tickChan <-chan time.Time
-       var endChan <-chan time.Time
-       c := make(chan os.Signal, 1)
-       signal.Notify(c, os.Interrupt, syscall.SIGTERM)
        if *StatsInterval > 0 {
                tickChan = time.NewTicker(*StatsInterval).C
        }
        if *StatsInterval > 0 {
                tickChan = time.NewTicker(*StatsInterval).C
        }
-       if *RunTime > 0 {
-               endChan = time.NewTicker(*RunTime).C
-       }
        var bytesIn uint64
        var bytesOut uint64
        var errors uint64
        var rateIn, rateOut float64
        var maxRateIn, maxRateOut float64
        var bytesIn uint64
        var bytesOut uint64
        var errors uint64
        var rateIn, rateOut float64
        var maxRateIn, maxRateOut float64
-       var exit, abort, printCsv bool
+       var exit, printCsv bool
        csv := log.New(os.Stdout, "", 0)
        csv := log.New(os.Stdout, "", 0)
+       csv.Println()
        csv.Println(csvHeader)
        for {
                select {
        csv.Println(csvHeader)
        for {
                select {
-               case <-tickChan:
-                       printCsv = true
-               case <-endChan:
+               case <-ctx.Done():
                        printCsv = true
                        exit = true
                        printCsv = true
                        exit = true
-               case <-c:
+               case <-tickChan:
                        printCsv = true
                        printCsv = true
-                       abort = true
-                       fmt.Print("\r") // Suppress the ^C print
                case i := <-bytesInChan:
                        bytesIn += i
                case o := <-bytesOutChan:
                case i := <-bytesInChan:
                        bytesIn += i
                case o := <-bytesOutChan:
@@ -203,7 +220,7 @@ func countBeans(wg *sync.WaitGroup, stopCh chan struct{}, stderr *log.Logger) {
                                maxRateOut = rateOut
                        }
                        line := fmt.Sprintf("%v,%v,%v,%.1f,%.1f,%v,%.1f,%.1f,%d,%d,%d,%t,%t,%d,%d,%s,%s,%s,%s,%d",
                                maxRateOut = rateOut
                        }
                        line := fmt.Sprintf("%v,%v,%v,%.1f,%.1f,%v,%.1f,%.1f,%d,%d,%d,%t,%t,%d,%d,%s,%s,%s,%s,%d",
-                               time.Now().Format("2006-01-02 15:04:05"),
+                               time.Now().Format("2006/01/02 15:04:05"),
                                elapsed,
                                bytesIn, rateIn, maxRateIn,
                                bytesOut, rateOut, maxRateOut,
                                elapsed,
                                bytesIn, rateIn, maxRateIn,
                                bytesOut, rateOut, maxRateOut,
@@ -222,22 +239,16 @@ func countBeans(wg *sync.WaitGroup, stopCh chan struct{}, stderr *log.Logger) {
                        )
                        csv.Println(line)
                        if exit {
                        )
                        csv.Println(line)
                        if exit {
-                               summary += line
+                               newSummary += line
+                               return
                        }
                        printCsv = false
                }
                        }
                        printCsv = false
                }
-               if abort {
-                       os.Exit(0)
-               }
-               if exit {
-                       close(stopCh)
-                       break
-               }
        }
        }
+       return
 }
 
 }
 
-func makeBufs(wg *sync.WaitGroup, nextBuf chan<- []byte, threadID int, stopCh <-chan struct{}, stderr *log.Logger) {
-       defer wg.Done()
+func makeBufs(nextBuf chan<- []byte, threadID int, stderr *log.Logger) {
        buf := make([]byte, *BlockSize)
        if *VaryThread {
                binary.PutVarint(buf, int64(threadID))
        buf := make([]byte, *BlockSize)
        if *VaryThread {
                binary.PutVarint(buf, int64(threadID))
@@ -254,44 +265,27 @@ func makeBufs(wg *sync.WaitGroup, nextBuf chan<- []byte, threadID int, stopCh <-
                        }
                        buf = append(rnd, buf[randSize:]...)
                }
                        }
                        buf = append(rnd, buf[randSize:]...)
                }
-               select {
-               case <-stopCh:
-                       close(nextBuf)
-                       return
-               case nextBuf <- buf:
-               }
+               nextBuf <- buf
        }
 }
 
        }
 }
 
-func doWrites(wg *sync.WaitGroup, kc *keepclient.KeepClient, nextBuf <-chan []byte, nextLocator *atomic.Value, stopCh <-chan struct{}, stderr *log.Logger) {
-       defer wg.Done()
-
-       for {
-               select {
-               case <-stopCh:
-                       return
-               case buf := <-nextBuf:
-                       locator, _, err := kc.PutB(buf)
-                       if err != nil {
-                               stderr.Print(err)
-                               errorsChan <- struct{}{}
-                               continue
-                       }
-                       select {
-                       case <-stopCh:
-                               return
-                       case bytesOutChan <- uint64(len(buf)):
-                       }
-                       nextLocator.Store(locator)
+func doWrites(ctx context.Context, kc *keepclient.KeepClient, nextBuf <-chan []byte, nextLocator *atomic.Value, stderr *log.Logger) {
+       for ctx.Err() == nil {
+               buf := <-nextBuf
+               locator, _, err := kc.PutB(buf)
+               if err != nil {
+                       stderr.Print(err)
+                       errorsChan <- struct{}{}
+                       continue
                }
                }
+               bytesOutChan <- uint64(len(buf))
+               nextLocator.Store(locator)
        }
 }
 
        }
 }
 
-func doReads(wg *sync.WaitGroup, kc *keepclient.KeepClient, nextLocator *atomic.Value, stopCh <-chan struct{}, stderr *log.Logger) {
-       defer wg.Done()
-
+func doReads(ctx context.Context, kc *keepclient.KeepClient, nextLocator *atomic.Value, stderr *log.Logger) {
        var locator string
        var locator string
-       for {
+       for ctx.Err() == nil {
                locator = nextLocator.Load().(string)
                rdr, size, url, err := kc.Get(locator)
                if err != nil {
                locator = nextLocator.Load().(string)
                rdr, size, url, err := kc.Get(locator)
                if err != nil {
@@ -309,11 +303,7 @@ func doReads(wg *sync.WaitGroup, kc *keepclient.KeepClient, nextLocator *atomic.
                        // partial/corrupt responses: we are measuring
                        // throughput, not resource consumption.
                }
                        // partial/corrupt responses: we are measuring
                        // throughput, not resource consumption.
                }
-               select {
-               case <-stopCh:
-                       return
-               case bytesInChan <- uint64(n):
-               }
+               bytesInChan <- uint64(n)
        }
 }
 
        }
 }