6518: Add check for sbatch/strigger command line in test.
[arvados.git] / services / crunch-dispatch-slurm / crunch-dispatch-slurm.go
1 package main
2
3 import (
4         "flag"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "io/ioutil"
8         "log"
9         "os"
10         "os/exec"
11         "os/signal"
12         "sync"
13         "syscall"
14         "time"
15 )
16
17 func main() {
18         err := doMain()
19         if err != nil {
20                 log.Fatalf("%q", err)
21         }
22 }
23
24 var (
25         arv              arvadosclient.ArvadosClient
26         runningCmds      map[string]*exec.Cmd
27         runningCmdsMutex sync.Mutex
28         waitGroup        sync.WaitGroup
29         doneProcessing   chan bool
30         sigChan          chan os.Signal
31 )
32
33 func doMain() error {
34         flags := flag.NewFlagSet("crunch-dispatch-slurm", flag.ExitOnError)
35
36         pollInterval := flags.Int(
37                 "poll-interval",
38                 10,
39                 "Interval in seconds to poll for queued containers")
40
41         priorityPollInterval := flags.Int(
42                 "container-priority-poll-interval",
43                 60,
44                 "Interval in seconds to check priority of a dispatched container")
45
46         crunchRunCommand := flags.String(
47                 "crunch-run-command",
48                 "/usr/bin/crunch-run",
49                 "Crunch command to run container")
50
51         finishCommand := flags.String(
52                 "finish-command",
53                 "/usr/bin/crunch-finish-slurm.sh",
54                 "Command to run from strigger when job is finished")
55
56         // Parse args; omit the first arg which is the command name
57         flags.Parse(os.Args[1:])
58
59         var err error
60         arv, err = arvadosclient.MakeArvadosClient()
61         if err != nil {
62                 return err
63         }
64
65         // Channel to terminate
66         doneProcessing = make(chan bool)
67
68         // Graceful shutdown
69         sigChan = make(chan os.Signal, 1)
70         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
71         go func(sig <-chan os.Signal) {
72                 for sig := range sig {
73                         log.Printf("Caught signal: %v", sig)
74                         doneProcessing <- true
75                 }
76         }(sigChan)
77
78         // Run all queued containers
79         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand, *finishCommand)
80
81         // Wait for all running crunch jobs to complete / terminate
82         waitGroup.Wait()
83
84         return nil
85 }
86
87 // Poll for queued containers using pollInterval.
88 // Invoke dispatchSlurm for each ticker cycle, which will run all the queued containers.
89 //
90 // Any errors encountered are logged but the program would continue to run (not exit).
91 // This is because, once one or more crunch jobs are running,
92 // we would need to wait for them complete.
93 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand, finishCommand string) {
94         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
95
96         for {
97                 select {
98                 case <-ticker.C:
99                         dispatchSlurm(priorityPollInterval, crunchRunCommand, finishCommand)
100                 case <-doneProcessing:
101                         ticker.Stop()
102                         return
103                 }
104         }
105 }
106
107 // Container data
108 type Container struct {
109         UUID     string `json:"uuid"`
110         State    string `json:"state"`
111         Priority int    `json:"priority"`
112 }
113
114 // ContainerList is a list of the containers from api
115 type ContainerList struct {
116         Items []Container `json:"items"`
117 }
118
119 // Get the list of queued containers from API server and invoke run for each container.
120 func dispatchSlurm(priorityPollInterval int, crunchRunCommand, finishCommand string) {
121         params := arvadosclient.Dict{
122                 "filters": [][]string{[]string{"state", "=", "Queued"}},
123         }
124
125         var containers ContainerList
126         err := arv.List("containers", params, &containers)
127         if err != nil {
128                 log.Printf("Error getting list of queued containers: %q", err)
129                 return
130         }
131
132         for i := 0; i < len(containers.Items); i++ {
133                 log.Printf("About to submit queued container %v", containers.Items[i].UUID)
134                 // Run the container
135                 go run(containers.Items[i], crunchRunCommand, finishCommand, priorityPollInterval)
136         }
137 }
138
139 // sbatchCmd
140 func sbatchFunc(uuid string) *exec.Cmd {
141         return exec.Command("sbatch", "--job-name="+uuid, "--share", "--parsable")
142 }
143
144 var sbatchCmd = sbatchFunc
145
146 // striggerCmd
147 func striggerFunc(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) *exec.Cmd {
148         return exec.Command("strigger", "--set", "--jobid="+jobid, "--fini",
149                 fmt.Sprintf("--program=%s %s %s %s %s", finishCommand, apiHost, apiToken, apiInsecure, containerUUID))
150 }
151
152 var striggerCmd = striggerFunc
153
154 // Submit job to slurm using sbatch.
155 func submit(container Container, crunchRunCommand string) (jobid string, submitErr error) {
156         submitErr = nil
157
158         // Mark record as complete if anything errors out.
159         defer func() {
160                 if submitErr != nil {
161                         // This really should be an "Error" state, see #8018
162                         updateErr := arv.Update("containers", container.UUID,
163                                 arvadosclient.Dict{
164                                         "container": arvadosclient.Dict{"state": "Complete"}},
165                                 nil)
166                         if updateErr != nil {
167                                 log.Printf("Error updating container state to 'Complete' for %v: %q", container.UUID, updateErr)
168                         }
169                 }
170         }()
171
172         // Create the command and attach to stdin/stdout
173         cmd := sbatchCmd(container.UUID)
174         stdinWriter, stdinerr := cmd.StdinPipe()
175         if stdinerr != nil {
176                 submitErr = fmt.Errorf("Error creating stdin pipe %v: %q", container.UUID, stdinerr)
177                 return
178         }
179
180         stdoutReader, stdoutErr := cmd.StdoutPipe()
181         if stdoutErr != nil {
182                 submitErr = fmt.Errorf("Error creating stdout pipe %v: %q", container.UUID, stdoutErr)
183                 return
184         }
185
186         stderrReader, stderrErr := cmd.StderrPipe()
187         if stderrErr != nil {
188                 submitErr = fmt.Errorf("Error creating stderr pipe %v: %q", container.UUID, stderrErr)
189                 return
190         }
191
192         err := cmd.Start()
193         if err != nil {
194                 submitErr = fmt.Errorf("Error starting %v: %v", cmd.Args, err)
195                 return
196         }
197
198         stdoutChan := make(chan []byte)
199         go func() {
200                 b, _ := ioutil.ReadAll(stdoutReader)
201                 stdoutChan <- b
202                 close(stdoutChan)
203         }()
204
205         stderrChan := make(chan []byte)
206         go func() {
207                 b, _ := ioutil.ReadAll(stderrReader)
208                 stderrChan <- b
209                 close(stderrChan)
210         }()
211
212         // Send a tiny script on stdin to execute the crunch-run command
213         // slurm actually enforces that this must be a #! script
214         fmt.Fprintf(stdinWriter, "#!/bin/sh\nexec '%s' '%s'\n", crunchRunCommand, container.UUID)
215         stdinWriter.Close()
216
217         err = cmd.Wait()
218
219         stdoutMsg := <-stdoutChan
220         stderrmsg := <-stderrChan
221
222         if err != nil {
223                 submitErr = fmt.Errorf("Container submission failed %v: %v %v", cmd.Args, err, stderrmsg)
224                 return
225         }
226
227         // If everything worked out, got the jobid on stdout
228         jobid = string(stdoutMsg)
229
230         return
231 }
232
233 // finalizeRecordOnFinish uses 'strigger' command to register a script that will run on
234 // the slurm controller when the job finishes.
235 func finalizeRecordOnFinish(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure string) {
236         cmd := striggerCmd(jobid, containerUUID, finishCommand, apiHost, apiToken, apiInsecure)
237         cmd.Stdout = os.Stdout
238         cmd.Stderr = os.Stderr
239         err := cmd.Run()
240         if err != nil {
241                 log.Printf("While setting up strigger: %v", err)
242         }
243 }
244
245 // Run a queued container.
246 // Set container state to locked (TBD)
247 // Submit job to slurm to execute crunch-run command for the container
248 // If the container priority becomes zero while crunch job is still running, cancel the job.
249 func run(container Container, crunchRunCommand, finishCommand string, priorityPollInterval int) {
250
251         jobid, err := submit(container, crunchRunCommand)
252         if err != nil {
253                 log.Printf("Error queuing container run: %v", err)
254                 return
255         }
256
257         insecure := "0"
258         if arv.ApiInsecure {
259                 insecure = "1"
260         }
261         finalizeRecordOnFinish(jobid, container.UUID, finishCommand, arv.ApiServer, arv.ApiToken, insecure)
262
263         // Update container status to Running, this is a temporary workaround
264         // to avoid resubmitting queued containers because record locking isn't
265         // implemented yet.
266         err = arv.Update("containers", container.UUID,
267                 arvadosclient.Dict{
268                         "container": arvadosclient.Dict{"state": "Running"}},
269                 nil)
270         if err != nil {
271                 log.Printf("Error updating container state to 'Running' for %v: %q", container.UUID, err)
272         }
273
274         log.Printf("Submitted container run for %v", container.UUID)
275
276         containerUUID := container.UUID
277
278         // A goroutine to terminate the runner if container priority becomes zero
279         priorityTicker := time.NewTicker(time.Duration(priorityPollInterval) * time.Second)
280         go func() {
281                 for _ = range priorityTicker.C {
282                         var container Container
283                         err := arv.Get("containers", containerUUID, nil, &container)
284                         if err != nil {
285                                 log.Printf("Error getting container info for %v: %q", container.UUID, err)
286                         } else {
287                                 if container.Priority == 0 {
288                                         log.Printf("Canceling container %v", container.UUID)
289                                         priorityTicker.Stop()
290                                         cancelcmd := exec.Command("scancel", "--name="+container.UUID)
291                                         cancelcmd.Run()
292                                 }
293                                 if container.State == "Complete" {
294                                         priorityTicker.Stop()
295                                 }
296                         }
297                 }
298         }()
299
300 }