8465: arvados-cwl-runner support for stdin and stderr redirection with containers.
[arvados.git] / services / crunch-dispatch-slurm / squeue.go
1 package main
2
3 import (
4         "bytes"
5         "log"
6         "os/exec"
7         "strings"
8         "sync"
9         "time"
10 )
11
12 // Squeue implements asynchronous polling monitor of the SLURM queue using the
13 // command 'squeue'.
14 type SqueueChecker struct {
15         Period    time.Duration
16         uuids     map[string]bool
17         startOnce sync.Once
18         done      chan struct{}
19         sync.Cond
20 }
21
22 func squeueFunc() *exec.Cmd {
23         return exec.Command("squeue", "--all", "--format=%j")
24 }
25
26 var squeueCmd = squeueFunc
27
28 // HasUUID checks if a given container UUID is in the slurm queue.
29 // This does not run squeue directly, but instead blocks until woken
30 // up by next successful update of squeue.
31 func (sqc *SqueueChecker) HasUUID(uuid string) bool {
32         sqc.startOnce.Do(sqc.start)
33
34         sqc.L.Lock()
35         defer sqc.L.Unlock()
36
37         // block until next squeue broadcast signaling an update.
38         sqc.Wait()
39         return sqc.uuids[uuid]
40 }
41
42 // Stop stops the squeue monitoring goroutine. Do not call HasUUID
43 // after calling Stop.
44 func (sqc *SqueueChecker) Stop() {
45         if sqc.done != nil {
46                 close(sqc.done)
47         }
48 }
49
50 // check gets the names of jobs in the SLURM queue (running and
51 // queued). If it succeeds, it updates squeue.uuids and wakes up any
52 // goroutines that are waiting in HasUUID() or All().
53 func (sqc *SqueueChecker) check() {
54         // Mutex between squeue sync and running sbatch or scancel.  This
55         // establishes a sequence so that squeue doesn't run concurrently with
56         // sbatch or scancel; the next update of squeue will occur only after
57         // sbatch or scancel has completed.
58         sqc.L.Lock()
59         defer sqc.L.Unlock()
60
61         cmd := squeueCmd()
62         stdout, stderr := &bytes.Buffer{}, &bytes.Buffer{}
63         cmd.Stdout, cmd.Stderr = stdout, stderr
64         if err := cmd.Run(); err != nil {
65                 log.Printf("Error running %q %q: %s %q", cmd.Path, cmd.Args, err, stderr.String())
66                 return
67         }
68
69         uuids := strings.Split(stdout.String(), "\n")
70         sqc.uuids = make(map[string]bool, len(uuids))
71         for _, uuid := range uuids {
72                 sqc.uuids[uuid] = true
73         }
74         sqc.Broadcast()
75 }
76
77 // Initialize, and start a goroutine to call check() once per
78 // squeue.Period until terminated by calling Stop().
79 func (sqc *SqueueChecker) start() {
80         sqc.L = &sync.Mutex{}
81         sqc.done = make(chan struct{})
82         go func() {
83                 ticker := time.NewTicker(sqc.Period)
84                 for {
85                         select {
86                         case <-sqc.done:
87                                 ticker.Stop()
88                                 return
89                         case <-ticker.C:
90                                 sqc.check()
91                         }
92                 }
93         }()
94 }
95
96 // All waits for the next squeue invocation, and returns all job
97 // names reported by squeue.
98 func (sqc *SqueueChecker) All() []string {
99         sqc.startOnce.Do(sqc.start)
100         sqc.L.Lock()
101         defer sqc.L.Unlock()
102         sqc.Wait()
103         var uuids []string
104         for uuid := range sqc.uuids {
105                 uuids = append(uuids, uuid)
106         }
107         return uuids
108 }