Merge branch 'master' into 6518-crunch2-dispatch-slurm
[arvados.git] / services / crunch-dispatch-slurm / crunch-dispatch-slurm.go
1 package main
2
3 import (
4         "flag"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "io/ioutil"
8         "log"
9         "os"
10         "os/exec"
11         "os/signal"
12         "sync"
13         "syscall"
14         "time"
15 )
16
17 func main() {
18         err := doMain()
19         if err != nil {
20                 log.Fatalf("%q", err)
21         }
22 }
23
24 var (
25         arv              arvadosclient.ArvadosClient
26         runningCmds      map[string]*exec.Cmd
27         runningCmdsMutex sync.Mutex
28         waitGroup        sync.WaitGroup
29         doneProcessing   chan bool
30         sigChan          chan os.Signal
31 )
32
33 func doMain() error {
34         flags := flag.NewFlagSet("crunch-dispatch-slurm", flag.ExitOnError)
35
36         pollInterval := flags.Int(
37                 "poll-interval",
38                 10,
39                 "Interval in seconds to poll for queued containers")
40
41         priorityPollInterval := flags.Int(
42                 "container-priority-poll-interval",
43                 60,
44                 "Interval in seconds to check priority of a dispatched container")
45
46         crunchRunCommand := flags.String(
47                 "crunch-run-command",
48                 "/usr/bin/crunch-run",
49                 "Crunch command to run container")
50
51         finishCommand := flags.String(
52                 "finish-command",
53                 "/usr/bin/crunch-finish-slurm.sh",
54                 "Command to run from strigger when job is finished")
55
56         // Parse args; omit the first arg which is the command name
57         flags.Parse(os.Args[1:])
58
59         var err error
60         arv, err = arvadosclient.MakeArvadosClient()
61         if err != nil {
62                 return err
63         }
64
65         // Channel to terminate
66         doneProcessing = make(chan bool)
67
68         // Graceful shutdown
69         sigChan = make(chan os.Signal, 1)
70         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
71         go func(sig <-chan os.Signal) {
72                 for sig := range sig {
73                         log.Printf("Caught signal: %v", sig)
74                         doneProcessing <- true
75                 }
76         }(sigChan)
77
78         // Run all queued containers
79         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand, *finishCommand)
80
81         // Wait for all running crunch jobs to complete / terminate
82         waitGroup.Wait()
83
84         return nil
85 }
86
87 // Poll for queued containers using pollInterval.
88 // Invoke dispatchSlurm for each ticker cycle, which will run all the queued containers.
89 //
90 // Any errors encountered are logged but the program would continue to run (not exit).
91 // This is because, once one or more crunch jobs are running,
92 // we would need to wait for them complete.
93 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand, finishCommand string) {
94         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
95
96         for {
97                 select {
98                 case <-ticker.C:
99                         dispatchSlurm(priorityPollInterval, crunchRunCommand, finishCommand)
100                 case <-doneProcessing:
101                         ticker.Stop()
102                         return
103                 }
104         }
105 }
106
107 // Container data
108 type Container struct {
109         UUID     string `json:"uuid"`
110         State    string `json:"state"`
111         Priority int    `json:"priority"`
112 }
113
114 // ContainerList is a list of the containers from api
115 type ContainerList struct {
116         Items []Container `json:"items"`
117 }
118
119 // Get the list of queued containers from API server and invoke run for each container.
120 func dispatchSlurm(priorityPollInterval int, crunchRunCommand, finishCommand string) {
121         params := arvadosclient.Dict{
122                 "filters": [][]string{[]string{"state", "=", "Queued"}},
123         }
124
125         var containers ContainerList
126         err := arv.List("containers", params, &containers)
127         if err != nil {
128                 log.Printf("Error getting list of queued containers: %q", err)
129                 return
130         }
131
132         for i := 0; i < len(containers.Items); i++ {
133                 log.Printf("About to submit queued container %v", containers.Items[i].UUID)
134                 // Run the container
135                 go run(containers.Items[i], crunchRunCommand, finishCommand, priorityPollInterval)
136         }
137 }
138
139 // sbatchCmd
140 var sbatchCmd = func(uuid string) *exec.Cmd {
141         return exec.Command("sbatch", "--job-name="+uuid, "--share", "--parsable")
142 }
143
144 // striggerCmd
145 var striggerCmd = func(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) *exec.Cmd {
146         return exec.Command("strigger", "--set", "--jobid="+jobid, "--fini",
147                 fmt.Sprintf("--program=%s %s %s %s %s", finishCommand, apiHost, apiToken, apiInsecure, containerUUID))
148 }
149
150 // Submit job to slurm using sbatch.
151 func submit(container Container, crunchRunCommand string) (jobid string, submitErr error) {
152         submitErr = nil
153
154         // Mark record as complete if anything errors out.
155         defer func() {
156                 if submitErr != nil {
157                         // This really should be an "Error" state, see #8018
158                         updateErr := arv.Update("containers", container.UUID,
159                                 arvadosclient.Dict{
160                                         "container": arvadosclient.Dict{"state": "Complete"}},
161                                 nil)
162                         if updateErr != nil {
163                                 log.Printf("Error updating container state to 'Complete' for %v: %q", container.UUID, updateErr)
164                         }
165                 }
166         }()
167
168         // Create the command and attach to stdin/stdout
169         cmd := sbatchCmd(container.UUID)
170         stdinWriter, stdinerr := cmd.StdinPipe()
171         if stdinerr != nil {
172                 submitErr = fmt.Errorf("Error creating stdin pipe %v: %q", container.UUID, stdinerr)
173                 return
174         }
175
176         stdoutReader, stdoutErr := cmd.StdoutPipe()
177         if stdoutErr != nil {
178                 submitErr = fmt.Errorf("Error creating stdout pipe %v: %q", container.UUID, stdoutErr)
179                 return
180         }
181
182         stderrReader, stderrErr := cmd.StderrPipe()
183         if stderrErr != nil {
184                 submitErr = fmt.Errorf("Error creating stderr pipe %v: %q", container.UUID, stderrErr)
185                 return
186         }
187
188         err := cmd.Start()
189         if err != nil {
190                 submitErr = fmt.Errorf("Error starting %v: %v", cmd.Args, err)
191                 return
192         }
193
194         stdoutChan := make(chan []byte)
195         go func() {
196                 b, _ := ioutil.ReadAll(stdoutReader)
197                 stdoutChan <- b
198                 close(stdoutChan)
199         }()
200
201         stderrChan := make(chan []byte)
202         go func() {
203                 b, _ := ioutil.ReadAll(stderrReader)
204                 stderrChan <- b
205                 close(stderrChan)
206         }()
207
208         // Send a tiny script on stdin to execute the crunch-run command
209         // slurm actually enforces that this must be a #! script
210         fmt.Fprintf(stdinWriter, "#!/bin/sh\nexec '%s' '%s'\n", crunchRunCommand, container.UUID)
211         stdinWriter.Close()
212
213         err = cmd.Wait()
214
215         stdoutMsg := <-stdoutChan
216         stderrmsg := <-stderrChan
217
218         if err != nil {
219                 submitErr = fmt.Errorf("Container submission failed %v: %v %v", cmd.Args, err, stderrmsg)
220                 return
221         }
222
223         // If everything worked out, got the jobid on stdout
224         jobid = string(stdoutMsg)
225
226         return
227 }
228
229 // finalizeRecordOnFinish uses 'strigger' command to register a script that will run on
230 // the slurm controller when the job finishes.
231 func finalizeRecordOnFinish(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) {
232         cmd := striggerCmd(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure)
233         cmd.Stdout = os.Stdout
234         cmd.Stderr = os.Stderr
235         err := cmd.Run()
236         if err != nil {
237                 log.Printf("While setting up strigger: %v", err)
238         }
239 }
240
241 // Run a queued container.
242 // Set container state to locked (TBD)
243 // Submit job to slurm to execute crunch-run command for the container
244 // If the container priority becomes zero while crunch job is still running, cancel the job.
245 func run(container Container, crunchRunCommand, finishCommand string, priorityPollInterval int) {
246
247         jobid, err := submit(container, crunchRunCommand)
248         if err != nil {
249                 log.Printf("Error queuing container run: %v", err)
250                 return
251         }
252
253         insecure := "0"
254         if arv.ApiInsecure {
255                 insecure = "1"
256         }
257         finalizeRecordOnFinish(jobid, container.UUID, finishCommand, arv.ApiServer, arv.ApiToken, insecure)
258
259         // Update container status to Running, this is a temporary workaround
260         // to avoid resubmitting queued containers because record locking isn't
261         // implemented yet.
262         err = arv.Update("containers", container.UUID,
263                 arvadosclient.Dict{
264                         "container": arvadosclient.Dict{"state": "Running"}},
265                 nil)
266         if err != nil {
267                 log.Printf("Error updating container state to 'Running' for %v: %q", container.UUID, err)
268         }
269
270         log.Printf("Submitted container run for %v", container.UUID)
271
272         containerUUID := container.UUID
273
274         // A goroutine to terminate the runner if container priority becomes zero
275         priorityTicker := time.NewTicker(time.Duration(priorityPollInterval) * time.Second)
276         go func() {
277                 for _ = range priorityTicker.C {
278                         var container Container
279                         err := arv.Get("containers", containerUUID, nil, &container)
280                         if err != nil {
281                                 log.Printf("Error getting container info for %v: %q", container.UUID, err)
282                         } else {
283                                 if container.Priority == 0 {
284                                         log.Printf("Canceling container %v", container.UUID)
285                                         priorityTicker.Stop()
286                                         cancelcmd := exec.Command("scancel", "--name="+container.UUID)
287                                         cancelcmd.Run()
288                                 }
289                                 if container.State == "Complete" {
290                                         priorityTicker.Stop()
291                                 }
292                         }
293                 }
294         }()
295
296 }