Refactor websocket listener.
author Tom Clegg <tom@tomclegg.ca>
Wed, 11 Mar 2020 20:43:37 +0000 (16:43 -0400)
committer Tom Clegg <tom@tomclegg.ca>
Wed, 11 Mar 2020 20:43:37 +0000 (16:43 -0400)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@tomclegg.ca>

arvados.go

index 5c4787dedbc1b0d065f21f84bc26b41228320315..bfeedc2679174e525e8c366460f50036d7dd92a7 100644 (file)
@@ -9,6 +9,7 @@ import (
        "os"
        "regexp"
        "strings"
+       "sync"
        "time"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
@@ -19,6 +20,151 @@ import (
        "golang.org/x/net/websocket"
 )
 
+type eventMessage struct {
+       Status     int
+       ObjectUUID string `json:"object_uuid"`
+       EventType  string `json:"event_type"`
+       Properties struct {
+               Text string
+       }
+}
+
+type arvadosClient struct {
+       *arvados.Client
+       notifying map[string]map[chan<- eventMessage]int
+       wantClose chan struct{}
+       wsconn    *websocket.Conn
+       mtx       sync.Mutex
+}
+
+// Subscribe listens for events concerning the given uuid. Each event
+// received for that uuid is sent to ch as an eventMessage; the
+// subscription is re-sent after reconnecting to the event stream. If
+// a {ch, uuid} pair is subscribed twice, each event is sent only once,
+// but two Unsubscribe calls are needed to stop sending them.
+func (client *arvadosClient) Subscribe(ch chan<- eventMessage, uuid string) {
+       client.mtx.Lock()
+       defer client.mtx.Unlock()
+       if client.notifying == nil {
+               client.notifying = map[string]map[chan<- eventMessage]int{}
+               client.wantClose = make(chan struct{})
+               go client.runNotifier()
+       }
+       chmap := client.notifying[uuid]
+       if chmap == nil {
+               chmap = map[chan<- eventMessage]int{}
+               client.notifying[uuid] = chmap
+       }
+       needSub := true
+       for _, nch := range chmap {
+               if nch > 0 {
+                       needSub = false
+                       break
+               }
+       }
+       chmap[ch]++
+       if needSub && client.wsconn != nil {
+               go json.NewEncoder(client.wsconn).Encode(map[string]interface{}{
+                       "method": "subscribe",
+                       "filters": [][]interface{}{
+                               {"object_uuid", "=", uuid},
+                               {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
+                       },
+               })
+       }
+}
+
+func (client *arvadosClient) Unsubscribe(ch chan<- eventMessage, uuid string) {
+       client.mtx.Lock()
+       defer client.mtx.Unlock()
+       chmap := client.notifying[uuid]
+       if n := chmap[ch] - 1; n == 0 {
+               delete(chmap, ch)
+               if len(chmap) == 0 {
+                       delete(client.notifying, uuid)
+               }
+               if client.wsconn != nil {
+                       go json.NewEncoder(client.wsconn).Encode(map[string]interface{}{
+                               "method": "unsubscribe",
+                               "filters": [][]interface{}{
+                                       {"object_uuid", "=", uuid},
+                                       {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
+                               },
+                       })
+               }
+       } else if n > 0 {
+               chmap[ch] = n
+       }
+}
+
+func (client *arvadosClient) Close() {
+       client.mtx.Lock()
+       defer client.mtx.Unlock()
+       if client.notifying != nil {
+               client.notifying = nil
+               close(client.wantClose)
+       }
+}
+
+func (client *arvadosClient) runNotifier() {
+reconnect:
+       for {
+               var cluster arvados.Cluster
+               err := client.RequestAndDecode(&cluster, "GET", arvados.EndpointConfigGet.Path, nil, nil)
+               if err != nil {
+                       log.Warnf("error getting cluster config: %s", err)
+                       time.Sleep(5 * time.Second)
+                       continue reconnect
+               }
+               wsURL := cluster.Services.Websocket.ExternalURL
+               wsURL.Scheme = strings.Replace(wsURL.Scheme, "http", "ws", 1)
+               wsURL.Path = "/websocket"
+               wsURL.RawQuery = url.Values{"api_token": []string{client.AuthToken}}.Encode()
+               conn, err := websocket.Dial(wsURL.String(), "", cluster.Services.Controller.ExternalURL.String())
+               if err != nil {
+                       log.Warnf("websocket connection error: %s", err)
+                       time.Sleep(5 * time.Second)
+                       continue reconnect
+               }
+               client.mtx.Lock()
+               client.wsconn = conn
+               client.mtx.Unlock()
+
+               w := json.NewEncoder(conn)
+               for uuid := range client.notifying {
+                       w.Encode(map[string]interface{}{
+                               "method": "subscribe",
+                               "filters": [][]interface{}{
+                                       {"object_uuid", "=", uuid},
+                                       {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
+                               },
+                       })
+               }
+
+               r := json.NewDecoder(conn)
+               for {
+                       var msg eventMessage
+                       err := r.Decode(&msg)
+                       select {
+                       case <-client.wantClose:
+                               return
+                       default:
+                               if err != nil {
+                                       log.Printf("error decoding websocket message: %s", err)
+                                       client.mtx.Lock()
+                                       client.wsconn = nil
+                                       client.mtx.Unlock()
+                                       go conn.Close()
+                                       continue reconnect
+                               }
+                               for ch := range client.notifying[msg.ObjectUUID] {
+                                       ch <- msg
+                               }
+                       }
+               }
+       }
+}
+
 type arvadosContainerRunner struct {
        Client      *arvados.Client
        Name        string
@@ -91,11 +237,13 @@ func (runner *arvadosContainerRunner) Run() (string, error) {
        log.Printf("container request UUID: %s", cr.UUID)
        log.Printf("container UUID: %s", cr.ContainerUUID)
 
-       var logch <-chan eventMessage
-       var logstream *logStream
+       logch := make(chan eventMessage)
+       client := arvadosClient{Client: runner.Client}
+       defer client.Close()
+       subscribedUUID := ""
        defer func() {
-               if logstream != nil {
-                       logstream.Close()
+               if subscribedUUID != "" {
+                       client.Unsubscribe(logch, subscribedUUID)
                }
        }()
 
@@ -113,25 +261,20 @@ func (runner *arvadosContainerRunner) Run() (string, error) {
                        log.Printf("container state: %s", cr.State)
                        lastState = cr.State
                }
+               if subscribedUUID != cr.ContainerUUID {
+                       if subscribedUUID != "" {
+                               client.Unsubscribe(logch, subscribedUUID)
+                       }
+                       client.Subscribe(logch, cr.ContainerUUID)
+                       subscribedUUID = cr.ContainerUUID
+               }
        }
 
-       subscribedUUID := ""
        for cr.State != arvados.ContainerRequestStateFinal {
-               if logch == nil && cr.ContainerUUID != subscribedUUID {
-                       if logstream != nil {
-                               logstream.Close()
-                       }
-                       logstream = runner.logStream(cr.ContainerUUID)
-                       logch = logstream.C
-               }
                select {
-               case msg, ok := <-logch:
-                       if !ok {
-                               logstream.Close()
-                               logstream = nil
-                               logch = nil
-                               break
-                       }
+               case <-ticker.C:
+                       refreshCR()
+               case msg := <-logch:
                        switch msg.EventType {
                        case "update":
                                refreshCR()
@@ -142,8 +285,6 @@ func (runner *arvadosContainerRunner) Run() (string, error) {
                                        }
                                }
                        }
-               case <-ticker.C:
-                       refreshCR()
                }
        }
 
@@ -250,68 +391,3 @@ func (runner *arvadosContainerRunner) makeCommandCollection() (string, error) {
        log.Printf("stored lightning binary in new collection %s", coll.UUID)
        return coll.UUID, nil
 }
-
-type eventMessage struct {
-       Status     int
-       ObjectUUID string `json:"object_uuid"`
-       EventType  string `json:"event_type"`
-       Properties struct {
-               Text string
-       }
-}
-
-type logStream struct {
-       C     <-chan eventMessage
-       Close func() error
-}
-
-func (runner *arvadosContainerRunner) logStream(uuid string) *logStream {
-       ch := make(chan eventMessage)
-       done := make(chan struct{})
-       go func() {
-               defer close(ch)
-               var cluster arvados.Cluster
-               runner.Client.RequestAndDecode(&cluster, "GET", arvados.EndpointConfigGet.Path, nil, nil)
-               wsURL := cluster.Services.Websocket.ExternalURL
-               wsURL.Scheme = strings.Replace(wsURL.Scheme, "http", "ws", 1)
-               wsURL.Path = "/websocket"
-               wsURL.RawQuery = url.Values{"api_token": []string{runner.Client.AuthToken}}.Encode()
-               conn, err := websocket.Dial(wsURL.String(), "", cluster.Services.Controller.ExternalURL.String())
-               if err != nil {
-                       log.Printf("websocket error: %s", err)
-                       return
-               }
-               w := json.NewEncoder(conn)
-               go w.Encode(map[string]interface{}{
-                       "method": "subscribe",
-                       "filters": [][]interface{}{
-                               {"object_uuid", "=", uuid},
-                               {"event_type", "in", []string{"stderr", "crunch-run", "update"}},
-                       },
-               })
-               r := json.NewDecoder(conn)
-               for {
-                       var msg eventMessage
-                       err := r.Decode(&msg)
-                       if err != nil {
-                               log.Printf("error decoding websocket message: %s", err)
-                               return
-                       }
-                       if msg.ObjectUUID == uuid {
-                               ch <- msg
-                       }
-                       select {
-                       case <-done:
-                               return
-                       default:
-                       }
-               }
-       }()
-       return &logStream{
-               C: ch,
-               Close: func() error {
-                       close(done)
-                       return nil
-               },
-       }
-}