14324: Use logrus in Azure driver. Fix Sirupsen->sirupsen in imports
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "crypto/md5"
9         "encoding/json"
10         "fmt"
11         "net/http"
12         "strings"
13         "sync"
14         "time"
15
16         "git.curoverse.com/arvados.git/lib/cloud"
17         "git.curoverse.com/arvados.git/lib/dispatchcloud/container"
18         "git.curoverse.com/arvados.git/lib/dispatchcloud/scheduler"
19         "git.curoverse.com/arvados.git/lib/dispatchcloud/ssh_executor"
20         "git.curoverse.com/arvados.git/lib/dispatchcloud/worker"
21         "git.curoverse.com/arvados.git/sdk/go/arvados"
22         "git.curoverse.com/arvados.git/sdk/go/auth"
23         "git.curoverse.com/arvados.git/sdk/go/httpserver"
24         "github.com/prometheus/client_golang/prometheus"
25         "github.com/prometheus/client_golang/prometheus/promhttp"
26         "github.com/sirupsen/logrus"
27         "golang.org/x/crypto/ssh"
28 )
29
30 const (
31         defaultPollInterval     = time.Second
32         defaultStaleLockTimeout = time.Minute
33 )
34
35 type pool interface {
36         scheduler.WorkerPool
37         Instances() []worker.InstanceView
38         Stop()
39 }
40
41 type dispatcher struct {
42         Cluster       *arvados.Cluster
43         InstanceSetID cloud.InstanceSetID
44
45         logger      logrus.FieldLogger
46         reg         *prometheus.Registry
47         instanceSet cloud.InstanceSet
48         pool        pool
49         queue       scheduler.ContainerQueue
50         httpHandler http.Handler
51         sshKey      ssh.Signer
52
53         setupOnce sync.Once
54         stop      chan struct{}
55         stopped   chan struct{}
56 }
57
58 // Start starts the dispatcher. Start can be called multiple times
59 // with no ill effect.
60 func (disp *dispatcher) Start() {
61         disp.setupOnce.Do(disp.setup)
62 }
63
64 // ServeHTTP implements service.Handler.
65 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
66         disp.Start()
67         disp.httpHandler.ServeHTTP(w, r)
68 }
69
70 // CheckHealth implements service.Handler.
71 func (disp *dispatcher) CheckHealth() error {
72         disp.Start()
73         return nil
74 }
75
76 // Stop dispatching containers and release resources. Typically used
77 // in tests.
78 func (disp *dispatcher) Close() {
79         disp.Start()
80         select {
81         case disp.stop <- struct{}{}:
82         default:
83         }
84         <-disp.stopped
85 }
86
87 // Make a worker.Executor for the given instance.
88 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
89         exr := ssh_executor.New(inst)
90         exr.SetSigners(disp.sshKey)
91         return exr
92 }
93
94 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
95         return ChooseInstanceType(disp.Cluster, ctr)
96 }
97
98 func (disp *dispatcher) setup() {
99         disp.initialize()
100         go disp.run()
101 }
102
103 func (disp *dispatcher) initialize() {
104         arvClient := arvados.NewClientFromEnv()
105         if disp.InstanceSetID == "" {
106                 if strings.HasPrefix(arvClient.AuthToken, "v2/") {
107                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(arvClient.AuthToken, "/")[1])
108                 } else {
109                         // Use some other string unique to this token
110                         // that doesn't reveal the token itself.
111                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(arvClient.AuthToken))))
112                 }
113         }
114         disp.stop = make(chan struct{}, 1)
115         disp.stopped = make(chan struct{})
116         disp.logger = logrus.StandardLogger()
117
118         if key, err := ssh.ParsePrivateKey(disp.Cluster.Dispatch.PrivateKey); err != nil {
119                 disp.logger.Fatalf("error parsing configured Dispatch.PrivateKey: %s", err)
120         } else {
121                 disp.sshKey = key
122         }
123
124         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID, disp.logger)
125         if err != nil {
126                 disp.logger.Fatalf("error initializing driver: %s", err)
127         }
128         disp.instanceSet = &instanceSetProxy{instanceSet}
129         disp.reg = prometheus.NewRegistry()
130         disp.pool = worker.NewPool(disp.logger, disp.reg, disp.instanceSet, disp.newExecutor, disp.Cluster)
131         disp.queue = container.NewQueue(disp.logger, disp.reg, disp.typeChooser, arvClient)
132
133         if disp.Cluster.ManagementToken == "" {
134                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
136                 })
137         } else {
138                 mux := http.NewServeMux()
139                 mux.HandleFunc("/arvados/v1/dispatch/containers", disp.apiContainers)
140                 mux.HandleFunc("/arvados/v1/dispatch/instances", disp.apiInstances)
141                 metricsH := promhttp.HandlerFor(disp.reg, promhttp.HandlerOpts{
142                         ErrorLog: disp.logger,
143                 })
144                 mux.Handle("/metrics", metricsH)
145                 mux.Handle("/metrics.json", metricsH)
146                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
147         }
148 }
149
150 func (disp *dispatcher) run() {
151         defer close(disp.stopped)
152         defer disp.instanceSet.Stop()
153         defer disp.pool.Stop()
154
155         staleLockTimeout := time.Duration(disp.Cluster.Dispatch.StaleLockTimeout)
156         if staleLockTimeout == 0 {
157                 staleLockTimeout = defaultStaleLockTimeout
158         }
159         pollInterval := time.Duration(disp.Cluster.Dispatch.PollInterval)
160         if pollInterval <= 0 {
161                 pollInterval = defaultPollInterval
162         }
163         sched := scheduler.New(disp.logger, disp.queue, disp.pool, staleLockTimeout, pollInterval)
164         sched.Start()
165         defer sched.Stop()
166
167         <-disp.stop
168 }
169
170 // Management API: all active and queued containers.
171 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
172         if r.Method != "GET" {
173                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
174                 return
175         }
176         var resp struct {
177                 Items []container.QueueEnt
178         }
179         qEntries, _ := disp.queue.Entries()
180         for _, ent := range qEntries {
181                 resp.Items = append(resp.Items, ent)
182         }
183         json.NewEncoder(w).Encode(resp)
184 }
185
186 // Management API: all active instances (cloud VMs).
187 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
188         if r.Method != "GET" {
189                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
190                 return
191         }
192         var resp struct {
193                 Items []worker.InstanceView
194         }
195         resp.Items = disp.pool.Instances()
196         json.NewEncoder(w).Encode(resp)
197 }