}()
flags := flag.NewFlagSet(prog, flag.ContinueOnError)
+ flags.SetOutput(stderr)
+
format := flags.String("format", "json", "output format (json, yaml, or uuid)")
flags.StringVar(format, "f", "json", "output format (json, yaml, or uuid)")
short := flags.Bool("short", false, "equivalent to --format=uuid")
"flag"
"fmt"
"io"
+ "io/ioutil"
+ "sort"
+ "strings"
)
// A RunFunc is the entry point of a command: it runs the command
// (named prog for use in messages) with the given args and standard
// streams, and returns the exit code the process should report.
type RunFunc func(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int
-// Multi returns a command that looks up its first argument in m, and
-// runs the resulting RunFunc with the remaining args.
+// Multi returns a RunFunc that looks up its first argument in m, and
+// invokes the resulting RunFunc with the remaining args.
//
// Example:
//
func Multi(m map[string]RunFunc) RunFunc {
return func(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
if len(args) < 1 {
- fmt.Fprintf(stderr, "usage: %s command [args]", prog)
+ fmt.Fprintf(stderr, "usage: %s command [args]\n", prog)
+ multiUsage(stderr, m)
return 2
}
if cmd, ok := m[args[0]]; !ok {
- fmt.Fprintf(stderr, "unrecognized command %q", args[0])
+ fmt.Fprintf(stderr, "unrecognized command %q\n", args[0])
+ multiUsage(stderr, m)
return 2
} else {
return cmd(prog+" "+args[0], args[1:], stdin, stdout, stderr)
}
}
+func multiUsage(stderr io.Writer, m map[string]RunFunc) {
+ var subcommands []string
+ for sc := range m {
+ if strings.HasPrefix(sc, "-") {
+ // Some subcommands have alternate versions
+ // like "--version" for compatibility. Don't
+ // clutter the subcommand summary with those.
+ continue
+ }
+ subcommands = append(subcommands, sc)
+ }
+ sort.Strings(subcommands)
+ fmt.Fprintf(stderr, "\nAvailable commands:\n")
+ for _, sc := range subcommands {
+ fmt.Fprintf(stderr, " %s\n", sc)
+ }
+}
+
// WithLateSubcommand wraps a RunFunc by skipping over some known
// flags to find a subcommand, and moving that subcommand to the front
// of the args before calling the wrapped RunFunc. For example:
}
// Ignore errors. We can't report a useful error
// message anyway.
+ flags.SetOutput(ioutil.Discard)
+ flags.Usage = func() {}
flags.Parse(args)
if flags.NArg() > 0 {
// Move the first arg after the recognized
"strings"
"testing"
+ "git.curoverse.com/arvados.git/lib/cmdtest"
check "gopkg.in/check.v1"
)
})
func (s *CmdSuite) TestHello(c *check.C) {
+ defer cmdtest.LeakCheck(c)()
stdout := bytes.NewBuffer(nil)
stderr := bytes.NewBuffer(nil)
exited := testCmd("prog", []string{"echo", "hello", "world"}, bytes.NewReader(nil), stdout, stderr)
c.Check(stderr.String(), check.Equals, "")
}
+func (s *CmdSuite) TestUsage(c *check.C) {
+ defer cmdtest.LeakCheck(c)()
+ stdout := bytes.NewBuffer(nil)
+ stderr := bytes.NewBuffer(nil)
+ exited := testCmd("prog", []string{"nosuchcommand", "hi"}, bytes.NewReader(nil), stdout, stderr)
+ c.Check(exited, check.Equals, 2)
+ c.Check(stdout.String(), check.Equals, "")
+ c.Check(stderr.String(), check.Matches, `(?ms)^unrecognized command "nosuchcommand"\n.*echo.*\n`)
+}
+
func (s *CmdSuite) TestWithLateSubcommand(c *check.C) {
+ defer cmdtest.LeakCheck(c)()
stdout := bytes.NewBuffer(nil)
stderr := bytes.NewBuffer(nil)
run := WithLateSubcommand(testCmd, []string{"format", "f"}, []string{"n"})
--- /dev/null
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+// Package cmdtest provides tools for testing command line tools.
+package cmdtest
+
+import (
+ "io"
+ "io/ioutil"
+ "os"
+
+ check "gopkg.in/check.v1"
+)
+
+// LeakCheck tests for output being leaked to os.Stdout and os.Stderr
+// that should be sent elsewhere (e.g., the stdout and stderr streams
+// passed to a cmd.RunFunc).
+//
+// It redirects os.Stdout and os.Stderr to a tempfile, and returns a
+// func, which the caller is expected to defer, that restores os.* and
+// checks that the tempfile is empty.
+//
+// Example:
+//
+// func (s *Suite) TestSomething(c *check.C) {
+// defer cmdtest.LeakCheck(c)()
+// // ... do things that shouldn't print to os.Stderr or os.Stdout
+// }
+func LeakCheck(c *check.C) func() {
+ tmpfiles := map[string]*os.File{"stdout": nil, "stderr": nil}
+ for i := range tmpfiles {
+ var err error
+ tmpfiles[i], err = ioutil.TempFile("", "")
+ c.Assert(err, check.IsNil)
+ err = os.Remove(tmpfiles[i].Name())
+ c.Assert(err, check.IsNil)
+ }
+
+ stdout, stderr := os.Stdout, os.Stderr
+ os.Stdout, os.Stderr = tmpfiles["stdout"], tmpfiles["stderr"]
+ return func() {
+ os.Stdout, os.Stderr = stdout, stderr
+
+ for i, tmpfile := range tmpfiles {
+ c.Log("checking %s", i)
+ _, err := tmpfile.Seek(0, io.SeekStart)
+ c.Assert(err, check.IsNil)
+ leaked, err := ioutil.ReadAll(tmpfile)
+ c.Assert(err, check.IsNil)
+ c.Check(string(leaked), check.Equals, "")
+ }
+ }
+}