19166: Pass GatewayAuthSecret to crunch-run through lsf/slurm.
[arvados.git] / lib / lsf / dispatch.go
index 537d52a072d6a503262b1a228c868afc8f28b151..e2348337e62992eb4463947690e809e1927bb232 100644 (file)
@@ -6,6 +6,8 @@ package lsf
 
 import (
        "context"
+       "crypto/hmac"
+       "crypto/sha256"
        "errors"
        "fmt"
        "math"
@@ -119,7 +121,7 @@ func (disp *dispatcher) init() {
        disp.lsfcli.logger = disp.logger
        disp.lsfqueue = lsfqueue{
                logger: disp.logger,
-               period: time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval),
+               period: disp.Cluster.Containers.CloudVMs.PollInterval.Duration(),
                lsfcli: &disp.lsfcli,
        }
        disp.ArvClient.AuthToken = disp.AuthToken
@@ -256,7 +258,7 @@ func (disp *dispatcher) runContainer(_ *dispatch.Dispatcher, ctr arvados.Contain
 
        // Try "bkill" every few seconds until the LSF job disappears
        // from the queue.
-       ticker := time.NewTicker(5 * time.Second)
+       ticker := time.NewTicker(disp.Cluster.Containers.CloudVMs.PollInterval.Duration() / 2)
        defer ticker.Stop()
        for qent, ok := disp.lsfqueue.Lookup(ctr.UUID); ok; _, ok = disp.lsfqueue.Lookup(ctr.UUID) {
                err := disp.lsfcli.Bkill(qent.ID)
@@ -274,7 +276,12 @@ func (disp *dispatcher) submit(container arvados.Container, crunchRunCommand []s
        var crArgs []string
        crArgs = append(crArgs, crunchRunCommand...)
        crArgs = append(crArgs, container.UUID)
-       crScript := execScript(crArgs)
+
+       h := hmac.New(sha256.New, []byte(disp.Cluster.SystemRootToken))
+       fmt.Fprint(h, container.UUID)
+       authsecret := fmt.Sprintf("%x", h.Sum(nil))
+
+       crScript := execScript(crArgs, map[string]string{"GatewayAuthSecret": authsecret})
 
        bsubArgs, err := disp.bsubArgs(container)
        if err != nil {
@@ -306,11 +313,16 @@ func (disp *dispatcher) bsubArgs(container arvados.Container) ([]string, error)
                "%M": fmt.Sprintf("%d", mem),
                "%T": fmt.Sprintf("%d", tmp),
                "%U": container.UUID,
+               "%G": fmt.Sprintf("%d", container.RuntimeConstraints.CUDA.DeviceCount),
        }
 
        re := regexp.MustCompile(`%.`)
        var substitutionErrors string
-       for _, a := range disp.Cluster.Containers.LSF.BsubArgumentsList {
+       argumentTemplate := disp.Cluster.Containers.LSF.BsubArgumentsList
+       if container.RuntimeConstraints.CUDA.DeviceCount > 0 {
+               argumentTemplate = append(argumentTemplate, disp.Cluster.Containers.LSF.BsubCUDAArguments...)
+       }
+       for _, a := range argumentTemplate {
                args = append(args, re.ReplaceAllStringFunc(a, func(s string) string {
                        subst := repl[s]
                        if len(subst) == 0 {
@@ -348,8 +360,14 @@ func (disp *dispatcher) checkLsfQueueForOrphans() {
        }
 }
 
-func execScript(args []string) []byte {
-       s := "#!/bin/sh\nexec"
+func execScript(args []string, env map[string]string) []byte {
+       s := "#!/bin/sh\n"
+       for k, v := range env {
+               s += k + `='`
+               s += strings.Replace(v, `'`, `'\''`, -1)
+               s += `' `
+       }
+       s += `exec`
        for _, w := range args {
                s += ` '`
                s += strings.Replace(w, `'`, `'\''`, -1)