Skip to content

Commit 0d08065

Browse files
authored
fix: use a waitgroup to ensure all connections are cleaned up in agent (#5910)
* fix: use a waitgroup to ensure all connections are cleaned up in agent There was a race where connections could be created at the same time as the agent was closing. The `net.Conn` produced by Tailscale doesn't close when the listener does. * Remove accidental test
1 parent ce36a84 commit 0d08065

File tree

2 files changed

+63
-39
lines changed

2 files changed

+63
-39
lines changed

agent/agent.go

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -398,24 +398,28 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
398398
}
399399
}()
400400
if err = a.trackConnGoroutine(func() {
401+
var wg sync.WaitGroup
401402
for {
402403
conn, err := sshListener.Accept()
403404
if err != nil {
404-
return
405+
break
405406
}
407+
wg.Add(1)
406408
closed := make(chan struct{})
407-
_ = a.trackConnGoroutine(func() {
409+
go func() {
408410
select {
409-
case <-network.Closed():
410411
case <-closed:
412+
case <-a.closed:
413+
_ = conn.Close()
411414
}
412-
_ = conn.Close()
413-
})
414-
_ = a.trackConnGoroutine(func() {
415+
wg.Done()
416+
}()
417+
go func() {
415418
defer close(closed)
416419
a.sshServer.HandleConn(conn)
417-
})
420+
}()
418421
}
422+
wg.Wait()
419423
}); err != nil {
420424
return nil, err
421425
}
@@ -431,35 +435,47 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
431435
}()
432436
if err = a.trackConnGoroutine(func() {
433437
logger := a.logger.Named("reconnecting-pty")
434-
438+
var wg sync.WaitGroup
435439
for {
436440
conn, err := reconnectingPTYListener.Accept()
437441
if err != nil {
438442
logger.Debug(ctx, "accept pty failed", slog.Error(err))
439-
return
440-
}
441-
// This cannot use a JSON decoder, since that can
442-
// buffer additional data that is required for the PTY.
443-
rawLen := make([]byte, 2)
444-
_, err = conn.Read(rawLen)
445-
if err != nil {
446-
continue
447-
}
448-
length := binary.LittleEndian.Uint16(rawLen)
449-
data := make([]byte, length)
450-
_, err = conn.Read(data)
451-
if err != nil {
452-
continue
453-
}
454-
var msg codersdk.WorkspaceAgentReconnectingPTYInit
455-
err = json.Unmarshal(data, &msg)
456-
if err != nil {
457-
continue
443+
break
458444
}
445+
wg.Add(1)
446+
closed := make(chan struct{})
447+
go func() {
448+
select {
449+
case <-closed:
450+
case <-a.closed:
451+
_ = conn.Close()
452+
}
453+
wg.Done()
454+
}()
459455
go func() {
456+
defer close(closed)
457+
// This cannot use a JSON decoder, since that can
458+
// buffer additional data that is required for the PTY.
459+
rawLen := make([]byte, 2)
460+
_, err = conn.Read(rawLen)
461+
if err != nil {
462+
return
463+
}
464+
length := binary.LittleEndian.Uint16(rawLen)
465+
data := make([]byte, length)
466+
_, err = conn.Read(data)
467+
if err != nil {
468+
return
469+
}
470+
var msg codersdk.WorkspaceAgentReconnectingPTYInit
471+
err = json.Unmarshal(data, &msg)
472+
if err != nil {
473+
return
474+
}
460475
_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
461476
}()
462477
}
478+
wg.Wait()
463479
}); err != nil {
464480
return nil, err
465481
}
@@ -474,20 +490,29 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
474490
}
475491
}()
476492
if err = a.trackConnGoroutine(func() {
493+
var wg sync.WaitGroup
477494
for {
478495
conn, err := speedtestListener.Accept()
479496
if err != nil {
480497
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
481-
return
498+
break
482499
}
483-
if err = a.trackConnGoroutine(func() {
500+
wg.Add(1)
501+
closed := make(chan struct{})
502+
go func() {
503+
select {
504+
case <-closed:
505+
case <-a.closed:
506+
_ = conn.Close()
507+
}
508+
wg.Done()
509+
}()
510+
go func() {
511+
defer close(closed)
484512
_ = speedtest.ServeConn(conn)
485-
}); err != nil {
486-
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
487-
_ = conn.Close()
488-
return
489-
}
513+
}()
490514
}
515+
wg.Wait()
491516
}); err != nil {
492517
return nil, err
493518
}
@@ -511,7 +536,10 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
511536
ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo),
512537
}
513538
go func() {
514-
<-ctx.Done()
539+
select {
540+
case <-ctx.Done():
541+
case <-a.closed:
542+
}
515543
_ = server.Close()
516544
}()
517545

scaletest/reconnectingpty/run_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ import (
2323

2424
func Test_Runner(t *testing.T) {
2525
t.Parallel()
26-
// There's a race condition in agent/agent.go where connections
27-
// aren't closed when the Tailnet connection is. This causes the
28-
// goroutines to hang around and cause the test to fail.
29-
t.Skip("TODO: fix this test")
3026

3127
t.Run("OK", func(t *testing.T) {
3228
t.Parallel()

0 commit comments

Comments
 (0)