Merge branch '15928-fs-deadlock'
[arvados.git] / sdk / go / health / aggregator.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package health
6
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"

	"git.arvados.org/arvados.git/sdk/go/arvados"
	"git.arvados.org/arvados.git/sdk/go/auth"
)
20
21 const defaultTimeout = arvados.Duration(2 * time.Second)
22
23 // Aggregator implements http.Handler. It handles "GET /_health/all"
24 // by checking the health of all configured services on the cluster
25 // and responding 200 if everything is healthy.
26 type Aggregator struct {
27         setupOnce  sync.Once
28         httpClient *http.Client
29         timeout    arvados.Duration
30
31         Cluster *arvados.Cluster
32
33         // If non-nil, Log is called after handling each request.
34         Log func(*http.Request, error)
35 }
36
37 func (agg *Aggregator) setup() {
38         agg.httpClient = http.DefaultClient
39         if agg.timeout == 0 {
40                 // this is always the case, except in the test suite
41                 agg.timeout = defaultTimeout
42         }
43 }
44
45 func (agg *Aggregator) CheckHealth() error {
46         return nil
47 }
48
49 func (agg *Aggregator) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
50         agg.setupOnce.Do(agg.setup)
51         sendErr := func(statusCode int, err error) {
52                 resp.WriteHeader(statusCode)
53                 json.NewEncoder(resp).Encode(map[string]string{"error": err.Error()})
54                 if agg.Log != nil {
55                         agg.Log(req, err)
56                 }
57         }
58
59         resp.Header().Set("Content-Type", "application/json")
60
61         if !agg.checkAuth(req) {
62                 sendErr(http.StatusUnauthorized, errUnauthorized)
63                 return
64         }
65         if req.URL.Path != "/_health/all" {
66                 sendErr(http.StatusNotFound, errNotFound)
67                 return
68         }
69         json.NewEncoder(resp).Encode(agg.ClusterHealth())
70         if agg.Log != nil {
71                 agg.Log(req, nil)
72         }
73 }
74
75 type ClusterHealthResponse struct {
76         // "OK" if all needed services are OK, otherwise "ERROR".
77         Health string `json:"health"`
78
79         // An entry for each known health check of each known instance
80         // of each needed component: "instance of service S on node N
81         // reports health-check C is OK."
82         Checks map[string]CheckResult `json:"checks"`
83
84         // An entry for each service type: "service S is OK." This
85         // exposes problems that can't be expressed in Checks, like
86         // "service S is needed, but isn't configured to run
87         // anywhere."
88         Services map[arvados.ServiceName]ServiceHealth `json:"services"`
89 }
90
// CheckResult records the outcome of a single health-check ping.
type CheckResult struct {
	// "OK" or "ERROR".
	Health string `json:"health"`
	// Why the check failed; omitted when empty.
	Error string `json:"error,omitempty"`
	// Status code and status line of the ping response, when one was
	// received; encoded under the Go field names and omitted when zero.
	HTTPStatusCode int    `json:",omitempty"`
	HTTPStatusText string `json:",omitempty"`
	// The decoded JSON body of the ping response.
	Response map[string]interface{} `json:"response"`
	// Elapsed time of the ping, in seconds.
	ResponseTime json.Number `json:"responseTime"`
}
99
// ServiceHealth summarizes every instance of one service type.
type ServiceHealth struct {
	// "OK" once at least one instance passed its check, else "ERROR".
	Health string `json:"health"`
	// Number of instances whose check passed.
	N int `json:"n"`
}
104
105 func (agg *Aggregator) ClusterHealth() ClusterHealthResponse {
106         resp := ClusterHealthResponse{
107                 Health:   "OK",
108                 Checks:   make(map[string]CheckResult),
109                 Services: make(map[arvados.ServiceName]ServiceHealth),
110         }
111
112         mtx := sync.Mutex{}
113         wg := sync.WaitGroup{}
114         for svcName, svc := range agg.Cluster.Services.Map() {
115                 // Ensure svc is listed in resp.Services.
116                 mtx.Lock()
117                 if _, ok := resp.Services[svcName]; !ok {
118                         resp.Services[svcName] = ServiceHealth{Health: "ERROR"}
119                 }
120                 mtx.Unlock()
121
122                 for addr := range svc.InternalURLs {
123                         wg.Add(1)
124                         go func(svcName arvados.ServiceName, addr arvados.URL) {
125                                 defer wg.Done()
126                                 var result CheckResult
127                                 pingURL, err := agg.pingURL(addr)
128                                 if err != nil {
129                                         result = CheckResult{
130                                                 Health: "ERROR",
131                                                 Error:  err.Error(),
132                                         }
133                                 } else {
134                                         result = agg.ping(pingURL)
135                                 }
136
137                                 mtx.Lock()
138                                 defer mtx.Unlock()
139                                 resp.Checks[fmt.Sprintf("%s+%s", svcName, pingURL)] = result
140                                 if result.Health == "OK" {
141                                         h := resp.Services[svcName]
142                                         h.N++
143                                         h.Health = "OK"
144                                         resp.Services[svcName] = h
145                                 } else {
146                                         resp.Health = "ERROR"
147                                 }
148                         }(svcName, addr)
149                 }
150         }
151         wg.Wait()
152
153         // Report ERROR if a needed service didn't fail any checks
154         // merely because it isn't configured to run anywhere.
155         for _, sh := range resp.Services {
156                 if sh.Health != "OK" {
157                         resp.Health = "ERROR"
158                         break
159                 }
160         }
161         return resp
162 }
163
164 func (agg *Aggregator) pingURL(svcURL arvados.URL) (*url.URL, error) {
165         base := url.URL(svcURL)
166         return base.Parse("/_health/ping")
167 }
168
169 func (agg *Aggregator) ping(target *url.URL) (result CheckResult) {
170         t0 := time.Now()
171
172         var err error
173         defer func() {
174                 result.ResponseTime = json.Number(fmt.Sprintf("%.6f", time.Since(t0).Seconds()))
175                 if err != nil {
176                         result.Health, result.Error = "ERROR", err.Error()
177                 } else {
178                         result.Health = "OK"
179                 }
180         }()
181
182         req, err := http.NewRequest("GET", target.String(), nil)
183         if err != nil {
184                 return
185         }
186         req.Header.Set("Authorization", "Bearer "+agg.Cluster.ManagementToken)
187
188         ctx, cancel := context.WithTimeout(req.Context(), time.Duration(agg.timeout))
189         defer cancel()
190         req = req.WithContext(ctx)
191         resp, err := agg.httpClient.Do(req)
192         if err != nil {
193                 return
194         }
195         result.HTTPStatusCode = resp.StatusCode
196         result.HTTPStatusText = resp.Status
197         err = json.NewDecoder(resp.Body).Decode(&result.Response)
198         if err != nil {
199                 err = fmt.Errorf("cannot decode response: %s", err)
200         } else if resp.StatusCode != http.StatusOK {
201                 err = fmt.Errorf("HTTP %d %s", resp.StatusCode, resp.Status)
202         } else if h, _ := result.Response["health"].(string); h != "OK" {
203                 if e, ok := result.Response["error"].(string); ok && e != "" {
204                         err = errors.New(e)
205                 } else {
206                         err = fmt.Errorf("health=%q in ping response", h)
207                 }
208         }
209         return
210 }
211
212 func (agg *Aggregator) checkAuth(req *http.Request) bool {
213         creds := auth.CredentialsFromRequest(req)
214         for _, token := range creds.Tokens {
215                 if token != "" && token == agg.Cluster.ManagementToken {
216                         return true
217                 }
218         }
219         return false
220 }