Merge branch '16265-security-updates' into dependabot/bundler/apps/workbench/loofah...
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "context"
9         "crypto/md5"
10         "encoding/json"
11         "fmt"
12         "net/http"
13         "strings"
14         "sync"
15         "time"
16
17         "git.arvados.org/arvados.git/lib/cloud"
18         "git.arvados.org/arvados.git/lib/dispatchcloud/container"
19         "git.arvados.org/arvados.git/lib/dispatchcloud/scheduler"
20         "git.arvados.org/arvados.git/lib/dispatchcloud/ssh_executor"
21         "git.arvados.org/arvados.git/lib/dispatchcloud/worker"
22         "git.arvados.org/arvados.git/sdk/go/arvados"
23         "git.arvados.org/arvados.git/sdk/go/auth"
24         "git.arvados.org/arvados.git/sdk/go/ctxlog"
25         "git.arvados.org/arvados.git/sdk/go/httpserver"
26         "github.com/julienschmidt/httprouter"
27         "github.com/prometheus/client_golang/prometheus"
28         "github.com/prometheus/client_golang/prometheus/promhttp"
29         "github.com/sirupsen/logrus"
30         "golang.org/x/crypto/ssh"
31 )
32
33 const (
34         defaultPollInterval     = time.Second
35         defaultStaleLockTimeout = time.Minute
36 )
37
38 type pool interface {
39         scheduler.WorkerPool
40         CheckHealth() error
41         Instances() []worker.InstanceView
42         SetIdleBehavior(cloud.InstanceID, worker.IdleBehavior) error
43         KillInstance(id cloud.InstanceID, reason string) error
44         Stop()
45 }
46
47 type dispatcher struct {
48         Cluster       *arvados.Cluster
49         Context       context.Context
50         ArvClient     *arvados.Client
51         AuthToken     string
52         Registry      *prometheus.Registry
53         InstanceSetID cloud.InstanceSetID
54
55         logger      logrus.FieldLogger
56         instanceSet cloud.InstanceSet
57         pool        pool
58         queue       scheduler.ContainerQueue
59         httpHandler http.Handler
60         sshKey      ssh.Signer
61
62         setupOnce sync.Once
63         stop      chan struct{}
64         stopped   chan struct{}
65 }
66
67 // Start starts the dispatcher. Start can be called multiple times
68 // with no ill effect.
69 func (disp *dispatcher) Start() {
70         disp.setupOnce.Do(disp.setup)
71 }
72
73 // ServeHTTP implements service.Handler.
74 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
75         disp.Start()
76         disp.httpHandler.ServeHTTP(w, r)
77 }
78
79 // CheckHealth implements service.Handler.
80 func (disp *dispatcher) CheckHealth() error {
81         disp.Start()
82         return disp.pool.CheckHealth()
83 }
84
85 // Stop dispatching containers and release resources. Typically used
86 // in tests.
87 func (disp *dispatcher) Close() {
88         disp.Start()
89         select {
90         case disp.stop <- struct{}{}:
91         default:
92         }
93         <-disp.stopped
94 }
95
96 // Make a worker.Executor for the given instance.
97 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
98         exr := ssh_executor.New(inst)
99         exr.SetTargetPort(disp.Cluster.Containers.CloudVMs.SSHPort)
100         exr.SetSigners(disp.sshKey)
101         return exr
102 }
103
104 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
105         return ChooseInstanceType(disp.Cluster, ctr)
106 }
107
108 func (disp *dispatcher) setup() {
109         disp.initialize()
110         go disp.run()
111 }
112
113 func (disp *dispatcher) initialize() {
114         disp.logger = ctxlog.FromContext(disp.Context)
115
116         disp.ArvClient.AuthToken = disp.AuthToken
117
118         if disp.InstanceSetID == "" {
119                 if strings.HasPrefix(disp.AuthToken, "v2/") {
120                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(disp.AuthToken, "/")[1])
121                 } else {
122                         // Use some other string unique to this token
123                         // that doesn't reveal the token itself.
124                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(disp.AuthToken))))
125                 }
126         }
127         disp.stop = make(chan struct{}, 1)
128         disp.stopped = make(chan struct{})
129
130         if key, err := ssh.ParsePrivateKey([]byte(disp.Cluster.Containers.DispatchPrivateKey)); err != nil {
131                 disp.logger.Fatalf("error parsing configured Containers.DispatchPrivateKey: %s", err)
132         } else {
133                 disp.sshKey = key
134         }
135
136         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger, disp.Registry)
137         if err != nil {
138                 disp.logger.Fatalf("error initializing driver: %s", err)
139         }
140         disp.instanceSet = instanceSet
141         disp.pool = worker.NewPool(disp.logger, disp.ArvClient, disp.Registry, disp.InstanceSetID, disp.instanceSet, disp.newExecutor, disp.sshKey.PublicKey(), disp.Cluster)
142         disp.queue = container.NewQueue(disp.logger, disp.Registry, disp.typeChooser, disp.ArvClient)
143
144         if disp.Cluster.ManagementToken == "" {
145                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
146                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
147                 })
148         } else {
149                 mux := httprouter.New()
150                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/containers", disp.apiContainers)
151                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/containers/kill", disp.apiContainerKill)
152                 mux.HandlerFunc("GET", "/arvados/v1/dispatch/instances", disp.apiInstances)
153                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/hold", disp.apiInstanceHold)
154                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/drain", disp.apiInstanceDrain)
155                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/run", disp.apiInstanceRun)
156                 mux.HandlerFunc("POST", "/arvados/v1/dispatch/instances/kill", disp.apiInstanceKill)
157                 metricsH := promhttp.HandlerFor(disp.Registry, promhttp.HandlerOpts{
158                         ErrorLog: disp.logger,
159                 })
160                 mux.Handler("GET", "/metrics", metricsH)
161                 mux.Handler("GET", "/metrics.json", metricsH)
162                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
163         }
164 }
165
166 func (disp *dispatcher) run() {
167         defer close(disp.stopped)
168         defer disp.instanceSet.Stop()
169         defer disp.pool.Stop()
170
171         staleLockTimeout := time.Duration(disp.Cluster.Containers.StaleLockTimeout)
172         if staleLockTimeout == 0 {
173                 staleLockTimeout = defaultStaleLockTimeout
174         }
175         pollInterval := time.Duration(disp.Cluster.Containers.CloudVMs.PollInterval)
176         if pollInterval <= 0 {
177                 pollInterval = defaultPollInterval
178         }
179         sched := scheduler.New(disp.Context, disp.queue, disp.pool, staleLockTimeout, pollInterval)
180         sched.Start()
181         defer sched.Stop()
182
183         <-disp.stop
184 }
185
186 // Management API: all active and queued containers.
187 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
188         var resp struct {
189                 Items []container.QueueEnt `json:"items"`
190         }
191         qEntries, _ := disp.queue.Entries()
192         for _, ent := range qEntries {
193                 resp.Items = append(resp.Items, ent)
194         }
195         json.NewEncoder(w).Encode(resp)
196 }
197
198 // Management API: all active instances (cloud VMs).
199 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
200         var resp struct {
201                 Items []worker.InstanceView `json:"items"`
202         }
203         resp.Items = disp.pool.Instances()
204         json.NewEncoder(w).Encode(resp)
205 }
206
207 // Management API: set idle behavior to "hold" for specified instance.
208 func (disp *dispatcher) apiInstanceHold(w http.ResponseWriter, r *http.Request) {
209         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorHold)
210 }
211
212 // Management API: set idle behavior to "drain" for specified instance.
213 func (disp *dispatcher) apiInstanceDrain(w http.ResponseWriter, r *http.Request) {
214         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorDrain)
215 }
216
217 // Management API: set idle behavior to "run" for specified instance.
218 func (disp *dispatcher) apiInstanceRun(w http.ResponseWriter, r *http.Request) {
219         disp.apiInstanceIdleBehavior(w, r, worker.IdleBehaviorRun)
220 }
221
222 // Management API: shutdown/destroy specified instance now.
223 func (disp *dispatcher) apiInstanceKill(w http.ResponseWriter, r *http.Request) {
224         id := cloud.InstanceID(r.FormValue("instance_id"))
225         if id == "" {
226                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
227                 return
228         }
229         err := disp.pool.KillInstance(id, "via management API: "+r.FormValue("reason"))
230         if err != nil {
231                 httpserver.Error(w, err.Error(), http.StatusNotFound)
232                 return
233         }
234 }
235
236 // Management API: send SIGTERM to specified container's crunch-run
237 // process now.
238 func (disp *dispatcher) apiContainerKill(w http.ResponseWriter, r *http.Request) {
239         uuid := r.FormValue("container_uuid")
240         if uuid == "" {
241                 httpserver.Error(w, "container_uuid parameter not provided", http.StatusBadRequest)
242                 return
243         }
244         if !disp.pool.KillContainer(uuid, "via management API: "+r.FormValue("reason")) {
245                 httpserver.Error(w, "container not found", http.StatusNotFound)
246                 return
247         }
248 }
249
250 func (disp *dispatcher) apiInstanceIdleBehavior(w http.ResponseWriter, r *http.Request, want worker.IdleBehavior) {
251         id := cloud.InstanceID(r.FormValue("instance_id"))
252         if id == "" {
253                 httpserver.Error(w, "instance_id parameter not provided", http.StatusBadRequest)
254                 return
255         }
256         err := disp.pool.SetIdleBehavior(id, want)
257         if err != nil {
258                 httpserver.Error(w, err.Error(), http.StatusNotFound)
259                 return
260         }
261 }