4ed7603d2a55689a298041286dddca5f09643b97
[arvados.git] / lib / boot / postgresql.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package boot
6
7 import (
8         "bytes"
9         "context"
10         "database/sql"
11         "fmt"
12         "os"
13         "os/exec"
14         "os/user"
15         "path/filepath"
16         "strconv"
17         "strings"
18         "time"
19
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "github.com/lib/pq"
22 )
23
// runPostgreSQL is a supervisor task that runs a postgresql server in
// a private data directory, then sets up a db user, database, and TCP
// listener that match the supervisor's configured database connection
// info.
type runPostgreSQL struct{}

// String returns the task's name, "postgresql".
func (runPostgreSQL) String() string {
	const taskName = "postgresql"
	return taskName
}
32
33 func (runPostgreSQL) Run(ctx context.Context, fail func(error), super *Supervisor) error {
34         err := super.wait(ctx, createCertificates{})
35         if err != nil {
36                 return err
37         }
38
39         if super.ClusterType == "production" {
40                 return nil
41         }
42
43         iamroot := false
44         if u, err := user.Current(); err != nil {
45                 return fmt.Errorf("user.Current(): %w", err)
46         } else if u.Uid == "0" {
47                 iamroot = true
48         }
49
50         buf := bytes.NewBuffer(nil)
51         err = super.RunProgram(ctx, super.tempdir, runOptions{output: buf}, "pg_config", "--bindir")
52         if err != nil {
53                 return err
54         }
55         bindir := strings.TrimSpace(buf.String())
56
57         datadir := filepath.Join(super.tempdir, "pgdata")
58         err = os.Mkdir(datadir, 0700)
59         if err != nil {
60                 return err
61         }
62         prog, args := filepath.Join(bindir, "initdb"), []string{"-D", datadir, "-E", "utf8"}
63         opts := runOptions{}
64         if iamroot {
65                 postgresUser, err := user.Lookup("postgres")
66                 if err != nil {
67                         return fmt.Errorf("user.Lookup(\"postgres\"): %s", err)
68                 }
69                 postgresUID, err := strconv.Atoi(postgresUser.Uid)
70                 if err != nil {
71                         return fmt.Errorf("user.Lookup(\"postgres\"): non-numeric uid?: %q", postgresUser.Uid)
72                 }
73                 postgresGid, err := strconv.Atoi(postgresUser.Gid)
74                 if err != nil {
75                         return fmt.Errorf("user.Lookup(\"postgres\"): non-numeric gid?: %q", postgresUser.Gid)
76                 }
77                 err = os.Chown(super.tempdir, 0, postgresGid)
78                 if err != nil {
79                         return err
80                 }
81                 err = os.Chmod(super.tempdir, 0710)
82                 if err != nil {
83                         return err
84                 }
85                 err = os.Chown(datadir, postgresUID, 0)
86                 if err != nil {
87                         return err
88                 }
89                 opts.user = "postgres"
90         }
91         err = super.RunProgram(ctx, super.tempdir, opts, prog, args...)
92         if err != nil {
93                 return err
94         }
95
96         err = super.RunProgram(ctx, super.tempdir, runOptions{}, "cp", "server.crt", "server.key", datadir)
97         if err != nil {
98                 return err
99         }
100         if iamroot {
101                 err = super.RunProgram(ctx, super.tempdir, runOptions{}, "chown", "postgres", datadir+"/server.crt", datadir+"/server.key")
102                 if err != nil {
103                         return err
104                 }
105         }
106
107         port := super.cluster.PostgreSQL.Connection["port"]
108
109         super.waitShutdown.Add(1)
110         go func() {
111                 defer super.waitShutdown.Done()
112                 prog, args := filepath.Join(bindir, "postgres"), []string{
113                         "-l",          // enable ssl
114                         "-D", datadir, // data dir
115                         "-k", datadir, // socket dir
116                         "-p", super.cluster.PostgreSQL.Connection["port"],
117                 }
118                 opts := runOptions{}
119                 if iamroot {
120                         opts.user = "postgres"
121                 }
122                 fail(super.RunProgram(ctx, super.tempdir, opts, prog, args...))
123         }()
124
125         for {
126                 if ctx.Err() != nil {
127                         return ctx.Err()
128                 }
129                 if exec.CommandContext(ctx, "pg_isready", "--timeout=10", "--host="+super.cluster.PostgreSQL.Connection["host"], "--port="+port).Run() == nil {
130                         break
131                 }
132                 time.Sleep(time.Second / 2)
133         }
134         pgconn := arvados.PostgreSQLConnection{
135                 "host":   datadir,
136                 "port":   port,
137                 "dbname": "postgres",
138         }
139         if iamroot {
140                 pgconn["user"] = "postgres"
141         }
142         db, err := sql.Open("postgres", pgconn.String())
143         if err != nil {
144                 return fmt.Errorf("db open failed: %s", err)
145         }
146         defer db.Close()
147         conn, err := db.Conn(ctx)
148         if err != nil {
149                 return fmt.Errorf("db conn failed: %s", err)
150         }
151         defer conn.Close()
152         _, err = conn.ExecContext(ctx, `CREATE USER `+pq.QuoteIdentifier(super.cluster.PostgreSQL.Connection["user"])+` WITH SUPERUSER ENCRYPTED PASSWORD `+pq.QuoteLiteral(super.cluster.PostgreSQL.Connection["password"]))
153         if err != nil {
154                 return fmt.Errorf("createuser failed: %s", err)
155         }
156         _, err = conn.ExecContext(ctx, `CREATE DATABASE `+pq.QuoteIdentifier(super.cluster.PostgreSQL.Connection["dbname"])+` WITH TEMPLATE template0 ENCODING 'utf8'`)
157         if err != nil {
158                 return fmt.Errorf("createdb failed: %s", err)
159         }
160         return nil
161 }