8815: Now expect /usr/local/bin/crunchrunner. Bind mount host certificates to
[arvados.git] / sdk / go / crunchrunner / crunchrunner.go
1 package main
2
3 import (
4         "crypto/x509"
5         "fmt"
6         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
7         "git.curoverse.com/arvados.git/sdk/go/keepclient"
8         "io/ioutil"
9         "log"
10         "net/http"
11         "os"
12         "os/exec"
13         "os/signal"
14         "path"
15         "strings"
16         "syscall"
17 )
18
// TaskDef describes one command for crunchrunner to run. It appears both in
// a job's script_parameters.tasks list and as a job_task's parameters.
type TaskDef struct {
	// Command is the argv of the process; runner expands the
	// $(task.tmpdir), $(task.outdir) and $(task.keep) placeholders in it.
	Command []string `json:"command"`
	// Env holds extra environment variables (values are placeholder-expanded).
	Env map[string]string `json:"task.env"`
	// Stdin, if non-empty, is a file (placeholder-expanded) used as stdin.
	Stdin string `json:"task.stdin"`
	// Stdout, if non-empty, is a path relative to the output directory that
	// receives the process's stdout.
	Stdout string `json:"task.stdout"`
	// Vwd maps paths relative to the output directory to symlink targets
	// (targets are placeholder-expanded).
	Vwd map[string]string `json:"task.vwd"`
	// Exit codes used by runner to classify the process result.
	SuccessCodes       []int `json:"task.successCodes"`
	PermanentFailCodes []int `json:"task.permanentFailCodes"`
	TemporaryFailCodes []int `json:"task.temporaryFailCodes"`
}

// Tasks is the decoded form of a job's script_parameters: the list of task
// definitions the job runs.
type Tasks struct {
	Tasks []TaskDef `json:"tasks"`
}

// Job holds the fields of a job record that crunchrunner reads.
type Job struct {
	Script_parameters Tasks `json:"script_parameters"`
}

// Task mirrors a job_task record. It is decoded from the API server and is
// also sent back when creating subtasks and when reporting task status.
type Task struct {
	Job_uuid                 string  `json:"job_uuid"`
	Created_by_job_task_uuid string  `json:"created_by_job_task_uuid"`
	Parameters               TaskDef `json:"parameters"`
	Sequence                 int     `json:"sequence"`
	Output                   string  `json:"output"`
	Success                  bool    `json:"success"`
	// The JSON name must differ from Sequence's: when two tagged fields
	// share a name, encoding/json silently ignores both of them.
	Progress float32 `json:"progress"`
}
47
48 type IArvadosClient interface {
49         Create(resourceType string, parameters arvadosclient.Dict, output interface{}) error
50         Update(resourceType string, uuid string, parameters arvadosclient.Dict, output interface{}) (err error)
51 }
52
// setupDirectories creates the "tmpdir" and "outdir" subdirectories of
// crunchtmpdir (permission 0700, in that order) and returns their paths.
// taskUuid is accepted but not used. If either directory cannot be created,
// both returned paths are empty and err describes the failure.
func setupDirectories(crunchtmpdir, taskUuid string) (tmpdir, outdir string, err error) {
	dirs := [2]string{crunchtmpdir + "/tmpdir", crunchtmpdir + "/outdir"}
	for _, d := range dirs {
		if mkErr := os.Mkdir(d, 0700); mkErr != nil {
			return "", "", mkErr
		}
	}
	return dirs[0], dirs[1], nil
}
68
// checkOutputFilename validates fn, a path relative to outdir that a task
// will write (its stdout redirect or a vwd symlink), and creates fn's parent
// directories under outdir.
//
// fn must not start or end with '/' and must not contain a ".." path
// component, so that, taken lexically, it stays inside outdir. An error is
// also returned if the parent directories cannot be created.
func checkOutputFilename(outdir, fn string) error {
	if strings.HasPrefix(fn, "/") || strings.HasSuffix(fn, "/") {
		return fmt.Errorf("Path must not start or end with '/'")
	}
	// Inspect whole components: ".." may appear without a trailing slash
	// (e.g. "a/.."), while names such as "foo.." are harmless.
	for _, part := range strings.Split(fn, "/") {
		if part == ".." {
			return fmt.Errorf("Path must not contain '../'")
		}
	}

	if sl := strings.LastIndex(fn, "/"); sl != -1 {
		if err := os.MkdirAll(outdir+"/"+fn[:sl], 0777); err != nil {
			return err
		}
	}
	return nil
}
83
84 func setupCommand(cmd *exec.Cmd, taskp TaskDef, outdir string, replacements map[string]string) (stdin, stdout string, err error) {
85         if taskp.Vwd != nil {
86                 for k, v := range taskp.Vwd {
87                         v = substitute(v, replacements)
88                         err = checkOutputFilename(outdir, k)
89                         if err != nil {
90                                 return "", "", err
91                         }
92                         os.Symlink(v, outdir+"/"+k)
93                 }
94         }
95
96         if taskp.Stdin != "" {
97                 // Set up stdin redirection
98                 stdin = substitute(taskp.Stdin, replacements)
99                 cmd.Stdin, err = os.Open(stdin)
100                 if err != nil {
101                         return "", "", err
102                 }
103         }
104
105         if taskp.Stdout != "" {
106                 err = checkOutputFilename(outdir, taskp.Stdout)
107                 if err != nil {
108                         return "", "", err
109                 }
110                 // Set up stdout redirection
111                 stdout = outdir + "/" + taskp.Stdout
112                 cmd.Stdout, err = os.Create(stdout)
113                 if err != nil {
114                         return "", "", err
115                 }
116         } else {
117                 cmd.Stdout = os.Stdout
118         }
119
120         cmd.Stderr = os.Stderr
121
122         if taskp.Env != nil {
123                 // Set up subprocess environment
124                 cmd.Env = os.Environ()
125                 for k, v := range taskp.Env {
126                         v = substitute(v, replacements)
127                         cmd.Env = append(cmd.Env, k+"="+v)
128                 }
129         }
130         return stdin, stdout, nil
131 }
132
// setupSignals registers for SIGTERM, SIGINT and SIGQUIT and returns the
// channel (buffered, capacity 1) on which they will be delivered. While
// registered, those signals no longer take their default action. Callers
// should pass the channel to signal.Stop when finished. cmd is not used.
func setupSignals(cmd *exec.Cmd) chan os.Signal {
	notifications := make(chan os.Signal, 1)
	signal.Notify(notifications, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
	return notifications
}
142
// inCodes reports whether code is one of the values in codes. A nil or
// empty list contains nothing.
func inCodes(code int, codes []int) bool {
	for i := range codes {
		if codes[i] == code {
			return true
		}
	}
	return false
}
153
// TASK_TEMPFAIL is the exit status main uses when runner returns a TempFail.
const TASK_TEMPFAIL = 111

type (
	// TempFail wraps an error to mark it as a temporary failure; its
	// message is that of the wrapped error.
	TempFail struct{ error }

	// PermFail marks a permanent task failure.
	PermFail struct{}
)

// Error implements the error interface.
func (PermFail) Error() string {
	return "PermFail"
}
162
// substitute returns inp with every occurrence of each key of subst replaced
// by its value. Keys are applied one at a time in map iteration order, so a
// value that itself contains another key may or may not be expanded again.
func substitute(inp string, subst map[string]string) string {
	result := inp
	for placeholder, value := range subst {
		result = strings.Replace(result, placeholder, value, -1)
	}
	return result
}
169
170 func runner(api IArvadosClient,
171         kc IKeepClient,
172         jobUuid, taskUuid, crunchtmpdir, keepmount string,
173         jobStruct Job, taskStruct Task) error {
174
175         var err error
176         taskp := taskStruct.Parameters
177
178         // If this is task 0 and there are multiple tasks, dispatch subtasks
179         // and exit.
180         if taskStruct.Sequence == 0 {
181                 if len(jobStruct.Script_parameters.Tasks) == 1 {
182                         taskp = jobStruct.Script_parameters.Tasks[0]
183                 } else {
184                         for _, task := range jobStruct.Script_parameters.Tasks {
185                                 err := api.Create("job_tasks",
186                                         map[string]interface{}{
187                                                 "job_task": Task{Job_uuid: jobUuid,
188                                                         Created_by_job_task_uuid: taskUuid,
189                                                         Sequence:                 1,
190                                                         Parameters:               task}},
191                                         nil)
192                                 if err != nil {
193                                         return TempFail{err}
194                                 }
195                         }
196                         err = api.Update("job_tasks", taskUuid,
197                                 map[string]interface{}{
198                                         "job_task": Task{
199                                                 Output:   "",
200                                                 Success:  true,
201                                                 Progress: 1.0}},
202                                 nil)
203                         return nil
204                 }
205         }
206
207         var tmpdir, outdir string
208         tmpdir, outdir, err = setupDirectories(crunchtmpdir, taskUuid)
209         if err != nil {
210                 return TempFail{err}
211         }
212
213         replacements := map[string]string{
214                 "$(task.tmpdir)": tmpdir,
215                 "$(task.outdir)": outdir,
216                 "$(task.keep)":   keepmount}
217
218         log.Printf("crunchrunner: $(task.tmpdir)=%v", tmpdir)
219         log.Printf("crunchrunner: $(task.outdir)=%v", outdir)
220         log.Printf("crunchrunner: $(task.keep)=%v", keepmount)
221
222         // Set up subprocess
223         for k, v := range taskp.Command {
224                 taskp.Command[k] = substitute(v, replacements)
225         }
226
227         cmd := exec.Command(taskp.Command[0], taskp.Command[1:]...)
228
229         cmd.Dir = outdir
230
231         var stdin, stdout string
232         stdin, stdout, err = setupCommand(cmd, taskp, outdir, replacements)
233         if err != nil {
234                 return err
235         }
236
237         // Run subprocess and wait for it to complete
238         if stdin != "" {
239                 stdin = " < " + stdin
240         }
241         if stdout != "" {
242                 stdout = " > " + stdout
243         }
244         log.Printf("Running %v%v%v", cmd.Args, stdin, stdout)
245
246         var caughtSignal os.Signal
247         sigChan := setupSignals(cmd)
248
249         err = cmd.Start()
250         if err != nil {
251                 signal.Stop(sigChan)
252                 return TempFail{err}
253         }
254
255         finishedSignalNotify := make(chan struct{})
256         go func(sig <-chan os.Signal) {
257                 for sig := range sig {
258                         caughtSignal = sig
259                         cmd.Process.Signal(caughtSignal)
260                 }
261                 close(finishedSignalNotify)
262         }(sigChan)
263
264         err = cmd.Wait()
265         signal.Stop(sigChan)
266
267         close(sigChan)
268         <-finishedSignalNotify
269
270         if caughtSignal != nil {
271                 log.Printf("Caught signal %v", caughtSignal)
272                 return PermFail{}
273         }
274
275         if err != nil {
276                 // Run() returns ExitError on non-zero exit code, but we handle
277                 // that down below.  So only return if it's not ExitError.
278                 if _, ok := err.(*exec.ExitError); !ok {
279                         return TempFail{err}
280                 }
281         }
282
283         var success bool
284
285         exitCode := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()
286
287         log.Printf("Completed with exit code %v", exitCode)
288
289         if inCodes(exitCode, taskp.PermanentFailCodes) {
290                 success = false
291         } else if inCodes(exitCode, taskp.TemporaryFailCodes) {
292                 return TempFail{fmt.Errorf("Process tempfail with exit code %v", exitCode)}
293         } else if inCodes(exitCode, taskp.SuccessCodes) || cmd.ProcessState.Success() {
294                 success = true
295         } else {
296                 success = false
297         }
298
299         // Upload output directory
300         manifest, err := WriteTree(kc, outdir)
301         if err != nil {
302                 return TempFail{err}
303         }
304
305         // Set status
306         err = api.Update("job_tasks", taskUuid,
307                 map[string]interface{}{
308                         "job_task": Task{
309                                 Output:   manifest,
310                                 Success:  success,
311                                 Progress: 1}},
312                 nil)
313         if err != nil {
314                 return TempFail{err}
315         }
316
317         if success {
318                 return nil
319         } else {
320                 return PermFail{}
321         }
322 }
323
324 func main() {
325         api, err := arvadosclient.MakeArvadosClient()
326         if err != nil {
327                 log.Fatal(err)
328         }
329
330         // Container may not have certificates installed, so need to look for
331         // /etc/arvados/ca-certificates.crt in addition to normal system certs.
332         var certFiles = []string{
333                 "/etc/ssl/certs/ca-certificates.crt", // Debian
334                 "/etc/pki/tls/certs/ca-bundle.crt",   // Red Hat
335                 "/etc/arvados/ca-certificates.crt",
336         }
337
338         certs := x509.NewCertPool()
339         for _, file := range certFiles {
340                 data, err := ioutil.ReadFile(file)
341                 if err == nil {
342                         log.Printf("Using TLS certificates at %v", file)
343                         certs.AppendCertsFromPEM(data)
344                 }
345         }
346         api.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs = certs
347
348         jobUuid := os.Getenv("JOB_UUID")
349         taskUuid := os.Getenv("TASK_UUID")
350         tmpdir := os.Getenv("TASK_WORK")
351         keepmount := os.Getenv("TASK_KEEPMOUNT")
352
353         var jobStruct Job
354         var taskStruct Task
355
356         err = api.Get("jobs", jobUuid, nil, &jobStruct)
357         if err != nil {
358                 log.Fatal(err)
359         }
360         err = api.Get("job_tasks", taskUuid, nil, &taskStruct)
361         if err != nil {
362                 log.Fatal(err)
363         }
364
365         var kc IKeepClient
366         kc, err = keepclient.MakeKeepClient(&api)
367         if err != nil {
368                 log.Fatal(err)
369         }
370
371         syscall.Umask(0022)
372         err = runner(api, kc, jobUuid, taskUuid, tmpdir, keepmount, jobStruct, taskStruct)
373
374         if err == nil {
375                 os.Exit(0)
376         } else if _, ok := err.(TempFail); ok {
377                 log.Print(err)
378                 os.Exit(TASK_TEMPFAIL)
379         } else if _, ok := err.(PermFail); ok {
380                 os.Exit(1)
381         } else {
382                 log.Fatal(err)
383         }
384 }