Fix some tests.
[lightning.git] / arvados.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "bytes"
10         "context"
11         "encoding/json"
12         "errors"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "net/http"
17         "net/url"
18         "os"
19         "regexp"
20         "strconv"
21         "strings"
22         "sync"
23         "time"
24
25         "git.arvados.org/arvados.git/lib/cmd"
26         "git.arvados.org/arvados.git/sdk/go/arvados"
27         "git.arvados.org/arvados.git/sdk/go/arvadosclient"
28         "git.arvados.org/arvados.git/sdk/go/keepclient"
29         "github.com/klauspost/pgzip"
30         log "github.com/sirupsen/logrus"
31         "golang.org/x/crypto/blake2b"
32         "golang.org/x/net/websocket"
33 )
34
// eventMessage holds the subset of fields this package decodes from
// messages received on the Arvados websocket event stream.
type eventMessage struct {
	Status int
	// ObjectUUID identifies the object the message concerns;
	// runNotifier uses it to route the message to subscribers.
	ObjectUUID string `json:"object_uuid"`
	// EventType is the kind of event, e.g. "update" (the only
	// type RunContext acts on), "stderr", or "crunch-run".
	EventType string `json:"event_type"`
	// Properties.Text is the event's text payload, if any. It is
	// not read anywhere in this file.
	Properties struct {
		Text string
	}
}
43
// arvadosClient wraps an *arvados.Client with a websocket event
// notifier. Subscribe/Unsubscribe register channels to receive events
// about specific object UUIDs; Close shuts the notifier down.
type arvadosClient struct {
	*arvados.Client
	// notifying maps object UUID -> subscribed channel -> number
	// of outstanding Subscribe calls for that {channel, uuid}
	// pair. It is nil until the first Subscribe call (which also
	// starts runNotifier) and is reset to nil by Close.
	notifying map[string]map[chan<- eventMessage]int
	// wantClose is created by Subscribe and closed by Close to
	// tell runNotifier to stop.
	wantClose chan struct{}
	// wsconn is the current websocket connection, or nil while
	// disconnected.
	wsconn *websocket.Conn
	// mtx guards notifying and wsconn.
	mtx sync.Mutex
}
51
52 // Listen for events concerning the given uuids. When an event occurs
53 // (and after connecting/reconnecting to the event stream), send each
54 // uuid to ch. If a {ch, uuid} pair is subscribed twice, the uuid will
55 // be sent only once for each update, but two Unsubscribe calls will
56 // be needed to stop sending them.
57 func (client *arvadosClient) Subscribe(ch chan<- eventMessage, uuid string) {
58         client.mtx.Lock()
59         defer client.mtx.Unlock()
60         if client.notifying == nil {
61                 client.notifying = map[string]map[chan<- eventMessage]int{}
62                 client.wantClose = make(chan struct{})
63                 go client.runNotifier()
64         }
65         chmap := client.notifying[uuid]
66         if chmap == nil {
67                 chmap = map[chan<- eventMessage]int{}
68                 client.notifying[uuid] = chmap
69         }
70         needSub := true
71         for _, nch := range chmap {
72                 if nch > 0 {
73                         needSub = false
74                         break
75                 }
76         }
77         chmap[ch]++
78         if needSub && client.wsconn != nil {
79                 go json.NewEncoder(client.wsconn).Encode(map[string]interface{}{
80                         "method": "subscribe",
81                         "filters": [][]interface{}{
82                                 {"object_uuid", "=", uuid},
83                                 {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
84                         },
85                 })
86         }
87 }
88
89 func (client *arvadosClient) Unsubscribe(ch chan<- eventMessage, uuid string) {
90         client.mtx.Lock()
91         defer client.mtx.Unlock()
92         chmap := client.notifying[uuid]
93         if n := chmap[ch] - 1; n == 0 {
94                 delete(chmap, ch)
95                 if len(chmap) == 0 {
96                         delete(client.notifying, uuid)
97                 }
98                 if client.wsconn != nil {
99                         go json.NewEncoder(client.wsconn).Encode(map[string]interface{}{
100                                 "method": "unsubscribe",
101                                 "filters": [][]interface{}{
102                                         {"object_uuid", "=", uuid},
103                                         {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
104                                 },
105                         })
106                 }
107         } else if n > 0 {
108                 chmap[ch] = n
109         }
110 }
111
112 func (client *arvadosClient) Close() {
113         client.mtx.Lock()
114         defer client.mtx.Unlock()
115         if client.notifying != nil {
116                 client.notifying = nil
117                 close(client.wantClose)
118         }
119 }
120
121 func (client *arvadosClient) runNotifier() {
122 reconnect:
123         for {
124                 var cluster arvados.Cluster
125                 err := client.RequestAndDecode(&cluster, "GET", arvados.EndpointConfigGet.Path, nil, nil)
126                 if err != nil {
127                         log.Warnf("error getting cluster config: %s", err)
128                         time.Sleep(5 * time.Second)
129                         continue reconnect
130                 }
131                 wsURL := cluster.Services.Websocket.ExternalURL
132                 wsURL.Scheme = strings.Replace(wsURL.Scheme, "http", "ws", 1)
133                 wsURL.Path = "/websocket"
134                 wsURLNoToken := wsURL.String()
135                 wsURL.RawQuery = url.Values{"api_token": []string{client.AuthToken}}.Encode()
136                 conn, err := websocket.Dial(wsURL.String(), "", cluster.Services.Controller.ExternalURL.String())
137                 if err != nil {
138                         log.Warnf("websocket connection error: %s", err)
139                         time.Sleep(5 * time.Second)
140                         continue reconnect
141                 }
142                 log.Printf("connected to websocket at %s", wsURLNoToken)
143
144                 client.mtx.Lock()
145                 client.wsconn = conn
146                 resubscribe := make([]string, 0, len(client.notifying))
147                 for uuid := range client.notifying {
148                         resubscribe = append(resubscribe, uuid)
149                 }
150                 client.mtx.Unlock()
151
152                 go func() {
153                         w := json.NewEncoder(conn)
154                         for _, uuid := range resubscribe {
155                                 w.Encode(map[string]interface{}{
156                                         "method": "subscribe",
157                                         "filters": [][]interface{}{
158                                                 {"object_uuid", "=", uuid},
159                                                 {"event_type", "in", []string{"stderr", "crunch-run", "crunchstat", "update"}},
160                                         },
161                                 })
162                         }
163                 }()
164
165                 r := json.NewDecoder(conn)
166                 for {
167                         var msg eventMessage
168                         err := r.Decode(&msg)
169                         select {
170                         case <-client.wantClose:
171                                 return
172                         default:
173                                 if err != nil {
174                                         log.Printf("error decoding websocket message: %s", err)
175                                         client.mtx.Lock()
176                                         client.wsconn = nil
177                                         client.mtx.Unlock()
178                                         go conn.Close()
179                                         continue reconnect
180                                 }
181                                 client.mtx.Lock()
182                                 for ch := range client.notifying[msg.ObjectUUID] {
183                                         go func() { ch <- msg }()
184                                 }
185                                 client.mtx.Unlock()
186                         }
187                 }
188         }
189 }
190
// refreshTicker is a package-level ticker that fires every 5 seconds
// and is never stopped. Because it is a single shared ticker, each tick
// is received by only one goroutine, even when several are waiting on
// refreshTicker.C at the same time.
var refreshTicker = time.NewTicker(5 * time.Second)
192
// arvadosContainerRunner describes a container to run on an Arvados
// cluster. See RunContext for how each field is used.
type arvadosContainerRunner struct {
	Client      *arvados.Client
	Name        string // container request name
	OutputName  string // output collection name; if empty, output_name is sent as null
	ProjectUUID string // owner of the container request and command collection (required)
	APIAccess   bool   // sent as runtime_constraints.API
	VCPUs       int    // sent as runtime_constraints.VCPUs; also sets GOMAXPROCS in the container
	RAM         int64  // sent as runtime_constraints.RAM
	// Prog is the program to run. If empty, run /proc/self/exe:
	// it is uploaded by makeCommandCollection and run as
	// /mnt/cmd/lightning.
	Prog   string
	Args   []string
	// Mounts holds additional container mounts, keyed by mount
	// path, and is merged over the default /mnt/output mount.
	// TranslatePaths adds collection mounts here.
	Mounts      map[string]map[string]interface{}
	Priority    int  // container request priority
	KeepCache   int  // 64 MiB cache buffers per VCPU (0 for default)
	Preemptible bool // sent as scheduling_parameters.preemptible
}
208
209 func (runner *arvadosContainerRunner) Run() (string, error) {
210         return runner.RunContext(context.Background())
211 }
212
213 func (runner *arvadosContainerRunner) RunContext(ctx context.Context) (string, error) {
214         if runner.ProjectUUID == "" {
215                 return "", errors.New("cannot run arvados container: ProjectUUID not provided")
216         }
217
218         mounts := map[string]map[string]interface{}{
219                 "/mnt/output": {
220                         "kind":     "collection",
221                         "writable": true,
222                 },
223         }
224         for path, mnt := range runner.Mounts {
225                 mounts[path] = mnt
226         }
227
228         prog := runner.Prog
229         if prog == "" {
230                 prog = "/mnt/cmd/lightning"
231                 cmdUUID, err := runner.makeCommandCollection()
232                 if err != nil {
233                         return "", err
234                 }
235                 mounts["/mnt/cmd"] = map[string]interface{}{
236                         "kind": "collection",
237                         "uuid": cmdUUID,
238                 }
239         }
240         command := append([]string{prog}, runner.Args...)
241
242         priority := runner.Priority
243         if priority < 1 {
244                 priority = 500
245         }
246         keepCache := runner.KeepCache
247         if keepCache < 1 {
248                 keepCache = 2
249         }
250         rc := arvados.RuntimeConstraints{
251                 API:          runner.APIAccess,
252                 VCPUs:        runner.VCPUs,
253                 RAM:          runner.RAM,
254                 KeepCacheRAM: (1 << 26) * int64(keepCache) * int64(runner.VCPUs),
255         }
256         outname := &runner.OutputName
257         if *outname == "" {
258                 outname = nil
259         }
260         var cr arvados.ContainerRequest
261         err := runner.Client.RequestAndDecode(&cr, "POST", "arvados/v1/container_requests", nil, map[string]interface{}{
262                 "container_request": map[string]interface{}{
263                         "owner_uuid":          runner.ProjectUUID,
264                         "name":                runner.Name,
265                         "container_image":     "lightning-runtime",
266                         "command":             command,
267                         "mounts":              mounts,
268                         "use_existing":        true,
269                         "output_path":         "/mnt/output",
270                         "output_name":         outname,
271                         "runtime_constraints": rc,
272                         "priority":            runner.Priority,
273                         "state":               arvados.ContainerRequestStateCommitted,
274                         "scheduling_parameters": arvados.SchedulingParameters{
275                                 Preemptible: runner.Preemptible,
276                                 Partitions:  []string{},
277                         },
278                         "environment": map[string]string{
279                                 "GOMAXPROCS": fmt.Sprintf("%d", rc.VCPUs),
280                         },
281                         "container_count_max": 1,
282                 },
283         })
284         if err != nil {
285                 return "", err
286         }
287         log.Printf("container request UUID: %s", cr.UUID)
288         log.Printf("container UUID: %s", cr.ContainerUUID)
289
290         logch := make(chan eventMessage)
291         client := arvadosClient{Client: runner.Client}
292         defer client.Close()
293         subscribedUUID := ""
294         defer func() {
295                 if subscribedUUID != "" {
296                         log.Printf("unsubscribe container UUID: %s", subscribedUUID)
297                         client.Unsubscribe(logch, subscribedUUID)
298                 }
299         }()
300
301         neednewline := ""
302         logTell := map[string]int64{}
303
304         lastState := cr.State
305         refreshCR := func() {
306                 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
307                 defer cancel()
308                 err = runner.Client.RequestAndDecodeContext(ctx, &cr, "GET", "arvados/v1/container_requests/"+cr.UUID, nil, nil)
309                 if err != nil {
310                         fmt.Fprint(os.Stderr, neednewline)
311                         neednewline = ""
312                         log.Printf("error getting container request: %s", err)
313                         return
314                 }
315                 if lastState != cr.State {
316                         fmt.Fprint(os.Stderr, neednewline)
317                         neednewline = ""
318                         log.Printf("container request state: %s", cr.State)
319                         lastState = cr.State
320                 }
321                 if subscribedUUID != cr.ContainerUUID {
322                         fmt.Fprint(os.Stderr, neednewline)
323                         neednewline = ""
324                         if subscribedUUID != "" {
325                                 log.Printf("unsubscribe container UUID: %s", subscribedUUID)
326                                 client.Unsubscribe(logch, subscribedUUID)
327                         }
328                         log.Printf("subscribe container UUID: %s", cr.ContainerUUID)
329                         client.Subscribe(logch, cr.ContainerUUID)
330                         subscribedUUID = cr.ContainerUUID
331                         logTell = map[string]int64{}
332                 }
333         }
334
335         var logWaitMax = time.Second * 10
336         var logWaitMin = time.Second
337         var logWait = logWaitMin
338         var logWaitDone = time.After(logWait)
339         var reCrunchstat = regexp.MustCompile(`mem .* (\d+) rss`)
340 waitctr:
341         for cr.State != arvados.ContainerRequestStateFinal {
342                 select {
343                 case <-ctx.Done():
344                         err := runner.Client.RequestAndDecode(&cr, "PATCH", "arvados/v1/container_requests/"+cr.UUID, nil, map[string]interface{}{
345                                 "container_request": map[string]interface{}{
346                                         "priority": 0,
347                                 },
348                         })
349                         if err != nil {
350                                 log.Errorf("error while trying to cancel container request %s: %s", cr.UUID, err)
351                         }
352                         break waitctr
353                 case <-refreshTicker.C:
354                         refreshCR()
355                 case msg := <-logch:
356                         if msg.EventType == "update" {
357                                 refreshCR()
358                         }
359                 case <-logWaitDone:
360                         any := false
361                         for _, fnm := range []string{"stderr.txt", "crunchstat.txt"} {
362                                 req, err := http.NewRequest("GET", "https://"+runner.Client.APIHost+"/arvados/v1/container_requests/"+cr.UUID+"/log/"+cr.ContainerUUID+"/"+fnm, nil)
363                                 if err != nil {
364                                         log.Errorf("error preparing log request: %s", err)
365                                         continue
366                                 }
367                                 req.Header.Set("Range", fmt.Sprintf("bytes=%d-", logTell[fnm]))
368                                 resp, err := runner.Client.Do(req)
369                                 if err != nil {
370                                         log.Errorf("error getting log data: %s", err)
371                                         continue
372                                 } else if (resp.StatusCode == http.StatusNotFound && logTell[fnm] == 0) ||
373                                         (resp.StatusCode == http.StatusRequestedRangeNotSatisfiable && logTell[fnm] > 0) {
374                                         continue
375                                 } else if resp.StatusCode >= 300 {
376                                         log.Errorf("error getting log data: %s", resp.Status)
377                                         continue
378                                 }
379                                 logdata, err := io.ReadAll(resp.Body)
380                                 if err != nil {
381                                         log.Errorf("error reading log data: %s", err)
382                                         continue
383                                 }
384                                 if len(logdata) == 0 {
385                                         continue
386                                 }
387                                 for {
388                                         eol := bytes.Index(logdata, []byte{'\n'})
389                                         if eol < 0 {
390                                                 break
391                                         }
392                                         line := string(logdata[:eol])
393                                         logdata = logdata[eol+1:]
394                                         logTell[fnm] += int64(eol + 1)
395                                         if len(line) == 0 {
396                                                 continue
397                                         }
398                                         any = true
399                                         if fnm == "stderr.txt" {
400                                                 fmt.Fprint(os.Stderr, neednewline)
401                                                 neednewline = ""
402                                                 log.Print(line)
403                                         } else if fnm == "crunchstat.txt" {
404                                                 m := reCrunchstat.FindStringSubmatch(line)
405                                                 if m != nil {
406                                                         rss, _ := strconv.ParseInt(m[1], 10, 64)
407                                                         fmt.Fprintf(os.Stderr, "%s rss %.3f GB           \r", cr.UUID, float64(rss)/1e9)
408                                                         neednewline = "\n"
409                                                 }
410                                         }
411                                 }
412                         }
413                         if any {
414                                 logWait = logWaitMin
415                         } else {
416                                 logWait = logWait * 2
417                                 if logWait > logWaitMax {
418                                         logWait = logWaitMax
419                                 }
420                         }
421                         logWaitDone = time.After(logWait)
422                 }
423         }
424         fmt.Fprint(os.Stderr, neednewline)
425
426         if err := ctx.Err(); err != nil {
427                 return "", err
428         }
429
430         var c arvados.Container
431         err = runner.Client.RequestAndDecode(&c, "GET", "arvados/v1/containers/"+cr.ContainerUUID, nil, nil)
432         if err != nil {
433                 return "", err
434         } else if c.State != arvados.ContainerStateComplete {
435                 return "", fmt.Errorf("container did not complete: %s", c.State)
436         } else if c.ExitCode != 0 {
437                 return "", fmt.Errorf("container exited %d", c.ExitCode)
438         }
439         return cr.OutputUUID, err
440 }
441
// collectionInPathRe matches a path that contains an Arvados collection
// identifier. The submatches are:
//   - m[1]: the prefix up to and including the "/" before the
//     identifier, if any.
//   - m[2]: the identifier, either a portable data hash (32 hex
//     digits, "+", and a decimal size) or a 27-character UUID
//     (5-5-15 lowercase alphanumerics).
//   - m[3]: the remainder of the path starting with "/", if any.
var collectionInPathRe = regexp.MustCompile(`^(.*/)?([0-9a-f]{32}\+[0-9]+|[0-9a-z]{5}-[0-9a-z]{5}-[0-9a-z]{15})(/.*)?$`)
443
444 func (runner *arvadosContainerRunner) TranslatePaths(paths ...*string) error {
445         if runner.Mounts == nil {
446                 runner.Mounts = make(map[string]map[string]interface{})
447         }
448         for _, path := range paths {
449                 if *path == "" || *path == "-" {
450                         continue
451                 }
452                 m := collectionInPathRe.FindStringSubmatch(*path)
453                 if m == nil {
454                         return fmt.Errorf("cannot find uuid in path: %q", *path)
455                 }
456                 collID := m[2]
457                 mnt, ok := runner.Mounts["/mnt/"+collID]
458                 if !ok {
459                         mnt = map[string]interface{}{
460                                 "kind": "collection",
461                         }
462                         if len(collID) == 27 {
463                                 mnt["uuid"] = collID
464                         } else {
465                                 mnt["portable_data_hash"] = collID
466                         }
467                         runner.Mounts["/mnt/"+collID] = mnt
468                 }
469                 *path = "/mnt/" + collID + m[3]
470         }
471         return nil
472 }
473
// mtxMakeCommandCollection serializes makeCommandCollection calls, so
// its find-existing-or-create sequence never runs concurrently within
// this process.
var mtxMakeCommandCollection sync.Mutex
475
476 func (runner *arvadosContainerRunner) makeCommandCollection() (string, error) {
477         mtxMakeCommandCollection.Lock()
478         defer mtxMakeCommandCollection.Unlock()
479         exe, err := ioutil.ReadFile("/proc/self/exe")
480         if err != nil {
481                 return "", err
482         }
483         b2 := blake2b.Sum256(exe)
484         cname := "lightning " + cmd.Version.String() // must build with "make", not just "go install"
485         var existing arvados.CollectionList
486         err = runner.Client.RequestAndDecode(&existing, "GET", "arvados/v1/collections", nil, arvados.ListOptions{
487                 Limit: 1,
488                 Count: "none",
489                 Filters: []arvados.Filter{
490                         {Attr: "name", Operator: "=", Operand: cname},
491                         {Attr: "owner_uuid", Operator: "=", Operand: runner.ProjectUUID},
492                         {Attr: "properties.blake2b", Operator: "=", Operand: fmt.Sprintf("%x", b2)},
493                 },
494         })
495         if err != nil {
496                 return "", err
497         }
498         if len(existing.Items) > 0 {
499                 coll := existing.Items[0]
500                 log.Printf("using lightning binary in existing collection %s (name is %q, hash is %q; did not verify whether content matches)", coll.UUID, cname, coll.Properties["blake2b"])
501                 return coll.UUID, nil
502         }
503         log.Printf("writing lightning binary to new collection %q", cname)
504         ac, err := arvadosclient.New(runner.Client)
505         if err != nil {
506                 return "", err
507         }
508         kc := keepclient.New(ac)
509         var coll arvados.Collection
510         fs, err := coll.FileSystem(runner.Client, kc)
511         if err != nil {
512                 return "", err
513         }
514         f, err := fs.OpenFile("lightning", os.O_CREATE|os.O_WRONLY, 0777)
515         if err != nil {
516                 return "", err
517         }
518         _, err = f.Write(exe)
519         if err != nil {
520                 return "", err
521         }
522         err = f.Close()
523         if err != nil {
524                 return "", err
525         }
526         mtxt, err := fs.MarshalManifest(".")
527         if err != nil {
528                 return "", err
529         }
530         err = runner.Client.RequestAndDecode(&coll, "POST", "arvados/v1/collections", nil, map[string]interface{}{
531                 "collection": map[string]interface{}{
532                         "owner_uuid":    runner.ProjectUUID,
533                         "manifest_text": mtxt,
534                         "name":          cname,
535                         "properties": map[string]interface{}{
536                                 "blake2b": fmt.Sprintf("%x", b2),
537                         },
538                 },
539         })
540         if err != nil {
541                 return "", err
542         }
543         log.Printf("stored lightning binary in new collection %s", coll.UUID)
544         return coll.UUID, nil
545 }
546
547 // zopen returns a reader for the given file, using the arvados API
548 // instead of arv-mount/fuse where applicable, and transparently
549 // decompressing the input if fnm ends with ".gz".
550 func zopen(fnm string) (io.ReadCloser, error) {
551         f, err := open(fnm)
552         if err != nil || !strings.HasSuffix(fnm, ".gz") {
553                 return f, err
554         }
555         rdr, err := pgzip.NewReader(bufio.NewReaderSize(f, 4*1024*1024))
556         if err != nil {
557                 f.Close()
558                 return nil, err
559         }
560         return gzipr{rdr, f}, nil
561 }
562
// gzipr pairs a decompressing reader with the underlying file it reads
// from. Reads go to the embedded ReadCloser. Close closes the reader
// first and then the file, and reports the first error encountered.
type gzipr struct {
	io.ReadCloser
	io.Closer
}

// Close closes both wrapped objects, in order, even if the first one
// fails. It returns the reader's error if there was one, otherwise the
// file's.
func (gr gzipr) Close() error {
	errs := [2]error{gr.ReadCloser.Close(), gr.Closer.Close()}
	for _, err := range errs {
		if err != nil {
			return err
		}
	}
	return nil
}
578
// Shared state used by open() to read collection files through the
// Arvados API.
var (
	// arvadosClientFromEnv is created at package initialization
	// by arvados.NewClientFromEnv.
	arvadosClientFromEnv = arvados.NewClientFromEnv()
	// keepClient and siteFS are initialized lazily by the first
	// open() call that reads through the API, while holding
	// siteFSMtx.
	keepClient *keepclient.KeepClient
	siteFS     arvados.CustomFileSystem
	siteFSMtx  sync.Mutex
)
585
// file is the set of operations open() promises its callers. It is
// satisfied by *os.File and by files from the Arvados site filesystem.
type file interface {
	io.Reader
	io.Closer
	io.Seeker
	Readdir(count int) ([]os.FileInfo, error)
}
591
592 func open(fnm string) (file, error) {
593         if os.Getenv("ARVADOS_API_HOST") == "" {
594                 return os.Open(fnm)
595         }
596         m := collectionInPathRe.FindStringSubmatch(fnm)
597         if m == nil {
598                 return os.Open(fnm)
599         }
600         collectionUUID := m[2]
601         collectionPath := m[3]
602
603         siteFSMtx.Lock()
604         defer siteFSMtx.Unlock()
605         if siteFS == nil {
606                 log.Info("setting up Arvados client")
607                 ac, err := arvadosclient.New(arvadosClientFromEnv)
608                 if err != nil {
609                         return nil, err
610                 }
611                 ac.Client = arvados.DefaultSecureClient
612                 keepClient = keepclient.New(ac)
613                 // Don't use keepclient's default short timeouts.
614                 keepClient.HTTPClient = arvados.DefaultSecureClient
615                 keepClient.BlockCache = &keepclient.BlockCache{MaxBlocks: 4}
616                 siteFS = arvadosClientFromEnv.SiteFileSystem(keepClient)
617         } else {
618                 keepClient.BlockCache.MaxBlocks += 2
619         }
620
621         log.Infof("reading %q from %s using Arvados client", collectionPath, collectionUUID)
622         f, err := siteFS.Open("by_id/" + collectionUUID + collectionPath)
623         if err != nil {
624                 return nil, err
625         }
626         return &reduceCacheOnClose{file: f}, nil
627 }
628
// reduceCacheOnClose wraps a file that open() opened through the
// Arvados site filesystem. Its Close method subtracts the 2 keep block
// cache blocks that open() added for the file. once ensures the
// subtraction happens only once, however many times Close is called.
type reduceCacheOnClose struct {
	file
	once sync.Once
}
633
634 func (rc *reduceCacheOnClose) Close() error {
635         rc.once.Do(func() { keepClient.BlockCache.MaxBlocks -= 2 })
636         return rc.file.Close()
637 }