14360: Merge branch 'master' into 14360-dispatch-cloud
[arvados.git] / lib / dispatchcloud / dispatcher.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package dispatchcloud
6
7 import (
8         "crypto/md5"
9         "encoding/json"
10         "fmt"
11         "net/http"
12         "strings"
13         "sync"
14         "time"
15
16         "git.curoverse.com/arvados.git/lib/cloud"
17         "git.curoverse.com/arvados.git/lib/dispatchcloud/container"
18         "git.curoverse.com/arvados.git/lib/dispatchcloud/scheduler"
19         "git.curoverse.com/arvados.git/lib/dispatchcloud/ssh_executor"
20         "git.curoverse.com/arvados.git/lib/dispatchcloud/worker"
21         "git.curoverse.com/arvados.git/sdk/go/arvados"
22         "git.curoverse.com/arvados.git/sdk/go/auth"
23         "git.curoverse.com/arvados.git/sdk/go/httpserver"
24         "github.com/Sirupsen/logrus"
25         "github.com/prometheus/client_golang/prometheus"
26         "github.com/prometheus/client_golang/prometheus/promhttp"
27         "golang.org/x/crypto/ssh"
28 )
29
30 const (
31         defaultPollInterval     = time.Second
32         defaultStaleLockTimeout = time.Minute
33 )
34
35 type pool interface {
36         scheduler.WorkerPool
37         Instances() []worker.InstanceView
38 }
39
40 type dispatcher struct {
41         Cluster       *arvados.Cluster
42         InstanceSetID cloud.InstanceSetID
43
44         logger      logrus.FieldLogger
45         reg         *prometheus.Registry
46         instanceSet cloud.InstanceSet
47         pool        pool
48         queue       scheduler.ContainerQueue
49         httpHandler http.Handler
50         sshKey      ssh.Signer
51
52         setupOnce sync.Once
53         stop      chan struct{}
54 }
55
56 // Start starts the dispatcher. Start can be called multiple times
57 // with no ill effect.
58 func (disp *dispatcher) Start() {
59         disp.setupOnce.Do(disp.setup)
60 }
61
62 // ServeHTTP implements service.Handler.
63 func (disp *dispatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
64         disp.Start()
65         disp.httpHandler.ServeHTTP(w, r)
66 }
67
68 // CheckHealth implements service.Handler.
69 func (disp *dispatcher) CheckHealth() error {
70         disp.Start()
71         return nil
72 }
73
74 // Stop dispatching containers and release resources. Typically used
75 // in tests.
76 func (disp *dispatcher) Close() {
77         disp.Start()
78         select {
79         case disp.stop <- struct{}{}:
80         default:
81         }
82 }
83
84 // Make a worker.Executor for the given instance.
85 func (disp *dispatcher) newExecutor(inst cloud.Instance) worker.Executor {
86         exr := ssh_executor.New(inst)
87         exr.SetSigners(disp.sshKey)
88         return exr
89 }
90
91 func (disp *dispatcher) typeChooser(ctr *arvados.Container) (arvados.InstanceType, error) {
92         return ChooseInstanceType(disp.Cluster, ctr)
93 }
94
95 func (disp *dispatcher) setup() {
96         disp.initialize()
97         go disp.run()
98 }
99
100 func (disp *dispatcher) initialize() {
101         arvClient := arvados.NewClientFromEnv()
102         if disp.InstanceSetID == "" {
103                 if strings.HasPrefix(arvClient.AuthToken, "v2/") {
104                         disp.InstanceSetID = cloud.InstanceSetID(strings.Split(arvClient.AuthToken, "/")[1])
105                 } else {
106                         // Use some other string unique to this token
107                         // that doesn't reveal the token itself.
108                         disp.InstanceSetID = cloud.InstanceSetID(fmt.Sprintf("%x", md5.Sum([]byte(arvClient.AuthToken))))
109                 }
110         }
111         disp.stop = make(chan struct{}, 1)
112         disp.logger = logrus.StandardLogger()
113
114         if key, err := ssh.ParsePrivateKey(disp.Cluster.Dispatch.PrivateKey); err != nil {
115                 disp.logger.Fatalf("error parsing configured Dispatch.PrivateKey: %s", err)
116         } else {
117                 disp.sshKey = key
118         }
119
120         instanceSet, err := newInstanceSet(disp.Cluster, disp.InstanceSetID)
121         if err != nil {
122                 disp.logger.Fatalf("error initializing driver: %s", err)
123         }
124         disp.instanceSet = &instanceSetProxy{instanceSet}
125         disp.reg = prometheus.NewRegistry()
126         disp.pool = worker.NewPool(disp.logger, disp.reg, disp.instanceSet, disp.newExecutor, disp.Cluster)
127         disp.queue = container.NewQueue(disp.logger, disp.reg, disp.typeChooser, arvClient)
128
129         if disp.Cluster.ManagementToken == "" {
130                 disp.httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131                         http.Error(w, "Management API authentication is not configured", http.StatusForbidden)
132                 })
133         } else {
134                 mux := http.NewServeMux()
135                 mux.HandleFunc("/arvados/v1/dispatch/containers", disp.apiContainers)
136                 mux.HandleFunc("/arvados/v1/dispatch/instances", disp.apiInstances)
137                 metricsH := promhttp.HandlerFor(disp.reg, promhttp.HandlerOpts{
138                         ErrorLog: disp.logger,
139                 })
140                 mux.Handle("/metrics", metricsH)
141                 mux.Handle("/metrics.json", metricsH)
142                 disp.httpHandler = auth.RequireLiteralToken(disp.Cluster.ManagementToken, mux)
143         }
144 }
145
146 func (disp *dispatcher) run() {
147         defer disp.instanceSet.Stop()
148
149         staleLockTimeout := time.Duration(disp.Cluster.Dispatch.StaleLockTimeout)
150         if staleLockTimeout == 0 {
151                 staleLockTimeout = defaultStaleLockTimeout
152         }
153         pollInterval := time.Duration(disp.Cluster.Dispatch.PollInterval)
154         if pollInterval <= 0 {
155                 pollInterval = defaultPollInterval
156         }
157         sched := scheduler.New(disp.logger, disp.queue, disp.pool, staleLockTimeout, pollInterval)
158         sched.Start()
159         defer sched.Stop()
160
161         <-disp.stop
162 }
163
164 // Management API: all active and queued containers.
165 func (disp *dispatcher) apiContainers(w http.ResponseWriter, r *http.Request) {
166         if r.Method != "GET" {
167                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
168                 return
169         }
170         var resp struct {
171                 Items []container.QueueEnt
172         }
173         qEntries, _ := disp.queue.Entries()
174         for _, ent := range qEntries {
175                 resp.Items = append(resp.Items, ent)
176         }
177         json.NewEncoder(w).Encode(resp)
178 }
179
180 // Management API: all active instances (cloud VMs).
181 func (disp *dispatcher) apiInstances(w http.ResponseWriter, r *http.Request) {
182         if r.Method != "GET" {
183                 httpserver.Error(w, "method not allowed", http.StatusMethodNotAllowed)
184                 return
185         }
186         var resp struct {
187                 Items []worker.InstanceView
188         }
189         resp.Items = disp.pool.Instances()
190         json.NewEncoder(w).Encode(resp)
191 }