Skip to content

Commit e73c9aa

Browse files
authored
Merge branch 'main' into mafredri/feat-shutdown-script
2 parents caa49d7 + 4432cd0 commit e73c9aa

37 files changed

+1328
-356
lines changed

agent/agent.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ type Options struct {
7272
type Client interface {
7373
Metadata(ctx context.Context) (agentsdk.Metadata, error)
7474
Listen(ctx context.Context) (net.Conn, error)
75-
ReportStats(ctx context.Context, log slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error)
75+
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
7676
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
7777
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
7878
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
@@ -113,6 +113,7 @@ func New(options Options) io.Closer {
113113
tempDir: options.TempDir,
114114
lifecycleUpdate: make(chan struct{}, 1),
115115
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
116+
connStatsChan: make(chan *agentsdk.Stats, 1),
116117
}
117118
a.init(ctx)
118119
return a
@@ -145,7 +146,8 @@ type agent struct {
145146
lifecycleMu sync.RWMutex // Protects following.
146147
lifecycleState codersdk.WorkspaceAgentLifecycle
147148

148-
network *tailnet.Conn
149+
network *tailnet.Conn
150+
connStatsChan chan *agentsdk.Stats
149151
}
150152

151153
// runLoop attempts to start the agent in a retry loop.
@@ -365,11 +367,20 @@ func (a *agent) run(ctx context.Context) error {
365367
return xerrors.New("agent is closed")
366368
}
367369

370+
setStatInterval := func(d time.Duration) {
371+
network.SetConnStatsCallback(d, 2048,
372+
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
373+
select {
374+
case a.connStatsChan <- convertAgentStats(virtual):
375+
default:
376+
a.logger.Warn(ctx, "network stat dropped")
377+
}
378+
},
379+
)
380+
}
381+
368382
// Report statistics from the created network.
369-
cl, err := a.client.ReportStats(ctx, a.logger, func() *agentsdk.Stats {
370-
stats := network.ExtractTrafficStats()
371-
return convertAgentStats(stats)
372-
})
383+
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, setStatInterval)
373384
if err != nil {
374385
a.logger.Error(ctx, "report stats", slog.Error(err))
375386
} else {
@@ -413,10 +424,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {
413424

414425
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) {
415426
network, err := tailnet.NewConn(&tailnet.Options{
416-
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
417-
DERPMap: derpMap,
418-
Logger: a.logger.Named("tailnet"),
419-
EnableTrafficStats: true,
427+
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
428+
DERPMap: derpMap,
429+
Logger: a.logger.Named("tailnet"),
420430
})
421431
if err != nil {
422432
return nil, xerrors.Errorf("create tailnet: %w", err)

agent/agent_test.go

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ import (
2222
"testing"
2323
"time"
2424

25-
"golang.org/x/xerrors"
26-
"tailscale.com/net/speedtest"
27-
"tailscale.com/tailcfg"
28-
2925
scp "github.com/bramvdbogaerde/go-scp"
3026
"github.com/google/uuid"
3127
"github.com/pion/udp"
@@ -37,6 +33,9 @@ import (
3733
"golang.org/x/crypto/ssh"
3834
"golang.org/x/text/encoding/unicode"
3935
"golang.org/x/text/transform"
36+
"golang.org/x/xerrors"
37+
"tailscale.com/net/speedtest"
38+
"tailscale.com/tailcfg"
4039

4140
"cdr.dev/slog"
4241
"cdr.dev/slog/sloggers/slogtest"
@@ -53,6 +52,8 @@ func TestMain(m *testing.M) {
5352
goleak.VerifyTestMain(m)
5453
}
5554

55+
// NOTE: These tests only work when your default shell is bash for some reason.
56+
5657
func TestAgent_Stats_SSH(t *testing.T) {
5758
t.Parallel()
5859
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -1341,17 +1342,16 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
13411342
closer := agent.New(agent.Options{
13421343
Client: c,
13431344
Filesystem: fs,
1344-
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
1345+
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
13451346
ReconnectingPTYTimeout: ptyTimeout,
13461347
})
13471348
t.Cleanup(func() {
13481349
_ = closer.Close()
13491350
})
13501351
conn, err := tailnet.NewConn(&tailnet.Options{
1351-
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
1352-
DERPMap: metadata.DERPMap,
1353-
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
1354-
EnableTrafficStats: true,
1352+
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
1353+
DERPMap: metadata.DERPMap,
1354+
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
13551355
})
13561356
require.NoError(t, err)
13571357
clientConn, serverConn := net.Pipe()
@@ -1439,28 +1439,27 @@ func (c *client) Listen(_ context.Context) (net.Conn, error) {
14391439
return clientConn, nil
14401440
}
14411441

1442-
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error) {
1442+
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
14431443
doneCh := make(chan struct{})
14441444
ctx, cancel := context.WithCancel(ctx)
14451445

14461446
go func() {
14471447
defer close(doneCh)
14481448

1449-
t := time.NewTicker(500 * time.Millisecond)
1450-
defer t.Stop()
1449+
setInterval(500 * time.Millisecond)
14511450
for {
14521451
select {
14531452
case <-ctx.Done():
14541453
return
1455-
case <-t.C:
1456-
}
1457-
select {
1458-
case c.statsChan <- stats():
1459-
case <-ctx.Done():
1460-
return
1461-
default:
1462-
// We don't want to send old stats.
1463-
continue
1454+
case stat := <-statsChan:
1455+
select {
1456+
case c.statsChan <- stat:
1457+
case <-ctx.Done():
1458+
return
1459+
default:
1460+
// We don't want to send old stats.
1461+
continue
1462+
}
14641463
}
14651464
}
14661465
}()

agent/reaper/reaper.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package reaper
22

3-
import "github.com/hashicorp/go-reap"
3+
import (
4+
"os"
5+
6+
"github.com/hashicorp/go-reap"
7+
)
48

59
type Option func(o *options)
610

@@ -22,7 +26,16 @@ func WithPIDCallback(ch reap.PidCh) Option {
2226
}
2327
}
2428

29+
// WithCatchSignals sets the signals that are caught and forwarded to the
30+
// child process. By default no signals are forwarded.
31+
func WithCatchSignals(sigs ...os.Signal) Option {
32+
return func(o *options) {
33+
o.CatchSignals = sigs
34+
}
35+
}
36+
2537
type options struct {
26-
ExecArgs []string
27-
PIDs reap.PidCh
38+
ExecArgs []string
39+
PIDs reap.PidCh
40+
CatchSignals []os.Signal
2841
}

agent/reaper/reaper_test.go

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
package reaper_test
44

55
import (
6+
"fmt"
67
"os"
78
"os/exec"
9+
"os/signal"
10+
"syscall"
811
"testing"
912
"time"
1013

@@ -15,9 +18,8 @@ import (
1518
"github.com/coder/coder/testutil"
1619
)
1720

21+
//nolint:paralleltest // Non-parallel subtest.
1822
func TestReap(t *testing.T) {
19-
t.Parallel()
20-
2123
// Don't run the reaper test in CI. It does weird
2224
// things like forkexecing which may have unintended
2325
// consequences in CI.
@@ -28,8 +30,9 @@ func TestReap(t *testing.T) {
2830
// OK checks that's the reaper is successfully reaping
2931
// exited processes and passing the PIDs through the shared
3032
// channel.
33+
34+
//nolint:paralleltest // Signal handling.
3135
t.Run("OK", func(t *testing.T) {
32-
t.Parallel()
3336
pids := make(reap.PidCh, 1)
3437
err := reaper.ForkReap(
3538
reaper.WithPIDCallback(pids),
@@ -64,3 +67,39 @@ func TestReap(t *testing.T) {
6467
}
6568
})
6669
}
70+
71+
//nolint:paralleltest // Signal handling.
72+
func TestReapInterrupt(t *testing.T) {
73+
// Don't run the reaper test in CI. It does weird
74+
// things like forkexecing which may have unintended
75+
// consequences in CI.
76+
if _, ok := os.LookupEnv("CI"); ok {
77+
t.Skip("Detected CI, skipping reaper tests")
78+
}
79+
80+
errC := make(chan error, 1)
81+
pids := make(reap.PidCh, 1)
82+
83+
// Use signals to notify when the child process is ready for the
84+
// next step of our test.
85+
usrSig := make(chan os.Signal, 1)
86+
signal.Notify(usrSig, syscall.SIGUSR1, syscall.SIGUSR2)
87+
defer signal.Stop(usrSig)
88+
89+
go func() {
90+
errC <- reaper.ForkReap(
91+
reaper.WithPIDCallback(pids),
92+
reaper.WithCatchSignals(os.Interrupt),
93+
// Signal propagation does not extend to children of children, so
94+
// we create a little bash script to ensure sleep is interrupted.
95+
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
96+
)
97+
}()
98+
99+
require.Equal(t, <-usrSig, syscall.SIGUSR1)
100+
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
101+
require.NoError(t, err)
102+
require.Equal(t, <-usrSig, syscall.SIGUSR2)
103+
104+
require.NoError(t, <-errC)
105+
}

agent/reaper/reaper_unix.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package reaper
44

55
import (
66
"os"
7+
"os/signal"
78
"syscall"
89

910
"github.com/hashicorp/go-reap"
@@ -15,6 +16,24 @@ func IsInitProcess() bool {
1516
return os.Getpid() == 1
1617
}
1718

19+
// catchSignals relays every signal in sigs that this process receives to the
// process identified by pid. It returns immediately when sigs is empty;
// otherwise it never returns, so it must be run in its own goroutine.
func catchSignals(pid int, sigs []os.Signal) {
	if len(sigs) == 0 {
		return
	}

	// signal.Notify does not block when delivering: a full channel means a
	// dropped signal. Reserve one slot per registered signal so that distinct
	// signals arriving back to back are all forwarded.
	sc := make(chan os.Signal, len(sigs))
	signal.Notify(sc, sigs...)
	defer signal.Stop(sc)

	// sc is never closed, so this loop runs for as long as the process lives.
	for s := range sc {
		// syscall.Kill requires a syscall.Signal; any other os.Signal
		// implementation cannot be forwarded and is skipped.
		if sig, ok := s.(syscall.Signal); ok {
			// Forwarding is best effort: there is nothing useful to do here
			// if delivery fails (for example, the child has already exited).
			_ = syscall.Kill(pid, sig)
		}
	}
}
36+
1837
// ForkReap spawns a goroutine that reaps children. In order to avoid
1938
// complications with spawning `exec.Commands` in the same process that
2039
// is reaping, we forkexec a child process. This prevents a race between
@@ -51,13 +70,17 @@ func ForkReap(opt ...Option) error {
5170
}
5271

5372
//#nosec G204
54-
pid, _ := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
73+
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
74+
if err != nil {
75+
return xerrors.Errorf("fork exec: %w", err)
76+
}
77+
78+
go catchSignals(pid, opts.CatchSignals)
5579

5680
var wstatus syscall.WaitStatus
5781
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
5882
for xerrors.Is(err, syscall.EINTR) {
5983
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
6084
}
61-
62-
return nil
85+
return err
6386
}

cli/agent.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ func workspaceAgent() *cobra.Command {
6868
// Do not start a reaper on the child process. It's important
6969
// to do this else we fork bomb ourselves.
7070
args := append(os.Args, "--no-reap")
71-
err := reaper.ForkReap(reaper.WithExecArgs(args...))
71+
err := reaper.ForkReap(
72+
reaper.WithExecArgs(args...),
73+
reaper.WithCatchSignals(InterruptSignals...),
74+
)
7275
if err != nil {
7376
logger.Error(ctx, "failed to reap", slog.Error(err))
7477
return xerrors.Errorf("fork reap: %w", err)

0 commit comments

Comments
 (0)