Merge branch 'master' into 8017-slurm-runtime-constraints
[arvados.git] / services / crunch-dispatch-slurm / crunch-dispatch-slurm.go
1 package main
2
3 import (
4         "flag"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "io/ioutil"
8         "log"
9         "os"
10         "os/exec"
11         "os/signal"
12         "strconv"
13         "sync"
14         "syscall"
15         "time"
16 )
17
18 func main() {
19         err := doMain()
20         if err != nil {
21                 log.Fatalf("%q", err)
22         }
23 }
24
25 var (
26         arv              arvadosclient.ArvadosClient
27         runningCmds      map[string]*exec.Cmd
28         runningCmdsMutex sync.Mutex
29         waitGroup        sync.WaitGroup
30         doneProcessing   chan bool
31         sigChan          chan os.Signal
32 )
33
34 func doMain() error {
35         flags := flag.NewFlagSet("crunch-dispatch-slurm", flag.ExitOnError)
36
37         pollInterval := flags.Int(
38                 "poll-interval",
39                 10,
40                 "Interval in seconds to poll for queued containers")
41
42         priorityPollInterval := flags.Int(
43                 "container-priority-poll-interval",
44                 60,
45                 "Interval in seconds to check priority of a dispatched container")
46
47         crunchRunCommand := flags.String(
48                 "crunch-run-command",
49                 "/usr/bin/crunch-run",
50                 "Crunch command to run container")
51
52         finishCommand := flags.String(
53                 "finish-command",
54                 "/usr/bin/crunch-finish-slurm.sh",
55                 "Command to run from strigger when job is finished")
56
57         // Parse args; omit the first arg which is the command name
58         flags.Parse(os.Args[1:])
59
60         var err error
61         arv, err = arvadosclient.MakeArvadosClient()
62         if err != nil {
63                 return err
64         }
65
66         // Channel to terminate
67         doneProcessing = make(chan bool)
68
69         // Graceful shutdown
70         sigChan = make(chan os.Signal, 1)
71         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
72         go func(sig <-chan os.Signal) {
73                 for sig := range sig {
74                         log.Printf("Caught signal: %v", sig)
75                         doneProcessing <- true
76                 }
77         }(sigChan)
78
79         // Run all queued containers
80         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand, *finishCommand)
81
82         // Wait for all running crunch jobs to complete / terminate
83         waitGroup.Wait()
84
85         return nil
86 }
87
88 // Poll for queued containers using pollInterval.
89 // Invoke dispatchSlurm for each ticker cycle, which will run all the queued containers.
90 //
91 // Any errors encountered are logged but the program would continue to run (not exit).
92 // This is because, once one or more crunch jobs are running,
93 // we would need to wait for them complete.
94 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand, finishCommand string) {
95         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
96
97         for {
98                 select {
99                 case <-ticker.C:
100                         dispatchSlurm(priorityPollInterval, crunchRunCommand, finishCommand)
101                 case <-doneProcessing:
102                         ticker.Stop()
103                         return
104                 }
105         }
106 }
107
108 // Container data
109 type Container struct {
110         UUID               string         `json:"uuid"`
111         State              string         `json:"state"`
112         Priority           int            `json:"priority"`
113         RuntimeConstraints map[string]int `json:"runtime_constraints"`
114 }
115
116 // ContainerList is a list of the containers from api
117 type ContainerList struct {
118         Items []Container `json:"items"`
119 }
120
121 // Get the list of queued containers from API server and invoke run for each container.
122 func dispatchSlurm(priorityPollInterval int, crunchRunCommand, finishCommand string) {
123         params := arvadosclient.Dict{
124                 "filters": [][]string{[]string{"state", "=", "Queued"}},
125         }
126
127         var containers ContainerList
128         err := arv.List("containers", params, &containers)
129         if err != nil {
130                 log.Printf("Error getting list of queued containers: %q", err)
131                 return
132         }
133
134         for i := 0; i < len(containers.Items); i++ {
135                 log.Printf("About to submit queued container %v", containers.Items[i].UUID)
136                 // Run the container
137                 go run(containers.Items[i], crunchRunCommand, finishCommand, priorityPollInterval)
138         }
139 }
140
141 // sbatchCmd
142 func sbatchFunc(container Container) *exec.Cmd {
143         return exec.Command("sbatch", "--share", "--parsable",
144                 "--job-name="+container.UUID,
145                 "--mem="+strconv.Itoa(container.RuntimeConstraints["ram"]),
146                 "--cpus-per-task="+strconv.Itoa(container.RuntimeConstraints["vcpus"]))
147 }
148
149 var sbatchCmd = sbatchFunc
150
151 // striggerCmd
152 func striggerFunc(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) *exec.Cmd {
153         return exec.Command("strigger", "--set", "--jobid="+jobid, "--fini",
154                 fmt.Sprintf("--program=%s %s %s %s %s", finishCommand, apiHost, apiToken, apiInsecure, containerUUID))
155 }
156
157 var striggerCmd = striggerFunc
158
159 // Submit job to slurm using sbatch.
160 func submit(container Container, crunchRunCommand string) (jobid string, submitErr error) {
161         submitErr = nil
162
163         // Mark record as complete if anything errors out.
164         defer func() {
165                 if submitErr != nil {
166                         // This really should be an "Error" state, see #8018
167                         updateErr := arv.Update("containers", container.UUID,
168                                 arvadosclient.Dict{
169                                         "container": arvadosclient.Dict{"state": "Complete"}},
170                                 nil)
171                         if updateErr != nil {
172                                 log.Printf("Error updating container state to 'Complete' for %v: %q", container.UUID, updateErr)
173                         }
174                 }
175         }()
176
177         // Create the command and attach to stdin/stdout
178         cmd := sbatchCmd(container)
179         stdinWriter, stdinerr := cmd.StdinPipe()
180         if stdinerr != nil {
181                 submitErr = fmt.Errorf("Error creating stdin pipe %v: %q", container.UUID, stdinerr)
182                 return
183         }
184
185         stdoutReader, stdoutErr := cmd.StdoutPipe()
186         if stdoutErr != nil {
187                 submitErr = fmt.Errorf("Error creating stdout pipe %v: %q", container.UUID, stdoutErr)
188                 return
189         }
190
191         stderrReader, stderrErr := cmd.StderrPipe()
192         if stderrErr != nil {
193                 submitErr = fmt.Errorf("Error creating stderr pipe %v: %q", container.UUID, stderrErr)
194                 return
195         }
196
197         err := cmd.Start()
198         if err != nil {
199                 submitErr = fmt.Errorf("Error starting %v: %v", cmd.Args, err)
200                 return
201         }
202
203         stdoutChan := make(chan []byte)
204         go func() {
205                 b, _ := ioutil.ReadAll(stdoutReader)
206                 stdoutChan <- b
207                 close(stdoutChan)
208         }()
209
210         stderrChan := make(chan []byte)
211         go func() {
212                 b, _ := ioutil.ReadAll(stderrReader)
213                 stderrChan <- b
214                 close(stderrChan)
215         }()
216
217         // Send a tiny script on stdin to execute the crunch-run command
218         // slurm actually enforces that this must be a #! script
219         fmt.Fprintf(stdinWriter, "#!/bin/sh\nexec '%s' '%s'\n", crunchRunCommand, container.UUID)
220         stdinWriter.Close()
221
222         err = cmd.Wait()
223
224         stdoutMsg := <-stdoutChan
225         stderrmsg := <-stderrChan
226
227         if err != nil {
228                 submitErr = fmt.Errorf("Container submission failed %v: %v %v", cmd.Args, err, stderrmsg)
229                 return
230         }
231
232         // If everything worked out, got the jobid on stdout
233         jobid = string(stdoutMsg)
234
235         return
236 }
237
238 // finalizeRecordOnFinish uses 'strigger' command to register a script that will run on
239 // the slurm controller when the job finishes.
240 func finalizeRecordOnFinish(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) {
241         cmd := striggerCmd(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure)
242         cmd.Stdout = os.Stdout
243         cmd.Stderr = os.Stderr
244         err := cmd.Run()
245         if err != nil {
246                 log.Printf("While setting up strigger: %v", err)
247         }
248 }
249
250 // Run a queued container.
251 // Set container state to locked (TBD)
252 // Submit job to slurm to execute crunch-run command for the container
253 // If the container priority becomes zero while crunch job is still running, cancel the job.
254 func run(container Container, crunchRunCommand, finishCommand string, priorityPollInterval int) {
255
256         jobid, err := submit(container, crunchRunCommand)
257         if err != nil {
258                 log.Printf("Error queuing container run: %v", err)
259                 return
260         }
261
262         insecure := "0"
263         if arv.ApiInsecure {
264                 insecure = "1"
265         }
266         finalizeRecordOnFinish(jobid, container.UUID, finishCommand, arv.ApiServer, arv.ApiToken, insecure)
267
268         // Update container status to Running, this is a temporary workaround
269         // to avoid resubmitting queued containers because record locking isn't
270         // implemented yet.
271         err = arv.Update("containers", container.UUID,
272                 arvadosclient.Dict{
273                         "container": arvadosclient.Dict{"state": "Running"}},
274                 nil)
275         if err != nil {
276                 log.Printf("Error updating container state to 'Running' for %v: %q", container.UUID, err)
277         }
278
279         log.Printf("Submitted container run for %v", container.UUID)
280
281         containerUUID := container.UUID
282
283         // A goroutine to terminate the runner if container priority becomes zero
284         priorityTicker := time.NewTicker(time.Duration(priorityPollInterval) * time.Second)
285         go func() {
286                 for _ = range priorityTicker.C {
287                         var container Container
288                         err := arv.Get("containers", containerUUID, nil, &container)
289                         if err != nil {
290                                 log.Printf("Error getting container info for %v: %q", container.UUID, err)
291                         } else {
292                                 if container.Priority == 0 {
293                                         log.Printf("Canceling container %v", container.UUID)
294                                         priorityTicker.Stop()
295                                         cancelcmd := exec.Command("scancel", "--name="+container.UUID)
296                                         cancelcmd.Run()
297                                 }
298                                 if container.State == "Complete" {
299                                         priorityTicker.Stop()
300                                 }
301                         }
302                 }
303         }()
304
305 }