Merge branch '20594-scaling-nginx-settings'. Closes #20594
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/controller/dblock"
19         "git.arvados.org/arvados.git/lib/ctrlctx"
20         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
23         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
24         "git.arvados.org/arvados.git/sdk/go/arvados"
25         "git.arvados.org/arvados.git/sdk/go/auth"
26         "git.arvados.org/arvados.git/sdk/go/ctxlog"
27         "git.arvados.org/arvados.git/sdk/go/health"
28         "git.arvados.org/arvados.git/sdk/go/httpserver"
29         "github.com/julienschmidt/httprouter"
30         "github.com/prometheus/client_golang/prometheus"
31         "github.com/prometheus/client_golang/prometheus/promhttp"
32         "github.com/sirupsen/logrus"
33         "golang.org/x/crypto/ssh"
34 )
35
const (
	// defaultPollInterval is the scheduler poll interval that
	// run() falls back to when the cluster configuration does not
	// supply a usable Containers.CloudVMs.PollInterval.
	defaultPollInterval = 1 * time.Second
	// defaultStaleLockTimeout is the stale lock timeout that run()
	// falls back to when the cluster configuration does not supply
	// a usable Containers.StaleLockTimeout.
	defaultStaleLockTimeout = 1 * time.Minute
)
40
41 type pool interface {
42         scheduler.WorkerPool
43         CheckHealth() error
44         Instances() []worker.InstanceView
45         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
46         KillInstance(id cloud.InstanceID, reason string) error
47         Stop()
48 }
49
50 type dispatcher struct {
51         Cluster       *arvados.Cluster
52         Context       context.Context
53         ArvClient     *arvados.Client
54         AuthToken     string
55         Registry      *prometheus.Registry
56         InstanceSetID cloud.InstanceSetID
57
58         dbConnector ctrlctx.DBConnector
59         logger      logrus.FieldLogger
60         instanceSet cloud.InstanceSet
61         pool        pool
62         queue       scheduler.ContainerQueue
63         httpHandler http.Handler
64         sshKey      ssh.Signer
65
66         setupOnce sync.Once
67         stop      chan struct{}
68         stopped   chan struct{}
69 }
70
71 // Start starts the dispatcher. Start can be called multiple times
72 // with no ill effect.
73 func (disp *dispatcher) Start() {
74         disp.setupOnce.Do(disp.setup)
75 }
76
77 // ServeHTTP implements service.Handler.
78 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
79         disp.Start()
80         disp.httpHandler.ServeHTTP(w, r)
81 }
82
83 // CheckHealth implements service.Handler.
84 func (disp *dispatcher) CheckHealth() error {
85         disp.Start()
86         return disp.pool.CheckHealth()
87 }
88
89 // Done implements service.Handler.
90 func (disp *dispatcher) Done() <-chan struct{} {
91         return disp.stopped
92 }
93
94 // Stop dispatching containers and release resources. Typically used
95 // in tests.
96 func (disp *dispatcher) Close() {
97         disp.Start()
98         select {
99         case disp.stop <- struct{}{}:
100         default:
101         }
102         <-disp.stopped
103 }
104
105 // Make a worker.Executor for the given instance.
106 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
107         exr := sshexecutor.New(inst)
108         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
109         exr.SetSigners(disp.sshKey)
110         return exr
111 }
112
113 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
114         return ChooseInstanceType(disp.Cluster, ctr)
115 }
116
117 func (disp *dispatcher) setup() {
118         disp.initialize()
119         go disp.run()
120 }
121
122 func (disp *dispatcher) initialize() {
123         disp.logger = ctxlog.FromContext(disp.Context)
124         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
125
126         disp.ArvClient.AuthToken = disp.AuthToken
127
128         if disp.InstanceSetID == "" {
129                 if strings.HasPrefix(disp.AuthToken, "v2/") {
130                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
131                 } else {
132                         // Use some other string unique to this token
133                         // that doesn't reveal the token itself.
134                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
135                 }
136         }
137         disp.stop = make(chan struct{}, 1)
138         disp.stopped = make(chan struct{})
139
140         if key, err := ssh.ParsePrivateKey([]byte(disp.Cluster.Containers.DispatchPrivateKey)); err != nil {
141                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
142         } else {
143                 disp.sshKey = key
144         }
145         installPublicKey := disp.sshKey.PublicKey()
146         if !disp.Cluster.Containers.CloudVMs.DeployPublicKey {
147                 installPublicKey = nil
148         }
149
150         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
151         if err != nil {
152                 disp.logger.Fatalf("error initializing driver: %s", err)
153         }
154         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
155         disp.instanceSet = instanceSet
156         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, installPublicKey, disp.Cluster)
157         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
158
159         if disp.Cluster.ManagementToken == "" {
160                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
161                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
162                 })
163         } else {
164                 mux := httprouter.New()
165                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
166                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
167                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
168                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
169                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
170                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
171                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
172                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
173                         ErrorLog: disp.logger,
174                 })
175                 mux.Handler("GET", "/metrics", metricsH)
176                 mux.Handler("GET", "/metrics.json", metricsH)
177                 mux.Handler("GET", "/_health/:check", &health.Handler{
178                         Token:  disp.Cluster.ManagementToken,
179                         Prefix: "/_health/",
180                         Routes: health.Routes{"ping": disp.CheckHealth},
181                 })
182                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
183         }
184 }
185
186 func (disp *dispatcher) run() {
187         defer dblock.Dispatch.Unlock()
188         defer close(disp.stopped)
189         defer disp.instanceSet.Stop()
190         defer disp.pool.Stop()
191
192         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
193         if staleLockTimeout == 0 {
194                 staleLockTimeout = defaultStaleLockTimeout
195         }
196         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
197         if pollInterval <= 0 {
198                 pollInterval = defaultPollInterval
199         }
200         maxSupervisors := int(float64(disp.Cluster.Containers.CloudVMs.MaxInstances) * disp.Cluster.Containers.CloudVMs.SupervisorFraction)
201         if maxSupervisors == 0 && disp.Cluster.Containers.CloudVMs.SupervisorFraction > 0 {
202                 maxSupervisors = 1
203         }
204         sched := scheduler.New(disp.Context, disp.ArvClient, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval, maxSupervisors)
205         sched.Start()
206         defer sched.Stop()
207
208         <-disp.stop
209 }
210
211 // Management API: all active and queued containers.
212 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
213         var resp struct {
214                 Items []container.QueueEnt `json:"items"`
215         }
216         qEntries, _ := disp.queue.Entries()
217         for _, ent := range qEntries {
218                 resp.Items = append(resp.Items, ent)
219         }
220         json.NewEncoder(w).Encode(resp)
221 }
222
223 // Management API: all active instances (cloud VMs).
224 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
225         var resp struct {
226                 Items []worker.InstanceView `json:"items"`
227         }
228         resp.Items = disp.pool.Instances()
229         json.NewEncoder(w).Encode(resp)
230 }
231
232 // Management API: set idle behavior to "hold" for specified instance.
233 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
234         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
235 }
236
237 // Management API: set idle behavior to "drain" for specified instance.
238 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
239         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
240 }
241
242 // Management API: set idle behavior to "run" for specified instance.
243 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
244         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
245 }
246
247 // Management API: shutdown/destroy specified instance now.
248 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
249         id := cloud.InstanceID(r.FormValue("instance_id"))
250         if id == "" {
251                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
252                 return
253         }
254         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
255         if err != nil {
256                 httpserver.Error(w, err.Error(), http.StatusNotFound)
257                 return
258         }
259 }
260
261 // Management API: send SIGTERM to specified container's crunch-run
262 // process now.
263 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
264         uuid := r.FormValue("container_uuid")
265         if uuid == "" {
266                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
267                 return
268         }
269         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
270                 httpserver.Error(w, "container not found", http.StatusNotFound)
271                 return
272         }
273 }
274
275 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
276         id := cloud.InstanceID(r.FormValue("instance_id"))
277         if id == "" {
278                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
279                 return
280         }
281         err := disp.pool.SetIdleBehavior(id, want)
282         if err != nil {
283                 httpserver.Error(w, err.Error(), http.StatusNotFound)
284                 return
285         }
286 }