import (
"bytes"
+ "fmt"
"io"
"io/ioutil"
"net"
}
func (s *executorSuite) TestInject(c *C) {
+ hostdir := c.MkDir()
+ c.Assert(os.WriteFile(hostdir+"/testfile", []byte("first tube"), 0777), IsNil)
+ mountdir := fmt.Sprintf("/injecttest-%d", os.Getpid())
s.spec.Command = []string{"sleep", "10"}
+ s.spec.BindMounts = map[string]bindmount{mountdir: {HostPath: hostdir, ReadOnly: true}}
c.Assert(s.executor.Create(s.spec), IsNil)
c.Assert(s.executor.Start(), IsNil)
starttime := time.Now()
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
defer cancel()
- injectcmd := []string{"cat", "/proc/1/cmdline"}
+ // Allow InjectCommand to fail a few times while the container
+ // is starting
+ for ctx.Err() == nil {
+ _, err := s.executor.InjectCommand(ctx, "", "root", false, []string{"true"})
+ if err == nil {
+ break
+ }
+ time.Sleep(time.Second / 10)
+ }
+
+ injectcmd := []string{"cat", mountdir + "/testfile"}
cmd, err := s.executor.InjectCommand(ctx, "", "root", false, injectcmd)
c.Assert(err, IsNil)
out, err := cmd.CombinedOutput()
c.Logf("inject %s => %q", injectcmd, out)
c.Check(err, IsNil)
- c.Check(string(out), Equals, "sleep\00010\000")
+ c.Check(string(out), Equals, "first tube")
s.executor.Stop()
code, _ := s.executor.Wait(ctx)
package crunchrun
import (
+ "bytes"
"errors"
"fmt"
"io/ioutil"
+ "net"
"os"
"os/exec"
+ "regexp"
"sort"
+ "strconv"
"syscall"
"time"
}
func (e *singularityExecutor) execCmd(path string) *exec.Cmd {
- args := []string{path, "exec", "--containall", "--cleanenv", "--pwd", e.spec.WorkingDir}
+ args := []string{path, "exec", "--containall", "--cleanenv", "--pwd", e.spec.WorkingDir, "--net"}
if !e.spec.EnableNetwork {
- args = append(args, "--net", "--network=none")
+ args = append(args, "--network=none")
}
-
if e.spec.CUDADeviceCount != 0 {
args = append(args, "--nv")
}
}
func (e *singularityExecutor) InjectCommand(ctx context.Context, detachKeys, username string, usingTTY bool, injectcmd []string) (*exec.Cmd, error) {
- return nil, errors.New("unimplemented")
+ target, err := e.containedProcess()
+ if err != nil {
+ return nil, err
+ }
+ return exec.CommandContext(ctx, "nsenter", append([]string{fmt.Sprintf("--target=%d", target), "--all"}, injectcmd...)...), nil
}
// errContainerHasNoIPAddress is returned by IPAddress when every IPv4
// address visible inside the container is also present in this
// process's own network namespace.
var errContainerHasNoIPAddress = errors.New("container has no IP address distinct from host")
+
func (e *singularityExecutor) IPAddress() (string, error) {
- return "", errors.New("unimplemented")
+ target, err := e.containedProcess()
+ if err != nil {
+ return "", err
+ }
+ targetIPs, err := processIPs(target)
+ if err != nil {
+ return "", err
+ }
+ selfIPs, err := processIPs(os.Getpid())
+ if err != nil {
+ return "", err
+ }
+ for ip := range targetIPs {
+ if !selfIPs[ip] {
+ return ip, nil
+ }
+ }
+ return "", errContainerHasNoIPAddress
+}
+
// processIPs returns the set of local IPv4 addresses listed in
// /proc/<pid>/net/fib_trie, i.e. the addresses in the network namespace
// of the given process. Every value in the returned map is true.
//
// An address is recorded when a line ending in the address is
// immediately followed by a "/32 host LOCAL" line, e.g.:
//
//	|-- 10.1.2.3
//	   /32 host LOCAL
//
// yields addrs["10.1.2.3"] = true. If the file cannot be read, the
// read error is returned with a nil map.
func processIPs(pid int) (map[string]bool, error) {
	fibtrie, err := os.ReadFile(fmt.Sprintf("/proc/%d/net/fib_trie", pid))
	if err != nil {
		return nil, err
	}

	addrs := map[string]bool{}
	lines := bytes.Split(fibtrie, []byte{'\n'})
	// Start at 1: every match needs the line before it.
	for n := 1; n < len(lines); n++ {
		if !bytes.HasSuffix(lines[n], []byte("/32 host LOCAL")) {
			continue
		}
		prev := lines[n-1]
		i := bytes.LastIndexByte(prev, ' ')
		if i < 0 {
			continue
		}
		addr := string(prev[i+1:])
		// ParseIP rejects anything that is not a well-formed address
		// (including empty or too-short tokens), so no separate length
		// check on prev is needed.
		if net.ParseIP(addr).To4() != nil {
			addrs[addr] = true
		}
	}
	return addrs, nil
}
+
// errContainerNotStarted is returned when the executor has no running
// child process yet.
var errContainerNotStarted = errors.New("container has not started yet")

// errCannotFindChild is returned when no PID namespace member can be
// traced back to the executor's child process.
var errCannotFindChild = errors.New("failed to find any process inside the container")

// reProcStatusPPid captures the parent PID from the contents of a
// /proc/<pid>/status file.
var reProcStatusPPid = regexp.MustCompile(`\nPPid:\t(\d+)\n`)
+
+// Return the PID of a process that is inside the container (not
+// necessarily the topmost/pid=1 process in the container).
+func (e *singularityExecutor) containedProcess() (int, error) {
+ if e.child == nil || e.child.Process == nil {
+ return 0, errContainerNotStarted
+ }
+ lsns, err := exec.Command("lsns").CombinedOutput()
+ if err != nil {
+ return 0, fmt.Errorf("lsns: %w", err)
+ }
+ for _, line := range bytes.Split(lsns, []byte{'\n'}) {
+ fields := bytes.Fields(line)
+ if len(fields) < 4 {
+ continue
+ }
+ if !bytes.Equal(fields[1], []byte("pid")) {
+ continue
+ }
+ pid, err := strconv.ParseInt(string(fields[3]), 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing PID field in lsns output: %q", fields[3])
+ }
+ for parent := pid; ; {
+ procstatus, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", parent))
+ if err != nil {
+ break
+ }
+ m := reProcStatusPPid.FindSubmatch(procstatus)
+ if m == nil {
+ break
+ }
+ parent, err = strconv.ParseInt(string(m[1]), 10, 64)
+ if err != nil {
+ break
+ }
+ if int(parent) == e.child.Process.Pid {
+ return int(pid), nil
+ }
+ }
+ }
+ return 0, errCannotFindChild
}