Merge branch '8815-crunchrunner-everywhere' closes #8815
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "crypto/x509"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "git.curoverse.com/arvados.git/sdk/go/keepclient"
8         "io/ioutil"
9         "log"
10         "net/http"
11         "os"
12         "os/exec"
13         "os/signal"
14         "strings"
15         "syscall"
16 )
17
// TaskDef describes one command for crunchrunner to run. It appears both in
// a job's script_parameters ("tasks" list) and as a job task's "parameters".
// The $(task.tmpdir), $(task.outdir) and $(task.keep) placeholders are
// expanded in Command, Stdin, Env values and Vwd targets; Stdout and Vwd
// keys are used as given.
type TaskDef struct {
	// Command is the argument vector; Command[0] is the program to run.
	Command []string `json:"command"`

	// Env holds extra environment variables appended to crunchrunner's own
	// environment for the child process.
	Env map[string]string `json:"task.env"`

	// Stdin, if set, is the path of a file opened as the child's stdin.
	Stdin string `json:"task.stdin"`

	// Stdout, if set, is a path relative to the output directory that
	// receives the child's stdout.
	Stdout string `json:"task.stdout"`

	// Vwd maps paths relative to the output directory to the targets of
	// symlinks created there before the command runs.
	Vwd map[string]string `json:"task.vwd"`

	// SuccessCodes lists exit codes treated as success in addition to 0.
	SuccessCodes []int `json:"task.successCodes"`

	// PermanentFailCodes lists exit codes that mark the task as failed;
	// they take precedence over the other code lists.
	PermanentFailCodes []int `json:"task.permanentFailCodes"`

	// TemporaryFailCodes lists exit codes that make crunchrunner exit with
	// TASK_TEMPFAIL instead of recording a result.
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`
}

// Tasks is the shape of a job's script_parameters: a list of task
// definitions.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}

// Job is the subset of an Arvados job record that crunchrunner reads.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}

// Task is the subset of an Arvados job_task record used by crunchrunner.
// It is decoded from the API and is also sent as the "job_task" payload when
// creating subtasks and updating the current task.
//
// Each field needs its own JSON key: encoding/json silently drops fields that
// share a tag. Sequence and Progress are omitted when zero so that an update
// payload built from a partially filled Task does not reset them.
type Task struct {
	Job_uuid                 string  `json:"job_uuid"`
	Created_by_job_task_uuid string  `json:"created_by_job_task_uuid"`
	Parameters               TaskDef `json:"parameters"`
	Sequence                 int     `json:"sequence,omitempty"`
	Output                   string  `json:"output"`
	Success                  bool    `json:"success"`
	Progress                 float32 `json:"progress,omitempty"`
}
46
47 type IArvadosClient interface {
48         Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
49         Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
50 }
51
52 func setupDirectories(crunchtmpdir, taskUuid string) (tmpdir, outdir string, err error) {
53         tmpdir = crunchtmpdir + "/tmpdir"
54         err = os.Mkdir(tmpdir, 0700)
55         if err != nil {
56                 return "", "", err
57         }
58
59         outdir = crunchtmpdir + "/outdir"
60         err = os.Mkdir(outdir, 0700)
61         if err != nil {
62                 return "", "", err
63         }
64
65         return tmpdir, outdir, nil
66 }
67
68 func checkOutputFilename(outdir, fn string) error {
69         if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
70                 return fmt.Errorf("Path must not start or end with '/'")
71         }
72         if strings.Index("../", fn) != -1 {
73                 return fmt.Errorf("Path must not contain '../'")
74         }
75
76         sl := strings.LastIndex(fn, "/")
77         if sl != -1 {
78                 os.MkdirAll(outdir+"/"+fn[0:sl], 0777)
79         }
80         return nil
81 }
82
// setupCommand configures cmd to run the task described by taskp:
//
//   - each taskp.Vwd entry becomes a symlink outdir/<key> pointing at the
//     placeholder-expanded value;
//   - taskp.Stdin (placeholder-expanded), if set, is opened as the child's
//     stdin;
//   - taskp.Stdout, if set, is a path relative to outdir that receives the
//     child's stdout; otherwise stdout goes to this process's stdout;
//   - stderr always goes to this process's stderr;
//   - taskp.Env entries (values placeholder-expanded) are appended to this
//     process's environment.
//
// It returns the stdin and stdout paths it set up ("" when not redirected)
// so the caller can log them. Any failure, including a failed symlink, is
// returned as err; a stdin file opened before the failure is closed.
func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout string, err error) {
	var stdinFile *os.File
	defer func() {
		// Don't leak the stdin file if a later step fails.
		if err != nil && stdinFile != nil {
			stdinFile.Close()
		}
	}()

	for k, v := range taskp.Vwd {
		if err = checkOutputFilename(outdir, k); err != nil {
			return "", "", err
		}
		if err = os.Symlink(substitute(v, replacements), outdir+"/"+k); err != nil {
			return "", "", err
		}
	}

	if taskp.Stdin != "" {
		stdin = substitute(taskp.Stdin, replacements)
		// Assign through a concrete *os.File so a failed Open never leaves
		// a typed-nil value in cmd.Stdin.
		if stdinFile, err = os.Open(stdin); err != nil {
			return "", "", err
		}
		cmd.Stdin = stdinFile
	}

	if taskp.Stdout != "" {
		if err = checkOutputFilename(outdir, taskp.Stdout); err != nil {
			return "", "", err
		}
		stdout = outdir + "/" + taskp.Stdout
		var stdoutFile *os.File
		if stdoutFile, err = os.Create(stdout); err != nil {
			return "", "", err
		}
		cmd.Stdout = stdoutFile
	} else {
		cmd.Stdout = os.Stdout
	}

	cmd.Stderr = os.Stderr

	if taskp.Env != nil {
		cmd.Env = os.Environ()
		for k, v := range taskp.Env {
			cmd.Env = append(cmd.Env, k+"="+substitute(v, replacements))
		}
	}
	return stdin, stdout, nil
}
131
132 // Set up signal handlers.  Go sends signal notifications to a "signal
133 // channel".
134 func setupSignals(cmd *exec.Cmd) chan os.Signal {
135         sigChan := make(chan os.Signal, 1)
136         signal.Notify(sigChan, syscall.SIGTERM)
137         signal.Notify(sigChan, syscall.SIGINT)
138         signal.Notify(sigChan, syscall.SIGQUIT)
139         return sigChan
140 }
141
// inCodes reports whether code is one of codes. A nil or empty list never
// matches.
func inCodes(code int, codes []int) bool {
	for i := range codes {
		if codes[i] == code {
			return true
		}
	}
	return false
}
152
// TASK_TEMPFAIL is the exit status main uses when runner reports a TempFail.
// NOTE(review): presumably the dispatcher treats this status as "retry the
// task" — confirm.
const (
	TASK_TEMPFAIL = 111
)

type (
	// TempFail wraps an error for a failure that may succeed if the task is
	// retried; its message is the wrapped error's message.
	TempFail struct{ error }

	// PermFail marks a task that failed and should not be retried.
	PermFail struct{}
)

// Error implements the error interface with the fixed text "PermFail".
func (PermFail) Error() string { return "PermFail" }
161
162 func substitute(inp string, subst map[string]string) string {
163         for k, v := range subst {
164                 inp = strings.Replace(inp, k, v, -1)
165         }
166         return inp
167 }
168
// runner executes one Arvados job task and records its outcome.
//
// For task 0 (Sequence == 0), if the job's script_parameters list exactly
// one task definition, that definition is run directly. Otherwise (including
// an empty list), one sequence-1 subtask per definition is queued and task 0
// is marked successful without running anything.
//
// Running a task means:
//   - creating tmpdir and outdir under crunchtmpdir;
//   - expanding $(task.tmpdir), $(task.outdir) and $(task.keep);
//   - running the command in outdir with SIGTERM/SIGINT/SIGQUIT forwarded
//     to it;
//   - uploading outdir with WriteTree;
//   - saving the resulting manifest and success flag on the job task.
//
// It returns:
//   - nil on success;
//   - PermFail{} when the command failed or a signal was forwarded;
//   - TempFail{err} for failures that may go away on retry (API or upload
//     errors, TemporaryFailCodes);
//   - a plain error when the task could not be set up.
func runner(api IArvadosClient,
	kc IKeepClient,
	jobUuid, taskUuid, crunchtmpdir, keepmount string,
	jobStruct Job, taskStruct Task) error {

	taskp := taskStruct.Parameters

	if taskStruct.Sequence == 0 {
		tasks := jobStruct.Script_parameters.Tasks
		if len(tasks) != 1 {
			return dispatchSubtasks(api, jobUuid, taskUuid, tasks)
		}
		taskp = tasks[0]
	}

	// Guard the Command[0] index below.
	if len(taskp.Command) == 0 {
		return fmt.Errorf("task has no command")
	}

	tmpdir, outdir, err := setupDirectories(crunchtmpdir, taskUuid)
	if err != nil {
		return TempFail{err}
	}

	replacements := map[string]string{
		"$(task.tmpdir)": tmpdir,
		"$(task.outdir)": outdir,
		"$(task.keep)":   keepmount}

	log.Printf("crunchrunner: $(task.tmpdir)=%v", tmpdir)
	log.Printf("crunchrunner: $(task.outdir)=%v", outdir)
	log.Printf("crunchrunner: $(task.keep)=%v", keepmount)

	// Expand into a new slice rather than in place: taskp.Command may share
	// its backing array with jobStruct's task list.
	args := make([]string, len(taskp.Command))
	for i, arg := range taskp.Command {
		args[i] = substitute(arg, replacements)
	}

	cmd := exec.Command(args[0], args[1:]...)
	cmd.Dir = outdir

	stdin, stdout, err := setupCommand(cmd, taskp, outdir, replacements)
	if err != nil {
		return err
	}

	if stdin != "" {
		stdin = " < " + stdin
	}
	if stdout != "" {
		stdout = " > " + stdout
	}
	log.Printf("Running %v%v%v", cmd.Args, stdin, stdout)

	caughtSignal, err := runForwardingSignals(cmd)
	if caughtSignal != nil {
		log.Printf("Caught signal %v", caughtSignal)
		return PermFail{}
	}
	if err != nil {
		// A non-zero exit status arrives as *exec.ExitError and is
		// classified below; anything else means the command could not be
		// started or waited for.
		if _, ok := err.(*exec.ExitError); !ok {
			return TempFail{err}
		}
	}

	exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
	log.Printf("Completed with exit code %v", exitCode)

	// PermanentFailCodes win over TemporaryFailCodes, which win over
	// SuccessCodes; otherwise only exit status 0 is success.
	var success bool
	switch {
	case inCodes(exitCode, taskp.PermanentFailCodes):
		success = false
	case inCodes(exitCode, taskp.TemporaryFailCodes):
		return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
	case inCodes(exitCode, taskp.SuccessCodes):
		success = true
	default:
		success = cmd.ProcessState.Success()
	}

	// Upload output directory
	manifest, err := WriteTree(kc, outdir)
	if err != nil {
		return TempFail{err}
	}

	// Set status
	err = api.Update("job_tasks", taskUuid,
		map[string]interface{}{
			"job_task": Task{
				Output:   manifest,
				Success:  success,
				Progress: 1}},
		nil)
	if err != nil {
		return TempFail{err}
	}

	if !success {
		return PermFail{}
	}
	return nil
}

// dispatchSubtasks queues one sequence-1 job task per definition in tasks,
// each created by taskUuid, then marks taskUuid finished and successful with
// empty output. Any API error is returned as TempFail.
func dispatchSubtasks(api IArvadosClient, jobUuid, taskUuid string, tasks []TaskDef) error {
	for _, task := range tasks {
		err := api.Create("job_tasks",
			map[string]interface{}{
				"job_task": Task{Job_uuid: jobUuid,
					Created_by_job_task_uuid: taskUuid,
					Sequence:                 1,
					Parameters:               task}},
			nil)
		if err != nil {
			return TempFail{err}
		}
	}

	err := api.Update("job_tasks", taskUuid,
		map[string]interface{}{
			"job_task": Task{
				Output:   "",
				Success:  true,
				Progress: 1.0}},
		nil)
	if err != nil {
		return TempFail{err}
	}
	return nil
}

// runForwardingSignals starts cmd and waits for it to exit. While the child
// runs, SIGTERM, SIGINT and SIGQUIT received by this process are relayed to
// it. It returns the last signal relayed (nil if none) and the error from
// cmd.Start or cmd.Wait.
func runForwardingSignals(cmd *exec.Cmd) (os.Signal, error) {
	sigChan := setupSignals(cmd)

	if err := cmd.Start(); err != nil {
		signal.Stop(sigChan)
		return nil, err
	}

	var caught os.Signal
	done := make(chan struct{})
	go func() {
		for sig := range sigChan {
			caught = sig
			// Best effort: the child may already have exited.
			cmd.Process.Signal(sig)
		}
		close(done)
	}()

	err := cmd.Wait()
	// Once signal.Stop returns, nothing more is sent on sigChan, so closing
	// it is safe. Closing it ends the relay goroutine, and waiting on done
	// makes the goroutine's last write to caught visible here.
	signal.Stop(sigChan)
	close(sigChan)
	<-done

	return caught, err
}
322
// main runs a single Arvados job task.
//
// It reads the job and task records named by $JOB_UUID and $TASK_UUID,
// uses $TASK_WORK as the scratch root and $TASK_KEEPMOUNT as the Keep mount,
// and runs the task via runner. It exits with:
//   - 0 on success;
//   - TASK_TEMPFAIL for a TempFail;
//   - 1 for a PermFail or any other error.
func main() {
	api, err := arvadosclient.MakeArvadosClient()
	if err != nil {
		log.Fatal(err)
	}

	// Container may not have certificates installed, so need to look for
	// /etc/arvados/ca-certificates.crt in addition to normal system certs.
	var certFiles = []string{
		"/etc/ssl/certs/ca-certificates.crt", // Debian
		"/etc/pki/tls/certs/ca-bundle.crt",   // Red Hat
		"/etc/arvados/ca-certificates.crt",
	}

	certs := x509.NewCertPool()
	loaded := false
	for _, file := range certFiles {
		data, err := ioutil.ReadFile(file)
		if err != nil {
			continue
		}
		log.Printf("Using TLS certificates at %v", file)
		if certs.AppendCertsFromPEM(data) {
			loaded = true
		}
	}
	// Only override the root pool when certificates were actually loaded.
	// An empty pool would make every TLS connection fail, whereas a nil
	// RootCAs lets crypto/tls use the host's root CA set.
	if loaded {
		api.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs = certs
	}

	jobUuid := os.Getenv("JOB_UUID")
	taskUuid := os.Getenv("TASK_UUID")
	tmpdir := os.Getenv("TASK_WORK")
	keepmount := os.Getenv("TASK_KEEPMOUNT")

	var jobStruct Job
	var taskStruct Task

	if err = api.Get("jobs", jobUuid, nil, &jobStruct); err != nil {
		log.Fatal(err)
	}
	if err = api.Get("job_tasks", taskUuid, nil, &taskStruct); err != nil {
		log.Fatal(err)
	}

	var kc IKeepClient
	kc, err = keepclient.MakeKeepClient(&api)
	if err != nil {
		log.Fatal(err)
	}

	// Clear group/other write bits on everything created from here on,
	// including files written by the task's command.
	syscall.Umask(0022)
	err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)

	switch err.(type) {
	case nil:
		os.Exit(0)
	case TempFail:
		log.Print(err)
		os.Exit(TASK_TEMPFAIL)
	case PermFail:
		os.Exit(1)
	default:
		log.Fatal(err)
	}
}