8017: RuntimeConstraints uses int64
[arvados.git] / services / crunch-dispatch-slurm / crunch-dispatch-slurm.go
1 package main
2
3 import (
4         "flag"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "io/ioutil"
8         "log"
9         "math"
10         "os"
11         "os/exec"
12         "os/signal"
13         "strconv"
14         "sync"
15         "syscall"
16         "time"
17 )
18
19 func main() {
20         err := doMain()
21         if err != nil {
22                 log.Fatalf("%q", err)
23         }
24 }
25
26 var (
27         arv              arvadosclient.ArvadosClient
28         runningCmds      map[string]*exec.Cmd
29         runningCmdsMutex sync.Mutex
30         waitGroup        sync.WaitGroup
31         doneProcessing   chan bool
32         sigChan          chan os.Signal
33 )
34
35 func doMain() error {
36         flags := flag.NewFlagSet("crunch-dispatch-slurm", flag.ExitOnError)
37
38         pollInterval := flags.Int(
39                 "poll-interval",
40                 10,
41                 "Interval in seconds to poll for queued containers")
42
43         priorityPollInterval := flags.Int(
44                 "container-priority-poll-interval",
45                 60,
46                 "Interval in seconds to check priority of a dispatched container")
47
48         crunchRunCommand := flags.String(
49                 "crunch-run-command",
50                 "/usr/bin/crunch-run",
51                 "Crunch command to run container")
52
53         finishCommand := flags.String(
54                 "finish-command",
55                 "/usr/bin/crunch-finish-slurm.sh",
56                 "Command to run from strigger when job is finished")
57
58         // Parse args; omit the first arg which is the command name
59         flags.Parse(os.Args[1:])
60
61         var err error
62         arv, err = arvadosclient.MakeArvadosClient()
63         if err != nil {
64                 return err
65         }
66
67         // Channel to terminate
68         doneProcessing = make(chan bool)
69
70         // Graceful shutdown
71         sigChan = make(chan os.Signal, 1)
72         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
73         go func(sig <-chan os.Signal) {
74                 for sig := range sig {
75                         log.Printf("Caught signal: %v", sig)
76                         doneProcessing <- true
77                 }
78         }(sigChan)
79
80         // Run all queued containers
81         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand, *finishCommand)
82
83         // Wait for all running crunch jobs to complete / terminate
84         waitGroup.Wait()
85
86         return nil
87 }
88
89 // Poll for queued containers using pollInterval.
90 // Invoke dispatchSlurm for each ticker cycle, which will run all the queued containers.
91 //
92 // Any errors encountered are logged but the program would continue to run (not exit).
93 // This is because, once one or more crunch jobs are running,
94 // we would need to wait for them complete.
95 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand, finishCommand string) {
96         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
97
98         for {
99                 select {
100                 case <-ticker.C:
101                         dispatchSlurm(priorityPollInterval, crunchRunCommand, finishCommand)
102                 case <-doneProcessing:
103                         ticker.Stop()
104                         return
105                 }
106         }
107 }
108
109 // Container data
110 type Container struct {
111         UUID               string           `json:"uuid"`
112         State              string           `json:"state"`
113         Priority           int              `json:"priority"`
114         RuntimeConstraints map[string]int64 `json:"runtime_constraints"`
115 }
116
117 // ContainerList is a list of the containers from api
118 type ContainerList struct {
119         Items []Container `json:"items"`
120 }
121
122 // Get the list of queued containers from API server and invoke run for each container.
123 func dispatchSlurm(priorityPollInterval int, crunchRunCommand, finishCommand string) {
124         params := arvadosclient.Dict{
125                 "filters": [][]string{[]string{"state", "=", "Queued"}},
126         }
127
128         var containers ContainerList
129         err := arv.List("containers", params, &containers)
130         if err != nil {
131                 log.Printf("Error getting list of queued containers: %q", err)
132                 return
133         }
134
135         for i := 0; i < len(containers.Items); i++ {
136                 log.Printf("About to submit queued container %v", containers.Items[i].UUID)
137                 // Run the container
138                 go run(containers.Items[i], crunchRunCommand, finishCommand, priorityPollInterval)
139         }
140 }
141
142 // sbatchCmd
143 func sbatchFunc(container Container) *exec.Cmd {
144         memPerCPU := math.Ceil(float64(container.RuntimeConstraints["ram"]) / float64(container.RuntimeConstraints["vcpus"]*1048576))
145         return exec.Command("sbatch", "--share", "--parsable",
146                 "--job-name="+container.UUID,
147                 "--mem-per-cpu="+strconv.Itoa(int(memPerCPU)),
148                 "--cpus-per-task="+strconv.Itoa(int(container.RuntimeConstraints["vcpus"])))
149 }
150
151 var sbatchCmd = sbatchFunc
152
153 // striggerCmd
154 func striggerFunc(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) *exec.Cmd {
155         return exec.Command("strigger", "--set", "--jobid="+jobid, "--fini",
156                 fmt.Sprintf("--program=%s %s %s %s %s", finishCommand, apiHost, apiToken, apiInsecure, containerUUID))
157 }
158
159 var striggerCmd = striggerFunc
160
161 // Submit job to slurm using sbatch.
162 func submit(container Container, crunchRunCommand string) (jobid string, submitErr error) {
163         submitErr = nil
164
165         // Mark record as complete if anything errors out.
166         defer func() {
167                 if submitErr != nil {
168                         // This really should be an "Error" state, see #8018
169                         updateErr := arv.Update("containers", container.UUID,
170                                 arvadosclient.Dict{
171                                         "container": arvadosclient.Dict{"state": "Complete"}},
172                                 nil)
173                         if updateErr != nil {
174                                 log.Printf("Error updating container state to 'Complete' for %v: %q", container.UUID, updateErr)
175                         }
176                 }
177         }()
178
179         // Create the command and attach to stdin/stdout
180         cmd := sbatchCmd(container)
181         stdinWriter, stdinerr := cmd.StdinPipe()
182         if stdinerr != nil {
183                 submitErr = fmt.Errorf("Error creating stdin pipe %v: %q", container.UUID, stdinerr)
184                 return
185         }
186
187         stdoutReader, stdoutErr := cmd.StdoutPipe()
188         if stdoutErr != nil {
189                 submitErr = fmt.Errorf("Error creating stdout pipe %v: %q", container.UUID, stdoutErr)
190                 return
191         }
192
193         stderrReader, stderrErr := cmd.StderrPipe()
194         if stderrErr != nil {
195                 submitErr = fmt.Errorf("Error creating stderr pipe %v: %q", container.UUID, stderrErr)
196                 return
197         }
198
199         err := cmd.Start()
200         if err != nil {
201                 submitErr = fmt.Errorf("Error starting %v: %v", cmd.Args, err)
202                 return
203         }
204
205         stdoutChan := make(chan []byte)
206         go func() {
207                 b, _ := ioutil.ReadAll(stdoutReader)
208                 stdoutChan <- b
209                 close(stdoutChan)
210         }()
211
212         stderrChan := make(chan []byte)
213         go func() {
214                 b, _ := ioutil.ReadAll(stderrReader)
215                 stderrChan <- b
216                 close(stderrChan)
217         }()
218
219         // Send a tiny script on stdin to execute the crunch-run command
220         // slurm actually enforces that this must be a #! script
221         fmt.Fprintf(stdinWriter, "#!/bin/sh\nexec '%s' '%s'\n", crunchRunCommand, container.UUID)
222         stdinWriter.Close()
223
224         err = cmd.Wait()
225
226         stdoutMsg := <-stdoutChan
227         stderrmsg := <-stderrChan
228
229         if err != nil {
230                 submitErr = fmt.Errorf("Container submission failed %v: %v %v", cmd.Args, err, stderrmsg)
231                 return
232         }
233
234         // If everything worked out, got the jobid on stdout
235         jobid = string(stdoutMsg)
236
237         return
238 }
239
240 // finalizeRecordOnFinish uses 'strigger' command to register a script that will run on
241 // the slurm controller when the job finishes.
242 func finalizeRecordOnFinish(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) {
243         cmd := striggerCmd(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure)
244         cmd.Stdout = os.Stdout
245         cmd.Stderr = os.Stderr
246         err := cmd.Run()
247         if err != nil {
248                 log.Printf("While setting up strigger: %v", err)
249         }
250 }
251
252 // Run a queued container.
253 // Set container state to locked (TBD)
254 // Submit job to slurm to execute crunch-run command for the container
255 // If the container priority becomes zero while crunch job is still running, cancel the job.
256 func run(container Container, crunchRunCommand, finishCommand string, priorityPollInterval int) {
257
258         jobid, err := submit(container, crunchRunCommand)
259         if err != nil {
260                 log.Printf("Error queuing container run: %v", err)
261                 return
262         }
263
264         insecure := "0"
265         if arv.ApiInsecure {
266                 insecure = "1"
267         }
268         finalizeRecordOnFinish(jobid, container.UUID, finishCommand, arv.ApiServer, arv.ApiToken, insecure)
269
270         // Update container status to Running, this is a temporary workaround
271         // to avoid resubmitting queued containers because record locking isn't
272         // implemented yet.
273         err = arv.Update("containers", container.UUID,
274                 arvadosclient.Dict{
275                         "container": arvadosclient.Dict{"state": "Running"}},
276                 nil)
277         if err != nil {
278                 log.Printf("Error updating container state to 'Running' for %v: %q", container.UUID, err)
279         }
280
281         log.Printf("Submitted container run for %v", container.UUID)
282
283         containerUUID := container.UUID
284
285         // A goroutine to terminate the runner if container priority becomes zero
286         priorityTicker := time.NewTicker(time.Duration(priorityPollInterval) * time.Second)
287         go func() {
288                 for _ = range priorityTicker.C {
289                         var container Container
290                         err := arv.Get("containers", containerUUID, nil, &container)
291                         if err != nil {
292                                 log.Printf("Error getting container info for %v: %q", container.UUID, err)
293                         } else {
294                                 if container.Priority == 0 {
295                                         log.Printf("Canceling container %v", container.UUID)
296                                         priorityTicker.Stop()
297                                         cancelcmd := exec.Command("scancel", "--name="+container.UUID)
298                                         cancelcmd.Run()
299                                 }
300                                 if container.State == "Complete" {
301                                         priorityTicker.Stop()
302                                 }
303                         }
304                 }
305         }()
306
307 }