Merge branch '8784-dir-listings'
[arvados.git] / services / crunch-dispatch-slurm / squeue.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "bytes"
9         "log"
10         "os/exec"
11         "strings"
12         "sync"
13         "time"
14 )
15
16 // Squeue implements asynchronous polling monitor of the SLURM queue using the
17 // command 'squeue'.
18 type SqueueChecker struct {
19         Period    time.Duration
20         uuids     map[string]bool
21         startOnce sync.Once
22         done      chan struct{}
23         sync.Cond
24 }
25
26 func squeueFunc() *exec.Cmd {
27         return exec.Command("squeue", "--all", "--format=%j")
28 }
29
30 var squeueCmd = squeueFunc
31
32 // HasUUID checks if a given container UUID is in the slurm queue.
33 // This does not run squeue directly, but instead blocks until woken
34 // up by next successful update of squeue.
35 func (sqc *SqueueChecker) HasUUID(uuid string) bool {
36         sqc.startOnce.Do(sqc.start)
37
38         sqc.L.Lock()
39         defer sqc.L.Unlock()
40
41         // block until next squeue broadcast signaling an update.
42         sqc.Wait()
43         return sqc.uuids[uuid]
44 }
45
46 // Stop stops the squeue monitoring goroutine. Do not call HasUUID
47 // after calling Stop.
48 func (sqc *SqueueChecker) Stop() {
49         if sqc.done != nil {
50                 close(sqc.done)
51         }
52 }
53
54 // check gets the names of jobs in the SLURM queue (running and
55 // queued). If it succeeds, it updates squeue.uuids and wakes up any
56 // goroutines that are waiting in HasUUID() or All().
57 func (sqc *SqueueChecker) check() {
58         // Mutex between squeue sync and running sbatch or scancel.  This
59         // establishes a sequence so that squeue doesn't run concurrently with
60         // sbatch or scancel; the next update of squeue will occur only after
61         // sbatch or scancel has completed.
62         sqc.L.Lock()
63         defer sqc.L.Unlock()
64
65         cmd := squeueCmd()
66         stdout, stderr := &bytes.Buffer{}, &bytes.Buffer{}
67         cmd.Stdout, cmd.Stderr = stdout, stderr
68         if err := cmd.Run(); err != nil {
69                 log.Printf("Error running %q %q: %s %q", cmd.Path, cmd.Args, err, stderr.String())
70                 return
71         }
72
73         uuids := strings.Split(stdout.String(), "\n")
74         sqc.uuids = make(map[string]bool, len(uuids))
75         for _, uuid := range uuids {
76                 sqc.uuids[uuid] = true
77         }
78         sqc.Broadcast()
79 }
80
81 // Initialize, and start a goroutine to call check() once per
82 // squeue.Period until terminated by calling Stop().
83 func (sqc *SqueueChecker) start() {
84         sqc.L = &sync.Mutex{}
85         sqc.done = make(chan struct{})
86         go func() {
87                 ticker := time.NewTicker(sqc.Period)
88                 for {
89                         select {
90                         case <-sqc.done:
91                                 ticker.Stop()
92                                 return
93                         case <-ticker.C:
94                                 sqc.check()
95                         }
96                 }
97         }()
98 }
99
100 // All waits for the next squeue invocation, and returns all job
101 // names reported by squeue.
102 func (sqc *SqueueChecker) All() []string {
103         sqc.startOnce.Do(sqc.start)
104         sqc.L.Lock()
105         defer sqc.L.Unlock()
106         sqc.Wait()
107         var uuids []string
108         for uuid := range sqc.uuids {
109                 uuids = append(uuids, uuid)
110         }
111         return uuids
112 }