Skip to content

Commit 651ac20

Browse files
committed
wip
1 parent e03f240 commit 651ac20

File tree

7 files changed

+471
-165
lines changed

7 files changed

+471
-165
lines changed

tailnet/controllers.go

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ type WorkspaceUpdatesClient interface {
104104

105105
type WorkspaceUpdatesController interface {
106106
New(WorkspaceUpdatesClient) CloserWaiter
107-
CurrentState() *proto.WorkspaceUpdate
108107
}
109108

110109
// DNSHostsSetter is something that you can set a mapping of DNS names to IPs on. It's the subset
@@ -113,6 +112,11 @@ type DNSHostsSetter interface {
113112
SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error
114113
}
115114

115+
// UpdatesHandler is anything that expects a stream of workspace update diffs.
116+
type UpdatesHandler interface {
117+
Update(*proto.WorkspaceUpdate) error
118+
}
119+
116120
// ControlProtocolClients represents an abstract interface to the tailnet control plane via a set
117121
// of protocol clients. The Closer should close all the clients (e.g. by closing the underlying
118122
// connection).
@@ -856,12 +860,12 @@ func (r *basicResumeTokenRefresher) refresh() {
856860
r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
857861
}
858862

859-
type tunnelAllWorkspaceUpdatesController struct {
860-
coordCtrl *TunnelSrcCoordController
861-
dnsHostSetter DNSHostsSetter
862-
updateCallback func(*proto.WorkspaceUpdate)
863-
ownerUsername string
864-
logger slog.Logger
863+
type TunnelAllWorkspaceUpdatesController struct {
864+
coordCtrl *TunnelSrcCoordController
865+
dnsHostSetter DNSHostsSetter
866+
updateHandler UpdatesHandler
867+
ownerUsername string
868+
logger slog.Logger
865869

866870
sync.Mutex
867871
updater *tunnelUpdater
@@ -906,7 +910,7 @@ type agent struct {
906910
name string
907911
}
908912

909-
func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter {
913+
func (t *TunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter {
910914
t.Lock()
911915
defer t.Unlock()
912916
updater := &tunnelUpdater{
@@ -915,7 +919,7 @@ func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient)
915919
logger: t.logger,
916920
coordCtrl: t.coordCtrl,
917921
dnsHostsSetter: t.dnsHostSetter,
918-
updateCallback: t.updateCallback,
922+
updateHandler: t.updateHandler,
919923
ownerUsername: t.ownerUsername,
920924
recvLoopDone: make(chan struct{}),
921925
workspaces: make(map[uuid.UUID]*workspace),
@@ -925,7 +929,7 @@ func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient)
925929
return t.updater
926930
}
927931

928-
func (t *tunnelAllWorkspaceUpdatesController) CurrentState() *proto.WorkspaceUpdate {
932+
func (t *TunnelAllWorkspaceUpdatesController) CurrentState() *proto.WorkspaceUpdate {
929933
t.Lock()
930934
defer t.Unlock()
931935
if t.updater == nil {
@@ -960,7 +964,7 @@ type tunnelUpdater struct {
960964
client WorkspaceUpdatesClient
961965
coordCtrl *TunnelSrcCoordController
962966
dnsHostsSetter DNSHostsSetter
963-
updateCallback func(*proto.WorkspaceUpdate)
967+
updateHandler UpdatesHandler
964968
ownerUsername string
965969
recvLoopDone chan struct{}
966970

@@ -1095,9 +1099,12 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
10951099
} else {
10961100
t.logger.Debug(context.Background(), "skipping setting DNS names because we have no setter")
10971101
}
1098-
if t.updateCallback != nil {
1099-
t.logger.Debug(context.Background(), "calling update callback")
1100-
t.updateCallback(update)
1102+
if t.updateHandler != nil {
1103+
t.logger.Debug(context.Background(), "calling update handler")
1104+
err := t.updateHandler.Update(update)
1105+
if err != nil {
1106+
t.logger.Error(context.Background(), "failed to call update handler", slog.Error(err))
1107+
}
11011108
}
11021109
return nil
11031110
}
@@ -1160,20 +1167,20 @@ func (t *tunnelUpdater) allDNSNames() map[dnsname.FQDN][]netip.Addr {
11601167
return names
11611168
}
11621169

1163-
type TunnelAllOption func(t *tunnelAllWorkspaceUpdatesController)
1170+
type TunnelAllOption func(t *TunnelAllWorkspaceUpdatesController)
11641171

11651172
// WithDNS configures the tunnelAllWorkspaceUpdatesController to set DNS names for all workspaces
11661173
// and agents it learns about.
11671174
func WithDNS(d DNSHostsSetter, ownerUsername string) TunnelAllOption {
1168-
return func(t *tunnelAllWorkspaceUpdatesController) {
1175+
return func(t *TunnelAllWorkspaceUpdatesController) {
11691176
t.dnsHostSetter = d
11701177
t.ownerUsername = ownerUsername
11711178
}
11721179
}
11731180

1174-
func WithCallback(cb func(*proto.WorkspaceUpdate)) TunnelAllOption {
1175-
return func(t *tunnelAllWorkspaceUpdatesController) {
1176-
t.updateCallback = cb
1181+
func WithHandler(h UpdatesHandler) TunnelAllOption {
1182+
return func(t *TunnelAllWorkspaceUpdatesController) {
1183+
t.updateHandler = h
11771184
}
11781185
}
11791186

@@ -1182,8 +1189,8 @@ func WithCallback(cb func(*proto.WorkspaceUpdate)) TunnelAllOption {
11821189
// DNSHostSetter is provided, it also programs DNS hosts based on the agent and workspace names.
11831190
func NewTunnelAllWorkspaceUpdatesController(
11841191
logger slog.Logger, c *TunnelSrcCoordController, opts ...TunnelAllOption,
1185-
) WorkspaceUpdatesController {
1186-
t := &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c}
1192+
) *TunnelAllWorkspaceUpdatesController {
1193+
t := &TunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c}
11871194
for _, opt := range opts {
11881195
opt(t)
11891196
}

tailnet/controllers_test.go

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"net/netip"
99
"slices"
10+
"strings"
1011
"sync"
1112
"sync/atomic"
1213
"testing"
@@ -1451,10 +1452,35 @@ func (f *fakeDNSSetter) SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error {
14511452
}
14521453
}
14531454

1455+
func newFakeUpdateHandler(ctx context.Context, t testing.TB) *fakeUpdateHandler {
1456+
return &fakeUpdateHandler{
1457+
ctx: ctx,
1458+
t: t,
1459+
ch: make(chan *proto.WorkspaceUpdate),
1460+
}
1461+
}
1462+
1463+
type fakeUpdateHandler struct {
1464+
ctx context.Context
1465+
t testing.TB
1466+
ch chan *proto.WorkspaceUpdate
1467+
}
1468+
1469+
func (f *fakeUpdateHandler) Update(wu *proto.WorkspaceUpdate) error {
1470+
f.t.Helper()
1471+
select {
1472+
case <-f.ctx.Done():
1473+
return timeoutOnFakeErr
1474+
case f.ch <- wu:
1475+
// OK
1476+
}
1477+
return nil
1478+
}
1479+
14541480
func setupConnectedAllWorkspaceUpdatesController(
14551481
ctx context.Context, t testing.TB, logger slog.Logger, opts ...tailnet.TunnelAllOption,
14561482
) (
1457-
*fakeCoordinatorClient, *fakeWorkspaceUpdateClient,
1483+
*fakeCoordinatorClient, *fakeWorkspaceUpdateClient, *tailnet.TunnelAllWorkspaceUpdatesController,
14581484
) {
14591485
fConn := &fakeCoordinatee{}
14601486
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
@@ -1484,17 +1510,20 @@ func setupConnectedAllWorkspaceUpdatesController(
14841510
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
14851511
require.ErrorIs(t, err, io.EOF)
14861512
})
1487-
return coordC, updateC
1513+
return coordC, updateC, uut
14881514
}
14891515

14901516
func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
14911517
t.Parallel()
14921518
ctx := testutil.Context(t, testutil.WaitShort)
14931519
logger := testutil.Logger(t)
14941520

1521+
fUH := newFakeUpdateHandler(ctx, t)
14951522
fDNS := newFakeDNSSetter(ctx, t)
1496-
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
1497-
tailnet.WithDNS(fDNS, "testy"))
1523+
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
1524+
tailnet.WithDNS(fDNS, "testy"),
1525+
tailnet.WithHandler(fUH),
1526+
)
14981527

14991528
// Initial update contains 2 workspaces with 1 & 2 agents, respectively
15001529
w1ID := testUUID(1)
@@ -1541,16 +1570,43 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
15411570
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
15421571
require.Equal(t, expectedDNS, dnsCall.hosts)
15431572
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
1573+
1574+
// And the callback
1575+
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
1576+
require.Equal(t, initUp, cbUpdate)
1577+
1578+
// Current state should match
1579+
state := updateCtrl.CurrentState()
1580+
slices.SortFunc(state.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
1581+
return strings.Compare(a.Name, b.Name)
1582+
})
1583+
slices.SortFunc(state.UpsertedAgents, func(a, b *proto.Agent) int {
1584+
return strings.Compare(a.Name, b.Name)
1585+
})
1586+
require.Equal(t, &proto.WorkspaceUpdate{
1587+
UpsertedWorkspaces: []*proto.Workspace{
1588+
{Id: w1ID[:], Name: "w1"},
1589+
{Id: w2ID[:], Name: "w2"},
1590+
},
1591+
UpsertedAgents: []*proto.Agent{
1592+
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
1593+
{Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:]},
1594+
{Id: w2a2ID[:], Name: "w2a2", WorkspaceId: w2ID[:]},
1595+
},
1596+
}, state)
15441597
}
15451598

15461599
func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
15471600
t.Parallel()
15481601
ctx := testutil.Context(t, testutil.WaitShort)
15491602
logger := testutil.Logger(t)
15501603

1604+
fUH := newFakeUpdateHandler(ctx, t)
15511605
fDNS := newFakeDNSSetter(ctx, t)
1552-
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
1553-
tailnet.WithDNS(fDNS, "testy"))
1606+
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
1607+
tailnet.WithDNS(fDNS, "testy"),
1608+
tailnet.WithHandler(fUH),
1609+
)
15541610

15551611
w1ID := testUUID(1)
15561612
w1a1ID := testUUID(1, 1)
@@ -1582,6 +1638,20 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
15821638
require.Equal(t, expectedDNS, dnsCall.hosts)
15831639
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
15841640

1641+
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
1642+
require.Equal(t, initUp, cbUpdate)
1643+
1644+
// Current state should match initial
1645+
state := updateCtrl.CurrentState()
1646+
require.Equal(t, &proto.WorkspaceUpdate{
1647+
UpsertedWorkspaces: []*proto.Workspace{
1648+
{Id: w1ID[:], Name: "w1"},
1649+
},
1650+
UpsertedAgents: []*proto.Agent{
1651+
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
1652+
},
1653+
}, state)
1654+
15851655
// Send update that removes w1a1 and adds w1a2
15861656
agentUpdate := &proto.WorkspaceUpdate{
15871657
UpsertedAgents: []*proto.Agent{
@@ -1613,6 +1683,19 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
16131683
dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls)
16141684
require.Equal(t, expectedDNS, dnsCall.hosts)
16151685
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
1686+
1687+
cbUpdate = testutil.RequireRecvCtx(ctx, t, fUH.ch)
1688+
require.Equal(t, agentUpdate, cbUpdate)
1689+
1690+
state = updateCtrl.CurrentState()
1691+
require.Equal(t, &proto.WorkspaceUpdate{
1692+
UpsertedWorkspaces: []*proto.Workspace{
1693+
{Id: w1ID[:], Name: "w1"},
1694+
},
1695+
UpsertedAgents: []*proto.Agent{
1696+
{Id: w1a2ID[:], Name: "w1a2", WorkspaceId: w1ID[:]},
1697+
},
1698+
}, state)
16161699
}
16171700

16181701
func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {

0 commit comments

Comments
 (0)