10684: Add Arvados-specific search path to Go SDK TLSClientConfig.
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "crypto/x509"
5         "encoding/json"
6         "fmt"
7         "git.curoverse.com/arvados.git/sdk/go/arvados"
8         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
9         "git.curoverse.com/arvados.git/sdk/go/keepclient"
10         "io"
11         "io/ioutil"
12         "log"
13         "net/http"
14         "os"
15         "os/exec"
16         "os/signal"
17         "strings"
18         "syscall"
19 )
20
// TaskDef describes a single task: the command to run, how to set up its
// environment, working-directory files and standard streams, and how to
// classify its exit code. runner substitutes the $(task.tmpdir),
// $(task.outdir) and $(task.keep) placeholders in Command, Stdin, and the
// values of Env and Vwd before use.
type TaskDef struct {
	// Command is the argument vector; Command[0] is the program executed.
	Command []string `json:"command"`

	// Env holds extra environment variables, appended to the runner's own
	// environment.
	Env map[string]string `json:"task.env"`

	// Stdin is the path of a file opened as the command's standard input.
	Stdin string `json:"task.stdin"`

	// Stdout names a file, relative to the output directory, that receives
	// standard output. Empty means the runner's stdout is inherited.
	Stdout string `json:"task.stdout"`

	// Stderr names a file, relative to the output directory, that receives
	// standard error. Empty means the runner's stderr is inherited.
	Stderr string `json:"task.stderr"`

	// Vwd maps file names within the output directory to source paths that
	// are symlinked (or, with KeepTmpOutput, copied) there.
	Vwd map[string]string `json:"task.vwd"`

	// SuccessCodes are exit codes treated as success in addition to 0.
	SuccessCodes []int `json:"task.successCodes"`

	// PermanentFailCodes are exit codes treated as permanent failure. They
	// are checked before TemporaryFailCodes and SuccessCodes.
	PermanentFailCodes []int `json:"task.permanentFailCodes"`

	// TemporaryFailCodes are exit codes reported as a temporary failure.
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`

	// KeepTmpOutput selects the directory named by TASK_KEEPMOUNT_TMP as the
	// output directory. The output manifest is then read from its
	// .arvados#collection file instead of being uploaded.
	KeepTmpOutput bool `json:"task.keepTmpOutput"`
}

// Tasks is the shape of a job's script_parameters: a list of task
// definitions.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}

// Job holds the job record fields crunchrunner reads.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}

// Task holds the job_tasks record fields crunchrunner reads and writes.
// main decodes it from the API; runner encodes it when creating subtasks
// and when updating the task's own status.
//
// Every field needs a distinct JSON name. When two fields at the same level
// share a name, encoding/json silently ignores both of them. Progress must
// therefore not reuse Sequence's "sequence" tag, or neither field would be
// encoded or decoded.
type Task struct {
	Job_uuid                 string  `json:"job_uuid"`
	Created_by_job_task_uuid string  `json:"created_by_job_task_uuid"`
	Parameters               TaskDef `json:"parameters"`
	Sequence                 int     `json:"sequence"`
	Output                   string  `json:"output"`
	Success                  bool    `json:"success"`
	Progress                 float32 `json:"progress"`
}
51
52 type IArvadosClient interface {
53         Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
54         Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
55 }
56
// setupDirectories creates the task's scratch directory and chooses its
// output directory.
//
// tmpdir is always crunchtmpdir+"/tmpdir". It is created with mode 0700, and
// os.Mkdir fails if it already exists.
//
// If keepTmp is true, outdir is taken from the TASK_KEEPMOUNT_TMP
// environment variable and is not created here. An empty value is an error,
// because otherwise the command would run in, and write into, the runner's
// current directory. If keepTmp is false, outdir is crunchtmpdir+"/outdir",
// created with mode 0700.
//
// taskUuid is currently unused.
func setupDirectories(crunchtmpdir, taskUuid string, keepTmp bool) (tmpdir, outdir string, err error) {
	tmpdir = crunchtmpdir + "/tmpdir"
	if err = os.Mkdir(tmpdir, 0700); err != nil {
		return "", "", err
	}

	if keepTmp {
		outdir = os.Getenv("TASK_KEEPMOUNT_TMP")
		if outdir == "" {
			return "", "", fmt.Errorf("task.keepTmpOutput is set but TASK_KEEPMOUNT_TMP is empty")
		}
	} else {
		outdir = crunchtmpdir + "/outdir"
		if err = os.Mkdir(outdir, 0700); err != nil {
			return "", "", err
		}
	}

	return tmpdir, outdir, nil
}
76
// checkOutputFilename validates fn as a relative file name inside outdir and
// creates its parent directories (mode 0777, before umask).
//
// fn must not start or end with '/', and must not contain a ".." path
// component. Either would let a task place files outside outdir. An error
// from creating the parent directories is returned as well.
func checkOutputFilename(outdir, fn string) error {
	if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
		return fmt.Errorf("Path must not start or end with '/'")
	}
	// Check whole components so "../x", "a/../b" and "a/.." are all
	// rejected, while names like "a..b" remain allowed.
	for _, part := range strings.Split(fn, "/") {
		if part == ".." {
			return fmt.Errorf("Path must not contain '..' components")
		}
	}

	if sl := strings.LastIndex(fn, "/"); sl != -1 {
		if err := os.MkdirAll(outdir+"/"+fn[:sl], 0777); err != nil {
			return err
		}
	}
	return nil
}
91
// copyFile copies the contents of the file at src to dst, creating or
// truncating dst.
//
// The error from closing dst is reported, because buffered writes (for
// example on network or FUSE filesystems) may only fail at close time.
func copyFile(dst, src string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close() // read-only; a close error cannot lose data

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := out.Close(); err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(out, in)
	return err
}
108
109 func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout, stderr string, err error) {
110         if taskp.Vwd != nil {
111                 for k, v := range taskp.Vwd {
112                         v = substitute(v, replacements)
113                         err = checkOutputFilename(outdir, k)
114                         if err != nil {
115                                 return "", "", "", err
116                         }
117                         if taskp.KeepTmpOutput {
118                                 err = copyFile(v, outdir+"/"+k)
119                         } else {
120                                 err = os.Symlink(v, outdir+"/"+k)
121                         }
122                         if err != nil {
123                                 return "", "", "", err
124                         }
125                 }
126         }
127
128         if taskp.Stdin != "" {
129                 // Set up stdin redirection
130                 stdin = substitute(taskp.Stdin, replacements)
131                 cmd.Stdin, err = os.Open(stdin)
132                 if err != nil {
133                         return "", "", "", err
134                 }
135         }
136
137         if taskp.Stdout != "" {
138                 err = checkOutputFilename(outdir, taskp.Stdout)
139                 if err != nil {
140                         return "", "", "", err
141                 }
142                 // Set up stdout redirection
143                 stdout = outdir + "/" + taskp.Stdout
144                 cmd.Stdout, err = os.Create(stdout)
145                 if err != nil {
146                         return "", "", "", err
147                 }
148         } else {
149                 cmd.Stdout = os.Stdout
150         }
151
152         if taskp.Stderr != "" {
153                 err = checkOutputFilename(outdir, taskp.Stderr)
154                 if err != nil {
155                         return "", "", "", err
156                 }
157                 // Set up stderr redirection
158                 stderr = outdir + "/" + taskp.Stderr
159                 cmd.Stderr, err = os.Create(stderr)
160                 if err != nil {
161                         return "", "", "", err
162                 }
163         } else {
164                 cmd.Stderr = os.Stderr
165         }
166
167         if taskp.Env != nil {
168                 // Set up subprocess environment
169                 cmd.Env = os.Environ()
170                 for k, v := range taskp.Env {
171                         v = substitute(v, replacements)
172                         cmd.Env = append(cmd.Env, k+"="+v)
173                 }
174         }
175         return stdin, stdout, stderr, nil
176 }
177
// setupSignals routes SIGTERM, SIGINT and SIGQUIT received by this process
// to the returned channel, which has a buffer of one. While registered,
// those signals no longer trigger their default action. The caller should
// call signal.Stop on the channel when it is done. cmd is currently unused.
func setupSignals(cmd *exec.Cmd) chan os.Signal {
	notify := make(chan os.Signal, 1)
	signal.Notify(notify, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
	return notify
}
187
// inCodes reports whether code is one of the values in codes. A nil or
// empty codes slice contains nothing.
func inCodes(code int, codes []int) bool {
	for i := range codes {
		if codes[i] == code {
			return true
		}
	}
	return false
}
198
// TASK_TEMPFAIL is the exit status main uses to report a temporary task
// failure.
const TASK_TEMPFAIL = 111

type (
	// TempFail wraps an error to mark it as a temporary failure. main exits
	// with TASK_TEMPFAIL when runner returns one. Its message is that of the
	// wrapped error.
	TempFail struct{ error }

	// PermFail marks a permanent failure. main exits with status 1 when
	// runner returns one.
	PermFail struct{}
)

// Error implements the error interface.
func (PermFail) Error() string {
	return "PermFail"
}
207
// substitute returns inp with every occurrence of each key of subst replaced
// by the corresponding value. Keys are applied one after another in map
// iteration order, which Go leaves unspecified. The result may therefore
// depend on that order when keys overlap or a value contains another key.
func substitute(inp string, subst map[string]string) string {
	out := inp
	for placeholder, value := range subst {
		out = strings.Replace(out, placeholder, value, -1)
	}
	return out
}
214
215 func getKeepTmp(outdir string) (manifest string, err error) {
216         fn, err := os.Open(outdir + "/" + ".arvados#collection")
217         if err != nil {
218                 return "", err
219         }
220         defer fn.Close()
221
222         buf, err := ioutil.ReadAll(fn)
223         if err != nil {
224                 return "", err
225         }
226         collection := arvados.Collection{}
227         err = json.Unmarshal(buf, &collection)
228         return collection.ManifestText, err
229 }
230
231 func runner(api IArvadosClient,
232         kc IKeepClient,
233         jobUuid, taskUuid, crunchtmpdir, keepmount string,
234         jobStruct Job, taskStruct Task) error {
235
236         var err error
237         taskp := taskStruct.Parameters
238
239         // If this is task 0 and there are multiple tasks, dispatch subtasks
240         // and exit.
241         if taskStruct.Sequence == 0 {
242                 if len(jobStruct.Script_parameters.Tasks) == 1 {
243                         taskp = jobStruct.Script_parameters.Tasks[0]
244                 } else {
245                         for _, task := range jobStruct.Script_parameters.Tasks {
246                                 err := api.Create("job_tasks",
247                                         map[string]interface{}{
248                                                 "job_task": Task{Job_uuid: jobUuid,
249                                                         Created_by_job_task_uuid: taskUuid,
250                                                         Sequence:                 1,
251                                                         Parameters:               task}},
252                                         nil)
253                                 if err != nil {
254                                         return TempFail{err}
255                                 }
256                         }
257                         err = api.Update("job_tasks", taskUuid,
258                                 map[string]interface{}{
259                                         "job_task": Task{
260                                                 Output:   "",
261                                                 Success:  true,
262                                                 Progress: 1.0}},
263                                 nil)
264                         return nil
265                 }
266         }
267
268         var tmpdir, outdir string
269         tmpdir, outdir, err = setupDirectories(crunchtmpdir, taskUuid, taskp.KeepTmpOutput)
270         if err != nil {
271                 return TempFail{err}
272         }
273
274         replacements := map[string]string{
275                 "$(task.tmpdir)": tmpdir,
276                 "$(task.outdir)": outdir,
277                 "$(task.keep)":   keepmount}
278
279         log.Printf("crunchrunner: $(task.tmpdir)=%v", tmpdir)
280         log.Printf("crunchrunner: $(task.outdir)=%v", outdir)
281         log.Printf("crunchrunner: $(task.keep)=%v", keepmount)
282
283         // Set up subprocess
284         for k, v := range taskp.Command {
285                 taskp.Command[k] = substitute(v, replacements)
286         }
287
288         cmd := exec.Command(taskp.Command[0], taskp.Command[1:]...)
289
290         cmd.Dir = outdir
291
292         var stdin, stdout, stderr string
293         stdin, stdout, stderr, err = setupCommand(cmd, taskp, outdir, replacements)
294         if err != nil {
295                 return err
296         }
297
298         // Run subprocess and wait for it to complete
299         if stdin != "" {
300                 stdin = " < " + stdin
301         }
302         if stdout != "" {
303                 stdout = " > " + stdout
304         }
305         if stderr != "" {
306                 stderr = " 2> " + stderr
307         }
308         log.Printf("Running %v%v%v%v", cmd.Args, stdin, stdout, stderr)
309
310         var caughtSignal os.Signal
311         sigChan := setupSignals(cmd)
312
313         err = cmd.Start()
314         if err != nil {
315                 signal.Stop(sigChan)
316                 return TempFail{err}
317         }
318
319         finishedSignalNotify := make(chan struct{})
320         go func(sig <-chan os.Signal) {
321                 for sig := range sig {
322                         caughtSignal = sig
323                         cmd.Process.Signal(caughtSignal)
324                 }
325                 close(finishedSignalNotify)
326         }(sigChan)
327
328         err = cmd.Wait()
329         signal.Stop(sigChan)
330
331         close(sigChan)
332         <-finishedSignalNotify
333
334         if caughtSignal != nil {
335                 log.Printf("Caught signal %v", caughtSignal)
336                 return PermFail{}
337         }
338
339         if err != nil {
340                 // Run() returns ExitError on non-zero exit code, but we handle
341                 // that down below.  So only return if it's not ExitError.
342                 if _, ok := err.(*exec.ExitError); !ok {
343                         return TempFail{err}
344                 }
345         }
346
347         var success bool
348
349         exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
350
351         log.Printf("Completed with exit code %v", exitCode)
352
353         if inCodes(exitCode, taskp.PermanentFailCodes) {
354                 success = false
355         } else if inCodes(exitCode, taskp.TemporaryFailCodes) {
356                 return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
357         } else if inCodes(exitCode, taskp.SuccessCodes) || cmd.ProcessState.Success() {
358                 success = true
359         } else {
360                 success = false
361         }
362
363         // Upload output directory
364         var manifest string
365         if taskp.KeepTmpOutput {
366                 manifest, err = getKeepTmp(outdir)
367         } else {
368                 manifest, err = WriteTree(kc, outdir)
369         }
370         if err != nil {
371                 return TempFail{err}
372         }
373
374         // Set status
375         err = api.Update("job_tasks", taskUuid,
376                 map[string]interface{}{
377                         "job_task": Task{
378                                 Output:   manifest,
379                                 Success:  success,
380                                 Progress: 1}},
381                 nil)
382         if err != nil {
383                 return TempFail{err}
384         }
385
386         if success {
387                 return nil
388         } else {
389                 return PermFail{}
390         }
391 }
392
393 func main() {
394         api, err := arvadosclient.MakeArvadosClient()
395         if err != nil {
396                 log.Fatal(err)
397         }
398
399         jobUuid := os.Getenv("JOB_UUID")
400         taskUuid := os.Getenv("TASK_UUID")
401         tmpdir := os.Getenv("TASK_WORK")
402         keepmount := os.Getenv("TASK_KEEPMOUNT")
403
404         var jobStruct Job
405         var taskStruct Task
406
407         err = api.Get("jobs", jobUuid, nil, &jobStruct)
408         if err != nil {
409                 log.Fatal(err)
410         }
411         err = api.Get("job_tasks", taskUuid, nil, &taskStruct)
412         if err != nil {
413                 log.Fatal(err)
414         }
415
416         var kc IKeepClient
417         kc, err = keepclient.MakeKeepClient(api)
418         if err != nil {
419                 log.Fatal(err)
420         }
421
422         syscall.Umask(0022)
423         err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)
424
425         if err == nil {
426                 os.Exit(0)
427         } else if _, ok := err.(TempFail); ok {
428                 log.Print(err)
429                 os.Exit(TASK_TEMPFAIL)
430         } else if _, ok := err.(PermFail); ok {
431                 os.Exit(1)
432         } else {
433                 log.Fatal(err)
434         }
435 }