20522: Load dispatch key from file if configured as file:///...
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/config"
19         "git.arvados.org/arvados.git/lib/controller/dblock"
20         "git.arvados.org/arvados.git/lib/ctrlctx"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
22         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
23         "git.arvados.org/arvados.git/lib/dispatchcloud/sshexecutor"
24         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
25         "git.arvados.org/arvados.git/sdk/go/arvados"
26         "git.arvados.org/arvados.git/sdk/go/auth"
27         "git.arvados.org/arvados.git/sdk/go/ctxlog"
28         "git.arvados.org/arvados.git/sdk/go/health"
29         "git.arvados.org/arvados.git/sdk/go/httpserver"
30         "github.com/julienschmidt/httprouter"
31         "github.com/prometheus/client_golang/prometheus"
32         "github.com/prometheus/client_golang/prometheus/promhttp"
33         "github.com/sirupsen/logrus"
34         "golang.org/x/crypto/ssh"
35 )
36
// Fallback values used by run when the corresponding cluster
// configuration entries are unset.
const (
	// defaultPollInterval is used when Containers.CloudVMs.PollInterval
	// is not positive.
	defaultPollInterval = 1 * time.Second
	// defaultStaleLockTimeout is used when Containers.StaleLockTimeout
	// is unset.
	defaultStaleLockTimeout = 60 * time.Second
)
41
42 type pool interface {
43         scheduler.WorkerPool
44         CheckHealth() error
45         Instances() []worker.InstanceView
46         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
47         KillInstance(id cloud.InstanceID, reason string) error
48         Stop()
49 }
50
51 type dispatcher struct {
52         Cluster       *arvados.Cluster
53         Context       context.Context
54         ArvClient     *arvados.Client
55         AuthToken     string
56         Registry      *prometheus.Registry
57         InstanceSetID cloud.InstanceSetID
58
59         dbConnector ctrlctx.DBConnector
60         logger      logrus.FieldLogger
61         instanceSet cloud.InstanceSet
62         pool        pool
63         queue       scheduler.ContainerQueue
64         httpHandler http.Handler
65         sshKey      ssh.Signer
66
67         setupOnce sync.Once
68         stop      chan struct{}
69         stopped   chan struct{}
70 }
71
72 // Start starts the dispatcher. Start can be called multiple times
73 // with no ill effect.
74 func (disp *dispatcher) Start() {
75         disp.setupOnce.Do(disp.setup)
76 }
77
78 // ServeHTTP implements service.Handler.
79 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
80         disp.Start()
81         disp.httpHandler.ServeHTTP(w, r)
82 }
83
84 // CheckHealth implements service.Handler.
85 func (disp *dispatcher) CheckHealth() error {
86         disp.Start()
87         return disp.pool.CheckHealth()
88 }
89
90 // Done implements service.Handler.
91 func (disp *dispatcher) Done() <-chan struct{} {
92         return disp.stopped
93 }
94
95 // Stop dispatching containers and release resources. Typically used
96 // in tests.
97 func (disp *dispatcher) Close() {
98         disp.Start()
99         select {
100         case disp.stop <- struct{}{}:
101         default:
102         }
103         <-disp.stopped
104 }
105
106 // Make a worker.Executor for the given instance.
107 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
108         exr := sshexecutor.New(inst)
109         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
110         exr.SetSigners(disp.sshKey)
111         return exr
112 }
113
114 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
115         return ChooseInstanceType(disp.Cluster, ctr)
116 }
117
118 func (disp *dispatcher) setup() {
119         disp.initialize()
120         go disp.run()
121 }
122
123 func (disp *dispatcher) initialize() {
124         disp.logger = ctxlog.FromContext(disp.Context)
125         disp.dbConnector = ctrlctx.DBConnector{PostgreSQL: disp.Cluster.PostgreSQL}
126
127         disp.ArvClient.AuthToken = disp.AuthToken
128
129         if disp.InstanceSetID == "" {
130                 if strings.HasPrefix(disp.AuthToken, "v2/") {
131                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
132                 } else {
133                         // Use some other string unique to this token
134                         // that doesn't reveal the token itself.
135                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
136                 }
137         }
138         disp.stop = make(chan struct{}, 1)
139         disp.stopped = make(chan struct{})
140
141         if key, err := config.LoadSSHKey(disp.Cluster.Containers.DispatchPrivateKey); err != nil {
142                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
143         } else {
144                 disp.sshKey = key
145         }
146
147         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
148         if err != nil {
149                 disp.logger.Fatalf("error initializing driver: %s", err)
150         }
151         dblock.Dispatch.Lock(disp.Context, disp.dbConnector.GetDB)
152         disp.instanceSet = instanceSet
153         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, disp.sshKey.PublicKey(), disp.Cluster)
154         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
155
156         if disp.Cluster.ManagementToken == "" {
157                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
159                 })
160         } else {
161                 mux := httprouter.New()
162                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
163                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
164                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
165                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
166                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
167                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
168                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
169                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
170                         ErrorLog: disp.logger,
171                 })
172                 mux.Handler("GET", "/metrics", metricsH)
173                 mux.Handler("GET", "/metrics.json", metricsH)
174                 mux.Handler("GET", "/_health/:check", &health.Handler{
175                         Token:  disp.Cluster.ManagementToken,
176                         Prefix: "/_health/",
177                         Routes: health.Routes{"ping": disp.CheckHealth},
178                 })
179                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
180         }
181 }
182
183 func (disp *dispatcher) run() {
184         defer dblock.Dispatch.Unlock()
185         defer close(disp.stopped)
186         defer disp.instanceSet.Stop()
187         defer disp.pool.Stop()
188
189         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
190         if staleLockTimeout == 0 {
191                 staleLockTimeout = defaultStaleLockTimeout
192         }
193         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
194         if pollInterval <= 0 {
195                 pollInterval = defaultPollInterval
196         }
197         maxSupervisors := int(float64(disp.Cluster.Containers.CloudVMs.MaxInstances) * disp.Cluster.Containers.CloudVMs.SupervisorFraction)
198         if maxSupervisors == 0 && disp.Cluster.Containers.CloudVMs.SupervisorFraction > 0 {
199                 maxSupervisors = 1
200         }
201         sched := scheduler.New(disp.Context, disp.ArvClient, disp.queue, disp.pool, disp.Registry, staleLockTimeout, pollInterval, maxSupervisors)
202         sched.Start()
203         defer sched.Stop()
204
205         <-disp.stop
206 }
207
208 // Management API: all active and queued containers.
209 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
210         var resp struct {
211                 Items []container.QueueEnt `json:"items"`
212         }
213         qEntries, _ := disp.queue.Entries()
214         for _, ent := range qEntries {
215                 resp.Items = append(resp.Items, ent)
216         }
217         json.NewEncoder(w).Encode(resp)
218 }
219
220 // Management API: all active instances (cloud VMs).
221 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
222         var resp struct {
223                 Items []worker.InstanceView `json:"items"`
224         }
225         resp.Items = disp.pool.Instances()
226         json.NewEncoder(w).Encode(resp)
227 }
228
229 // Management API: set idle behavior to "hold" for specified instance.
230 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
231         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
232 }
233
234 // Management API: set idle behavior to "drain" for specified instance.
235 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
236         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
237 }
238
239 // Management API: set idle behavior to "run" for specified instance.
240 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
241         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
242 }
243
244 // Management API: shutdown/destroy specified instance now.
245 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
246         id := cloud.InstanceID(r.FormValue("instance_id"))
247         if id == "" {
248                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
249                 return
250         }
251         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
252         if err != nil {
253                 httpserver.Error(w, err.Error(), http.StatusNotFound)
254                 return
255         }
256 }
257
258 // Management API: send SIGTERM to specified container's crunch-run
259 // process now.
260 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
261         uuid := r.FormValue("container_uuid")
262         if uuid == "" {
263                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
264                 return
265         }
266         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
267                 httpserver.Error(w, "container not found", http.StatusNotFound)
268                 return
269         }
270 }
271
272 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
273         id := cloud.InstanceID(r.FormValue("instance_id"))
274         if id == "" {
275                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
276                 return
277         }
278         err := disp.pool.SetIdleBehavior(id, want)
279         if err != nil {
280                 httpserver.Error(w, err.Error(), http.StatusNotFound)
281                 return
282         }
283 }