From adc8f271319a026ecef2a82c44f45aca1a5ec5fc Mon Sep 17 00:00:00 2001
From: Daniel Dao <dqminh89@gmail.com>
Date: Thu, 21 Sep 2017 13:34:22 +0100
Subject: [PATCH] handle runc command context manually

instead of defering to Go's stdlib to handle context, which only sends
SIGKILL when the context timed out, handle the context timeout ourselves
so we can inject a custom signal first ( default to SIGTERM, and send
SIGKILL after 10 seconds) to stop container more gracefully.

If no signal is specified, then fall back to sending SIGKILL just like
stdlib.

Fix #21
---
 command_linux.go |  5 ++--
 command_other.go |  5 ++--
 monitor.go       | 43 ++++++++++++++++++++++++++++++---
 monitor_test.go  | 40 +++++++++++++++++++++++++++++++
 runc.go          | 62 ++++++++++++++++++++++++------------------------
 5 files changed, 115 insertions(+), 40 deletions(-)
 create mode 100644 monitor_test.go

diff --git a/command_linux.go b/command_linux.go
index 8b4a502..b2d6065 100644
--- a/command_linux.go
+++ b/command_linux.go
@@ -1,17 +1,16 @@
 package runc
 
 import (
-	"context"
 	"os/exec"
 	"syscall"
 )
 
-func (r *Runc) command(context context.Context, args ...string) *exec.Cmd {
+func (r *Runc) command(args ...string) *exec.Cmd {
 	command := r.Command
 	if command == "" {
 		command = DefaultCommand
 	}
-	cmd := exec.CommandContext(context, command, append(r.args(), args...)...)
+	cmd := exec.Command(command, append(r.args(), args...)...)
 	cmd.SysProcAttr = &syscall.SysProcAttr{
 		Setpgid: r.Setpgid,
 	}
diff --git a/command_other.go b/command_other.go
index bf03a2f..d539e42 100644
--- a/command_other.go
+++ b/command_other.go
@@ -3,14 +3,13 @@
 package runc
 
 import (
-	"context"
 	"os/exec"
 )
 
-func (r *Runc) command(context context.Context, args ...string) *exec.Cmd {
+func (r *Runc) command(args ...string) *exec.Cmd {
 	command := r.Command
 	if command == "" {
 		command = DefaultCommand
 	}
-	return exec.CommandContext(context, command, append(r.args(), args...)...)
+	return exec.Command(command, append(r.args(), args...)...)
 }
diff --git a/monitor.go b/monitor.go
index 2d62c5a..b10f410 100644
--- a/monitor.go
+++ b/monitor.go
@@ -1,17 +1,22 @@
 package runc
 
 import (
+	"context"
+	"os"
 	"os/exec"
 	"syscall"
 	"time"
+
+	"golang.org/x/sys/unix"
 )
 
-var Monitor ProcessMonitor = &defaultMonitor{}
+var Monitor ProcessMonitor = DefaultMonitor(unix.SIGTERM, 10*time.Second)
 
 type Exit struct {
 	Timestamp time.Time
 	Pid       int
 	Status    int
+	Signal    os.Signal
 }
 
 // ProcessMonitor is an interface for process monitoring
@@ -22,25 +27,55 @@ type Exit struct {
 // These methods should match the methods exposed by exec.Cmd to provide
 // a consistent experience for the caller
 type ProcessMonitor interface {
-	Start(*exec.Cmd) (chan Exit, error)
+	Start(context.Context, *exec.Cmd) (chan Exit, error)
 	Wait(*exec.Cmd, chan Exit) (int, error)
 }
 
+func DefaultMonitor(defaultSignal os.Signal, killTimeout time.Duration) ProcessMonitor {
+	return &defaultMonitor{
+		defaultSignal: defaultSignal,
+		killTimeout:   killTimeout,
+	}
+}
+
 type defaultMonitor struct {
+	defaultSignal os.Signal
+	killTimeout   time.Duration
 }
 
-func (m *defaultMonitor) Start(c *exec.Cmd) (chan Exit, error) {
+func (m *defaultMonitor) Start(ctx context.Context, c *exec.Cmd) (chan Exit, error) {
 	if err := c.Start(); err != nil {
 		return nil, err
 	}
 	ec := make(chan Exit, 1)
+	waitDone := make(chan struct{}, 1)
+	go func() {
+		select {
+		case <-ctx.Done():
+			if m.defaultSignal == nil {
+				c.Process.Signal(unix.SIGKILL)
+			} else {
+				c.Process.Signal(m.defaultSignal)
+				if m.killTimeout > 0 {
+					select {
+					case <-time.After(m.killTimeout):
+						c.Process.Kill()
+					case <-waitDone:
+					}
+				}
+			}
+		case <-waitDone:
+		}
+	}()
 	go func() {
 		var status int
+		var signal os.Signal
 		if err := c.Wait(); err != nil {
 			status = 255
 			if exitErr, ok := err.(*exec.ExitError); ok {
 				if ws, ok := exitErr.Sys().(syscall.WaitStatus); ok {
 					status = ws.ExitStatus()
+					signal = ws.Signal()
 				}
 			}
 		}
@@ -48,8 +83,10 @@ func (m *defaultMonitor) Start(c *exec.Cmd) (chan Exit, error) {
 			Timestamp: time.Now(),
 			Pid:       c.Process.Pid,
 			Status:    status,
+			Signal:    signal,
 		}
 		close(ec)
+		close(waitDone)
 	}()
 	return ec, nil
 }
diff --git a/monitor_test.go b/monitor_test.go
new file mode 100644
index 0000000..39bdcff
--- /dev/null
+++ b/monitor_test.go
@@ -0,0 +1,40 @@
+package runc
+
+import (
+	"context"
+	"os/exec"
+	"testing"
+	"time"
+
+	"golang.org/x/sys/unix"
+)
+
+func TestMonitorCustomSignal(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	cmd := exec.Command("sleep", "10")
+	monitor := DefaultMonitor(unix.SIGTERM, time.Second)
+	ec, err := monitor.Start(ctx, cmd)
+	if err != nil {
+		t.Errorf("Failed to start command: %v", err)
+	}
+	e := <-ec
+	if e.Signal != unix.SIGTERM {
+		t.Errorf("Got signal (%v), expected (%v)", e.Signal, unix.SIGTERM)
+	}
+}
+
+func TestMonitorKill(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	cmd := exec.Command("sleep", "10")
+	monitor := &defaultMonitor{}
+	ec, err := monitor.Start(ctx, cmd)
+	if err != nil {
+		t.Errorf("Failed to start command: %v", err)
+	}
+	e := <-ec
+	if e.Signal != unix.SIGKILL {
+		t.Errorf("Got signal (%v), expected (%v)", e.Signal, unix.SIGTERM)
+	}
+}
diff --git a/runc.go b/runc.go
index 4cd402a..a9eceff 100644
--- a/runc.go
+++ b/runc.go
@@ -46,7 +46,7 @@ type Runc struct {
 
 // List returns all containers created inside the provided runc root directory
 func (r *Runc) List(context context.Context) ([]*Container, error) {
-	data, err := cmdOutput(r.command(context, "list", "--format=json"), false)
+	data, err := cmdOutput(context, r.command("list", "--format=json"), false)
 	if err != nil {
 		return nil, err
 	}
@@ -59,7 +59,7 @@ func (r *Runc) List(context context.Context) ([]*Container, error) {
 
 // State returns the state for the container provided by id
 func (r *Runc) State(context context.Context, id string) (*Container, error) {
-	data, err := cmdOutput(r.command(context, "state", id), true)
+	data, err := cmdOutput(context, r.command("state", id), true)
 	if err != nil {
 		return nil, fmt.Errorf("%s: %s", err, data)
 	}
@@ -121,20 +121,20 @@ func (r *Runc) Create(context context.Context, id, bundle string, opts *CreateOp
 		}
 		args = append(args, oargs...)
 	}
-	cmd := r.command(context, append(args, id)...)
+	cmd := r.command(append(args, id)...)
 	if opts != nil && opts.IO != nil {
 		opts.Set(cmd)
 	}
 	cmd.ExtraFiles = opts.ExtraFiles
 
 	if cmd.Stdout == nil && cmd.Stderr == nil {
-		data, err := cmdOutput(cmd, true)
+		data, err := cmdOutput(context, cmd, true)
 		if err != nil {
 			return fmt.Errorf("%s: %s", err, data)
 		}
 		return nil
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		return err
 	}
@@ -154,7 +154,7 @@ func (r *Runc) Create(context context.Context, id, bundle string, opts *CreateOp
 
 // Start will start an already created container
 func (r *Runc) Start(context context.Context, id string) error {
-	return r.runOrError(r.command(context, "start", id))
+	return r.runOrError(context, r.command("start", id))
 }
 
 type ExecOpts struct {
@@ -202,18 +202,18 @@ func (r *Runc) Exec(context context.Context, id string, spec specs.Process, opts
 		}
 		args = append(args, oargs...)
 	}
-	cmd := r.command(context, append(args, id)...)
+	cmd := r.command(append(args, id)...)
 	if opts != nil && opts.IO != nil {
 		opts.Set(cmd)
 	}
 	if cmd.Stdout == nil && cmd.Stderr == nil {
-		data, err := cmdOutput(cmd, true)
+		data, err := cmdOutput(context, cmd, true)
 		if err != nil {
 			return fmt.Errorf("%s: %s", err, data)
 		}
 		return nil
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		return err
 	}
@@ -242,11 +242,11 @@ func (r *Runc) Run(context context.Context, id, bundle string, opts *CreateOpts)
 		}
 		args = append(args, oargs...)
 	}
-	cmd := r.command(context, append(args, id)...)
+	cmd := r.command(append(args, id)...)
 	if opts != nil {
 		opts.Set(cmd)
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		return -1, err
 	}
@@ -270,7 +270,7 @@ func (r *Runc) Delete(context context.Context, id string, opts *DeleteOpts) erro
 	if opts != nil {
 		args = append(args, opts.args()...)
 	}
-	return r.runOrError(r.command(context, append(args, id)...))
+	return r.runOrError(context, r.command(append(args, id)...))
 }
 
 // KillOpts specifies options for killing a container and its processes
@@ -293,17 +293,17 @@ func (r *Runc) Kill(context context.Context, id string, sig int, opts *KillOpts)
 	if opts != nil {
 		args = append(args, opts.args()...)
 	}
-	return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...))
+	return r.runOrError(context, r.command(append(args, id, strconv.Itoa(sig))...))
 }
 
 // Stats return the stats for a container like cpu, memory, and io
 func (r *Runc) Stats(context context.Context, id string) (*Stats, error) {
-	cmd := r.command(context, "events", "--stats", id)
+	cmd := r.command("events", "--stats", id)
 	rd, err := cmd.StdoutPipe()
 	if err != nil {
 		return nil, err
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		return nil, err
 	}
@@ -320,12 +320,12 @@ func (r *Runc) Stats(context context.Context, id string) (*Stats, error) {
 
 // Events returns an event stream from runc for a container with stats and OOM notifications
 func (r *Runc) Events(context context.Context, id string, interval time.Duration) (chan *Event, error) {
-	cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id)
+	cmd := r.command("events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id)
 	rd, err := cmd.StdoutPipe()
 	if err != nil {
 		return nil, err
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		rd.Close()
 		return nil, err
@@ -359,17 +359,17 @@ func (r *Runc) Events(context context.Context, id string, interval time.Duration
 
 // Pause the container with the provided id
 func (r *Runc) Pause(context context.Context, id string) error {
-	return r.runOrError(r.command(context, "pause", id))
+	return r.runOrError(context, r.command("pause", id))
 }
 
 // Resume the container with the provided id
 func (r *Runc) Resume(context context.Context, id string) error {
-	return r.runOrError(r.command(context, "resume", id))
+	return r.runOrError(context, r.command("resume", id))
 }
 
 // Ps lists all the processes inside the container returning their pids
 func (r *Runc) Ps(context context.Context, id string) ([]int, error) {
-	data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true)
+	data, err := cmdOutput(context, r.command("ps", "--format", "json", id), true)
 	if err != nil {
 		return nil, fmt.Errorf("%s: %s", err, data)
 	}
@@ -467,7 +467,7 @@ func (r *Runc) Checkpoint(context context.Context, id string, opts *CheckpointOp
 	for _, a := range actions {
 		args = a(args)
 	}
-	return r.runOrError(r.command(context, append(args, id)...))
+	return r.runOrError(context, r.command(append(args, id)...))
 }
 
 type RestoreOpts struct {
@@ -512,11 +512,11 @@ func (r *Runc) Restore(context context.Context, id, bundle string, opts *Restore
 		args = append(args, oargs...)
 	}
 	args = append(args, "--bundle", bundle)
-	cmd := r.command(context, append(args, id)...)
+	cmd := r.command(append(args, id)...)
 	if opts != nil {
 		opts.Set(cmd)
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(context, cmd)
 	if err != nil {
 		return -1, err
 	}
@@ -537,9 +537,9 @@ func (r *Runc) Update(context context.Context, id string, resources *specs.Linux
 		return err
 	}
 	args := []string{"update", "--resources", "-", id}
-	cmd := r.command(context, args...)
+	cmd := r.command(args...)
 	cmd.Stdin = buf
-	return r.runOrError(cmd)
+	return r.runOrError(context, cmd)
 }
 
 var ErrParseRuncVersion = errors.New("unable to parse runc version")
@@ -552,7 +552,7 @@ type Version struct {
 
 // Version returns the runc and runtime-spec versions
 func (r *Runc) Version(context context.Context) (Version, error) {
-	data, err := cmdOutput(r.command(context, "--version"), false)
+	data, err := cmdOutput(context, r.command("--version"), false)
 	if err != nil {
 		return Version{}, err
 	}
@@ -618,9 +618,9 @@ func (r *Runc) args() (out []string) {
 // encountered and neither Stdout or Stderr was set the error and the
 // stderr of the command will be returned in the format of <error>:
 // <stderr>
-func (r *Runc) runOrError(cmd *exec.Cmd) error {
+func (r *Runc) runOrError(ctx context.Context, cmd *exec.Cmd) error {
 	if cmd.Stdout != nil || cmd.Stderr != nil {
-		ec, err := Monitor.Start(cmd)
+		ec, err := Monitor.Start(ctx, cmd)
 		if err != nil {
 			return err
 		}
@@ -630,21 +630,21 @@ func (r *Runc) runOrError(cmd *exec.Cmd) error {
 		}
 		return err
 	}
-	data, err := cmdOutput(cmd, true)
+	data, err := cmdOutput(ctx, cmd, true)
 	if err != nil {
 		return fmt.Errorf("%s: %s", err, data)
 	}
 	return nil
 }
 
-func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) {
+func cmdOutput(ctx context.Context, cmd *exec.Cmd, combined bool) ([]byte, error) {
 	var b bytes.Buffer
 
 	cmd.Stdout = &b
 	if combined {
 		cmd.Stderr = &b
 	}
-	ec, err := Monitor.Start(cmd)
+	ec, err := Monitor.Start(ctx, cmd)
 	if err != nil {
 		return nil, err
 	}