Skip to content

Commit 975ef8b

Browse files
chore: fix disconnect bug and add agentcontainers test
1 parent 5aef560 commit 975ef8b

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

agent/agentcontainers/api.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,20 @@ func (api *API) watchContainers(rw http.ResponseWriter, r *http.Request) {
560560
return
561561
}
562562

563+
ctx = api.ctx
564+
563565
go httpapi.Heartbeat(ctx, conn)
564566
defer conn.Close(websocket.StatusNormalClosure, "connection closed")
565567

566568
encoder := wsjson.NewEncoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText)
567569
defer encoder.Close(websocket.StatusNormalClosure)
568570

569-
updateCh := make(chan struct{})
570-
defer close(updateCh)
571+
updateCh := make(chan struct{}, 1)
572+
defer func() {
573+
api.mu.Lock()
574+
close(updateCh)
575+
api.mu.Unlock()
576+
}()
571577

572578
api.mu.Lock()
573579
api.updateChans = append(api.updateChans, updateCh)
@@ -644,7 +650,10 @@ func (api *API) updateContainers(ctx context.Context) error {
644650

645651
// Broadcast our updates
646652
for _, ch := range api.updateChans {
647-
ch <- struct{}{}
653+
select {
654+
case ch <- struct{}{}:
655+
default:
656+
}
648657
}
649658

650659
api.logger.Debug(ctx, "containers updated successfully", slog.F("container_count", len(api.containers.Containers)), slog.F("warning_count", len(api.containers.Warnings)), slog.F("devcontainer_count", len(api.knownDevcontainers)))

agent/agentcontainers/api_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"github.com/coder/coder/v2/pty"
3737
"github.com/coder/coder/v2/testutil"
3838
"github.com/coder/quartz"
39+
"github.com/coder/websocket"
3940
)
4041

4142
// fakeContainerCLI implements the agentcontainers.ContainerCLI interface for
@@ -441,6 +442,73 @@ func TestAPI(t *testing.T) {
441442
logbuf.Reset()
442443
})
443444

445+
t.Run("Watch", func(t *testing.T) {
446+
t.Parallel()
447+
448+
fakeContainer1 := fakeContainer(t)
449+
fakeContainer2 := fakeContainer(t)
450+
fakeContainer3 := fakeContainer(t)
451+
452+
makeResponse := func(cts ...codersdk.WorkspaceAgentContainer) codersdk.WorkspaceAgentListContainersResponse {
453+
return codersdk.WorkspaceAgentListContainersResponse{Containers: cts}
454+
}
455+
456+
var (
457+
ctx = testutil.Context(t, testutil.WaitShort)
458+
mClock = quartz.NewMock(t)
459+
updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop")
460+
mCtrl = gomock.NewController(t)
461+
mLister = acmock.NewMockContainerCLI(mCtrl)
462+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
463+
)
464+
465+
mLister.EXPECT().List(gomock.Any()).Return(makeResponse(), nil)
466+
467+
api := agentcontainers.NewAPI(logger,
468+
agentcontainers.WithClock(mClock),
469+
agentcontainers.WithContainerCLI(mLister),
470+
agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"),
471+
)
472+
api.Start()
473+
defer api.Close()
474+
475+
srv := httptest.NewServer(api.Routes())
476+
defer srv.Close()
477+
478+
updaterTickerTrap.MustWait(ctx).MustRelease(ctx)
479+
defer updaterTickerTrap.Close()
480+
481+
client, _, err := websocket.Dial(ctx, srv.URL+"/watch", nil)
482+
require.NoError(t, err)
483+
484+
for _, mockResponse := range []codersdk.WorkspaceAgentListContainersResponse{
485+
makeResponse(),
486+
makeResponse(fakeContainer1),
487+
makeResponse(fakeContainer1, fakeContainer2),
488+
makeResponse(fakeContainer1, fakeContainer2, fakeContainer3),
489+
makeResponse(fakeContainer1, fakeContainer2),
490+
makeResponse(fakeContainer1),
491+
makeResponse(),
492+
} {
493+
mLister.EXPECT().List(gomock.Any()).Return(mockResponse, nil)
494+
495+
// Given: We allow the update loop to progress
496+
_, aw := mClock.AdvanceNext()
497+
aw.MustWait(ctx)
498+
499+
// When: We attempt to read a message from the socket.
500+
mt, msg, err := client.Read(ctx)
501+
require.NoError(t, err)
502+
require.Equal(t, websocket.MessageText, mt)
503+
504+
// Then: We expect the receieved message matches the mocked response.
505+
var got codersdk.WorkspaceAgentListContainersResponse
506+
err = json.Unmarshal(msg, &got)
507+
require.NoError(t, err)
508+
require.Equal(t, mockResponse, got)
509+
}
510+
})
511+
444512
// List tests the API.getContainers method using a mock
445513
// implementation. It specifically tests caching behavior.
446514
t.Run("List", func(t *testing.T) {

coderd/workspaceagents.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,8 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Re
865865
return
866866
}
867867

868+
ctx = api.ctx
869+
868870
go httpapi.Heartbeat(ctx, conn)
869871
defer conn.Close(websocket.StatusNormalClosure, "connection closed")
870872

0 commit comments

Comments
 (0)