Fix panic in test case (overwriting a locked sync.Mutex with an unlocked one).
[arvados.git] / sdk / go / dispatch / dispatch.go
1 // Package dispatch is a helper library for building Arvados container
2 // dispatchers.
3 package dispatch
4
5 import (
6         "context"
7         "fmt"
8         "log"
9         "sync"
10         "time"
11
12         "git.curoverse.com/arvados.git/sdk/go/arvados"
13         "git.curoverse.com/arvados.git/sdk/go/arvadosclient"
14 )
15
16 const (
17         Queued    = arvados.ContainerStateQueued
18         Locked    = arvados.ContainerStateLocked
19         Running   = arvados.ContainerStateRunning
20         Complete  = arvados.ContainerStateComplete
21         Cancelled = arvados.ContainerStateCancelled
22 )
23
24 // Dispatcher struct
25 type Dispatcher struct {
26         Arv *arvadosclient.ArvadosClient
27
28         // Queue polling frequency
29         PollPeriod time.Duration
30
31         // Time to wait between successive attempts to run the same container
32         MinRetryPeriod time.Duration
33
34         // Func that implements the container lifecycle. Must be set
35         // to a non-nil DispatchFunc before calling Run().
36         RunContainer DispatchFunc
37
38         auth     arvados.APIClientAuthorization
39         mtx      sync.Mutex
40         trackers map[string]*runTracker
41         throttle throttle
42 }
43
44 // A DispatchFunc executes a container (if the container record is
45 // Locked) or resume monitoring an already-running container, and wait
46 // until that container exits.
47 //
48 // While the container runs, the DispatchFunc should listen for
49 // updated container records on the provided channel. When the channel
50 // closes, the DispatchFunc should stop the container if it's still
51 // running, and return.
52 //
53 // The DispatchFunc should not return until the container is finished.
54 type DispatchFunc func(*Dispatcher, arvados.Container, <-chan arvados.Container)
55
56 // Run watches the API server's queue for containers that are either
57 // ready to run and available to lock, or are already locked by this
58 // dispatcher's token. When a new one appears, Run calls RunContainer
59 // in a new goroutine.
60 func (d *Dispatcher) Run(ctx context.Context) error {
61         err := d.Arv.Call("GET", "api_client_authorizations", "", "current", nil, &d.auth)
62         if err != nil {
63                 return fmt.Errorf("error getting my token UUID: %v", err)
64         }
65
66         d.throttle.hold = d.MinRetryPeriod
67
68         poll := time.NewTicker(d.PollPeriod)
69         defer poll.Stop()
70
71         for {
72                 tracked := d.trackedUUIDs()
73                 d.checkForUpdates([][]interface{}{
74                         {"uuid", "in", tracked}})
75                 d.checkForUpdates([][]interface{}{
76                         {"locked_by_uuid", "=", d.auth.UUID},
77                         {"uuid", "not in", tracked}})
78                 d.checkForUpdates([][]interface{}{
79                         {"state", "=", Queued},
80                         {"priority", ">", "0"},
81                         {"uuid", "not in", tracked}})
82                 select {
83                 case <-poll.C:
84                         continue
85                 case <-ctx.Done():
86                         return ctx.Err()
87                 }
88         }
89 }
90
91 func (d *Dispatcher) trackedUUIDs() []string {
92         d.mtx.Lock()
93         defer d.mtx.Unlock()
94         if len(d.trackers) == 0 {
95                 // API bug: ["uuid", "not in", []] does not work as
96                 // expected, but this does:
97                 return []string{"this-uuid-does-not-exist"}
98         }
99         uuids := make([]string, 0, len(d.trackers))
100         for x := range d.trackers {
101                 uuids = append(uuids, x)
102         }
103         return uuids
104 }
105
106 // Start a runner in a new goroutine, and send the initial container
107 // record to its updates channel.
108 func (d *Dispatcher) start(c arvados.Container) *runTracker {
109         tracker := &runTracker{updates: make(chan arvados.Container, 1)}
110         tracker.updates <- c
111         go func() {
112                 d.RunContainer(d, c, tracker.updates)
113
114                 d.mtx.Lock()
115                 delete(d.trackers, c.UUID)
116                 d.mtx.Unlock()
117         }()
118         return tracker
119 }
120
121 func (d *Dispatcher) checkForUpdates(filters [][]interface{}) {
122         params := arvadosclient.Dict{
123                 "filters": filters,
124                 "order":   []string{"priority desc"}}
125
126         var list arvados.ContainerList
127         for offset, more := 0, true; more; offset += len(list.Items) {
128                 params["offset"] = offset
129                 err := d.Arv.List("containers", params, &list)
130                 if err != nil {
131                         log.Printf("Error getting list of containers: %q", err)
132                         return
133                 }
134                 more = len(list.Items) > 0 && list.ItemsAvailable > len(list.Items)+offset
135                 d.checkListForUpdates(list.Items)
136         }
137 }
138
139 func (d *Dispatcher) checkListForUpdates(containers []arvados.Container) {
140         d.mtx.Lock()
141         defer d.mtx.Unlock()
142         if d.trackers == nil {
143                 d.trackers = make(map[string]*runTracker)
144         }
145
146         for _, c := range containers {
147                 tracker, alreadyTracking := d.trackers[c.UUID]
148                 if c.LockedByUUID != "" && c.LockedByUUID != d.auth.UUID {
149                         log.Printf("debug: ignoring %s locked by %s", c.UUID, c.LockedByUUID)
150                 } else if alreadyTracking {
151                         switch c.State {
152                         case Queued:
153                                 tracker.close()
154                         case Locked, Running:
155                                 tracker.update(c)
156                         case Cancelled, Complete:
157                                 tracker.close()
158                         }
159                 } else {
160                         switch c.State {
161                         case Queued:
162                                 if !d.throttle.Check(c.UUID) {
163                                         break
164                                 }
165                                 err := d.lock(c.UUID)
166                                 if err != nil {
167                                         log.Printf("debug: error locking container %s: %s", c.UUID, err)
168                                         break
169                                 }
170                                 c.State = Locked
171                                 d.trackers[c.UUID] = d.start(c)
172                         case Locked, Running:
173                                 if !d.throttle.Check(c.UUID) {
174                                         break
175                                 }
176                                 d.trackers[c.UUID] = d.start(c)
177                         case Cancelled, Complete:
178                                 // no-op (we already stopped monitoring)
179                         }
180                 }
181         }
182 }
183
184 // UpdateState makes an API call to change the state of a container.
185 func (d *Dispatcher) UpdateState(uuid string, state arvados.ContainerState) error {
186         err := d.Arv.Update("containers", uuid,
187                 arvadosclient.Dict{
188                         "container": arvadosclient.Dict{"state": state},
189                 }, nil)
190         if err != nil {
191                 log.Printf("Error updating container %s to state %q: %s", uuid, state, err)
192         }
193         return err
194 }
195
196 // Lock makes the lock API call which updates the state of a container to Locked.
197 func (d *Dispatcher) lock(uuid string) error {
198         return d.Arv.Call("POST", "containers", uuid, "lock", nil, nil)
199 }
200
201 // Unlock makes the unlock API call which updates the state of a container to Queued.
202 func (d *Dispatcher) Unlock(uuid string) error {
203         return d.Arv.Call("POST", "containers", uuid, "unlock", nil, nil)
204 }
205
206 // TrackContainer ensures a tracker is running for the given UUID,
207 // regardless of the current state of the container (except: if the
208 // container is locked by a different dispatcher, a tracker will not
209 // be started). If the container is not in Locked or Running state,
210 // the new tracker will close down immediately.
211 //
212 // This allows the dispatcher to put its own RunContainer func into a
213 // cleanup phase (for example, to kill local processes created by a
214 // prevous dispatch process that are still running even though the
215 // container state is final) without the risk of having multiple
216 // goroutines monitoring the same UUID.
217 func (d *Dispatcher) TrackContainer(uuid string) error {
218         var cntr arvados.Container
219         err := d.Arv.Call("GET", "containers", uuid, "", nil, &cntr)
220         if err != nil {
221                 return err
222         }
223         if cntr.LockedByUUID != "" && cntr.LockedByUUID != d.auth.UUID {
224                 return nil
225         }
226
227         d.mtx.Lock()
228         defer d.mtx.Unlock()
229         if _, alreadyTracking := d.trackers[uuid]; alreadyTracking {
230                 return nil
231         }
232         if d.trackers == nil {
233                 d.trackers = make(map[string]*runTracker)
234         }
235         d.trackers[uuid] = d.start(cntr)
236         switch cntr.State {
237         case Queued, Cancelled, Complete:
238                 d.trackers[uuid].close()
239         }
240         return nil
241 }
242
243 type runTracker struct {
244         closing bool
245         updates chan arvados.Container
246 }
247
248 func (tracker *runTracker) close() {
249         if !tracker.closing {
250                 close(tracker.updates)
251         }
252         tracker.closing = true
253 }
254
255 func (tracker *runTracker) update(c arvados.Container) {
256         if tracker.closing {
257                 return
258         }
259         select {
260         case <-tracker.updates:
261                 log.Printf("debug: runner is handling updates slowly, discarded previous update for %s", c.UUID)
262         default:
263         }
264         tracker.updates <- c
265 }