Skip to content

Commit 94e90a3

Browse files
committed
Add coordinator tests
1 parent 01c52c7 commit 94e90a3

File tree

3 files changed

+222
-82
lines changed

3 files changed

+222
-82
lines changed

.vscode/settings.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"cSpell.words": [
33
"apps",
44
"awsidentity",
5+
"bodyclose",
56
"buildinfo",
67
"buildname",
78
"circbuf",
@@ -92,6 +93,7 @@
9293
"templateversions",
9394
"testdata",
9495
"testid",
96+
"testutil",
9597
"tfexec",
9698
"tfjson",
9799
"tfplan",

tailnet/coordinator.go

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111
"nhooyr.io/websocket/wsjson"
1212
)
1313

14-
// Coordinate matches the RW structure of a coordinator to exchange node messages.
15-
func Coordinate(ctx context.Context, socket *websocket.Conn, updateNodes func(node []*Node)) (func(node *Node), <-chan error) {
14+
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
15+
func ServeCoordinator(ctx context.Context, socket *websocket.Conn, updateNodes func(node []*Node)) (func(node *Node), <-chan error) {
1616
errChan := make(chan error, 1)
1717
go func() {
1818
for {
@@ -37,139 +37,178 @@ func Coordinate(ctx context.Context, socket *websocket.Conn, updateNodes func(no
3737
// NewCoordinator constructs a new in-memory connection coordinator.
3838
func NewCoordinator() *Coordinator {
3939
return &Coordinator{
40-
agentNodes: map[uuid.UUID]*Node{},
41-
agentClientNodes: map[uuid.UUID]map[uuid.UUID]*Node{},
42-
agentSockets: map[uuid.UUID]*websocket.Conn{},
43-
agentClientSockets: map[uuid.UUID]map[uuid.UUID]*websocket.Conn{},
40+
nodes: map[uuid.UUID]*Node{},
41+
agentSockets: map[uuid.UUID]*websocket.Conn{},
42+
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*websocket.Conn{},
4443
}
4544
}
4645

47-
// Coordinator brokers connections over WebSockets.
46+
// Coordinator exchanges nodes with agents to establish connections.
47+
// ┌────────────────────────┐   ┌───────────────────────┐   ┌──────────────────────┐   ┌────────────────────────┐
48+
// │tailnet.ServeCoordinator├──►│Coordinator.ServeClient│◄─►│Coordinator.ServeAgent│◄──┤tailnet.ServeCoordinator│
49+
// └────────────────────────┘   └───────────────────────┘   └──────────────────────┘   └────────────────────────┘
50+
// This coordinator is incompatible with multiple Coder
51+
// replicas as all node data is in-memory.
4852
type Coordinator struct {
4953
mutex sync.Mutex
50-
// Stores the most recent node an agent sent.
51-
agentNodes map[uuid.UUID]*Node
52-
// Stores the most recent node reported by each client.
53-
agentClientNodes map[uuid.UUID]map[uuid.UUID]*Node
54-
// Stores the active connection from an agent.
54+
55+
// Maps agent and connection IDs to a node.
56+
nodes map[uuid.UUID]*Node
57+
// Maps agent ID to an open socket.
5558
agentSockets map[uuid.UUID]*websocket.Conn
56-
// Stores the active connection from a client to an agent.
57-
agentClientSockets map[uuid.UUID]map[uuid.UUID]*websocket.Conn
59+
// Maps agent ID to the sockets of its connections, keyed by
60+
// connection ID, for sending new node data as it comes in.
61+
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*websocket.Conn
5862
}
5963

60-
// Client represents a tailnet looking to peer with an agent.
61-
func (c *Coordinator) Client(ctx context.Context, agentID uuid.UUID, socket *websocket.Conn) error {
62-
id := uuid.New()
64+
// Node returns an in-memory node by ID.
65+
func (c *Coordinator) Node(id uuid.UUID) *Node {
6366
c.mutex.Lock()
64-
clients, ok := c.agentClientSockets[agentID]
65-
if !ok {
66-
clients = map[uuid.UUID]*websocket.Conn{}
67-
c.agentClientSockets[agentID] = clients
68-
}
69-
clients[id] = socket
70-
agentNode, ok := c.agentNodes[agentID]
67+
defer c.mutex.Unlock()
68+
node := c.nodes[id]
69+
return node
70+
}
71+
72+
// ServeClient accepts a WebSocket connection that wants to
73+
// connect to an agent with the specified ID.
74+
func (c *Coordinator) ServeClient(ctx context.Context, socket *websocket.Conn, id uuid.UUID, agent uuid.UUID) error {
75+
c.mutex.Lock()
76+
// When a new connection is requested, we update it with the latest
77+
// node of the agent. This allows the connection to establish.
78+
node, ok := c.nodes[agent]
7179
if ok {
72-
err := wsjson.Write(ctx, socket, []*Node{agentNode})
80+
err := wsjson.Write(ctx, socket, []*Node{node})
7381
if err != nil {
7482
c.mutex.Unlock()
75-
return xerrors.Errorf("write agent node: %w", err)
83+
return xerrors.Errorf("write nodes: %w", err)
7684
}
7785
}
78-
86+
connectionSockets, ok := c.agentToConnectionSockets[agent]
87+
if !ok {
88+
connectionSockets = map[uuid.UUID]*websocket.Conn{}
89+
c.agentToConnectionSockets[agent] = connectionSockets
90+
}
91+
// Insert this connection into a map so the agent
92+
// can publish node updates.
93+
connectionSockets[id] = socket
7994
c.mutex.Unlock()
8095
defer func() {
8196
c.mutex.Lock()
8297
defer c.mutex.Unlock()
83-
clients, ok := c.agentClientSockets[agentID]
98+
// Clean all traces of this connection from the maps.
99+
delete(c.nodes, id)
100+
connectionSockets, ok := c.agentToConnectionSockets[agent]
84101
if !ok {
85102
return
86103
}
87-
delete(clients, id)
88-
nodes, ok := c.agentClientNodes[agentID]
89-
if !ok {
104+
delete(connectionSockets, id)
105+
if len(connectionSockets) != 0 {
90106
return
91107
}
92-
delete(nodes, id)
108+
delete(c.agentToConnectionSockets, agent)
93109
}()
94110

95111
for {
96112
var node Node
97113
err := wsjson.Read(ctx, socket, &node)
98-
if errors.Is(err, context.Canceled) {
114+
if errors.Is(err, context.Canceled) || errors.As(err, &websocket.CloseError{}) {
99115
return nil
100116
}
101117
if err != nil {
102118
return xerrors.Errorf("read json: %w", err)
103119
}
104120
c.mutex.Lock()
105-
nodes, ok := c.agentClientNodes[agentID]
121+
// Update the node of this client in our in-memory map.
122+
// If an agent entirely shuts down and reconnects, it
123+
// needs to be aware of all clients attempting to
124+
// establish connections.
125+
c.nodes[id] = &node
126+
agentSocket, ok := c.agentSockets[agent]
106127
if !ok {
107-
nodes = map[uuid.UUID]*Node{}
108-
c.agentClientNodes[agentID] = nodes
109-
}
110-
nodes[id] = &node
111-
112-
agentSocket, ok := c.agentSockets[agentID]
113-
if !ok {
114-
// If the agent isn't connected yet, that's fine. It'll reconcile later.
115128
c.mutex.Unlock()
116129
continue
117130
}
131+
// Write the new node from this client to the actively
132+
// connected agent.
118133
err = wsjson.Write(ctx, agentSocket, []*Node{&node})
134+
if errors.Is(err, context.Canceled) {
135+
c.mutex.Unlock()
136+
return nil
137+
}
119138
if err != nil {
120139
c.mutex.Unlock()
121-
return xerrors.Errorf("write node to agent: %w", err)
140+
return xerrors.Errorf("write json: %w", err)
122141
}
123142
c.mutex.Unlock()
124143
}
125144
}
126145

127-
func (c *Coordinator) Agent(ctx context.Context, agentID uuid.UUID, socket *websocket.Conn) error {
146+
// ServeAgent accepts a WebSocket connection to an agent that
147+
// listens to incoming connections and publishes node updates.
148+
func (c *Coordinator) ServeAgent(ctx context.Context, socket *websocket.Conn, id uuid.UUID) error {
128149
c.mutex.Lock()
129-
agentSocket, ok := c.agentSockets[agentID]
130-
if ok {
131-
agentSocket.Close(websocket.StatusGoingAway, "another agent started with the same id")
132-
}
133-
c.agentSockets[agentID] = socket
134-
nodes, ok := c.agentClientNodes[agentID]
150+
sockets, ok := c.agentToConnectionSockets[id]
135151
if ok {
152+
// Publish the nodes of all connections that want to
153+
// connect to this agent.
154+
nodes := make([]*Node, 0, len(sockets))
155+
for targetID := range sockets {
156+
node, ok := c.nodes[targetID]
157+
if !ok {
158+
continue
159+
}
160+
nodes = append(nodes, node)
161+
}
136162
err := wsjson.Write(ctx, socket, nodes)
137163
if err != nil {
138164
c.mutex.Unlock()
139165
return xerrors.Errorf("write nodes: %w", err)
140166
}
141167
}
168+
169+
// If an old agent socket is connected, we close it
170+
// to avoid any leaks. This shouldn't ever occur because
171+
// we expect one agent to be running.
172+
oldAgentSocket, ok := c.agentSockets[id]
173+
if ok {
174+
_ = oldAgentSocket.Close(websocket.StatusNormalClosure, "another agent connected with the same id")
175+
}
176+
c.agentSockets[id] = socket
142177
c.mutex.Unlock()
143178
defer func() {
144179
c.mutex.Lock()
145180
defer c.mutex.Unlock()
146-
delete(c.agentSockets, agentID)
181+
delete(c.agentSockets, id)
182+
delete(c.nodes, id)
147183
}()
148184

149185
for {
150186
var node Node
151187
err := wsjson.Read(ctx, socket, &node)
152-
if errors.Is(err, context.Canceled) {
188+
if errors.Is(err, context.Canceled) || errors.As(err, &websocket.CloseError{}) {
153189
return nil
154190
}
155191
if err != nil {
156-
return xerrors.Errorf("read node: %w", err)
192+
return xerrors.Errorf("read json: %w", err)
157193
}
158194
c.mutex.Lock()
159-
c.agentNodes[agentID] = &node
160-
161-
clients, ok := c.agentClientSockets[agentID]
195+
c.nodes[id] = &node
196+
connectionSockets, ok := c.agentToConnectionSockets[id]
162197
if !ok {
163198
c.mutex.Unlock()
164199
continue
165200
}
166-
for _, client := range clients {
167-
err = wsjson.Write(ctx, client, []*Node{&node})
168-
if err != nil {
169-
c.mutex.Unlock()
170-
return xerrors.Errorf("write to client: %w", err)
171-
}
201+
// Publish the new node to every listening socket.
202+
var wg sync.WaitGroup
203+
wg.Add(len(connectionSockets))
204+
for _, connectionSocket := range connectionSockets {
205+
connectionSocket := connectionSocket
206+
go func() {
207+
_ = wsjson.Write(ctx, connectionSocket, []*Node{&node})
208+
wg.Done()
209+
}()
172210
}
211+
wg.Wait()
173212
c.mutex.Unlock()
174213
}
175214
}

0 commit comments

Comments
 (0)