Merge branch '14018-acr-set-container-properties' into main
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
19         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
20         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "git.arvados.org/arvados.git/sdk/go/auth"
24         "git.arvados.org/arvados.git/sdk/go/ctxlog"
25         "git.arvados.org/arvados.git/sdk/go/health"
26         "git.arvados.org/arvados.git/sdk/go/httpserver"
27         "github.com/julienschmidt/httprouter"
28         "github.com/prometheus/client_golang/prometheus"
29         "github.com/prometheus/client_golang/prometheus/promhttp"
30         "github.com/sirupsen/logrus"
31         "golang.org/x/crypto/ssh"
32 )
33
34 const (
35         defaultPollInterval     = time.Second
36         defaultStaleLockTimeout = time.Minute
37 )
38
39 type pool interface {
40         scheduler.WorkerPool
41         CheckHealth() error
42         Instances() []worker.InstanceView
43         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
44         KillInstance(id cloud.InstanceID, reason string) error
45         Stop()
46 }
47
48 type dispatcher struct {
49         Cluster       *arvados.Cluster
50         Context       context.Context
51         ArvClient     *arvados.Client
52         AuthToken     string
53         Registry      *prometheus.Registry
54         InstanceSetID cloud.InstanceSetID
55
56         logger      logrus.FieldLogger
57         instanceSet cloud.InstanceSet
58         pool        pool
59         queue       scheduler.ContainerQueue
60         httpHandler http.Handler
61         sshKey      ssh.Signer
62
63         setupOnce sync.Once
64         stop      chan struct{}
65         stopped   chan struct{}
66 }
67
68 // Start starts the dispatcher. Start can be called multiple times
69 // with no ill effect.
70 func (disp *dispatcher) Start() {
71         disp.setupOnce.Do(disp.setup)
72 }
73
74 // ServeHTTP implements service.Handler.
75 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
76         disp.Start()
77         disp.httpHandler.ServeHTTP(w, r)
78 }
79
80 // CheckHealth implements service.Handler.
81 func (disp *dispatcher) CheckHealth() error {
82         disp.Start()
83         return disp.pool.CheckHealth()
84 }
85
86 // Done implements service.Handler.
87 func (disp *dispatcher) Done() <-chan struct{} {
88         return disp.stopped
89 }
90
91 // Stop dispatching containers and release resources. Typically used
92 // in tests.
93 func (disp *dispatcher) Close() {
94         disp.Start()
95         select {
96         case disp.stop <- struct{}{}:
97         default:
98         }
99         <-disp.stopped
100 }
101
102 // Make a worker.Executor for the given instance.
103 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
104         exr := sshexecutor.New(inst)
105         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
106         exr.SetSigners(disp.sshKey)
107         return exr
108 }
109
110 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
111         return ChooseInstanceType(disp.Cluster, ctr)
112 }
113
114 func (disp *dispatcher) setup() {
115         disp.initialize()
116         go disp.run()
117 }
118
119 func (disp *dispatcher) initialize() {
120         disp.logger = ctxlog.FromContext(disp.Context)
121
122         disp.ArvClient.AuthToken = disp.AuthToken
123
124         if disp.InstanceSetID == "" {
125                 if strings.HasPrefix(disp.AuthToken, "v2/") {
126                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
127                 } else {
128                         // Use some other string unique to this token
129                         // that doesn't reveal the token itself.
130                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
131                 }
132         }
133         disp.stop = make(chan struct{}, 1)
134         disp.stopped = make(chan struct{})
135
136         if key, err := ssh.ParsePrivateKey([]byte(disp.Cluster.Containers.DispatchPrivateKey)); err != nil {
137                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
138         } else {
139                 disp.sshKey = key
140         }
141
142         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
143         if err != nil {
144                 disp.logger.Fatalf("error initializing driver: %s", err)
145         }
146         disp.instanceSet = instanceSet
147         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, disp.sshKey.PublicKey(), disp.Cluster)
148         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
149
150         if disp.Cluster.ManagementToken == "" {
151                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
153                 })
154         } else {
155                 mux := httprouter.New()
156                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
157                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
158                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
159                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
160                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
161                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
162                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
163                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
164                         ErrorLog: disp.logger,
165                 })
166                 mux.Handler("GET", "/metrics", metricsH)
167                 mux.Handler("GET", "/metrics.json", metricsH)
168                 mux.Handler("GET", "/_health/:check", &health.Handler{
169                         Token:  disp.Cluster.ManagementToken,
170                         Prefix: "/_health/",
171                         Routes: health.Routes{"ping": disp.CheckHealth},
172                 })
173                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
174         }
175 }
176
177 func (disp *dispatcher) run() {
178         defer close(disp.stopped)
179         defer disp.instanceSet.Stop()
180         defer disp.pool.Stop()
181
182         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
183         if staleLockTimeout == 0 {
184                 staleLockTimeout = defaultStaleLockTimeout
185         }
186         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
187         if pollInterval <= 0 {
188                 pollInterval = defaultPollInterval
189         }
190         sched := scheduler.New(disp.Context, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval)
191         sched.Start()
192         defer sched.Stop()
193
194         <-disp.stop
195 }
196
197 // Management API: all active and queued containers.
198 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
199         var resp struct {
200                 Items []container.QueueEnt `json:"items"`
201         }
202         qEntries, _ := disp.queue.Entries()
203         for _, ent := range qEntries {
204                 resp.Items = append(resp.Items, ent)
205         }
206         json.NewEncoder(w).Encode(resp)
207 }
208
209 // Management API: all active instances (cloud VMs).
210 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
211         var resp struct {
212                 Items []worker.InstanceView `json:"items"`
213         }
214         resp.Items = disp.pool.Instances()
215         json.NewEncoder(w).Encode(resp)
216 }
217
218 // Management API: set idle behavior to "hold" for specified instance.
219 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
220         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
221 }
222
223 // Management API: set idle behavior to "drain" for specified instance.
224 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
225         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
226 }
227
228 // Management API: set idle behavior to "run" for specified instance.
229 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
230         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
231 }
232
233 // Management API: shutdown/destroy specified instance now.
234 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
235         id := cloud.InstanceID(r.FormValue("instance_id"))
236         if id == "" {
237                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
238                 return
239         }
240         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
241         if err != nil {
242                 httpserver.Error(w, err.Error(), http.StatusNotFound)
243                 return
244         }
245 }
246
247 // Management API: send SIGTERM to specified container's crunch-run
248 // process now.
249 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
250         uuid := r.FormValue("container_uuid")
251         if uuid == "" {
252                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
253                 return
254         }
255         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
256                 httpserver.Error(w, "container not found", http.StatusNotFound)
257                 return
258         }
259 }
260
261 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
262         id := cloud.InstanceID(r.FormValue("instance_id"))
263         if id == "" {
264                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
265                 return
266         }
267         err := disp.pool.SetIdleBehavior(id, want)
268         if err != nil {
269                 httpserver.Error(w, err.Error(), http.StatusNotFound)
270                 return
271         }
272 }