closes #8508
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "crypto/x509"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "git.curoverse.com/arvados.git/sdk/go/keepclient"
8         "io/ioutil"
9         "log"
10         "net/http"
11         "os"
12         "os/exec"
13         "os/signal"
14         "path"
15         "strings"
16         "syscall"
17 )
18
// TaskDef defines one task: the command to run and how to set up its
// environment. It is read from a job's script_parameters ("tasks") and from
// a job task's "parameters".
type TaskDef struct {
	// Command is the argument vector; Command[0] is the program to run.
	// runner expands $(task.*) placeholders in every element.
	Command []string `json:"command"`

	// Env holds extra environment variables for the command; values are
	// placeholder-expanded.
	Env map[string]string `json:"task.env"`

	// Stdin, if set, is a (placeholder-expanded) path opened as the
	// command's standard input.
	Stdin string `json:"task.stdin"`

	// Stdout, if set, names a file relative to the output directory that
	// receives the command's standard output.
	Stdout string `json:"task.stdout"`

	// Vwd maps names relative to the output directory to symlink targets
	// (placeholder-expanded).
	Vwd map[string]string `json:"task.vwd"`

	// These lists classify the command's exit status; runner consults
	// PermanentFailCodes, then TemporaryFailCodes, then SuccessCodes.
	SuccessCodes       []int `json:"task.successCodes"`
	PermanentFailCodes []int `json:"task.permanentFailCodes"`
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`
}

// Tasks is the shape of a job's script_parameters: a list of task
// definitions.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}

// Job is the subset of an Arvados job record that crunchrunner reads.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}

// Task is the subset of an Arvados job task record that crunchrunner reads
// and writes.
//
// Every field must have a distinct JSON name. When two fields at the same
// level share a name, encoding/json silently ignores both of them. An
// earlier tag clash here made Sequence impossible to decode.
//
// Fields marked omitempty are left out of the JSON when zero, so a
// partially filled Task does not overwrite them with "" or 0. Output and
// Success are always sent, because an empty output and success=false are
// meaningful values.
type Task struct {
	Job_uuid                 string  `json:"job_uuid,omitempty"`
	Created_by_job_task_uuid string  `json:"created_by_job_task_uuid,omitempty"`
	Parameters               TaskDef `json:"parameters"`
	Sequence                 int     `json:"sequence,omitempty"`
	Output                   string  `json:"output"`
	Success                  bool    `json:"success"`
	Progress                 float32 `json:"progress,omitempty"`
}
47
48 type IArvadosClient interface {
49         Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
50         Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
51 }
52
// setupDirectories creates the task's scratch directory
// (crunchtmpdir + "/tmpdir") and output directory (crunchtmpdir + "/outdir"),
// each with mode 0700 before umask, and returns their paths.
//
// taskUuid is currently unused. On error both returned paths are empty. If
// creating the output directory fails, a tmpdir created just before is left
// in place.
func setupDirectories(crunchtmpdir, taskUuid string) (tmpdir, outdir string, err error) {
	scratch := crunchtmpdir + "/tmpdir"
	if err = os.Mkdir(scratch, 0700); err != nil {
		return "", "", err
	}

	output := crunchtmpdir + "/outdir"
	if err = os.Mkdir(output, 0700); err != nil {
		return "", "", err
	}

	return scratch, output, nil
}
68
// checkOutputFilename validates fn as a path relative to outdir and creates
// any parent directories it names.
//
// fn must be non-empty. It must not start or end with '/'. None of its
// '/'-separated components may be "..", so the resulting path cannot climb
// out of outdir. Names that merely begin with dots, such as "..hidden", are
// allowed.
//
// The returned error describes the first rule fn breaks, or the failure to
// create its parent directories.
func checkOutputFilename(outdir, fn string) error {
	if fn == "" {
		return fmt.Errorf("Path must not be empty")
	}
	if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
		return fmt.Errorf("Path must not start or end with '/'")
	}
	// Check whole components: a substring test misses "a/.." and, with its
	// arguments reversed, misses "a/../../x" entirely.
	for _, part := range strings.Split(fn, "/") {
		if part == ".." {
			return fmt.Errorf("Path must not contain '../'")
		}
	}

	if sl := strings.LastIndex(fn, "/"); sl != -1 {
		if err := os.MkdirAll(outdir+"/"+fn[:sl], 0777); err != nil {
			return err
		}
	}
	return nil
}
83
84 func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout string, err error) {
85         if taskp.Vwd != nil {
86                 for k, v := range taskp.Vwd {
87                         v = substitute(v, replacements)
88                         err = checkOutputFilename(outdir, k)
89                         if err != nil {
90                                 return "", "", err
91                         }
92                         os.Symlink(v, outdir+"/"+k)
93                 }
94         }
95
96         if taskp.Stdin != "" {
97                 // Set up stdin redirection
98                 stdin = substitute(taskp.Stdin, replacements)
99                 cmd.Stdin, err = os.Open(stdin)
100                 if err != nil {
101                         return "", "", err
102                 }
103         }
104
105         if taskp.Stdout != "" {
106                 err = checkOutputFilename(outdir, taskp.Stdout)
107                 if err != nil {
108                         return "", "", err
109                 }
110                 // Set up stdout redirection
111                 stdout = outdir + "/" + taskp.Stdout
112                 cmd.Stdout, err = os.Create(stdout)
113                 if err != nil {
114                         return "", "", err
115                 }
116         } else {
117                 cmd.Stdout = os.Stdout
118         }
119
120         if taskp.Env != nil {
121                 // Set up subprocess environment
122                 cmd.Env = os.Environ()
123                 for k, v := range taskp.Env {
124                         v = substitute(v, replacements)
125                         cmd.Env = append(cmd.Env, k+"="+v)
126                 }
127         }
128         return stdin, stdout, nil
129 }
130
// setupSignals routes SIGTERM, SIGINT and SIGQUIT sent to this process onto
// the returned channel (capacity 1) instead of their default action. The
// caller should call signal.Stop on the channel when done with it. cmd is
// currently unused.
func setupSignals(cmd *exec.Cmd) chan os.Signal {
	notifications := make(chan os.Signal, 1)
	signal.Notify(notifications, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
	return notifications
}
140
// inCodes reports whether code appears in codes. A nil or empty list
// contains nothing.
func inCodes(code int, codes []int) bool {
	for i := range codes {
		if codes[i] == code {
			return true
		}
	}
	return false
}
151
// TASK_TEMPFAIL is the exit status main uses when runner reports a
// TempFail.
const TASK_TEMPFAIL = 111

// TempFail marks a failure that may succeed if the task is retried. It
// embeds the underlying error, so its message is that error's message.
type TempFail struct{ error }

// PermFail marks a failure that retrying will not fix.
type PermFail struct{}

// Both failure kinds are used as error values.
var (
	_ error = TempFail{}
	_ error = PermFail{}
)

// Error implements the error interface.
func (PermFail) Error() string {
	return "PermFail"
}
160
// substitute replaces every occurrence of each key of subst in inp with the
// corresponding value.
//
// Keys are applied one at a time in map iteration order. A value that
// itself contains another key may therefore be expanded again, or not,
// depending on that order.
func substitute(inp string, subst map[string]string) string {
	result := inp
	for placeholder, value := range subst {
		result = strings.Replace(result, placeholder, value, -1)
	}
	return result
}
167
168 func runner(api IArvadosClient,
169         kc IKeepClient,
170         jobUuid, taskUuid, crunchtmpdir, keepmount string,
171         jobStruct Job, taskStruct Task) error {
172
173         var err error
174         taskp := taskStruct.Parameters
175
176         // If this is task 0 and there are multiple tasks, dispatch subtasks
177         // and exit.
178         if taskStruct.Sequence == 0 {
179                 if len(jobStruct.Script_parameters.Tasks) == 1 {
180                         taskp = jobStruct.Script_parameters.Tasks[0]
181                 } else {
182                         for _, task := range jobStruct.Script_parameters.Tasks {
183                                 err := api.Create("job_tasks",
184                                         map[string]interface{}{
185                                                 "job_task": Task{Job_uuid: jobUuid,
186                                                         Created_by_job_task_uuid: taskUuid,
187                                                         Sequence:                 1,
188                                                         Parameters:               task}},
189                                         nil)
190                                 if err != nil {
191                                         return TempFail{err}
192                                 }
193                         }
194                         err = api.Update("job_tasks", taskUuid,
195                                 map[string]interface{}{
196                                         "job_task": Task{
197                                                 Output:   "",
198                                                 Success:  true,
199                                                 Progress: 1.0}},
200                                 nil)
201                         return nil
202                 }
203         }
204
205         var tmpdir, outdir string
206         tmpdir, outdir, err = setupDirectories(crunchtmpdir, taskUuid)
207         if err != nil {
208                 return TempFail{err}
209         }
210
211         replacements := map[string]string{
212                 "$(task.tmpdir)": tmpdir,
213                 "$(task.outdir)": outdir,
214                 "$(task.keep)":   keepmount}
215
216         log.Printf("crunchrunner: $(task.tmpdir)=%v", tmpdir)
217         log.Printf("crunchrunner: $(task.outdir)=%v", outdir)
218         log.Printf("crunchrunner: $(task.keep)=%v", keepmount)
219
220         // Set up subprocess
221         for k, v := range taskp.Command {
222                 taskp.Command[k] = substitute(v, replacements)
223         }
224
225         cmd := exec.Command(taskp.Command[0], taskp.Command[1:]...)
226
227         cmd.Dir = outdir
228
229         var stdin, stdout string
230         stdin, stdout, err = setupCommand(cmd, taskp, outdir, replacements)
231         if err != nil {
232                 return err
233         }
234
235         // Run subprocess and wait for it to complete
236         if stdin != "" {
237                 stdin = " < " + stdin
238         }
239         if stdout != "" {
240                 stdout = " > " + stdout
241         }
242         log.Printf("Running %v%v%v", cmd.Args, stdin, stdout)
243
244         var caughtSignal os.Signal
245         sigChan := setupSignals(cmd)
246
247         err = cmd.Start()
248         if err != nil {
249                 signal.Stop(sigChan)
250                 return TempFail{err}
251         }
252
253         finishedSignalNotify := make(chan struct{})
254         go func(sig <-chan os.Signal) {
255                 for sig := range sig {
256                         caughtSignal = sig
257                         cmd.Process.Signal(caughtSignal)
258                 }
259                 close(finishedSignalNotify)
260         }(sigChan)
261
262         err = cmd.Wait()
263         signal.Stop(sigChan)
264
265         close(sigChan)
266         <-finishedSignalNotify
267
268         if caughtSignal != nil {
269                 log.Printf("Caught signal %v", caughtSignal)
270                 return PermFail{}
271         }
272
273         if err != nil {
274                 // Run() returns ExitError on non-zero exit code, but we handle
275                 // that down below.  So only return if it's not ExitError.
276                 if _, ok := err.(*exec.ExitError); !ok {
277                         return TempFail{err}
278                 }
279         }
280
281         var success bool
282
283         exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
284
285         log.Printf("Completed with exit code %v", exitCode)
286
287         if inCodes(exitCode, taskp.PermanentFailCodes) {
288                 success = false
289         } else if inCodes(exitCode, taskp.TemporaryFailCodes) {
290                 return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
291         } else if inCodes(exitCode, taskp.SuccessCodes) || cmd.ProcessState.Success() {
292                 success = true
293         } else {
294                 success = false
295         }
296
297         // Upload output directory
298         manifest, err := WriteTree(kc, outdir)
299         if err != nil {
300                 return TempFail{err}
301         }
302
303         // Set status
304         err = api.Update("job_tasks", taskUuid,
305                 map[string]interface{}{
306                         "job_task": Task{
307                                 Output:   manifest,
308                                 Success:  success,
309                                 Progress: 1}},
310                 nil)
311         if err != nil {
312                 return TempFail{err}
313         }
314
315         if success {
316                 return nil
317         } else {
318                 return PermFail{}
319         }
320 }
321
322 func main() {
323         api, err := arvadosclient.MakeArvadosClient()
324         if err != nil {
325                 log.Fatal(err)
326         }
327
328         certpath := path.Join(path.Dir(os.Args[0]), "ca-certificates.crt")
329         certdata, err := ioutil.ReadFile(certpath)
330         if err == nil {
331                 log.Printf("Using TLS certificates at %v", certpath)
332                 certs := x509.NewCertPool()
333                 certs.AppendCertsFromPEM(certdata)
334                 api.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs = certs
335         }
336
337         jobUuid := os.Getenv("JOB_UUID")
338         taskUuid := os.Getenv("TASK_UUID")
339         tmpdir := os.Getenv("TASK_WORK")
340         keepmount := os.Getenv("TASK_KEEPMOUNT")
341
342         var jobStruct Job
343         var taskStruct Task
344
345         err = api.Get("jobs", jobUuid, nil, &jobStruct)
346         if err != nil {
347                 log.Fatal(err)
348         }
349         err = api.Get("job_tasks", taskUuid, nil, &taskStruct)
350         if err != nil {
351                 log.Fatal(err)
352         }
353
354         var kc IKeepClient
355         kc, err = keepclient.MakeKeepClient(&api)
356         if err != nil {
357                 log.Fatal(err)
358         }
359
360         syscall.Umask(0022)
361         err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)
362
363         if err == nil {
364                 os.Exit(0)
365         } else if _, ok := err.(TempFail); ok {
366                 log.Print(err)
367                 os.Exit(TASK_TEMPFAIL)
368         } else if _, ok := err.(PermFail); ok {
369                 os.Exit(1)
370         } else {
371                 log.Fatal(err)
372         }
373 }