8087: makes changes suggested by radhika
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "fmt"
5         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
6         "git.curoverse.com/arvados.git/sdk/go/keepclient"
7         "log"
8         "os"
9         "os/exec"
10         "os/signal"
11         "strings"
12         "syscall"
13 )
14
// TaskDef describes one command to run, as decoded from JSON.  The values of
// Command, Env, Stdin and Vwd are passed through substitute() before use, so
// they may contain placeholders such as "$(task.tmpdir)".
type TaskDef struct {
	// Command is the argument vector; element 0 is the program to execute.
	Command []string `json:"command"`
	// Env holds extra NAME -> value pairs appended to the runner's own
	// environment for the subprocess.
	Env map[string]string `json:"task.env"`
	// Stdin, when non-empty, is the path of a file opened as the
	// subprocess's standard input.
	Stdin string `json:"task.stdin"`
	// Stdout, when non-empty, is a path relative to the output directory
	// that receives the subprocess's standard output.
	Stdout string `json:"task.stdout"`
	// Vwd maps a path inside the output directory to the target of a
	// symlink created at that path before the command runs.
	Vwd map[string]string `json:"task.vwd"`
	// SuccessCodes lists exit codes that count as success.
	SuccessCodes []int `json:"task.successCodes"`
	// PermanentFailCodes lists exit codes that count as a permanent failure.
	PermanentFailCodes []int `json:"task.permanentFailCodes"`
	// TemporaryFailCodes lists exit codes that count as a temporary failure.
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`
}
25
// Tasks is the JSON wrapper object holding the list of task definitions
// under the "tasks" key.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}
29
// Job is the portion of a job record that the runner reads: only the
// "script_parameters" object, which carries the task definitions.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}
33
34 type Task struct {
35         Job_uuid                 string  `json:"job_uuid"`
36         Created_by_job_task_uuid string  `json:"created_by_job_task_uuid"`
37         Parameters               TaskDef `json:"parameters"`
38         Sequence                 int     `json:"sequence"`
39         Output                   string  `json:"output"`
40         Success                  bool    `json:"success"`
41         Progress                 float32 `json:"sequence"`
42 }
43
// IArvadosClient is the subset of API client operations that runner() needs.
// Declaring it as an interface lets runner() accept any implementation that
// provides these two methods.
type IArvadosClient interface {
	// Create makes a new record of the given resource type from parameters;
	// output, if non-nil, is presumably filled with the response -- confirm
	// against the arvadosclient package.
	Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
	// Update modifies the record identified by uuid using parameters.
	Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
}
48
49 func setupDirectories(crunchtmpdir, taskUuid string) (tmpdir, outdir string, err error) {
50         tmpdir = crunchtmpdir + "/tmpdir"
51         err = os.Mkdir(tmpdir, 0700)
52         if err != nil {
53                 return "", "", err
54         }
55
56         outdir = crunchtmpdir + "/outdir"
57         err = os.Mkdir(outdir, 0700)
58         if err != nil {
59                 return "", "", err
60         }
61
62         return tmpdir, outdir, nil
63 }
64
65 func checkOutputFilename(outdir, fn string) error {
66         if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
67                 return fmt.Errorf("Path must not start or end with '/'")
68         }
69         if strings.Index("../", fn) != -1 {
70                 return fmt.Errorf("Path must not contain '../'")
71         }
72
73         sl := strings.LastIndex(fn, "/")
74         if sl != -1 {
75                 os.MkdirAll(outdir+"/"+fn[0:sl], 0777)
76         }
77         return nil
78 }
79
80 func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout string, err error) {
81         if taskp.Vwd != nil {
82                 for k, v := range taskp.Vwd {
83                         v = substitute(v, replacements)
84                         err = checkOutputFilename(outdir, k)
85                         if err != nil {
86                                 return "", "", err
87                         }
88                         os.Symlink(v, outdir+"/"+k)
89                 }
90         }
91
92         if taskp.Stdin != "" {
93                 // Set up stdin redirection
94                 stdin = substitute(taskp.Stdin, replacements)
95                 cmd.Stdin, err = os.Open(stdin)
96                 if err != nil {
97                         return "", "", err
98                 }
99         }
100
101         if taskp.Stdout != "" {
102                 err = checkOutputFilename(outdir, taskp.Stdout)
103                 if err != nil {
104                         return "", "", err
105                 }
106                 // Set up stdout redirection
107                 stdout = outdir + "/" + taskp.Stdout
108                 cmd.Stdout, err = os.Create(stdout)
109                 if err != nil {
110                         return "", "", err
111                 }
112         } else {
113                 cmd.Stdout = os.Stdout
114         }
115
116         if taskp.Env != nil {
117                 // Set up subprocess environment
118                 cmd.Env = os.Environ()
119                 for k, v := range taskp.Env {
120                         v = substitute(v, replacements)
121                         cmd.Env = append(cmd.Env, k+"="+v)
122                 }
123         }
124         return stdin, stdout, nil
125 }
126
127 // Set up signal handlers.  Go sends signal notifications to a "signal
128 // channel".
129 func setupSignals(cmd *exec.Cmd) chan os.Signal {
130         sigChan := make(chan os.Signal, 1)
131         signal.Notify(sigChan, syscall.SIGTERM)
132         signal.Notify(sigChan, syscall.SIGINT)
133         signal.Notify(sigChan, syscall.SIGQUIT)
134         return sigChan
135 }
136
// inCodes reports whether code is one of the entries of codes.  A nil or
// empty list never matches.
func inCodes(code int, codes []int) bool {
	for i := 0; i < len(codes); i++ {
		if codes[i] == code {
			return true
		}
	}
	return false
}
147
// TASK_TEMPFAIL is the process exit status main() uses to report a temporary
// failure.
const TASK_TEMPFAIL = 111

type (
	// TempFail marks an error as temporary.  It embeds the underlying
	// error, so its Error() text is that of the wrapped error.
	TempFail struct{ error }

	// PermFail marks a permanent failure.  It carries no further detail.
	PermFail struct{}
)

// Error makes PermFail satisfy the error interface; the text is always
// "PermFail".
func (PermFail) Error() string {
	return "PermFail"
}
156
// substitute returns inp with every occurrence of each key of subst replaced
// by the corresponding value.  The keys are applied one after another in map
// iteration order, which Go does not define, so the result is only
// predictable when no key occurs in another key or in a replacement value.
func substitute(inp string, subst map[string]string) string {
	result := inp
	for placeholder := range subst {
		result = strings.Replace(result, placeholder, subst[placeholder], -1)
	}
	return result
}
163
164 func runner(api IArvadosClient,
165         kc IKeepClient,
166         jobUuid, taskUuid, crunchtmpdir, keepmount string,
167         jobStruct Job, taskStruct Task) error {
168
169         var err error
170         taskp := taskStruct.Parameters
171
172         // If this is task 0 and there are multiple tasks, dispatch subtasks
173         // and exit.
174         if taskStruct.Sequence == 0 {
175                 if len(jobStruct.Script_parameters.Tasks) == 1 {
176                         taskp = jobStruct.Script_parameters.Tasks[0]
177                 } else {
178                         for _, task := range jobStruct.Script_parameters.Tasks {
179                                 err := api.Create("job_tasks",
180                                         map[string]interface{}{
181                                                 "job_task": Task{Job_uuid: jobUuid,
182                                                         Created_by_job_task_uuid: taskUuid,
183                                                         Sequence:                 1,
184                                                         Parameters:               task}},
185                                         nil)
186                                 if err != nil {
187                                         return TempFail{err}
188                                 }
189                         }
190                         err = api.Update("job_tasks", taskUuid,
191                                 map[string]interface{}{
192                                         "job_task": Task{
193                                                 Output:   "",
194                                                 Success:  true,
195                                                 Progress: 1.0}},
196                                 nil)
197                         return nil
198                 }
199         }
200
201         var tmpdir, outdir string
202         tmpdir, outdir, err = setupDirectories(crunchtmpdir, taskUuid)
203         if err != nil {
204                 return TempFail{err}
205         }
206
207         replacements := map[string]string{
208                 "$(task.tmpdir)": tmpdir,
209                 "$(task.outdir)": outdir,
210                 "$(task.keep)":   keepmount}
211
212         // Set up subprocess
213         for k, v := range taskp.Command {
214                 taskp.Command[k] = substitute(v, replacements)
215         }
216
217         cmd := exec.Command(taskp.Command[0], taskp.Command[1:]...)
218
219         cmd.Dir = outdir
220
221         var stdin, stdout string
222         stdin, stdout, err = setupCommand(cmd, taskp, outdir, replacements)
223         if err != nil {
224                 return err
225         }
226
227         // Run subprocess and wait for it to complete
228         if stdin != "" {
229                 stdin = " < " + stdin
230         }
231         if stdout != "" {
232                 stdout = " > " + stdout
233         }
234         log.Printf("Running %v%v%v", cmd.Args, stdin, stdout)
235
236         var caughtSignal os.Signal
237         sigChan := setupSignals(cmd)
238
239         err = cmd.Start()
240         if err != nil {
241                 signal.Stop(sigChan)
242                 return TempFail{err}
243         }
244
245         finishedSignalNotify := make(chan struct{})
246         go func(sig <-chan os.Signal) {
247                 for sig := range sig {
248                         caughtSignal = sig
249                         cmd.Process.Signal(caughtSignal)
250                 }
251                 close(finishedSignalNotify)
252         }(sigChan)
253
254         err = cmd.Wait()
255         signal.Stop(sigChan)
256
257         close(sigChan)
258         <-finishedSignalNotify
259
260         if caughtSignal != nil {
261                 log.Printf("Caught signal %v", caughtSignal)
262                 return PermFail{}
263         }
264
265         if err != nil {
266                 // Run() returns ExitError on non-zero exit code, but we handle
267                 // that down below.  So only return if it's not ExitError.
268                 if _, ok := err.(*exec.ExitError); !ok {
269                         return TempFail{err}
270                 }
271         }
272
273         var success bool
274
275         exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
276
277         log.Printf("Completed with exit code %v", exitCode)
278
279         if inCodes(exitCode, taskp.PermanentFailCodes) {
280                 success = false
281         } else if inCodes(exitCode, taskp.TemporaryFailCodes) {
282                 return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
283         } else if inCodes(exitCode, taskp.SuccessCodes) || cmd.ProcessState.Success() {
284                 success = true
285         } else {
286                 success = false
287         }
288
289         // Upload output directory
290         manifest, err := WriteTree(kc, outdir)
291         if err != nil {
292                 return TempFail{err}
293         }
294
295         // Set status
296         err = api.Update("job_tasks", taskUuid,
297                 map[string]interface{}{
298                         "job_task": Task{
299                                 Output:   manifest,
300                                 Success:  success,
301                                 Progress: 1}},
302                 nil)
303         if err != nil {
304                 return TempFail{err}
305         }
306
307         if success {
308                 return nil
309         } else {
310                 return PermFail{}
311         }
312 }
313
314 func main() {
315         api, err := arvadosclient.MakeArvadosClient()
316         if err != nil {
317                 log.Fatal(err)
318         }
319
320         jobUuid := os.Getenv("JOB_UUID")
321         taskUuid := os.Getenv("TASK_UUID")
322         tmpdir := os.Getenv("TASK_WORK")
323         keepmount := os.Getenv("TASK_KEEPMOUNT")
324
325         var jobStruct Job
326         var taskStruct Task
327
328         err = api.Get("jobs", jobUuid, nil, &jobStruct)
329         if err != nil {
330                 log.Fatal(err)
331         }
332         err = api.Get("job_tasks", taskUuid, nil, &taskStruct)
333         if err != nil {
334                 log.Fatal(err)
335         }
336
337         var kc IKeepClient
338         kc, err = keepclient.MakeKeepClient(&api)
339         if err != nil {
340                 log.Fatal(err)
341         }
342
343         syscall.Umask(0022)
344         err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)
345
346         if err == nil {
347                 os.Exit(0)
348         } else if _, ok := err.(TempFail); ok {
349                 log.Print(err)
350                 os.Exit(TASK_TEMPFAIL)
351         } else if _, ok := err.(PermFail); ok {
352                 os.Exit(1)
353         } else {
354                 log.Fatal(err)
355         }
356 }