18874: Make `run-tests.sh --only services/workbench2` work.
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
import (
	"context"
	"crypto/md5"
	"encoding/json"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"git.arvados.org/arvados.git/lib/cloud"
	"git.arvados.org/arvados.git/lib/config"
	"git.arvados.org/arvados.git/lib/controller/dblock"
	"git.arvados.org/arvados.git/lib/ctrlctx"
	"git.arvados.org/arvados.git/lib/dispatchcloud/container"
	"git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
	"git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
	"git.arvados.org/arvados.git/lib/dispatchcloud/worker"
	"git.arvados.org/arvados.git/sdk/go/arvados"
	"git.arvados.org/arvados.git/sdk/go/auth"
	"git.arvados.org/arvados.git/sdk/go/ctxlog"
	"git.arvados.org/arvados.git/sdk/go/health"
	"git.arvados.org/arvados.git/sdk/go/httpserver"
	"github.com/julienschmidt/httprouter"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
36
// defaultPollInterval is the scheduler poll interval used when the
// cluster config does not provide a positive
// Containers.CloudVMs.PollInterval.
const defaultPollInterval = 1 * time.Second

// defaultStaleLockTimeout is the stale-lock timeout passed to the
// scheduler when the cluster config leaves
// Containers.StaleLockTimeout unset.
const defaultStaleLockTimeout = 1 * time.Minute
41
42 type pool interface {
43         scheduler.WorkerPool
44         CheckHealth() error
45         Instances() []worker.InstanceView
46         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
47         KillInstance(id cloud.InstanceID, reason string) error
48         Stop()
49 }
50
51 type dispatcher struct {
52         Cluster       *arvados.Cluster
53         Context       context.Context
54         ArvClient     *arvados.Client
55         AuthToken     string
56         Registry      *prometheus.Registry
57         InstanceSetID cloud.InstanceSetID
58
59         dbConnector ctrlctx.DBConnector
60         logger      logrus.FieldLogger
61         instanceSet cloud.InstanceSet
62         pool        pool
63         queue       scheduler.ContainerQueue
64         httpHandler http.Handler
65         sshKey      ssh.Signer
66
67         setupOnce sync.Once
68         stop      chan struct{}
69         stopped   chan struct{}
70 }
71
72 // Start starts the dispatcher. Start can be called multiple times
73 // with no ill effect.
74 func (disp *dispatcher) Start() {
75         disp.setupOnce.Do(disp.setup)
76 }
77
78 // ServeHTTP implements service.Handler.
79 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
80         disp.Start()
81         disp.httpHandler.ServeHTTP(w, r)
82 }
83
84 // CheckHealth implements service.Handler.
85 func (disp *dispatcher) CheckHealth() error {
86         disp.Start()
87         return disp.pool.CheckHealth()
88 }
89
90 // Done implements service.Handler.
91 func (disp *dispatcher) Done() <-chan struct{} {
92         return disp.stopped
93 }
94
95 // Stop dispatching containers and release resources. Typically used
96 // in tests.
97 func (disp *dispatcher) Close() {
98         disp.Start()
99         select {
100         case disp.stop <- struct{}{}:
101         default:
102         }
103         <-disp.stopped
104 }
105
106 // Make a worker.Executor for the given instance.
107 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
108         exr := sshexecutor.New(inst)
109         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
110         exr.SetSigners(disp.sshKey)
111         return exr
112 }
113
114 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
115         return ChooseInstanceType(disp.Cluster, ctr)
116 }
117
118 func (disp *dispatcher) setup() {
119         disp.initialize()
120         go disp.run()
121 }
122
123 func (disp *dispatcher) initialize() {
124         disp.logger = ctxlog.FromContext(disp.Context)
125         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
126
127         disp.ArvClient.AuthToken = disp.AuthToken
128
129         if disp.InstanceSetID == "" {
130                 if strings.HasPrefix(disp.AuthToken, "v2/") {
131                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
132                 } else {
133                         // Use some other string unique to this token
134                         // that doesn't reveal the token itself.
135                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
136                 }
137         }
138         disp.stop = make(chan struct{}, 1)
139         disp.stopped = make(chan struct{})
140
141         if key, err := config.LoadSSHKey(disp.Cluster.Containers.DispatchPrivateKey); err != nil {
142                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
143         } else {
144                 disp.sshKey = key
145         }
146         installPublicKey := disp.sshKey.PublicKey()
147         if !disp.Cluster.Containers.CloudVMs.DeployPublicKey {
148                 installPublicKey = nil
149         }
150
151         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
152         if err != nil {
153                 disp.logger.Fatalf("error initializing driver: %s", err)
154         }
155         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
156         disp.instanceSet = instanceSet
157         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, installPublicKey, disp.Cluster)
158         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
159
160         if disp.Cluster.ManagementToken == "" {
161                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
162                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
163                 })
164         } else {
165                 mux := httprouter.New()
166                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
167                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
168                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
169                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
170                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
171                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
172                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
173                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
174                         ErrorLog: disp.logger,
175                 })
176                 mux.Handler("GET", "/metrics", metricsH)
177                 mux.Handler("GET", "/metrics.json", metricsH)
178                 mux.Handler("GET", "/_health/:check", &health.Handler{
179                         Token:  disp.Cluster.ManagementToken,
180                         Prefix: "/_health/",
181                         Routes: health.Routes{"ping": disp.CheckHealth},
182                 })
183                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
184         }
185 }
186
187 func (disp *dispatcher) run() {
188         defer dblock.Dispatch.Unlock()
189         defer close(disp.stopped)
190         defer disp.instanceSet.Stop()
191         defer disp.pool.Stop()
192
193         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
194         if staleLockTimeout == 0 {
195                 staleLockTimeout = defaultStaleLockTimeout
196         }
197         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
198         if pollInterval <= 0 {
199                 pollInterval = defaultPollInterval
200         }
201         sched := scheduler.New(disp.Context, disp.ArvClient, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval,
202                 disp.Cluster.Containers.CloudVMs.InitialQuotaEstimate,
203                 disp.Cluster.Containers.CloudVMs.MaxInstances,
204                 disp.Cluster.Containers.CloudVMs.SupervisorFraction)
205         sched.Start()
206         defer sched.Stop()
207
208         <-disp.stop
209 }
210
211 // Management API: all active and queued containers.
212 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
213         var resp struct {
214                 Items []container.QueueEnt `json:"items"`
215         }
216         qEntries, _ := disp.queue.Entries()
217         for _, ent := range qEntries {
218                 resp.Items = append(resp.Items, ent)
219         }
220         json.NewEncoder(w).Encode(resp)
221 }
222
223 // Management API: all active instances (cloud VMs).
224 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
225         var resp struct {
226                 Items []worker.InstanceView `json:"items"`
227         }
228         resp.Items = disp.pool.Instances()
229         json.NewEncoder(w).Encode(resp)
230 }
231
232 // Management API: set idle behavior to "hold" for specified instance.
233 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
234         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
235 }
236
237 // Management API: set idle behavior to "drain" for specified instance.
238 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
239         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
240 }
241
242 // Management API: set idle behavior to "run" for specified instance.
243 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
244         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
245 }
246
247 // Management API: shutdown/destroy specified instance now.
248 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
249         id := cloud.InstanceID(r.FormValue("instance_id"))
250         if id == "" {
251                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
252                 return
253         }
254         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
255         if err != nil {
256                 httpserver.Error(w, err.Error(), http.StatusNotFound)
257                 return
258         }
259 }
260
261 // Management API: send SIGTERM to specified container's crunch-run
262 // process now.
263 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
264         uuid := r.FormValue("container_uuid")
265         if uuid == "" {
266                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
267                 return
268         }
269         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
270                 httpserver.Error(w, "container not found", http.StatusNotFound)
271                 return
272         }
273 }
274
275 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
276         id := cloud.InstanceID(r.FormValue("instance_id"))
277         if id == "" {
278                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
279                 return
280         }
281         err := disp.pool.SetIdleBehavior(id, want)
282         if err != nil {
283                 httpserver.Error(w, err.Error(), http.StatusNotFound)
284                 return
285         }
286 }