Skip to content

Commit 087bcd5

Browse files
committed
fix: Fix goroutine leak by propagating websocket closure
Fixes #1508
1 parent 51c420c commit 087bcd5

File tree

1 file changed

+52
-8
lines changed

1 file changed

+52
-8
lines changed

coderd/workspaceagents.go

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
7+
"errors"
68
"fmt"
79
"io"
810
"net"
@@ -16,6 +18,7 @@ import (
1618
"nhooyr.io/websocket"
1719

1820
"cdr.dev/slog"
21+
1922
"github.com/coder/coder/agent"
2023
"github.com/coder/coder/coderd/database"
2124
"github.com/coder/coder/coderd/httpapi"
@@ -310,16 +313,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
310313
})
311314
return
312315
}
313-
defer func() {
314-
_ = wsConn.Close(websocket.StatusNormalClosure, "")
315-
}()
316-
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
317-
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
316+
317+
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
318+
defer wsNetConn.Close() // Also closes conn.
319+
320+
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
318321
select {
319-
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
320-
case <-r.Context().Done():
322+
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
323+
case <-ctx.Done():
321324
}
322-
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
325+
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
323326
}
324327

325328
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -501,3 +504,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
501504

502505
return workspaceAgent, nil
503506
}
507+
508+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
509+
// is called if io.EOF is encountered.
510+
type wsNetConn struct {
511+
cancel context.CancelFunc
512+
net.Conn
513+
}
514+
515+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
516+
n, err = c.Conn.Read(b)
517+
if errors.Is(err, io.EOF) {
518+
c.cancel()
519+
}
520+
return n, err
521+
}
522+
523+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
524+
n, err = c.Conn.Write(b)
525+
if errors.Is(err, io.EOF) {
526+
c.cancel()
527+
}
528+
return n, err
529+
}
530+
531+
func (c *wsNetConn) Close() error {
532+
defer c.cancel()
533+
return c.Conn.Close()
534+
}
535+
536+
// websocketNetConn wraps websocket.NetConn and returns a context that
537+
// is tied to the parent context and the lifetime of the conn. A io.EOF
538+
// error during read or write will cancel the context, but not close the
539+
// conn. Close should be called to release context resources.
540+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
541+
ctx, cancel := context.WithCancel(ctx)
542+
nc := websocket.NetConn(ctx, conn, msgType)
543+
return ctx, &wsNetConn{
544+
cancel: cancel,
545+
Conn: nc,
546+
}
547+
}

0 commit comments

Comments
 (0)