Merge branch '20647-cr-logs-preflight'
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/config"
19         "git.arvados.org/arvados.git/lib/controller/dblock"
20         "git.arvados.org/arvados.git/lib/ctrlctx"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
23         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
24         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
25         "git.arvados.org/arvados.git/sdk/go/arvados"
26         "git.arvados.org/arvados.git/sdk/go/auth"
27         "git.arvados.org/arvados.git/sdk/go/ctxlog"
28         "git.arvados.org/arvados.git/sdk/go/health"
29         "git.arvados.org/arvados.git/sdk/go/httpserver"
30         "github.com/julienschmidt/httprouter"
31         "github.com/prometheus/client_golang/prometheus"
32         "github.com/prometheus/client_golang/prometheus/promhttp"
33         "github.com/sirupsen/logrus"
34         "golang.org/x/crypto/ssh"
35 )
36
const (
	// defaultPollInterval is the poll interval handed to the
	// scheduler by run() when the cluster config does not provide
	// a positive Containers.CloudVMs.PollInterval.
	defaultPollInterval     = time.Second
	// defaultStaleLockTimeout is the stale lock timeout handed to
	// the scheduler by run() when the cluster config leaves
	// Containers.StaleLockTimeout unset (zero).
	defaultStaleLockTimeout = time.Minute
)
41
// pool is the set of worker pool operations the dispatcher depends
// on: everything the scheduler requires (scheduler.WorkerPool) plus
// the health check and instance management calls used by the
// service.Handler methods and management API handlers in this file.
// initialize() stores the value returned by worker.NewPool here.
type pool interface {
	scheduler.WorkerPool
	// CheckHealth supplies the result of dispatcher.CheckHealth.
	CheckHealth() error
	// Instances supplies the items returned by the
	// "GET /arvados/v1/dispatch/instances" management API.
	Instances() []worker.InstanceView
	// SetIdleBehavior is called by the hold/drain/run management
	// APIs; a non-nil error is reported to the client as 404.
	SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
	// KillInstance is called by the instance kill management API;
	// a non-nil error is reported to the client as 404.
	KillInstance(id cloud.InstanceID, reason string) error
	// Stop is called once, when the dispatcher's run loop exits.
	Stop()
}
50
// dispatcher ties together an instance set, a worker pool, a
// container queue and a scheduler, and exposes a management HTTP API
// for them.
//
// The exported fields are read by initialize() and run(), so they
// must be filled in before the first call to Start (or to any method
// that calls Start). The unexported fields are populated by
// initialize().
type dispatcher struct {
	// Cluster is the cluster configuration (SSH settings, cloud VM
	// settings, management token, PostgreSQL connection, ...).
	Cluster       *arvados.Cluster
	// Context supplies the logger and is passed to the scheduler
	// and to the dispatch database lock.
	Context       context.Context
	// ArvClient is the API client handed to the pool, queue and
	// scheduler. initialize() overwrites its AuthToken with
	// AuthToken below.
	ArvClient     *arvados.Client
	AuthToken     string
	// Registry receives metrics and backs the /metrics endpoints.
	Registry      *prometheus.Registry
	// InstanceSetID may be left empty, in which case initialize()
	// derives it from AuthToken: the second "/"-separated field
	// of a "v2/..." token, otherwise the hex MD5 of the token.
	InstanceSetID cloud.InstanceSetID

	dbConnector ctrlctx.DBConnector
	logger      logrus.FieldLogger
	instanceSet cloud.InstanceSet
	pool        pool
	queue       scheduler.ContainerQueue
	// httpHandler serves the management API, or rejects every
	// request with 403 if Cluster.ManagementToken is empty.
	httpHandler http.Handler
	// sshKey is loaded from Containers.DispatchPrivateKey and used
	// as the signer for every executor made by newExecutor.
	sshKey      ssh.Signer

	// setupOnce ensures setup() runs only once, no matter how
	// many times Start is called.
	setupOnce sync.Once
	// stop has capacity 1. Close() does a non-blocking send on it;
	// run() returns after receiving from it.
	stop      chan struct{}
	// stopped is closed when run() returns.
	stopped   chan struct{}
}
71
72 // Start starts the dispatcher. Start can be called multiple times
73 // with no ill effect.
74 func (disp *dispatcher) Start() {
75         disp.setupOnce.Do(disp.setup)
76 }
77
78 // ServeHTTP implements service.Handler.
79 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
80         disp.Start()
81         disp.httpHandler.ServeHTTP(w, r)
82 }
83
84 // CheckHealth implements service.Handler.
85 func (disp *dispatcher) CheckHealth() error {
86         disp.Start()
87         return disp.pool.CheckHealth()
88 }
89
90 // Done implements service.Handler.
91 func (disp *dispatcher) Done() <-chan struct{} {
92         return disp.stopped
93 }
94
95 // Stop dispatching containers and release resources. Typically used
96 // in tests.
97 func (disp *dispatcher) Close() {
98         disp.Start()
99         select {
100         case disp.stop <- struct{}{}:
101         default:
102         }
103         <-disp.stopped
104 }
105
106 // Make a worker.Executor for the given instance.
107 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
108         exr := sshexecutor.New(inst)
109         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
110         exr.SetSigners(disp.sshKey)
111         return exr
112 }
113
114 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
115         return ChooseInstanceType(disp.Cluster, ctr)
116 }
117
118 func (disp *dispatcher) setup() {
119         disp.initialize()
120         go disp.run()
121 }
122
123 func (disp *dispatcher) initialize() {
124         disp.logger = ctxlog.FromContext(disp.Context)
125         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
126
127         disp.ArvClient.AuthToken = disp.AuthToken
128
129         if disp.InstanceSetID == "" {
130                 if strings.HasPrefix(disp.AuthToken, "v2/") {
131                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
132                 } else {
133                         // Use some other string unique to this token
134                         // that doesn't reveal the token itself.
135                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
136                 }
137         }
138         disp.stop = make(chan struct{}, 1)
139         disp.stopped = make(chan struct{})
140
141         if key, err := config.LoadSSHKey(disp.Cluster.Containers.DispatchPrivateKey); err != nil {
142                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
143         } else {
144                 disp.sshKey = key
145         }
146         installPublicKey := disp.sshKey.PublicKey()
147         if !disp.Cluster.Containers.CloudVMs.DeployPublicKey {
148                 installPublicKey = nil
149         }
150
151         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
152         if err != nil {
153                 disp.logger.Fatalf("error initializing driver: %s", err)
154         }
155         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
156         disp.instanceSet = instanceSet
157         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, installPublicKey, disp.Cluster)
158         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
159
160         if disp.Cluster.ManagementToken == "" {
161                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
162                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
163                 })
164         } else {
165                 mux := httprouter.New()
166                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
167                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
168                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
169                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
170                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
171                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
172                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
173                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
174                         ErrorLog: disp.logger,
175                 })
176                 mux.Handler("GET", "/metrics", metricsH)
177                 mux.Handler("GET", "/metrics.json", metricsH)
178                 mux.Handler("GET", "/_health/:check", &health.Handler{
179                         Token:  disp.Cluster.ManagementToken,
180                         Prefix: "/_health/",
181                         Routes: health.Routes{"ping": disp.CheckHealth},
182                 })
183                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
184         }
185 }
186
187 func (disp *dispatcher) run() {
188         defer dblock.Dispatch.Unlock()
189         defer close(disp.stopped)
190         defer disp.instanceSet.Stop()
191         defer disp.pool.Stop()
192
193         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
194         if staleLockTimeout == 0 {
195                 staleLockTimeout = defaultStaleLockTimeout
196         }
197         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
198         if pollInterval <= 0 {
199                 pollInterval = defaultPollInterval
200         }
201         sched := scheduler.New(disp.Context, disp.ArvClient, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval, disp.Cluster.Containers.CloudVMs.MaxInstances, disp.Cluster.Containers.CloudVMs.SupervisorFraction)
202         sched.Start()
203         defer sched.Stop()
204
205         <-disp.stop
206 }
207
208 // Management API: all active and queued containers.
209 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
210         var resp struct {
211                 Items []container.QueueEnt `json:"items"`
212         }
213         qEntries, _ := disp.queue.Entries()
214         for _, ent := range qEntries {
215                 resp.Items = append(resp.Items, ent)
216         }
217         json.NewEncoder(w).Encode(resp)
218 }
219
220 // Management API: all active instances (cloud VMs).
221 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
222         var resp struct {
223                 Items []worker.InstanceView `json:"items"`
224         }
225         resp.Items = disp.pool.Instances()
226         json.NewEncoder(w).Encode(resp)
227 }
228
229 // Management API: set idle behavior to "hold" for specified instance.
230 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
231         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
232 }
233
234 // Management API: set idle behavior to "drain" for specified instance.
235 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
236         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
237 }
238
239 // Management API: set idle behavior to "run" for specified instance.
240 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
241         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
242 }
243
244 // Management API: shutdown/destroy specified instance now.
245 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
246         id := cloud.InstanceID(r.FormValue("instance_id"))
247         if id == "" {
248                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
249                 return
250         }
251         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
252         if err != nil {
253                 httpserver.Error(w, err.Error(), http.StatusNotFound)
254                 return
255         }
256 }
257
258 // Management API: send SIGTERM to specified container's crunch-run
259 // process now.
260 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
261         uuid := r.FormValue("container_uuid")
262         if uuid == "" {
263                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
264                 return
265         }
266         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
267                 httpserver.Error(w, "container not found", http.StatusNotFound)
268                 return
269         }
270 }
271
272 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
273         id := cloud.InstanceID(r.FormValue("instance_id"))
274         if id == "" {
275                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
276                 return
277         }
278         err := disp.pool.SetIdleBehavior(id, want)
279         if err != nil {
280                 httpserver.Error(w, err.Error(), http.StatusNotFound)
281                 return
282         }
283 }