8028: when a signal is received, terminate all running commands and wait in a WaitGroup.
[arvados.git] / services / crunch-dispatch-local / crunch-dispatch-local.go
1 package main
2
3 import (
4         "flag"
5         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
6         "log"
7         "os"
8         "os/exec"
9         "os/signal"
10         "sync"
11         "syscall"
12         "time"
13 )
14
15 func main() {
16         err := doMain()
17         if err != nil {
18                 log.Fatalf("%q", err)
19         }
20 }
21
22 var arv arvadosclient.ArvadosClient
23 var runningCmds map[string]*exec.Cmd
24
25 func doMain() error {
26         flags := flag.NewFlagSet("crunch-dispatch-local", flag.ExitOnError)
27
28         pollInterval := flags.Int(
29                 "poll-interval",
30                 10,
31                 "Interval in seconds to poll for queued containers")
32
33         priorityPollInterval := flags.Int(
34                 "container-priority-poll-interval",
35                 60,
36                 "Interval in seconds to check priority of a dispatched container")
37
38         crunchRunCommand := flags.String(
39                 "crunch-run-command",
40                 "/usr/bin/crunch-run",
41                 "Crunch command to run container")
42
43         // Parse args; omit the first arg which is the command name
44         flags.Parse(os.Args[1:])
45
46         var err error
47         arv, err = arvadosclient.MakeArvadosClient()
48         if err != nil {
49                 return err
50         }
51
52         runningCmds = make(map[string]*exec.Cmd)
53         sigChan = make(chan os.Signal, 1)
54         signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
55         go func(sig <-chan os.Signal) {
56                 var wg sync.WaitGroup
57                 for sig := range sig {
58                         doneProcessing <- true
59                         caught := sig
60                         for uuid, cmd := range runningCmds {
61                                 go func(uuid string) {
62                                         wg.Add(1)
63                                         defer wg.Done()
64                                         cmd.Process.Signal(caught)
65                                         if _, err := cmd.Process.Wait(); err != nil {
66                                                 log.Printf("Error while waiting for process to finish for %v: %q", uuid, err)
67                                         }
68                                 }(uuid)
69                         }
70                 }
71                 wg.Wait()
72         }(sigChan)
73
74         // channel to terminate
75         doneProcessing = make(chan bool)
76
77         // run all queued containers
78         runQueuedContainers(*pollInterval, *priorityPollInterval, *crunchRunCommand)
79         return nil
80 }
81
var (
	// doneProcessing tells the polling loop in runQueuedContainers to stop.
	doneProcessing chan bool

	// sigChan receives the OS signals the dispatcher subscribes to.
	sigChan chan os.Signal
)
84
85 // Poll for queued containers using pollInterval.
86 // Invoke dispatchLocal for each ticker cycle, which will run all the queued containers.
87 //
88 // Any errors encountered are logged but the program would continue to run (not exit).
89 // This is because, once one or more child processes are running,
90 // we would need to wait for them complete.
91 func runQueuedContainers(pollInterval, priorityPollInterval int, crunchRunCommand string) {
92         ticker := time.NewTicker(time.Duration(pollInterval) * time.Second)
93
94         for {
95                 select {
96                 case <-ticker.C:
97                         dispatchLocal(priorityPollInterval, crunchRunCommand)
98                 case <-doneProcessing:
99                         ticker.Stop()
100                         return
101                 }
102         }
103 }
104
// Container holds the subset of a container record that the dispatcher reads
// from the API server.
type Container struct {
	// UUID is decoded from the "uuid" JSON field.
	UUID string `json:"uuid"`
	// State is decoded from the "state" JSON field; this file compares it
	// against "Running".
	State string `json:"state"`
	// Priority is decoded from the "priority" JSON field; a value of zero
	// causes the dispatcher to interrupt the running command.
	Priority int `json:"priority"`
}
111
112 // ContainerList is a list of the containers from api
113 type ContainerList struct {
114         Items []Container `json:"items"`
115 }
116
117 // Get the list of queued containers from API server and invoke run for each container.
118 func dispatchLocal(priorityPollInterval int, crunchRunCommand string) {
119         params := arvadosclient.Dict{
120                 "filters": [][]string{[]string{"state", "=", "Queued"}},
121         }
122
123         var containers ContainerList
124         err := arv.List("containers", params, &containers)
125         if err != nil {
126                 log.Printf("Error getting list of queued containers: %q", err)
127                 return
128         }
129
130         for i := 0; i < len(containers.Items); i++ {
131                 log.Printf("About to run queued container %v", containers.Items[i].UUID)
132                 go run(containers.Items[i].UUID, crunchRunCommand, priorityPollInterval)
133         }
134 }
135
136 // Run queued container:
137 // Set container state to locked (TBD)
138 // Run container using the given crunch-run command
139 // Set the container state to Running
140 // If the container priority becomes zero while crunch job is still running, terminate it.
141 func run(uuid string, crunchRunCommand string, priorityPollInterval int) {
142         cmd := exec.Command(crunchRunCommand, uuid)
143
144         cmd.Stdin = nil
145         cmd.Stderr = os.Stderr
146         cmd.Stdout = os.Stderr
147         if err := cmd.Start(); err != nil {
148                 log.Printf("Error running container for %v: %q", uuid, err)
149                 return
150         }
151
152         runningCmds[uuid] = cmd
153
154         log.Printf("Started container run for %v", uuid)
155
156         err := arv.Update("containers", uuid,
157                 arvadosclient.Dict{
158                         "container": arvadosclient.Dict{"state": "Running"}},
159                 nil)
160         if err != nil {
161                 log.Printf("Error updating container state to 'Running' for %v: %q", uuid, err)
162         }
163
164         // Terminate the runner if container priority becomes zero
165         priorityTicker := time.NewTicker(time.Duration(priorityPollInterval) * time.Second)
166         go func() {
167                 for {
168                         select {
169                         case <-priorityTicker.C:
170                                 var container Container
171                                 err := arv.Get("containers", uuid, nil, &container)
172                                 if err != nil {
173                                         log.Printf("Error getting container info for %v: %q", uuid, err)
174                                 } else {
175                                         if container.Priority == 0 {
176                                                 priorityTicker.Stop()
177                                                 cmd.Process.Signal(os.Interrupt)
178                                                 delete(runningCmds, uuid)
179                                                 return
180                                         }
181                                 }
182                         }
183                 }
184         }()
185
186         // Wait for the process to exit
187         if _, err := cmd.Process.Wait(); err != nil {
188                 log.Printf("Error while waiting for process to finish for %v: %q", uuid, err)
189         }
190         delete(runningCmds, uuid)
191
192         priorityTicker.Stop()
193
194         var container Container
195         err = arv.Get("containers", uuid, nil, &container)
196         if container.State == "Running" {
197                 log.Printf("After crunch-run process termination, the state is still 'Running' for %v. Updating it to 'Complete'", uuid)
198                 err = arv.Update("containers", uuid,
199                         arvadosclient.Dict{
200                                 "container": arvadosclient.Dict{"state": "Complete"}},
201                         nil)
202                 if err != nil {
203                         log.Printf("Error updating container state to Complete for %v: %q", uuid, err)
204                 }
205         }
206 }