Merge branch '20235-probe-after-upgrade'
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/controller/dblock"
19         "git.arvados.org/arvados.git/lib/ctrlctx"
20         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
23         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
24         "git.arvados.org/arvados.git/sdk/go/arvados"
25         "git.arvados.org/arvados.git/sdk/go/auth"
26         "git.arvados.org/arvados.git/sdk/go/ctxlog"
27         "git.arvados.org/arvados.git/sdk/go/health"
28         "git.arvados.org/arvados.git/sdk/go/httpserver"
29         "github.com/julienschmidt/httprouter"
30         "github.com/prometheus/client_golang/prometheus"
31         "github.com/prometheus/client_golang/prometheus/promhttp"
32         "github.com/sirupsen/logrus"
33         "golang.org/x/crypto/ssh"
34 )
35
// Fallback timing values used by the dispatcher's run loop when the
// cluster configuration does not supply a usable value of its own.
const defaultPollInterval = time.Second
const defaultStaleLockTimeout = time.Minute
40
41 type pool interface {
42         scheduler.WorkerPool
43         CheckHealth() error
44         Instances() []worker.InstanceView
45         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
46         KillInstance(id cloud.InstanceID, reason string) error
47         Stop()
48 }
49
50 type dispatcher struct {
51         Cluster       *arvados.Cluster
52         Context       context.Context
53         ArvClient     *arvados.Client
54         AuthToken     string
55         Registry      *prometheus.Registry
56         InstanceSetID cloud.InstanceSetID
57
58         dbConnector ctrlctx.DBConnector
59         logger      logrus.FieldLogger
60         instanceSet cloud.InstanceSet
61         pool        pool
62         queue       scheduler.ContainerQueue
63         httpHandler http.Handler
64         sshKey      ssh.Signer
65
66         setupOnce sync.Once
67         stop      chan struct{}
68         stopped   chan struct{}
69 }
70
71 // Start starts the dispatcher. Start can be called multiple times
72 // with no ill effect.
73 func (disp *dispatcher) Start() {
74         disp.setupOnce.Do(disp.setup)
75 }
76
77 // ServeHTTP implements service.Handler.
78 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
79         disp.Start()
80         disp.httpHandler.ServeHTTP(w, r)
81 }
82
83 // CheckHealth implements service.Handler.
84 func (disp *dispatcher) CheckHealth() error {
85         disp.Start()
86         return disp.pool.CheckHealth()
87 }
88
89 // Done implements service.Handler.
90 func (disp *dispatcher) Done() <-chan struct{} {
91         return disp.stopped
92 }
93
94 // Stop dispatching containers and release resources. Typically used
95 // in tests.
96 func (disp *dispatcher) Close() {
97         disp.Start()
98         select {
99         case disp.stop <- struct{}{}:
100         default:
101         }
102         <-disp.stopped
103 }
104
105 // Make a worker.Executor for the given instance.
106 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
107         exr := sshexecutor.New(inst)
108         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
109         exr.SetSigners(disp.sshKey)
110         return exr
111 }
112
113 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
114         return ChooseInstanceType(disp.Cluster, ctr)
115 }
116
117 func (disp *dispatcher) setup() {
118         disp.initialize()
119         go disp.run()
120 }
121
122 func (disp *dispatcher) initialize() {
123         disp.logger = ctxlog.FromContext(disp.Context)
124         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
125
126         disp.ArvClient.AuthToken = disp.AuthToken
127
128         if disp.InstanceSetID == "" {
129                 if strings.HasPrefix(disp.AuthToken, "v2/") {
130                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
131                 } else {
132                         // Use some other string unique to this token
133                         // that doesn't reveal the token itself.
134                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
135                 }
136         }
137         disp.stop = make(chan struct{}, 1)
138         disp.stopped = make(chan struct{})
139
140         if key, err := ssh.ParsePrivateKey([]byte(disp.Cluster.Containers.DispatchPrivateKey)); err != nil {
141                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
142         } else {
143                 disp.sshKey = key
144         }
145
146         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
147         if err != nil {
148                 disp.logger.Fatalf("error initializing driver: %s", err)
149         }
150         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
151         disp.instanceSet = instanceSet
152         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, disp.sshKey.PublicKey(), disp.Cluster)
153         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
154
155         if disp.Cluster.ManagementToken == "" {
156                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
157                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
158                 })
159         } else {
160                 mux := httprouter.New()
161                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
162                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
163                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
164                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
165                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
166                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
167                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
168                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
169                         ErrorLog: disp.logger,
170                 })
171                 mux.Handler("GET", "/metrics", metricsH)
172                 mux.Handler("GET", "/metrics.json", metricsH)
173                 mux.Handler("GET", "/_health/:check", &health.Handler{
174                         Token:  disp.Cluster.ManagementToken,
175                         Prefix: "/_health/",
176                         Routes: health.Routes{"ping": disp.CheckHealth},
177                 })
178                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
179         }
180 }
181
182 func (disp *dispatcher) run() {
183         defer dblock.Dispatch.Unlock()
184         defer close(disp.stopped)
185         defer disp.instanceSet.Stop()
186         defer disp.pool.Stop()
187
188         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
189         if staleLockTimeout == 0 {
190                 staleLockTimeout = defaultStaleLockTimeout
191         }
192         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
193         if pollInterval <= 0 {
194                 pollInterval = defaultPollInterval
195         }
196         maxSupervisors := int(float64(disp.Cluster.Containers.CloudVMs.MaxInstances) * disp.Cluster.Containers.CloudVMs.SupervisorFraction)
197         if maxSupervisors == 0 && disp.Cluster.Containers.CloudVMs.SupervisorFraction > 0 {
198                 maxSupervisors = 1
199         }
200         sched := scheduler.New(disp.Context, disp.ArvClient, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval, maxSupervisors)
201         sched.Start()
202         defer sched.Stop()
203
204         <-disp.stop
205 }
206
207 // Management API: all active and queued containers.
208 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
209         var resp struct {
210                 Items []container.QueueEnt `json:"items"`
211         }
212         qEntries, _ := disp.queue.Entries()
213         for _, ent := range qEntries {
214                 resp.Items = append(resp.Items, ent)
215         }
216         json.NewEncoder(w).Encode(resp)
217 }
218
219 // Management API: all active instances (cloud VMs).
220 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
221         var resp struct {
222                 Items []worker.InstanceView `json:"items"`
223         }
224         resp.Items = disp.pool.Instances()
225         json.NewEncoder(w).Encode(resp)
226 }
227
228 // Management API: set idle behavior to "hold" for specified instance.
229 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
230         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
231 }
232
233 // Management API: set idle behavior to "drain" for specified instance.
234 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
235         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
236 }
237
238 // Management API: set idle behavior to "run" for specified instance.
239 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
240         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
241 }
242
243 // Management API: shutdown/destroy specified instance now.
244 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
245         id := cloud.InstanceID(r.FormValue("instance_id"))
246         if id == "" {
247                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
248                 return
249         }
250         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
251         if err != nil {
252                 httpserver.Error(w, err.Error(), http.StatusNotFound)
253                 return
254         }
255 }
256
257 // Management API: send SIGTERM to specified container's crunch-run
258 // process now.
259 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
260         uuid := r.FormValue("container_uuid")
261         if uuid == "" {
262                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
263                 return
264         }
265         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
266                 httpserver.Error(w, "container not found", http.StatusNotFound)
267                 return
268         }
269 }
270
271 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
272         id := cloud.InstanceID(r.FormValue("instance_id"))
273         if id == "" {
274                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
275                 return
276         }
277         err := disp.pool.SetIdleBehavior(id, want)
278         if err != nil {
279                 httpserver.Error(w, err.Error(), http.StatusNotFound)
280                 return
281         }
282 }