16053: Run postgresql as "postgres" user if supervisor is root.
[arvados.git] / lib / boot / postgresql.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package boot
6
7 import (
8         "bytes"
9         "context"
10         "database/sql"
11         "fmt"
12         "os"
13         "os/exec"
14         "os/user"
15         "path/filepath"
16         "strconv"
17         "strings"
18         "time"
19
20         "git.arvados.org/arvados.git/sdk/go/arvados"
21         "github.com/lib/pq"
22 )
23
// runPostgreSQL runs a postgresql server in a private data
// directory, then sets up a database user, a database, and a TCP
// listener matching the supervisor's configured database connection
// info. The type itself carries no state.
type runPostgreSQL struct{}

// String returns "postgresql", the name of this component.
func (runPostgreSQL) String() string {
	const name = "postgresql"
	return name
}
32
33 func (runPostgreSQL) Run(ctx context.Context, fail func(error), super *Supervisor) error {
34         err := super.wait(ctx, createCertificates{})
35         if err != nil {
36                 return err
37         }
38
39         iamroot := false
40         if u, err := user.Current(); err != nil {
41                 return fmt.Errorf("user.Current(): %s", err)
42         } else if u.Uid == "0" {
43                 iamroot = true
44         }
45
46         buf := bytes.NewBuffer(nil)
47         err = super.RunProgram(ctx, super.tempdir, buf, nil, "pg_config", "--bindir")
48         if err != nil {
49                 return err
50         }
51         bindir := strings.TrimSpace(buf.String())
52
53         datadir := filepath.Join(super.tempdir, "pgdata")
54         err = os.Mkdir(datadir, 0700)
55         if err != nil {
56                 return err
57         }
58         prog, args := filepath.Join(bindir, "initdb"), []string{"-D", datadir, "-E", "utf8"}
59         if iamroot {
60                 postgresUser, err := user.Lookup("postgres")
61                 if err != nil {
62                         return fmt.Errorf("user.Lookup(\"postgres\"): %s", err)
63                 }
64                 postgresUid, err := strconv.Atoi(postgresUser.Uid)
65                 if err != nil {
66                         return fmt.Errorf("user.Lookup(\"postgres\"): non-numeric uid?: %q", postgresUser.Uid)
67                 }
68                 postgresGid, err := strconv.Atoi(postgresUser.Gid)
69                 if err != nil {
70                         return fmt.Errorf("user.Lookup(\"postgres\"): non-numeric gid?: %q", postgresUser.Gid)
71                 }
72                 err = os.Chown(super.tempdir, 0, postgresGid)
73                 if err != nil {
74                         return err
75                 }
76                 err = os.Chmod(super.tempdir, 0710)
77                 if err != nil {
78                         return err
79                 }
80                 err = os.Chown(datadir, postgresUid, 0)
81                 if err != nil {
82                         return err
83                 }
84                 args = append([]string{"-u", "postgres", prog}, args...)
85                 prog = "sudo"
86         }
87         err = super.RunProgram(ctx, super.tempdir, nil, nil, prog, args...)
88         if err != nil {
89                 return err
90         }
91
92         err = super.RunProgram(ctx, super.tempdir, nil, nil, "cp", "server.crt", "server.key", datadir)
93         if err != nil {
94                 return err
95         }
96         if iamroot {
97                 err = super.RunProgram(ctx, super.tempdir, nil, nil, "chown", "postgres", datadir+"/server.crt", datadir+"/server.key")
98                 if err != nil {
99                         return err
100                 }
101         }
102
103         port := super.cluster.PostgreSQL.Connection["port"]
104
105         super.waitShutdown.Add(1)
106         go func() {
107                 defer super.waitShutdown.Done()
108                 prog, args := filepath.Join(bindir, "postgres"), []string{
109                         "-l",          // enable ssl
110                         "-D", datadir, // data dir
111                         "-k", datadir, // socket dir
112                         "-p", super.cluster.PostgreSQL.Connection["port"],
113                 }
114                 if iamroot {
115                         args = append([]string{"-u", "postgres", prog}, args...)
116                         prog = "sudo"
117                 }
118                 fail(super.RunProgram(ctx, super.tempdir, nil, nil, prog, args...))
119         }()
120
121         for {
122                 if ctx.Err() != nil {
123                         return ctx.Err()
124                 }
125                 if exec.CommandContext(ctx, "pg_isready", "--timeout=10", "--host="+super.cluster.PostgreSQL.Connection["host"], "--port="+port).Run() == nil {
126                         break
127                 }
128                 time.Sleep(time.Second / 2)
129         }
130         pgconn := arvados.PostgreSQLConnection{
131                 "host":   datadir,
132                 "port":   port,
133                 "dbname": "postgres",
134         }
135         if iamroot {
136                 pgconn["user"] = "postgres"
137         }
138         db, err := sql.Open("postgres", pgconn.String())
139         if err != nil {
140                 return fmt.Errorf("db open failed: %s", err)
141         }
142         defer db.Close()
143         conn, err := db.Conn(ctx)
144         if err != nil {
145                 return fmt.Errorf("db conn failed: %s", err)
146         }
147         defer conn.Close()
148         _, err = conn.ExecContext(ctx, `CREATE USER `+pq.QuoteIdentifier(super.cluster.PostgreSQL.Connection["user"])+` WITH SUPERUSER ENCRYPTED PASSWORD `+pq.QuoteLiteral(super.cluster.PostgreSQL.Connection["password"]))
149         if err != nil {
150                 return fmt.Errorf("createuser failed: %s", err)
151         }
152         _, err = conn.ExecContext(ctx, `CREATE DATABASE `+pq.QuoteIdentifier(super.cluster.PostgreSQL.Connection["dbname"]))
153         if err != nil {
154                 return fmt.Errorf("createdb failed: %s", err)
155         }
156         return nil
157 }