21934: Use assertGreater
[arvados.git] / lib / boot / cert.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package boot
6
7 import (
8         "context"
9         "crypto/rsa"
10         "crypto/tls"
11         "crypto/x509"
12         "encoding/pem"
13         "errors"
14         "fmt"
15         "io/ioutil"
16         "net"
17         "net/http"
18         "net/url"
19         "os"
20         "path/filepath"
21         "strings"
22         "time"
23
24         "golang.org/x/crypto/acme"
25         "golang.org/x/crypto/acme/autocert"
26 )
27
28 const stagingDirectoryURL = "https://acme-staging-v02.api.letsencrypt.org/directory"
29
30 var errInvalidHost = errors.New("unrecognized target host in incoming TLS request")
31
32 type createCertificates struct{}
33
34 func (createCertificates) String() string {
35         return "certificates"
36 }
37
38 func (createCertificates) Run(ctx context.Context, fail func(error), super *Supervisor) error {
39         if super.cluster.TLS.ACME.Server != "" {
40                 return bootAutoCert(ctx, fail, super)
41         } else if super.cluster.TLS.Key == "" && super.cluster.TLS.Certificate == "" {
42                 return createSelfSignedCert(ctx, fail, super)
43         } else {
44                 return nil
45         }
46 }
47
48 // bootAutoCert uses Let's Encrypt to get certificates for all the
49 // domains appearing in ExternalURLs, writes them to files where Nginx
50 // can load them, and updates super.cluster.TLS fields (Key and
51 // Certificiate) to point to those files.
52 //
53 // It also runs a background task to keep the files up to date.
54 //
55 // After bootAutoCert returns, other service components will get the
56 // certificates they need by reading these files or by using a
57 // read-only autocert cache.
58 //
59 // Currently this only works when port 80 of every ExternalURL domain
60 // is routed to this host, i.e., on a single-node cluster. Wildcard
61 // domains [for WebDAV] are not supported.
62 func bootAutoCert(ctx context.Context, fail func(error), super *Supervisor) error {
63         hosts := map[string]bool{}
64         for _, svc := range super.cluster.Services.Map() {
65                 u := url.URL(svc.ExternalURL)
66                 if u.Scheme == "https" || u.Scheme == "wss" {
67                         hosts[strings.ToLower(u.Hostname())] = true
68                 }
69         }
70         mgr := &autocert.Manager{
71                 Cache:  autocert.DirCache(super.tempdir + "/autocert"),
72                 Prompt: autocert.AcceptTOS,
73                 HostPolicy: func(ctx context.Context, host string) error {
74                         if hosts[strings.ToLower(host)] {
75                                 return nil
76                         } else {
77                                 return errInvalidHost
78                         }
79                 },
80         }
81         if srv := super.cluster.TLS.ACME.Server; srv == "LE" {
82                 // Leaving mgr.Client == nil means use Let's Encrypt
83                 // production environment
84         } else if srv == "LE-staging" {
85                 mgr.Client = &acme.Client{DirectoryURL: stagingDirectoryURL}
86         } else if strings.HasPrefix(srv, "https://") {
87                 mgr.Client = &acme.Client{DirectoryURL: srv}
88         } else {
89                 return fmt.Errorf("autocert setup: invalid directory URL in TLS.ACME.Server: %q", srv)
90         }
91         go func() {
92                 err := http.ListenAndServe(":80", mgr.HTTPHandler(nil))
93                 fail(fmt.Errorf("autocert http-01 challenge handler stopped: %w", err))
94         }()
95         u := url.URL(super.cluster.Services.Controller.ExternalURL)
96         extHost := u.Hostname()
97         update := func() error {
98                 for h := range hosts {
99                         cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: h})
100                         if err != nil {
101                                 return err
102                         }
103                         if h == extHost {
104                                 err = writeCert(super.tempdir, "server.key", "server.crt", cert)
105                                 if err != nil {
106                                         return err
107                                 }
108                         }
109                 }
110                 return nil
111         }
112         err := update()
113         if err != nil {
114                 return err
115         }
116         go func() {
117                 for range time.NewTicker(time.Hour).C {
118                         err := update()
119                         if err != nil {
120                                 super.logger.WithError(err).Error("error getting certificate from autocert")
121                         }
122                 }
123         }()
124         super.cluster.TLS.Key = "file://" + super.tempdir + "/server.key"
125         super.cluster.TLS.Certificate = "file://" + super.tempdir + "/server.crt"
126         return nil
127 }
128
129 // Save cert chain and key in a format Nginx can read.
130 func writeCert(outdir, keyfile, certfile string, cert *tls.Certificate) error {
131         keytmp, err := os.CreateTemp(outdir, keyfile+".tmp.*")
132         if err != nil {
133                 return err
134         }
135         defer keytmp.Close()
136         defer os.Remove(keytmp.Name())
137
138         certtmp, err := os.CreateTemp(outdir, certfile+".tmp.*")
139         if err != nil {
140                 return err
141         }
142         defer certtmp.Close()
143         defer os.Remove(certtmp.Name())
144
145         switch privkey := cert.PrivateKey.(type) {
146         case *rsa.PrivateKey:
147                 err = pem.Encode(keytmp, &pem.Block{
148                         Type:  "RSA PRIVATE KEY",
149                         Bytes: x509.MarshalPKCS1PrivateKey(privkey),
150                 })
151                 if err != nil {
152                         return err
153                 }
154         default:
155                 buf, err := x509.MarshalPKCS8PrivateKey(privkey)
156                 if err != nil {
157                         return err
158                 }
159                 err = pem.Encode(keytmp, &pem.Block{
160                         Type:  "PRIVATE KEY",
161                         Bytes: buf,
162                 })
163                 if err != nil {
164                         return err
165                 }
166         }
167         err = keytmp.Close()
168         if err != nil {
169                 return err
170         }
171
172         for _, cert := range cert.Certificate {
173                 err = pem.Encode(certtmp, &pem.Block{
174                         Type:  "CERTIFICATE",
175                         Bytes: cert,
176                 })
177                 if err != nil {
178                         return err
179                 }
180         }
181         err = certtmp.Close()
182         if err != nil {
183                 return err
184         }
185
186         err = os.Rename(keytmp.Name(), filepath.Join(outdir, keyfile))
187         if err != nil {
188                 return err
189         }
190         err = os.Rename(certtmp.Name(), filepath.Join(outdir, certfile))
191         if err != nil {
192                 return err
193         }
194         return nil
195 }
196
197 // Create a root CA key and use it to make a new server
198 // certificate+key pair.
199 //
200 // In future we'll make one root CA key per host instead of one per
201 // cluster, so it only needs to be imported to a browser once for
202 // ongoing dev/test usage.
203 func createSelfSignedCert(ctx context.Context, fail func(error), super *Supervisor) error {
204         san := "DNS:localhost,DNS:localhost.localdomain"
205         if net.ParseIP(super.ListenHost) != nil {
206                 san += fmt.Sprintf(",IP:%s", super.ListenHost)
207         } else {
208                 san += fmt.Sprintf(",DNS:%s", super.ListenHost)
209         }
210         hostname, err := os.Hostname()
211         if err != nil {
212                 return fmt.Errorf("hostname: %w", err)
213         }
214         if hostname != super.ListenHost {
215                 san += ",DNS:" + hostname
216         }
217
218         // Generate root key
219         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "openssl", "genrsa", "-out", "rootCA.key", "4096")
220         if err != nil {
221                 return err
222         }
223         // Generate a self-signed root certificate
224         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "openssl", "req", "-x509", "-new", "-nodes", "-key", "rootCA.key", "-sha256", "-days", "3650", "-out", "rootCA.crt", "-subj", "/C=US/ST=MA/O=Example Org/CN=localhost")
225         if err != nil {
226                 return err
227         }
228         // Generate server key
229         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "openssl", "genrsa", "-out", "server.key", "2048")
230         if err != nil {
231                 return err
232         }
233         // Build config file for signing request
234         defaultconf, err := ioutil.ReadFile("/etc/ssl/openssl.cnf")
235         if err != nil {
236                 return err
237         }
238         conf := append(defaultconf, []byte(fmt.Sprintf("\n[SAN]\nsubjectAltName=%s\n", san))...)
239         err = ioutil.WriteFile(filepath.Join(super.tempdir, "server.cfg"), conf, 0644)
240         if err != nil {
241                 return err
242         }
243         // Generate signing request
244         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "openssl", "req", "-new", "-sha256", "-key", "server.key", "-subj", "/C=US/ST=MA/O=Example Org/CN=localhost", "-reqexts", "SAN", "-config", "server.cfg", "-out", "server.csr")
245         if err != nil {
246                 return err
247         }
248         // Sign certificate
249         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "openssl", "x509", "-req", "-in", "server.csr", "-CA", "rootCA.crt", "-CAkey", "rootCA.key", "-CAcreateserial", "-out", "server.crt", "-extfile", "server.cfg", "-extensions", "SAN", "-days", "3650", "-sha256")
250         if err != nil {
251                 return err
252         }
253         super.cluster.TLS.Key = "file://" + super.tempdir + "/server.key"
254         super.cluster.TLS.Certificate = "file://" + super.tempdir + "/server.crt"
255         return nil
256 }