9684: update arvados_model -> recursive_stringify to convert ":foo" to "foo"
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "crypto/x509"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "git.curoverse.com/arvados.git/sdk/go/keepclient"
8         "io/ioutil"
9         "log"
10         "net/http"
11         "os"
12         "os/exec"
13         "os/signal"
14         "strings"
15         "syscall"
16 )
17
// TaskDef describes one command to run.  It appears both as an element
// of a job's "tasks" script parameter and as the parameters of a job
// task.
type TaskDef struct {
	// Command is the argument vector of the subprocess; element 0 is
	// the program to execute.
	Command []string `json:"command"`

	// Env lists extra environment variables that are appended to the
	// environment inherited from crunchrunner.
	Env map[string]string `json:"task.env"`

	// Stdin, if not empty, is the path of a file to attach to the
	// subprocess' standard input.
	Stdin string `json:"task.stdin"`

	// Stdout, if not empty, is a file name relative to the output
	// directory that receives the subprocess' standard output.
	Stdout string `json:"task.stdout"`

	// Stderr, if not empty, is a file name relative to the output
	// directory that receives the subprocess' standard error.
	Stderr string `json:"task.stderr"`

	// Vwd maps names inside the output directory to the targets of
	// symlinks created there before the subprocess starts.
	Vwd map[string]string `json:"task.vwd"`

	// SuccessCodes are exit codes that count as success.
	SuccessCodes []int `json:"task.successCodes"`

	// PermanentFailCodes are exit codes that count as a permanent
	// failure.
	PermanentFailCodes []int `json:"task.permanentFailCodes"`

	// TemporaryFailCodes are exit codes that count as a temporary
	// failure.
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`
}

// Tasks is the shape of a job's script parameters: a list of task
// definitions.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}

// Job holds the part of a job record that crunchrunner reads.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}

// Task holds the job task fields that crunchrunner reads and writes.
type Task struct {
	Job_uuid string `json:"job_uuid"`

	Created_by_job_task_uuid string `json:"created_by_job_task_uuid"`

	Parameters TaskDef `json:"parameters"`

	// Sequence is omitted when zero so that values built with only a
	// few fields set, as the status updates are, do not carry an
	// explicit sequence of 0.
	Sequence int `json:"sequence,omitempty"`

	Output string `json:"output"`

	Success bool `json:"success"`

	// Progress must have its own JSON name: when two fields of a struct
	// share a name, encoding/json silently ignores both of them, which
	// previously made Sequence and Progress invisible to the encoder
	// and the decoder.
	Progress float32 `json:"progress"`
}
47
48 type IArvadosClient interface {
49         Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
50         Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
51 }
52
// setupDirectories creates the task's scratch directory
// (crunchtmpdir/tmpdir) and output directory (crunchtmpdir/outdir),
// both with mode 0700, and returns their paths.  If either directory
// cannot be created the paths are returned empty together with the
// error.  taskUuid is currently unused.
func setupDirectories(crunchtmpdir, taskUuid string) (tmpdir, outdir string, err error) {
	scratch := crunchtmpdir + "/tmpdir"
	output := crunchtmpdir + "/outdir"

	for _, dir := range []string{scratch, output} {
		if mkErr := os.Mkdir(dir, 0700); mkErr != nil {
			return "", "", mkErr
		}
	}

	return scratch, output, nil
}
68
// checkOutputFilename validates fn, a file name relative to outdir, and
// creates the parent directories it needs below outdir.
//
// fn is rejected when it is empty or ".", when it starts or ends with
// '/', or when any of its path components is "..", so that the
// resulting path cannot leave outdir.  An error from creating the
// parent directories is returned as well.
func checkOutputFilename(outdir, fn string) error {
	if fn == "" || fn == "." {
		return fmt.Errorf("Path must name a file under the output directory")
	}
	if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
		return fmt.Errorf("Path must not start or end with '/'")
	}
	// Compare whole components so that names which merely contain two
	// dots (for example "a..b") stay valid while "..", "../x", "x/.."
	// and "a/../b" are all refused.
	for _, component := range strings.Split(fn, "/") {
		if component == ".." {
			return fmt.Errorf("Path must not contain '../'")
		}
	}

	if sl := strings.LastIndex(fn, "/"); sl != -1 {
		if err := os.MkdirAll(outdir+"/"+fn[:sl], 0777); err != nil {
			return err
		}
	}
	return nil
}
83
84 func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout, stderr string, err error) {
85         if taskp.Vwd != nil {
86                 for k, v := range taskp.Vwd {
87                         v = substitute(v, replacements)
88                         err = checkOutputFilename(outdir, k)
89                         if err != nil {
90                                 return "", "", "", err
91                         }
92                         os.Symlink(v, outdir+"/"+k)
93                 }
94         }
95
96         if taskp.Stdin != "" {
97                 // Set up stdin redirection
98                 stdin = substitute(taskp.Stdin, replacements)
99                 cmd.Stdin, err = os.Open(stdin)
100                 if err != nil {
101                         return "", "", "", err
102                 }
103         }
104
105         if taskp.Stdout != "" {
106                 err = checkOutputFilename(outdir, taskp.Stdout)
107                 if err != nil {
108                         return "", "", "", err
109                 }
110                 // Set up stdout redirection
111                 stdout = outdir + "/" + taskp.Stdout
112                 cmd.Stdout, err = os.Create(stdout)
113                 if err != nil {
114                         return "", "", "", err
115                 }
116         } else {
117                 cmd.Stdout = os.Stdout
118         }
119
120         if taskp.Stderr != "" {
121                 err = checkOutputFilename(outdir, taskp.Stderr)
122                 if err != nil {
123                         return "", "", "", err
124                 }
125                 // Set up stderr redirection
126                 stderr = outdir + "/" + taskp.Stderr
127                 cmd.Stderr, err = os.Create(stderr)
128                 if err != nil {
129                         return "", "", "", err
130                 }
131         } else {
132                 cmd.Stderr = os.Stderr
133         }
134
135         if taskp.Env != nil {
136                 // Set up subprocess environment
137                 cmd.Env = os.Environ()
138                 for k, v := range taskp.Env {
139                         v = substitute(v, replacements)
140                         cmd.Env = append(cmd.Env, k+"="+v)
141                 }
142         }
143         return stdin, stdout, stderr, nil
144 }
145
// setupSignals returns a channel, buffered for one signal, on which
// SIGTERM, SIGINT and SIGQUIT received by this process are delivered.
// The caller is responsible for calling signal.Stop on the channel.
// cmd is currently unused.
func setupSignals(cmd *exec.Cmd) chan os.Signal {
	notifications := make(chan os.Signal, 1)
	signal.Notify(notifications, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
	return notifications
}
155
// inCodes reports whether code is one of codes.  A nil or empty list
// contains nothing.
func inCodes(code int, codes []int) bool {
	for i := 0; i < len(codes); i++ {
		if codes[i] == code {
			return true
		}
	}
	return false
}
166
// TASK_TEMPFAIL is the process exit status main uses when runner
// reports a temporary failure.
const TASK_TEMPFAIL = 111

type (
	// TempFail wraps an error to mark it as a temporary failure.
	TempFail struct{ error }

	// PermFail marks a permanent failure; it carries no further
	// detail.
	PermFail struct{}
)

// Error implements the error interface.
func (PermFail) Error() string { return "PermFail" }
175
// substitute returns inp with every occurrence of each key of subst
// replaced by the corresponding value.  The replacements are applied
// one key after another, in map iteration order.
func substitute(inp string, subst map[string]string) string {
	result := inp
	for pattern, value := range subst {
		result = strings.Replace(result, pattern, value, -1)
	}
	return result
}
182
183 func runner(api IArvadosClient,
184         kc IKeepClient,
185         jobUuid, taskUuid, crunchtmpdir, keepmount string,
186         jobStruct Job, taskStruct Task) error {
187
188         var err error
189         taskp := taskStruct.Parameters
190
191         // If this is task 0 and there are multiple tasks, dispatch subtasks
192         // and exit.
193         if taskStruct.Sequence == 0 {
194                 if len(jobStruct.Script_parameters.Tasks) == 1 {
195                         taskp = jobStruct.Script_parameters.Tasks[0]
196                 } else {
197                         for _, task := range jobStruct.Script_parameters.Tasks {
198                                 err := api.Create("job_tasks",
199                                         map[string]interface{}{
200                                                 "job_task": Task{Job_uuid: jobUuid,
201                                                         Created_by_job_task_uuid: taskUuid,
202                                                         Sequence:                 1,
203                                                         Parameters:               task}},
204                                         nil)
205                                 if err != nil {
206                                         return TempFail{err}
207                                 }
208                         }
209                         err = api.Update("job_tasks", taskUuid,
210                                 map[string]interface{}{
211                                         "job_task": Task{
212                                                 Output:   "",
213                                                 Success:  true,
214                                                 Progress: 1.0}},
215                                 nil)
216                         return nil
217                 }
218         }
219
220         var tmpdir, outdir string
221         tmpdir, outdir, err = setupDirectories(crunchtmpdir, taskUuid)
222         if err != nil {
223                 return TempFail{err}
224         }
225
226         replacements := map[string]string{
227                 "$(task.tmpdir)": tmpdir,
228                 "$(task.outdir)": outdir,
229                 "$(task.keep)":   keepmount}
230
231         log.Printf("crunchrunner: $(task.tmpdir)=%v", tmpdir)
232         log.Printf("crunchrunner: $(task.outdir)=%v", outdir)
233         log.Printf("crunchrunner: $(task.keep)=%v", keepmount)
234
235         // Set up subprocess
236         for k, v := range taskp.Command {
237                 taskp.Command[k] = substitute(v, replacements)
238         }
239
240         cmd := exec.Command(taskp.Command[0], taskp.Command[1:]...)
241
242         cmd.Dir = outdir
243
244         var stdin, stdout, stderr string
245         stdin, stdout, stderr, err = setupCommand(cmd, taskp, outdir, replacements)
246         if err != nil {
247                 return err
248         }
249
250         // Run subprocess and wait for it to complete
251         if stdin != "" {
252                 stdin = " < " + stdin
253         }
254         if stdout != "" {
255                 stdout = " > " + stdout
256         }
257         if stderr != "" {
258                 stderr = " 2> " + stderr
259         }
260         log.Printf("Running %v%v%v%v", cmd.Args, stdin, stdout, stderr)
261
262         var caughtSignal os.Signal
263         sigChan := setupSignals(cmd)
264
265         err = cmd.Start()
266         if err != nil {
267                 signal.Stop(sigChan)
268                 return TempFail{err}
269         }
270
271         finishedSignalNotify := make(chan struct{})
272         go func(sig <-chan os.Signal) {
273                 for sig := range sig {
274                         caughtSignal = sig
275                         cmd.Process.Signal(caughtSignal)
276                 }
277                 close(finishedSignalNotify)
278         }(sigChan)
279
280         err = cmd.Wait()
281         signal.Stop(sigChan)
282
283         close(sigChan)
284         <-finishedSignalNotify
285
286         if caughtSignal != nil {
287                 log.Printf("Caught signal %v", caughtSignal)
288                 return PermFail{}
289         }
290
291         if err != nil {
292                 // Run() returns ExitError on non-zero exit code, but we handle
293                 // that down below.  So only return if it's not ExitError.
294                 if _, ok := err.(*exec.ExitError); !ok {
295                         return TempFail{err}
296                 }
297         }
298
299         var success bool
300
301         exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
302
303         log.Printf("Completed with exit code %v", exitCode)
304
305         if inCodes(exitCode, taskp.PermanentFailCodes) {
306                 success = false
307         } else if inCodes(exitCode, taskp.TemporaryFailCodes) {
308                 return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
309         } else if inCodes(exitCode, taskp.SuccessCodes) || cmd.ProcessState.Success() {
310                 success = true
311         } else {
312                 success = false
313         }
314
315         // Upload output directory
316         manifest, err := WriteTree(kc, outdir)
317         if err != nil {
318                 return TempFail{err}
319         }
320
321         // Set status
322         err = api.Update("job_tasks", taskUuid,
323                 map[string]interface{}{
324                         "job_task": Task{
325                                 Output:   manifest,
326                                 Success:  success,
327                                 Progress: 1}},
328                 nil)
329         if err != nil {
330                 return TempFail{err}
331         }
332
333         if success {
334                 return nil
335         } else {
336                 return PermFail{}
337         }
338 }
339
340 func main() {
341         api, err := arvadosclient.MakeArvadosClient()
342         if err != nil {
343                 log.Fatal(err)
344         }
345
346         // Container may not have certificates installed, so need to look for
347         // /etc/arvados/ca-certificates.crt in addition to normal system certs.
348         var certFiles = []string{
349                 "/etc/ssl/certs/ca-certificates.crt", // Debian
350                 "/etc/pki/tls/certs/ca-bundle.crt",   // Red Hat
351                 "/etc/arvados/ca-certificates.crt",
352         }
353
354         certs := x509.NewCertPool()
355         for _, file := range certFiles {
356                 data, err := ioutil.ReadFile(file)
357                 if err == nil {
358                         log.Printf("Using TLS certificates at %v", file)
359                         certs.AppendCertsFromPEM(data)
360                 }
361         }
362         api.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs = certs
363
364         jobUuid := os.Getenv("JOB_UUID")
365         taskUuid := os.Getenv("TASK_UUID")
366         tmpdir := os.Getenv("TASK_WORK")
367         keepmount := os.Getenv("TASK_KEEPMOUNT")
368
369         var jobStruct Job
370         var taskStruct Task
371
372         err = api.Get("jobs", jobUuid, nil, &jobStruct)
373         if err != nil {
374                 log.Fatal(err)
375         }
376         err = api.Get("job_tasks", taskUuid, nil, &taskStruct)
377         if err != nil {
378                 log.Fatal(err)
379         }
380
381         var kc IKeepClient
382         kc, err = keepclient.MakeKeepClient(&api)
383         if err != nil {
384                 log.Fatal(err)
385         }
386
387         syscall.Umask(0022)
388         err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)
389
390         if err == nil {
391                 os.Exit(0)
392         } else if _, ok := err.(TempFail); ok {
393                 log.Print(err)
394                 os.Exit(TASK_TEMPFAIL)
395         } else if _, ok := err.(PermFail); ok {
396                 os.Exit(1)
397         } else {
398                 log.Fatal(err)
399         }
400 }