14360: Shutdown between tests to eliminate leaking logs.
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "crypto/md5"
9         "encoding/json"
10         "fmt"
11         "net/http"
12         "strings"
13         "sync"
14         "time"
15
16         "git.curoverse.com/arvados.git/lib/cloud"
17         "git.curoverse.com/arvados.git/lib/dispatchcloud/container"
18         "git.curoverse.com/arvados.git/lib/dispatchcloud/scheduler"
19         "git.curoverse.com/arvados.git/lib/dispatchcloud/ssh_executor"
20         "git.curoverse.com/arvados.git/lib/dispatchcloud/worker"
21         "git.curoverse.com/arvados.git/sdk/go/arvados"
22         "git.curoverse.com/arvados.git/sdk/go/auth"
23         "git.curoverse.com/arvados.git/sdk/go/httpserver"
24         "github.com/Sirupsen/logrus"
25         "github.com/prometheus/client_golang/prometheus"
26         "github.com/prometheus/client_golang/prometheus/promhttp"
27         "golang.org/x/crypto/ssh"
28 )
29
30 const (
31         defaultPollInterval     = time.Second
32         defaultStaleLockTimeout = time.Minute
33 )
34
35 type pool interface {
36         scheduler.WorkerPool
37         Instances() []worker.InstanceView
38 }
39
40 type dispatcher struct {
41         Cluster       *arvados.Cluster
42         InstanceSetID cloud.InstanceSetID
43
44         logger      logrus.FieldLogger
45         reg         *prometheus.Registry
46         instanceSet cloud.InstanceSet
47         pool        pool
48         queue       scheduler.ContainerQueue
49         httpHandler http.Handler
50         sshKey      ssh.Signer
51
52         setupOnce sync.Once
53         stop      chan struct{}
54         stopped   chan struct{}
55 }
56
57 // Start starts the dispatcher. Start can be called multiple times
58 // with no ill effect.
59 func (disp *dispatcher) Start() {
60         disp.setupOnce.Do(disp.setup)
61 }
62
63 // ServeHTTP implements service.Handler.
64 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
65         disp.Start()
66         disp.httpHandler.ServeHTTP(w, r)
67 }
68
69 // CheckHealth implements service.Handler.
70 func (disp *dispatcher) CheckHealth() error {
71         disp.Start()
72         return nil
73 }
74
75 // Stop dispatching containers and release resources. Typically used
76 // in tests.
77 func (disp *dispatcher) Close() {
78         disp.Start()
79         select {
80         case disp.stop <- struct{}{}:
81         default:
82         }
83         <-disp.stopped
84 }
85
86 // Make a worker.Executor for the given instance.
87 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
88         exr := ssh_executor.New(inst)
89         exr.SetSigners(disp.sshKey)
90         return exr
91 }
92
93 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
94         return ChooseInstanceType(disp.Cluster, ctr)
95 }
96
97 func (disp *dispatcher) setup() {
98         disp.initialize()
99         go disp.run()
100 }
101
102 func (disp *dispatcher) initialize() {
103         arvClient := arvados.NewClientFromEnv()
104         if disp.InstanceSetID == "" {
105                 if strings.HasPrefix(arvClient.AuthToken, "v2/") {
106                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(arvClient.AuthToken, "/")[1])
107                 } else {
108                         // Use some other string unique to this token
109                         // that doesn't reveal the token itself.
110                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(arvClient.AuthToken))))
111                 }
112         }
113         disp.stop = make(chan struct{}, 1)
114         disp.stopped = make(chan struct{})
115         disp.logger = logrus.StandardLogger()
116
117         if key, err := ssh.ParsePrivateKey(disp.Cluster.Dispatch.PrivateKey); err != nil {
118                 disp.logger.Fatalf("error parsing configured Dispatch.PrivateKey: %s", err)
119         } else {
120                 disp.sshKey = key
121         }
122
123         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID)
124         if err != nil {
125                 disp.logger.Fatalf("error initializing driver: %s", err)
126         }
127         disp.instanceSet = &instanceSetProxy{instanceSet}
128         disp.reg = prometheus.NewRegistry()
129         disp.pool = worker.NewPool(disp.logger, disp.reg, disp.instanceSet, disp.newExecutor, disp.Cluster)
130         disp.queue = container.NewQueue(disp.logger, disp.reg, disp.typeChooser, arvClient)
131
132         if disp.Cluster.ManagementToken == "" {
133                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
134                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
135                 })
136         } else {
137                 mux := http.NewServeMux()
138                 mux.HandleFunc("/arvados/v1/dispatch/containers", disp.apiContainers)
139                 mux.HandleFunc("/arvados/v1/dispatch/instances", disp.apiInstances)
140                 metricsH := promhttp.HandlerFor(disp.reg, promhttp.HandlerOpts{
141                         ErrorLog: disp.logger,
142                 })
143                 mux.Handle("/metrics", metricsH)
144                 mux.Handle("/metrics.json", metricsH)
145                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
146         }
147 }
148
149 func (disp *dispatcher) run() {
150         defer close(disp.stopped)
151         defer disp.instanceSet.Stop()
152
153         staleLockTimeout := time.Duration(disp.Cluster.Dispatch.StaleLockTimeout)
154         if staleLockTimeout == 0 {
155                 staleLockTimeout = defaultStaleLockTimeout
156         }
157         pollInterval := time.Duration(disp.Cluster.Dispatch.PollInterval)
158         if pollInterval <= 0 {
159                 pollInterval = defaultPollInterval
160         }
161         sched := scheduler.New(disp.logger, disp.queue, disp.pool, staleLockTimeout, pollInterval)
162         sched.Start()
163         defer sched.Stop()
164
165         <-disp.stop
166 }
167
168 // Management API: all active and queued containers.
169 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
170         if r.Method != "GET" {
171                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
172                 return
173         }
174         var resp struct {
175                 Items []container.QueueEnt
176         }
177         qEntries, _ := disp.queue.Entries()
178         for _, ent := range qEntries {
179                 resp.Items = append(resp.Items, ent)
180         }
181         json.NewEncoder(w).Encode(resp)
182 }
183
184 // Management API: all active instances (cloud VMs).
185 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
186         if r.Method != "GET" {
187                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
188                 return
189         }
190         var resp struct {
191                 Items []worker.InstanceView
192         }
193         resp.Items = disp.pool.Instances()
194         json.NewEncoder(w).Encode(resp)
195 }