Merge branch 'main' into 18842-arv-mount-disk-config
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/controller/dblock"
19         "git.arvados.org/arvados.git/lib/ctrlctx"
20         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
23         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
24         "git.arvados.org/arvados.git/sdk/go/arvados"
25         "git.arvados.org/arvados.git/sdk/go/auth"
26         "git.arvados.org/arvados.git/sdk/go/ctxlog"
27         "git.arvados.org/arvados.git/sdk/go/health"
28         "git.arvados.org/arvados.git/sdk/go/httpserver"
29         "github.com/julienschmidt/httprouter"
30         "github.com/prometheus/client_golang/prometheus"
31         "github.com/prometheus/client_golang/prometheus/promhttp"
32         "github.com/sirupsen/logrus"
33         "golang.org/x/crypto/ssh"
34 )
35
const (
	// defaultPollInterval is the scheduler poll interval used by run
	// when Containers.CloudVMs.PollInterval is not configured.
	defaultPollInterval time.Duration = time.Second

	// defaultStaleLockTimeout is the stale-lock timeout used by run
	// when Containers.StaleLockTimeout is not configured.
	defaultStaleLockTimeout time.Duration = time.Minute
)
40
41 type pool interface {
42         scheduler.WorkerPool
43         CheckHealth() error
44         Instances() []worker.InstanceView
45         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
46         KillInstance(id cloud.InstanceID, reason string) error
47         Stop()
48 }
49
50 type dispatcher struct {
51         Cluster       *arvados.Cluster
52         Context       context.Context
53         ArvClient     *arvados.Client
54         AuthToken     string
55         Registry      *prometheus.Registry
56         InstanceSetID cloud.InstanceSetID
57
58         dbConnector ctrlctx.DBConnector
59         logger      logrus.FieldLogger
60         instanceSet cloud.InstanceSet
61         pool        pool
62         queue       scheduler.ContainerQueue
63         httpHandler http.Handler
64         sshKey      ssh.Signer
65
66         setupOnce sync.Once
67         stop      chan struct{}
68         stopped   chan struct{}
69 }
70
71 // Start starts the dispatcher. Start can be called multiple times
72 // with no ill effect.
73 func (disp *dispatcher) Start() {
74         disp.setupOnce.Do(disp.setup)
75 }
76
77 // ServeHTTP implements service.Handler.
78 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
79         disp.Start()
80         disp.httpHandler.ServeHTTP(w, r)
81 }
82
83 // CheckHealth implements service.Handler.
84 func (disp *dispatcher) CheckHealth() error {
85         disp.Start()
86         return disp.pool.CheckHealth()
87 }
88
89 // Done implements service.Handler.
90 func (disp *dispatcher) Done() <-chan struct{} {
91         return disp.stopped
92 }
93
94 // Stop dispatching containers and release resources. Typically used
95 // in tests.
96 func (disp *dispatcher) Close() {
97         disp.Start()
98         select {
99         case disp.stop <- struct{}{}:
100         default:
101         }
102         <-disp.stopped
103 }
104
105 // Make a worker.Executor for the given instance.
106 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
107         exr := sshexecutor.New(inst)
108         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
109         exr.SetSigners(disp.sshKey)
110         return exr
111 }
112
113 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
114         return ChooseInstanceType(disp.Cluster, ctr)
115 }
116
117 func (disp *dispatcher) setup() {
118         disp.initialize()
119         go disp.run()
120 }
121
122 func (disp *dispatcher) initialize() {
123         disp.logger = ctxlog.FromContext(disp.Context)
124         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
125
126         disp.ArvClient.AuthToken = disp.AuthToken
127
128         if disp.InstanceSetID == "" {
129                 if strings.HasPrefix(disp.AuthToken, "v2/") {
130                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
131                 } else {
132                         // Use some other string unique to this token
133                         // that doesn't reveal the token itself.
134                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
135                 }
136         }
137         disp.stop = make(chan struct{}, 1)
138         disp.stopped = make(chan struct{})
139
140         if key, err := ssh.ParsePrivateKey([]byte(disp.Cluster.Containers.DispatchPrivateKey)); err != nil {
141                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
142         } else {
143                 disp.sshKey = key
144         }
145
146         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
147         if err != nil {
148                 disp.logger.Fatalf("error initializing driver: %s", err)
149         }
150         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
151         disp.instanceSet = instanceSet
152         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, disp.sshKey.PublicKey(), disp.Cluster)
153         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
154
155         if disp.Cluster.ManagementToken == "" {
156                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
157                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
158                 })
159         } else {
160                 mux := httprouter.New()
161                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
162                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
163                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
164                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
165                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
166                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
167                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
168                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
169                         ErrorLog: disp.logger,
170                 })
171                 mux.Handler("GET", "/metrics", metricsH)
172                 mux.Handler("GET", "/metrics.json", metricsH)
173                 mux.Handler("GET", "/_health/:check", &health.Handler{
174                         Token:  disp.Cluster.ManagementToken,
175                         Prefix: "/_health/",
176                         Routes: health.Routes{"ping": disp.CheckHealth},
177                 })
178                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
179         }
180 }
181
182 func (disp *dispatcher) run() {
183         defer dblock.Dispatch.Unlock()
184         defer close(disp.stopped)
185         defer disp.instanceSet.Stop()
186         defer disp.pool.Stop()
187
188         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
189         if staleLockTimeout == 0 {
190                 staleLockTimeout = defaultStaleLockTimeout
191         }
192         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
193         if pollInterval <= 0 {
194                 pollInterval = defaultPollInterval
195         }
196         sched := scheduler.New(disp.Context, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval)
197         sched.Start()
198         defer sched.Stop()
199
200         <-disp.stop
201 }
202
203 // Management API: all active and queued containers.
204 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
205         var resp struct {
206                 Items []container.QueueEnt `json:"items"`
207         }
208         qEntries, _ := disp.queue.Entries()
209         for _, ent := range qEntries {
210                 resp.Items = append(resp.Items, ent)
211         }
212         json.NewEncoder(w).Encode(resp)
213 }
214
215 // Management API: all active instances (cloud VMs).
216 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
217         var resp struct {
218                 Items []worker.InstanceView `json:"items"`
219         }
220         resp.Items = disp.pool.Instances()
221         json.NewEncoder(w).Encode(resp)
222 }
223
224 // Management API: set idle behavior to "hold" for specified instance.
225 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
226         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
227 }
228
229 // Management API: set idle behavior to "drain" for specified instance.
230 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
231         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
232 }
233
234 // Management API: set idle behavior to "run" for specified instance.
235 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
236         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
237 }
238
239 // Management API: shutdown/destroy specified instance now.
240 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
241         id := cloud.InstanceID(r.FormValue("instance_id"))
242         if id == "" {
243                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
244                 return
245         }
246         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
247         if err != nil {
248                 httpserver.Error(w, err.Error(), http.StatusNotFound)
249                 return
250         }
251 }
252
253 // Management API: send SIGTERM to specified container's crunch-run
254 // process now.
255 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
256         uuid := r.FormValue("container_uuid")
257         if uuid == "" {
258                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
259                 return
260         }
261         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
262                 httpserver.Error(w, "container not found", http.StatusNotFound)
263                 return
264         }
265 }
266
267 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
268         id := cloud.InstanceID(r.FormValue("instance_id"))
269         if id == "" {
270                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
271                 return
272         }
273         err := disp.pool.SetIdleBehavior(id, want)
274         if err != nil {
275                 httpserver.Error(w, err.Error(), http.StatusNotFound)
276                 return
277         }
278 }