3ff3912e8c892ab1e33f936170157eaa2886df46
[arvados.git] / services / crunchstat / crunchstat.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package main
6
7 import (
8         "bufio"
9         "flag"
10         "fmt"
11         "io"
12         "log"
13         "os"
14         "os/exec"
15         "os/signal"
16         "syscall"
17         "time"
18
19         "git.arvados.org/arvados.git/lib/crunchstat"
20 )
21
22 const MaxLogLine = 1 << 14 // Child stderr lines >16KiB will be split
23
24 var (
25         signalOnDeadPPID  int = 15
26         ppidCheckInterval     = time.Second
27         version               = "dev"
28 )
29
30 func main() {
31         reporter := crunchstat.Reporter{
32                 Logger: log.New(os.Stderr, "crunchstat: ", 0),
33         }
34
35         flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
36         flags.StringVar(&reporter.CgroupRoot, "cgroup-root", "", "Root of cgroup tree")
37         flags.StringVar(&reporter.CgroupParent, "cgroup-parent", "", "Name of container parent under cgroup")
38         flags.StringVar(&reporter.CIDFile, "cgroup-cid", "", "Path to container id file")
39         flags.IntVar(&signalOnDeadPPID, "signal-on-dead-ppid", signalOnDeadPPID, "Signal to send child if crunchstat's parent process disappears (0 to disable)")
40         flags.DurationVar(&ppidCheckInterval, "ppid-check-interval", ppidCheckInterval, "Time between checks for parent process disappearance")
41         pollMsec := flags.Int64("poll", 1000, "Reporting interval, in milliseconds")
42         getVersion := flags.Bool("version", false, "Print version information and exit.")
43
44         err := flags.Parse(os.Args[1:])
45         if err == flag.ErrHelp {
46                 return
47         } else if err != nil {
48                 reporter.Logger.Print(err)
49                 os.Exit(2)
50         }
51
52         // Print version information if requested
53         if *getVersion {
54                 fmt.Printf("crunchstat %s\n", version)
55                 return
56         }
57
58         if flags.NArg() == 0 {
59                 fmt.Fprintf(flags.Output(), "Usage: %s [options] program [args...]\n\nOptions:\n", os.Args[0])
60                 flags.PrintDefaults()
61                 os.Exit(2)
62         }
63
64         reporter.Logger.Printf("crunchstat %s started", version)
65
66         if reporter.CgroupRoot == "" {
67                 reporter.Logger.Fatal("error: must provide -cgroup-root")
68         } else if signalOnDeadPPID < 0 {
69                 reporter.Logger.Fatalf("-signal-on-dead-ppid=%d is invalid (use a positive signal number, or 0 to disable)", signalOnDeadPPID)
70         }
71         reporter.PollPeriod = time.Duration(*pollMsec) * time.Millisecond
72
73         reporter.Start()
74         err = runCommand(flags.Args(), reporter.Logger)
75         reporter.Stop()
76
77         if err, ok := err.(*exec.ExitError); ok {
78                 // The program has exited with an exit code != 0
79
80                 // This works on both Unix and Windows. Although
81                 // package syscall is generally platform dependent,
82                 // WaitStatus is defined for both Unix and Windows and
83                 // in both cases has an ExitStatus() method with the
84                 // same signature.
85                 if status, ok := err.Sys().(syscall.WaitStatus); ok {
86                         os.Exit(status.ExitStatus())
87                 } else {
88                         reporter.Logger.Fatalln("ExitError without WaitStatus:", err)
89                 }
90         } else if err != nil {
91                 reporter.Logger.Fatalln("error in cmd.Wait:", err)
92         }
93 }
94
95 func runCommand(argv []string, logger *log.Logger) error {
96         cmd := exec.Command(argv[0], argv[1:]...)
97
98         logger.Println("Running", argv)
99
100         // Child process will use our stdin and stdout pipes
101         // (we close our copies below)
102         cmd.Stdin = os.Stdin
103         cmd.Stdout = os.Stdout
104
105         // Forward SIGINT and SIGTERM to child process
106         sigChan := make(chan os.Signal, 1)
107         go func(sig <-chan os.Signal) {
108                 catch := <-sig
109                 if cmd.Process != nil {
110                         cmd.Process.Signal(catch)
111                 }
112                 logger.Println("notice: caught signal:", catch)
113         }(sigChan)
114         signal.Notify(sigChan, syscall.SIGTERM)
115         signal.Notify(sigChan, syscall.SIGINT)
116
117         // Kill our child proc if our parent process disappears
118         if signalOnDeadPPID != 0 {
119                 go sendSignalOnDeadPPID(ppidCheckInterval, signalOnDeadPPID, os.Getppid(), cmd, logger)
120         }
121
122         // Funnel stderr through our channel
123         stderrPipe, err := cmd.StderrPipe()
124         if err != nil {
125                 logger.Fatalln("error in StderrPipe:", err)
126         }
127
128         // Run subprocess
129         if err := cmd.Start(); err != nil {
130                 logger.Fatalln("error in cmd.Start:", err)
131         }
132
133         // Close stdin/stdout in this (parent) process
134         os.Stdin.Close()
135         os.Stdout.Close()
136
137         copyPipeToChildLog(stderrPipe, log.New(os.Stderr, "", 0))
138
139         return cmd.Wait()
140 }
141
142 func sendSignalOnDeadPPID(intvl time.Duration, signum, ppidOrig int, cmd *exec.Cmd, logger *log.Logger) {
143         ticker := time.NewTicker(intvl)
144         for range ticker.C {
145                 ppid := os.Getppid()
146                 if ppid == ppidOrig {
147                         continue
148                 }
149                 if cmd.Process == nil {
150                         // Child process isn't running yet
151                         continue
152                 }
153                 logger.Printf("notice: crunchstat ppid changed from %d to %d -- killing child pid %d with signal %d", ppidOrig, ppid, cmd.Process.Pid, signum)
154                 err := cmd.Process.Signal(syscall.Signal(signum))
155                 if err != nil {
156                         logger.Printf("error: sending signal: %s", err)
157                         continue
158                 }
159                 ticker.Stop()
160                 break
161         }
162 }
163
164 func copyPipeToChildLog(in io.ReadCloser, logger *log.Logger) {
165         reader := bufio.NewReaderSize(in, MaxLogLine)
166         var prefix string
167         for {
168                 line, isPrefix, err := reader.ReadLine()
169                 if err == io.EOF {
170                         break
171                 } else if err != nil {
172                         logger.Fatal("error reading child stderr:", err)
173                 }
174                 var suffix string
175                 if isPrefix {
176                         suffix = "[...]"
177                 }
178                 logger.Print(prefix, string(line), suffix)
179                 // Set up prefix for following line
180                 if isPrefix {
181                         prefix = "[...]"
182                 } else {
183                         prefix = ""
184                 }
185         }
186         in.Close()
187 }