d362f66d14b3ee12b9a4fb6b197b9a34747d944c
[arvados.git] / lib / lsf / dispatch.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lsf
6
7 import (
8         "context"
9         "crypto/hmac"
10         "crypto/sha256"
11         "errors"
12         "fmt"
13         "math"
14         "net/http"
15         "regexp"
16         "strings"
17         "sync"
18         "time"
19
20         "git.arvados.org/arvados.git/lib/cmd"
21         "git.arvados.org/arvados.git/lib/dispatchcloud"
22         "git.arvados.org/arvados.git/lib/service"
23         "git.arvados.org/arvados.git/sdk/go/arvados"
24         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
25         "git.arvados.org/arvados.git/sdk/go/auth"
26         "git.arvados.org/arvados.git/sdk/go/ctxlog"
27         "git.arvados.org/arvados.git/sdk/go/dispatch"
28         "git.arvados.org/arvados.git/sdk/go/health"
29         "github.com/julienschmidt/httprouter"
30         "github.com/prometheus/client_golang/prometheus"
31         "github.com/prometheus/client_golang/prometheus/promhttp"
32         "github.com/sirupsen/logrus"
33 )
34
35 var DispatchCommand cmd.Handler = service.Command(arvados.ServiceNameDispatchLSF, newHandler)
36
37 func newHandler(ctx context.Context, cluster *arvados.Cluster, token string, reg *prometheus.Registry) service.Handler {
38         ac, err := arvados.NewClientFromConfig(cluster)
39         if err != nil {
40                 return service.ErrorHandler(ctx, cluster, fmt.Errorf("error initializing client from cluster config: %s", err))
41         }
42         d := &dispatcher{
43                 Cluster:   cluster,
44                 Context:   ctx,
45                 ArvClient: ac,
46                 AuthToken: token,
47                 Registry:  reg,
48         }
49         go d.Start()
50         return d
51 }
52
53 type dispatcher struct {
54         Cluster   *arvados.Cluster
55         Context   context.Context
56         ArvClient *arvados.Client
57         AuthToken string
58         Registry  *prometheus.Registry
59
60         logger        logrus.FieldLogger
61         lsfcli        lsfcli
62         lsfqueue      lsfqueue
63         arvDispatcher *dispatch.Dispatcher
64         httpHandler   http.Handler
65
66         initOnce sync.Once
67         stop     chan struct{}
68         stopped  chan struct{}
69 }
70
71 // Start starts the dispatcher. Start can be called multiple times
72 // with no ill effect.
73 func (disp *dispatcher) Start() {
74         disp.initOnce.Do(func() {
75                 disp.init()
76                 go func() {
77                         disp.checkLsfQueueForOrphans()
78                         err := disp.arvDispatcher.Run(disp.Context)
79                         if err != nil {
80                                 disp.logger.Error(err)
81                                 disp.Close()
82                         }
83                 }()
84         })
85 }
86
87 // ServeHTTP implements service.Handler.
88 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
89         disp.Start()
90         disp.httpHandler.ServeHTTP(w, r)
91 }
92
93 // CheckHealth implements service.Handler.
94 func (disp *dispatcher) CheckHealth() error {
95         disp.Start()
96         select {
97         case <-disp.stopped:
98                 return errors.New("stopped")
99         default:
100                 return nil
101         }
102 }
103
104 // Done implements service.Handler.
105 func (disp *dispatcher) Done() <-chan struct{} {
106         return disp.stopped
107 }
108
109 // Stop dispatching containers and release resources. Used by tests.
110 func (disp *dispatcher) Close() {
111         disp.Start()
112         select {
113         case disp.stop <- struct{}{}:
114         default:
115         }
116         <-disp.stopped
117 }
118
119 func (disp *dispatcher) init() {
120         disp.logger = ctxlog.FromContext(disp.Context)
121         disp.lsfcli.logger = disp.logger
122         disp.lsfqueue = lsfqueue{
123                 logger: disp.logger,
124                 period: disp.Cluster.Containers.CloudVMs.PollInterval.Duration(),
125                 lsfcli: &disp.lsfcli,
126         }
127         disp.ArvClient.AuthToken = disp.AuthToken
128         disp.stop = make(chan struct{}, 1)
129         disp.stopped = make(chan struct{})
130
131         arv, err := arvadosclient.New(disp.ArvClient)
132         if err != nil {
133                 disp.logger.Fatalf("Error making Arvados client: %v", err)
134         }
135         arv.Retries = 25
136         arv.ApiToken = disp.AuthToken
137         disp.arvDispatcher = &dispatch.Dispatcher{
138                 Arv:            arv,
139                 Logger:         disp.logger,
140                 BatchSize:      disp.Cluster.API.MaxItemsPerResponse,
141                 RunContainer:   disp.runContainer,
142                 PollPeriod:     time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval),
143                 MinRetryPeriod: time.Duration(disp.Cluster.Containers.MinRetryPeriod),
144         }
145
146         if disp.Cluster.ManagementToken == "" {
147                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
148                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
149                 })
150         } else {
151                 mux := httprouter.New()
152                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
153                         ErrorLog: disp.logger,
154                 })
155                 mux.Handler("GET", "/metrics", metricsH)
156                 mux.Handler("GET", "/metrics.json", metricsH)
157                 mux.Handler("GET", "/_health/:check", &health.Handler{
158                         Token:  disp.Cluster.ManagementToken,
159                         Prefix: "/_health/",
160                         Routes: health.Routes{"ping": disp.CheckHealth},
161                 })
162                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
163         }
164 }
165
166 func (disp *dispatcher) runContainer(_ *dispatch.Dispatcher, ctr arvados.Container, status <-chan arvados.Container) error {
167         ctx, cancel := context.WithCancel(disp.Context)
168         defer cancel()
169
170         if ctr.State != dispatch.Locked {
171                 // already started by prior invocation
172         } else if _, ok := disp.lsfqueue.Lookup(ctr.UUID); !ok {
173                 if _, err := dispatchcloud.ChooseInstanceType(disp.Cluster, &ctr); errors.As(err, &dispatchcloud.ConstraintsNotSatisfiableError{}) {
174                         err := disp.arvDispatcher.Arv.Update("containers", ctr.UUID, arvadosclient.Dict{
175                                 "container": map[string]interface{}{
176                                         "runtime_status": map[string]string{
177                                                 "error": err.Error(),
178                                         },
179                                 },
180                         }, nil)
181                         if err != nil {
182                                 return fmt.Errorf("error setting runtime_status on %s: %s", ctr.UUID, err)
183                         }
184                         return disp.arvDispatcher.UpdateState(ctr.UUID, dispatch.Cancelled)
185                 }
186                 disp.logger.Printf("Submitting container %s to LSF", ctr.UUID)
187                 cmd := []string{disp.Cluster.Containers.CrunchRunCommand}
188                 cmd = append(cmd, "--runtime-engine="+disp.Cluster.Containers.RuntimeEngine)
189                 cmd = append(cmd, disp.Cluster.Containers.CrunchRunArgumentsList...)
190                 err := disp.submit(ctr, cmd)
191                 if err != nil {
192                         return err
193                 }
194         }
195
196         disp.logger.Printf("Start monitoring container %v in state %q", ctr.UUID, ctr.State)
197         defer disp.logger.Printf("Done monitoring container %s", ctr.UUID)
198
199         go func(uuid string) {
200                 for ctx.Err() == nil {
201                         _, ok := disp.lsfqueue.Lookup(uuid)
202                         if !ok {
203                                 // If the container disappears from
204                                 // the lsf queue, there is no point in
205                                 // waiting for further dispatch
206                                 // updates: just clean up and return.
207                                 disp.logger.Printf("container %s job disappeared from LSF queue", uuid)
208                                 cancel()
209                                 return
210                         }
211                 }
212         }(ctr.UUID)
213
214         for done := false; !done; {
215                 select {
216                 case <-ctx.Done():
217                         // Disappeared from lsf queue
218                         if err := disp.arvDispatcher.Arv.Get("containers", ctr.UUID, nil, &ctr); err != nil {
219                                 disp.logger.Printf("error getting final container state for %s: %s", ctr.UUID, err)
220                         }
221                         switch ctr.State {
222                         case dispatch.Running:
223                                 disp.arvDispatcher.UpdateState(ctr.UUID, dispatch.Cancelled)
224                         case dispatch.Locked:
225                                 disp.arvDispatcher.Unlock(ctr.UUID)
226                         }
227                         return nil
228                 case updated, ok := <-status:
229                         if !ok {
230                                 // status channel is closed, which is
231                                 // how arvDispatcher tells us to stop
232                                 // touching the container record, kill
233                                 // off any remaining LSF processes,
234                                 // etc.
235                                 done = true
236                                 break
237                         }
238                         if updated.State != ctr.State {
239                                 disp.logger.Infof("container %s changed state from %s to %s", ctr.UUID, ctr.State, updated.State)
240                         }
241                         ctr = updated
242                         if ctr.Priority < 1 {
243                                 disp.logger.Printf("container %s has state %s, priority %d: cancel lsf job", ctr.UUID, ctr.State, ctr.Priority)
244                                 disp.bkill(ctr)
245                         } else {
246                                 disp.lsfqueue.SetPriority(ctr.UUID, int64(ctr.Priority))
247                         }
248                 }
249         }
250         disp.logger.Printf("container %s is done", ctr.UUID)
251
252         // Try "bkill" every few seconds until the LSF job disappears
253         // from the queue.
254         ticker := time.NewTicker(disp.Cluster.Containers.CloudVMs.PollInterval.Duration() / 2)
255         defer ticker.Stop()
256         for qent, ok := disp.lsfqueue.Lookup(ctr.UUID); ok; _, ok = disp.lsfqueue.Lookup(ctr.UUID) {
257                 err := disp.lsfcli.Bkill(qent.ID)
258                 if err != nil {
259                         disp.logger.Warnf("%s: bkill(%s): %s", ctr.UUID, qent.ID, err)
260                 }
261                 <-ticker.C
262         }
263         return nil
264 }
265
266 func (disp *dispatcher) submit(container arvados.Container, crunchRunCommand []string) error {
267         // Start with an empty slice here to ensure append() doesn't
268         // modify crunchRunCommand's underlying array
269         var crArgs []string
270         crArgs = append(crArgs, crunchRunCommand...)
271         crArgs = append(crArgs, container.UUID)
272
273         h := hmac.New(sha256.New, []byte(disp.Cluster.SystemRootToken))
274         fmt.Fprint(h, container.UUID)
275         authsecret := fmt.Sprintf("%x", h.Sum(nil))
276
277         crScript := execScript(crArgs, map[string]string{"GatewayAuthSecret": authsecret})
278
279         bsubArgs, err := disp.bsubArgs(container)
280         if err != nil {
281                 return err
282         }
283         return disp.lsfcli.Bsub(crScript, bsubArgs, disp.ArvClient)
284 }
285
286 func (disp *dispatcher) bkill(ctr arvados.Container) {
287         if qent, ok := disp.lsfqueue.Lookup(ctr.UUID); !ok {
288                 disp.logger.Debugf("bkill(%s): redundant, job not in queue", ctr.UUID)
289         } else if err := disp.lsfcli.Bkill(qent.ID); err != nil {
290                 disp.logger.Warnf("%s: bkill(%s): %s", ctr.UUID, qent.ID, err)
291         }
292 }
293
294 func (disp *dispatcher) bsubArgs(container arvados.Container) ([]string, error) {
295         args := []string{"bsub"}
296
297         tmp := int64(math.Ceil(float64(dispatchcloud.EstimateScratchSpace(&container)) / 1048576))
298         vcpus := container.RuntimeConstraints.VCPUs
299         mem := int64(math.Ceil(float64(container.RuntimeConstraints.RAM+
300                 container.RuntimeConstraints.KeepCacheRAM+
301                 int64(disp.Cluster.Containers.ReserveExtraRAM)) / 1048576))
302
303         repl := map[string]string{
304                 "%%": "%",
305                 "%C": fmt.Sprintf("%d", vcpus),
306                 "%M": fmt.Sprintf("%d", mem),
307                 "%T": fmt.Sprintf("%d", tmp),
308                 "%U": container.UUID,
309                 "%G": fmt.Sprintf("%d", container.RuntimeConstraints.CUDA.DeviceCount),
310         }
311
312         re := regexp.MustCompile(`%.`)
313         var substitutionErrors string
314         argumentTemplate := disp.Cluster.Containers.LSF.BsubArgumentsList
315         if container.RuntimeConstraints.CUDA.DeviceCount > 0 {
316                 argumentTemplate = append(argumentTemplate, disp.Cluster.Containers.LSF.BsubCUDAArguments...)
317         }
318         for _, a := range argumentTemplate {
319                 args = append(args, re.ReplaceAllStringFunc(a, func(s string) string {
320                         subst := repl[s]
321                         if len(subst) == 0 {
322                                 substitutionErrors += fmt.Sprintf("Unknown substitution parameter %s in BsubArgumentsList, ", s)
323                         }
324                         return subst
325                 }))
326         }
327         if len(substitutionErrors) != 0 {
328                 return nil, fmt.Errorf("%s", substitutionErrors[:len(substitutionErrors)-2])
329         }
330
331         if u := disp.Cluster.Containers.LSF.BsubSudoUser; u != "" {
332                 args = append([]string{"sudo", "-E", "-u", u}, args...)
333         }
334         return args, nil
335 }
336
337 // Check the next bjobs report, and invoke TrackContainer for all the
338 // containers in the report. This gives us a chance to cancel existing
339 // Arvados LSF jobs (started by a previous dispatch process) that
340 // never released their LSF job allocations even though their
341 // container states are Cancelled or Complete. See
342 // https://dev.arvados.org/issues/10979
343 func (disp *dispatcher) checkLsfQueueForOrphans() {
344         containerUuidPattern := regexp.MustCompile(`^[a-z0-9]{5}-dz642-[a-z0-9]{15}$`)
345         for _, uuid := range disp.lsfqueue.All() {
346                 if !containerUuidPattern.MatchString(uuid) || !strings.HasPrefix(uuid, disp.Cluster.ClusterID) {
347                         continue
348                 }
349                 err := disp.arvDispatcher.TrackContainer(uuid)
350                 if err != nil {
351                         disp.logger.Warnf("checkLsfQueueForOrphans: TrackContainer(%s): %s", uuid, err)
352                 }
353         }
354 }
355
356 func execScript(args []string, env map[string]string) []byte {
357         s := "#!/bin/sh\n"
358         for k, v := range env {
359                 s += k + `='`
360                 s += strings.Replace(v, `'`, `'\''`, -1)
361                 s += `' `
362         }
363         s += `exec`
364         for _, w := range args {
365                 s += ` '`
366                 s += strings.Replace(w, `'`, `'\''`, -1)
367                 s += `'`
368         }
369         return []byte(s + "\n")
370 }