@@ -147,29 +147,31 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
147
147
defer api .websocketWaitGroup .Done ()
148
148
149
149
workspaceAgent := httpmw .WorkspaceAgent (r )
150
- conn , err := websocket .Accept (rw , r , & websocket.AcceptOptions {
151
- CompressionMode : websocket .CompressionDisabled ,
152
- })
150
+
151
+ resource , err := api .Database .GetWorkspaceResourceByID (r .Context (), workspaceAgent .ResourceID )
153
152
if err != nil {
154
153
httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
155
154
Message : fmt .Sprintf ("accept websocket: %s" , err ),
156
155
})
157
156
return
158
157
}
159
- resource , err := api .Database .GetWorkspaceResourceByID (r .Context (), workspaceAgent .ResourceID )
158
+
159
+ conn , err := websocket .Accept (rw , r , & websocket.AcceptOptions {
160
+ CompressionMode : websocket .CompressionDisabled ,
161
+ })
160
162
if err != nil {
161
163
httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
162
164
Message : fmt .Sprintf ("accept websocket: %s" , err ),
163
165
})
164
166
return
165
167
}
166
168
167
- defer func () {
168
- _ = conn .Close (websocket . StatusNormalClosure , "" )
169
- }()
169
+ ctx , wsNetConn := websocketNetConn ( r . Context (), conn , websocket . MessageBinary )
170
+ defer wsNetConn .Close () // Also closes conn.
171
+
170
172
config := yamux .DefaultConfig ()
171
173
config .LogOutput = io .Discard
172
- session , err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) , config )
174
+ session , err := yamux .Server (wsNetConn , config )
173
175
if err != nil {
174
176
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
175
177
return
@@ -197,7 +199,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
197
199
}
198
200
disconnectedAt := workspaceAgent .DisconnectedAt
199
201
updateConnectionTimes := func () error {
200
- err = api .Database .UpdateWorkspaceAgentConnectionByID (r . Context () , database.UpdateWorkspaceAgentConnectionByIDParams {
202
+ err = api .Database .UpdateWorkspaceAgentConnectionByID (ctx , database.UpdateWorkspaceAgentConnectionByIDParams {
201
203
ID : workspaceAgent .ID ,
202
204
FirstConnectedAt : firstConnectedAt ,
203
205
LastConnectedAt : lastConnectedAt ,
@@ -208,15 +210,15 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
208
210
}
209
211
return nil
210
212
}
211
- build , err := api .Database .GetWorkspaceBuildByJobID (r . Context () , resource .JobID )
213
+ build , err := api .Database .GetWorkspaceBuildByJobID (ctx , resource .JobID )
212
214
if err != nil {
213
215
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
214
216
return
215
217
}
216
218
// Ensure the resource is still valid!
217
219
// We only accept agents for resources on the latest build.
218
220
ensureLatestBuild := func () error {
219
- latestBuild , err := api .Database .GetLatestWorkspaceBuildByWorkspaceID (r . Context () , build .WorkspaceID )
221
+ latestBuild , err := api .Database .GetLatestWorkspaceBuildByWorkspaceID (ctx , build .WorkspaceID )
220
222
if err != nil {
221
223
return err
222
224
}
@@ -245,7 +247,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
245
247
return
246
248
}
247
249
248
- api .Logger .Info (r . Context () , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
250
+ api .Logger .Info (ctx , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
249
251
250
252
ticker := time .NewTicker (api .AgentConnectionUpdateFrequency )
251
253
defer ticker .Stop ()
0 commit comments