Merge branch 'master' into 9161-node-state-fixes
[arvados.git] / services / crunch-dispatch-slurm / crunch-dispatch-slurm.go
1 package main
2
3 import (
4         "bufio"
5         "flag"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
8         "io/ioutil"
9         "log"
10         "math"
11         "os"
12         "os/exec"
13         "os/signal"
14         "strconv"
15         "sync"
16         "syscall"
17         "time"
18 )
19
20 func main() {
21         err := doMain()
22         if err != nil {
23                 log.Fatalf("%q", err)
24         }
25 }
26
27 var (
28         arv              arvadosclient.ArvadosClient
29         runningCmds      map[string]*exec.Cmd
30         runningCmdsMutex sync.Mutex
31         waitGroup        sync.WaitGroup
32         doneProcessing   chan bool
33         sigChan          chan os.Signal
34 )
35
36 func doMain() error {
37         flags := flag.NewFlagSet("crunch-dispatch-slurm", flag.ExitOnError)
38
39         pollInterval := flags.Int(
40                 "poll-interval",
41                 10,
42                 "Interval in seconds to poll for queued containers")
43
44         priorityPollInterval := flags.Int(
45                 "container-priority-poll-interval",
46                 60,
47                 "Interval in seconds to check priority of a dispatched container")
48
49         crunchRunCommand := flags.String(
50                 "crunch-run-command",
51                 "/usr/bin/crunch-run",
52                 "Crunch command to run container")
53
54         finishCommand := flags.String(
55                 "finish-command",
56                 "/usr/bin/crunch-finish-slurm.sh",
57                 "Command to run from strigger when job is finished")
58
59         // Parse args; omit the first arg which is the command name
60         flags.Parse(os.Args[1:])
61
62         var err error
63         arv, err = arvadosclient.MakeArvadosClient()
64         if err != nil {
65                 return err
66         }
67
68         // Channel to terminate
69         doneProcessing = make(chan bool)
70
71         // Graceful shutdown
72         sigChan = make(chan os.Signal, 1)
73         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
74         go func(sig <-chan os.Signal) {
75                 for sig := range sig {
76                         log.Printf("Caught signal: %v", sig)
77                         doneProcessing <- true
78                 }
79         }(sigChan)
80
81         // Run all queued containers
82         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand, *finishCommand)
83
84         // Wait for all running crunch jobs to complete / terminate
85         waitGroup.Wait()
86
87         return nil
88 }
89
90 type apiClientAuthorization struct {
91         UUID     string `json:"uuid"`
92         APIToken string `json:"api_token"`
93 }
94
95 type apiClientAuthorizationList struct {
96         Items []apiClientAuthorization `json:"items"`
97 }
98
99 // Poll for queued containers using pollInterval.
100 // Invoke dispatchSlurm for each ticker cycle, which will run all the queued containers.
101 //
102 // Any errors encountered are logged but the program would continue to run (not exit).
103 // This is because, once one or more crunch jobs are running,
104 // we would need to wait for them complete.
105 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand, finishCommand string) {
106         var auth apiClientAuthorization
107         err := arv.Call("GET", "api_client_authorizations", "", "current", nil, &auth)
108         if err != nil {
109                 log.Printf("Error getting my token UUID: %v", err)
110                 return
111         }
112
113         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
114         for {
115                 select {
116                 case <-ticker.C:
117                         dispatchSlurm(auth, time.Duration(priorityPollInterval)*time.Second, crunchRunCommand, finishCommand)
118                 case <-doneProcessing:
119                         ticker.Stop()
120                         return
121                 }
122         }
123 }
124
125 // Container data
126 type Container struct {
127         UUID               string           `json:"uuid"`
128         State              string           `json:"state"`
129         Priority           int              `json:"priority"`
130         RuntimeConstraints map[string]int64 `json:"runtime_constraints"`
131         LockedByUUID       string           `json:"locked_by_uuid"`
132 }
133
134 // ContainerList is a list of the containers from api
135 type ContainerList struct {
136         Items []Container `json:"items"`
137 }
138
139 // Get the list of queued containers from API server and invoke run
140 // for each container.
141 func dispatchSlurm(auth apiClientAuthorization, pollInterval time.Duration, crunchRunCommand, finishCommand string) {
142         params := arvadosclient.Dict{
143                 "filters": [][]interface{}{{"state", "in", []string{"Queued", "Locked"}}},
144         }
145
146         var containers ContainerList
147         err := arv.List("containers", params, &containers)
148         if err != nil {
149                 log.Printf("Error getting list of queued containers: %q", err)
150                 return
151         }
152
153         for _, container := range containers.Items {
154                 if container.State == "Locked" {
155                         if container.LockedByUUID != auth.UUID {
156                                 // Locked by a different dispatcher
157                                 continue
158                         } else if checkMine(container.UUID) {
159                                 // I already have a goroutine running
160                                 // for this container: it just hasn't
161                                 // gotten past Locked state yet.
162                                 continue
163                         }
164                         log.Printf("WARNING: found container %s already locked by my token %s, but I didn't submit it. "+
165                                 "Assuming it was left behind by a previous dispatch process, and waiting for it to finish.",
166                                 container.UUID, auth.UUID)
167                         setMine(container.UUID, true)
168                         go func() {
169                                 waitContainer(container, pollInterval)
170                                 setMine(container.UUID, false)
171                         }()
172                 }
173                 go run(container, crunchRunCommand, finishCommand, pollInterval)
174         }
175 }
176
177 // sbatchCmd
178 func sbatchFunc(container Container) *exec.Cmd {
179         memPerCPU := math.Ceil((float64(container.RuntimeConstraints["ram"])) / (float64(container.RuntimeConstraints["vcpus"] * 1048576)))
180         return exec.Command("sbatch", "--share", "--parsable",
181                 "--job-name="+container.UUID,
182                 "--mem-per-cpu="+strconv.Itoa(int(memPerCPU)),
183                 "--cpus-per-task="+strconv.Itoa(int(container.RuntimeConstraints["vcpus"])))
184 }
185
186 var sbatchCmd = sbatchFunc
187
188 // striggerCmd
189 func striggerFunc(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) *exec.Cmd {
190         return exec.Command("strigger", "--set", "--jobid="+jobid, "--fini",
191                 fmt.Sprintf("--program=%s %s %s %s %s", finishCommand, apiHost, apiToken, apiInsecure, containerUUID))
192 }
193
194 var striggerCmd = striggerFunc
195
196 // Submit job to slurm using sbatch.
197 func submit(container Container, crunchRunCommand string) (jobid string, submitErr error) {
198         submitErr = nil
199
200         defer func() {
201                 // If we didn't get as far as submitting a slurm job,
202                 // unlock the container and return it to the queue.
203                 if submitErr == nil {
204                         // OK, no cleanup needed
205                         return
206                 }
207                 err := arv.Update("containers", container.UUID,
208                         arvadosclient.Dict{
209                                 "container": arvadosclient.Dict{"state": "Queued"}},
210                         nil)
211                 if err != nil {
212                         log.Printf("Error unlocking container %s: %v", container.UUID, err)
213                 }
214         }()
215
216         // Create the command and attach to stdin/stdout
217         cmd := sbatchCmd(container)
218         stdinWriter, stdinerr := cmd.StdinPipe()
219         if stdinerr != nil {
220                 submitErr = fmt.Errorf("Error creating stdin pipe %v: %q", container.UUID, stdinerr)
221                 return
222         }
223
224         stdoutReader, stdoutErr := cmd.StdoutPipe()
225         if stdoutErr != nil {
226                 submitErr = fmt.Errorf("Error creating stdout pipe %v: %q", container.UUID, stdoutErr)
227                 return
228         }
229
230         stderrReader, stderrErr := cmd.StderrPipe()
231         if stderrErr != nil {
232                 submitErr = fmt.Errorf("Error creating stderr pipe %v: %q", container.UUID, stderrErr)
233                 return
234         }
235
236         err := cmd.Start()
237         if err != nil {
238                 submitErr = fmt.Errorf("Error starting %v: %v", cmd.Args, err)
239                 return
240         }
241
242         stdoutChan := make(chan []byte)
243         go func() {
244                 b, _ := ioutil.ReadAll(stdoutReader)
245                 stdoutReader.Close()
246                 stdoutChan <- b
247                 close(stdoutChan)
248         }()
249
250         stderrChan := make(chan []byte)
251         go func() {
252                 b, _ := ioutil.ReadAll(stderrReader)
253                 stderrReader.Close()
254                 stderrChan <- b
255                 close(stderrChan)
256         }()
257
258         // Send a tiny script on stdin to execute the crunch-run command
259         // slurm actually enforces that this must be a #! script
260         fmt.Fprintf(stdinWriter, "#!/bin/sh\nexec '%s' '%s'\n", crunchRunCommand, container.UUID)
261         stdinWriter.Close()
262
263         err = cmd.Wait()
264
265         stdoutMsg := <-stdoutChan
266         stderrmsg := <-stderrChan
267
268         if err != nil {
269                 submitErr = fmt.Errorf("Container submission failed %v: %v %v", cmd.Args, err, stderrmsg)
270                 return
271         }
272
273         // If everything worked out, got the jobid on stdout
274         jobid = string(stdoutMsg)
275
276         return
277 }
278
279 // finalizeRecordOnFinish uses 'strigger' command to register a script that will run on
280 // the slurm controller when the job finishes.
281 func finalizeRecordOnFinish(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) {
282         cmd := striggerCmd(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure)
283         cmd.Stdout = os.Stdout
284         cmd.Stderr = os.Stderr
285         err := cmd.Run()
286         if err != nil {
287                 log.Printf("While setting up strigger: %v", err)
288                 // BUG: we drop the error here and forget about it. A
289                 // human has to notice the container is stuck in
290                 // Running state, and fix it manually.
291         }
292 }
293
294 // Run a queued container: [1] Set container state to locked. [2]
295 // Execute crunch-run as a slurm batch job. [3] waitContainer().
296 func run(container Container, crunchRunCommand, finishCommand string, pollInterval time.Duration) {
297         setMine(container.UUID, true)
298         defer setMine(container.UUID, false)
299
300         // Update container status to Locked. This will fail if
301         // another dispatcher (token) has already locked it. It will
302         // succeed if *this* dispatcher has already locked it.
303         err := arv.Update("containers", container.UUID,
304                 arvadosclient.Dict{
305                         "container": arvadosclient.Dict{"state": "Locked"}},
306                 nil)
307         if err != nil {
308                 log.Printf("Error updating container state to 'Locked' for %v: %q", container.UUID, err)
309                 return
310         }
311
312         log.Printf("About to submit queued container %v", container.UUID)
313
314         jobid, err := submit(container, crunchRunCommand)
315         if err != nil {
316                 log.Printf("Error submitting container %s to slurm: %v", container.UUID, err)
317                 return
318         }
319
320         insecure := "0"
321         if arv.ApiInsecure {
322                 insecure = "1"
323         }
324         finalizeRecordOnFinish(jobid, container.UUID, finishCommand, arv.ApiServer, arv.ApiToken, insecure)
325
326         // Update container status to Running. This will fail if
327         // another dispatcher (token) has already locked it. It will
328         // succeed if *this* dispatcher has already locked it.
329         err = arv.Update("containers", container.UUID,
330                 arvadosclient.Dict{
331                         "container": arvadosclient.Dict{"state": "Running"}},
332                 nil)
333         if err != nil {
334                 log.Printf("Error updating container state to 'Running' for %v: %q", container.UUID, err)
335         }
336         log.Printf("Submitted container %v to slurm", container.UUID)
337         waitContainer(container, pollInterval)
338 }
339
340 // Wait for a container to finish. Cancel the slurm job if the
341 // container priority changes to zero before it ends.
342 func waitContainer(container Container, pollInterval time.Duration) {
343         log.Printf("Monitoring container %v started", container.UUID)
344         defer log.Printf("Monitoring container %v finished", container.UUID)
345
346         pollTicker := time.NewTicker(pollInterval)
347         defer pollTicker.Stop()
348         for _ = range pollTicker.C {
349                 var updated Container
350                 err := arv.Get("containers", container.UUID, nil, &updated)
351                 if err != nil {
352                         log.Printf("Error getting container %s: %q", container.UUID, err)
353                         continue
354                 }
355                 if updated.State == "Complete" || updated.State == "Cancelled" {
356                         return
357                 }
358                 if updated.Priority != 0 {
359                         continue
360                 }
361
362                 // Priority is zero, but state is Running or Locked
363                 log.Printf("Canceling container %s", container.UUID)
364
365                 err = exec.Command("scancel", "--name="+container.UUID).Run()
366                 if err != nil {
367                         log.Printf("Error stopping container %s with scancel: %v", container.UUID, err)
368                         if inQ, err := checkSqueue(container.UUID); err != nil {
369                                 log.Printf("Error running squeue: %v", err)
370                                 continue
371                         } else if inQ {
372                                 log.Printf("Container %s is still in squeue; will retry", container.UUID)
373                                 continue
374                         }
375                 }
376
377                 err = arv.Update("containers", container.UUID,
378                         arvadosclient.Dict{
379                                 "container": arvadosclient.Dict{"state": "Cancelled"}},
380                         nil)
381                 if err != nil {
382                         log.Printf("Error updating state for container %s: %s", container.UUID, err)
383                         continue
384                 }
385
386                 return
387         }
388 }
389
390 func checkSqueue(uuid string) (bool, error) {
391         cmd := exec.Command("squeue", "--format=%j")
392         sq, err := cmd.StdoutPipe()
393         if err != nil {
394                 return false, err
395         }
396         cmd.Start()
397         defer cmd.Wait()
398         scanner := bufio.NewScanner(sq)
399         found := false
400         for scanner.Scan() {
401                 if scanner.Text() == uuid {
402                         found = true
403                 }
404         }
405         if err := scanner.Err(); err != nil {
406                 return false, err
407         }
408         return found, nil
409 }
410
411 var mineMutex sync.RWMutex
412 var mineMap = make(map[string]bool)
413
414 // Goroutine-safely add/remove uuid to the set of "my" containers,
415 // i.e., ones for which this process has a goroutine running.
416 func setMine(uuid string, t bool) {
417         mineMutex.Lock()
418         if t {
419                 mineMap[uuid] = true
420         } else {
421                 delete(mineMap, uuid)
422         }
423         mineMutex.Unlock()
424 }
425
426 // Check whether there is already a goroutine running for this
427 // container.
428 func checkMine(uuid string) bool {
429         mineMutex.RLocker().Lock()
430         defer mineMutex.RLocker().Unlock()
431         return mineMap[uuid]
432 }