Skip to content

fix: use a waitgroup to ensure all connections are cleaned up in agent #5910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 63 additions & 35 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,24 +398,28 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
}
}()
if err = a.trackConnGoroutine(func() {
var wg sync.WaitGroup
for {
conn, err := sshListener.Accept()
if err != nil {
return
break
}
wg.Add(1)
closed := make(chan struct{})
_ = a.trackConnGoroutine(func() {
go func() {
select {
case <-network.Closed():
case <-closed:
case <-a.closed:
_ = conn.Close()
}
_ = conn.Close()
})
_ = a.trackConnGoroutine(func() {
wg.Done()
}()
go func() {
defer close(closed)
a.sshServer.HandleConn(conn)
})
}()
}
wg.Wait()
}); err != nil {
return nil, err
}
Expand All @@ -431,35 +435,47 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
}()
if err = a.trackConnGoroutine(func() {
logger := a.logger.Named("reconnecting-pty")

var wg sync.WaitGroup
for {
conn, err := reconnectingPTYListener.Accept()
if err != nil {
logger.Debug(ctx, "accept pty failed", slog.Error(err))
return
}
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
_, err = conn.Read(rawLen)
if err != nil {
continue
}
length := binary.LittleEndian.Uint16(rawLen)
data := make([]byte, length)
_, err = conn.Read(data)
if err != nil {
continue
}
var msg codersdk.WorkspaceAgentReconnectingPTYInit
err = json.Unmarshal(data, &msg)
if err != nil {
continue
break
}
wg.Add(1)
closed := make(chan struct{})
go func() {
select {
case <-closed:
case <-a.closed:
_ = conn.Close()
}
wg.Done()
}()
go func() {
defer close(closed)
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
_, err = conn.Read(rawLen)
if err != nil {
return
}
length := binary.LittleEndian.Uint16(rawLen)
data := make([]byte, length)
_, err = conn.Read(data)
if err != nil {
return
}
var msg codersdk.WorkspaceAgentReconnectingPTYInit
err = json.Unmarshal(data, &msg)
if err != nil {
return
}
_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
}()
}
wg.Wait()
}); err != nil {
return nil, err
}
Expand All @@ -474,20 +490,29 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
}
}()
if err = a.trackConnGoroutine(func() {
var wg sync.WaitGroup
for {
conn, err := speedtestListener.Accept()
if err != nil {
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
return
break
}
if err = a.trackConnGoroutine(func() {
wg.Add(1)
closed := make(chan struct{})
go func() {
select {
case <-closed:
case <-a.closed:
_ = conn.Close()
}
wg.Done()
}()
go func() {
defer close(closed)
_ = speedtest.ServeConn(conn)
}); err != nil {
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
_ = conn.Close()
return
}
}()
}
wg.Wait()
}); err != nil {
return nil, err
}
Expand All @@ -511,7 +536,10 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo),
}
go func() {
<-ctx.Done()
select {
case <-ctx.Done():
case <-a.closed:
}
_ = server.Close()
}()

Expand Down
4 changes: 0 additions & 4 deletions scaletest/reconnectingpty/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import (

func Test_Runner(t *testing.T) {
t.Parallel()
// There's a race condition in agent/agent.go where connections
// aren't closed when the Tailnet connection is. This causes the
// goroutines to hang around and cause the test to fail.
t.Skip("TODO: fix this test")

t.Run("OK", func(t *testing.T) {
t.Parallel()
Expand Down