Skip to content

Commit 0b0aac0

Browse files
committed
fix: Use of rw and r.Context() in workspaceAgentListen
1 parent f1357a4 commit 0b0aac0

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

coderd/workspaceagents.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,29 +147,31 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
147147
defer api.websocketWaitGroup.Done()
148148

149149
workspaceAgent := httpmw.WorkspaceAgent(r)
150-
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
151-
CompressionMode: websocket.CompressionDisabled,
152-
})
150+
151+
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
153152
if err != nil {
154153
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
155154
Message: fmt.Sprintf("accept websocket: %s", err),
156155
})
157156
return
158157
}
159-
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
158+
159+
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
160+
CompressionMode: websocket.CompressionDisabled,
161+
})
160162
if err != nil {
161163
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
162164
Message: fmt.Sprintf("accept websocket: %s", err),
163165
})
164166
return
165167
}
166168

167-
defer func() {
168-
_ = conn.Close(websocket.StatusNormalClosure, "")
169-
}()
169+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
170+
defer wsNetConn.Close() // Also closes conn.
171+
170172
config := yamux.DefaultConfig()
171173
config.LogOutput = io.Discard
172-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
174+
session, err := yamux.Server(wsNetConn, config)
173175
if err != nil {
174176
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
175177
return
@@ -197,7 +199,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
197199
}
198200
disconnectedAt := workspaceAgent.DisconnectedAt
199201
updateConnectionTimes := func() error {
200-
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
202+
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
201203
ID: workspaceAgent.ID,
202204
FirstConnectedAt: firstConnectedAt,
203205
LastConnectedAt: lastConnectedAt,
@@ -208,15 +210,15 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
208210
}
209211
return nil
210212
}
211-
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
213+
build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
212214
if err != nil {
213215
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
214216
return
215217
}
216218
// Ensure the resource is still valid!
217219
// We only accept agents for resources on the latest build.
218220
ensureLatestBuild := func() error {
219-
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
221+
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
220222
if err != nil {
221223
return err
222224
}
@@ -245,7 +247,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
245247
return
246248
}
247249

248-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
250+
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
249251

250252
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
251253
defer ticker.Stop()

0 commit comments

Comments
 (0)