Skip to content

Commit cd88dca

Browse files
committed
Fix coordination
1 parent 5cf709c commit cd88dca

File tree

8 files changed

+86
-51
lines changed

8 files changed

+86
-51
lines changed

agent/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,5 +181,5 @@ func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
181181
func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
182182
_, rawPort, _ := net.SplitHostPort(addr)
183183
port, _ := strconv.Atoi(rawPort)
184-
return c.Server.DialContextTCP(ctx, netip.AddrPortFrom(c.Target, uint16(port)))
184+
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.Target, uint16(port)))
185185
}

coderd/coderd.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ type Options struct {
7272
TURNServer *turnconn.Server
7373
TracerProvider *sdktrace.TracerProvider
7474

75-
DERPMap *tailcfg.DERPMap
75+
ConnCoordinator *tailnet.Coordinator
76+
DERPMap *tailcfg.DERPMap
7677
}
7778

7879
// New constructs a Coder API handler.
@@ -99,6 +100,9 @@ func New(options *Options) *API {
99100
if options.PrometheusRegistry == nil {
100101
options.PrometheusRegistry = prometheus.NewRegistry()
101102
}
103+
if options.ConnCoordinator == nil {
104+
options.ConnCoordinator = tailnet.NewCoordinator()
105+
}
102106

103107
siteCacheDir := options.CacheDir
104108
if siteCacheDir != "" {
@@ -338,7 +342,7 @@ func New(options *Options) *API {
338342
r.Get("/iceservers", api.workspaceAgentICEServers)
339343

340344
// Everything below this is Tailnet.
341-
r.Get("/node", api.workspaceAgentNode)
345+
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
342346
})
343347
r.Route("/{workspaceagent}", func(r chi.Router) {
344348
r.Use(
@@ -355,7 +359,7 @@ func New(options *Options) *API {
355359
r.Get("/derpmap", func(w http.ResponseWriter, r *http.Request) {
356360
httpapi.Write(w, http.StatusOK, options.DERPMap)
357361
})
358-
r.Post("/node", api.postWorkspaceAgentNode)
362+
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
359363
})
360364
})
361365
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {

coderd/workspaceagents.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,10 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Co
507507
}, nil
508508
}
509509

510-
// workspaceAgentNode accepts a WebSocket that reads node network updates.
510+
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
511511
// After accepting, a PubSub starts listening for new connection node updates
512512
// which are written to the WebSocket.
513-
func (api *API) workspaceAgentNode(rw http.ResponseWriter, r *http.Request) {
513+
func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
514514
api.websocketWaitMutex.Lock()
515515
api.websocketWaitGroup.Add(1)
516516
api.websocketWaitMutex.Unlock()

codersdk/workspaceagents.go

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package codersdk
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"io"
89
"net"
910
"net/http"
1011
"net/http/cookiejar"
1112
"net/netip"
13+
"time"
1214

1315
"cloud.google.com/go/compute/metadata"
1416
"github.com/google/uuid"
@@ -29,6 +31,7 @@ import (
2931
"github.com/coder/coder/peerbroker/proto"
3032
"github.com/coder/coder/provisionersdk"
3133
"github.com/coder/coder/tailnet"
34+
"github.com/coder/retry"
3235
)
3336

3437
type GoogleInstanceIdentityToken struct {
@@ -51,6 +54,11 @@ type WorkspaceAgentAuthenticateResponse struct {
5154
SessionToken string `json:"session_token"`
5255
}
5356

57+
type WorkspaceAgentConnectionRequest struct {
58+
DERPMap tailcfg.DERPMap `json:"derp_map"`
59+
IPAddress netip.Addr `json:"ip_address"`
60+
}
61+
5462
// AuthWorkspaceGoogleInstanceIdentity uses the Google Compute Engine Metadata API to
5563
// fetch a signed JWT, and exchange it for a session token for a workspace agent.
5664
//
@@ -303,7 +311,7 @@ func (c *Client) WorkspaceAgentNodeBroker(ctx context.Context) (agent.NodeBroker
303311
return &workspaceAgentNodeBroker{conn}, nil
304312
}
305313

306-
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, agentID uuid.UUID, logger slog.Logger) (agent.Conn, error) {
314+
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) {
307315
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/derpmap", agentID), nil)
308316
if err != nil {
309317
return nil, err
@@ -319,48 +327,64 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, agentID uuid.UUI
319327
}
320328

321329
ip := tailnet.IP()
322-
server, err := tailnet.NewConn(&tailnet.Options{
330+
conn, err := tailnet.NewConn(&tailnet.Options{
323331
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
324332
DERPMap: &derpMap,
325333
Logger: logger,
326334
})
327335
if err != nil {
328336
return nil, xerrors.Errorf("create tailnet: %w", err)
329337
}
330-
server.SetNodeCallback(func(node *tailnet.Node) {
331-
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/node", agentID), node)
332-
if err != nil {
333-
logger.Error(ctx, "update node", slog.Error(err), slog.F("node", node))
334-
return
335-
}
336-
defer res.Body.Close()
337-
if res.StatusCode != http.StatusOK {
338-
logger.Error(ctx, "update node", slog.F("status_code", res.StatusCode), slog.F("node", node))
339-
}
340-
})
341-
workspaceAgent, err := c.WorkspaceAgent(ctx, agentID)
338+
339+
coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate")
342340
if err != nil {
343-
return nil, xerrors.Errorf("get workspace agent: %w", err)
344-
}
345-
ipRanges := make([]netip.Prefix, 0, len(workspaceAgent.IPAddresses))
346-
for _, address := range workspaceAgent.IPAddresses {
347-
ipRanges = append(ipRanges, netip.PrefixFrom(address, 128))
348-
}
349-
agentNode := &tailnet.Node{
350-
Key: workspaceAgent.NodePublicKey,
351-
DiscoKey: workspaceAgent.DiscoPublicKey,
352-
PreferredDERP: workspaceAgent.PreferredDERP,
353-
Addresses: ipRanges,
354-
AllowedIPs: ipRanges,
341+
return nil, xerrors.Errorf("parse url: %w", err)
355342
}
356-
logger.Debug(ctx, "adding agent node", slog.F("node", agentNode))
357-
err = server.UpdateNodes([]*tailnet.Node{agentNode})
343+
jar, err := cookiejar.New(nil)
358344
if err != nil {
359-
return nil, xerrors.Errorf("update nodes: %w", err)
345+
return nil, xerrors.Errorf("create cookie jar: %w", err)
346+
}
347+
jar.SetCookies(coordinateURL, []*http.Cookie{{
348+
Name: SessionTokenKey,
349+
Value: c.SessionToken,
350+
}})
351+
httpClient := &http.Client{
352+
Jar: jar,
360353
}
354+
go func() {
355+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
356+
logger.Debug(ctx, "connecting")
357+
ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
358+
HTTPClient: httpClient,
359+
// Need to disable compression to avoid a data-race.
360+
CompressionMode: websocket.CompressionDisabled,
361+
})
362+
if errors.Is(err, context.Canceled) {
363+
return
364+
}
365+
if err != nil {
366+
logger.Debug(ctx, "failed to dial", slog.Error(err))
367+
continue
368+
}
369+
_ = res.Body.Close()
370+
sendNode, errChan := tailnet.ServeCoordinator(ctx, ws, func(node []*tailnet.Node) error {
371+
return conn.UpdateNodes(node)
372+
})
373+
conn.SetNodeCallback(sendNode)
374+
logger.Debug(ctx, "serving coordinator")
375+
err = <-errChan
376+
if errors.Is(err, context.Canceled) {
377+
return
378+
}
379+
if err != nil {
380+
logger.Debug(ctx, "error serving coordinator", slog.Error(err))
381+
continue
382+
}
383+
}
384+
}()
361385
return &agent.TailnetConn{
362386
Target: workspaceAgent.IPAddresses[0],
363-
Server: server,
387+
Conn: conn,
364388
}, nil
365389
}
366390

codersdk/workspaceresources.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@ import (
55
"encoding/json"
66
"fmt"
77
"net/http"
8-
"net/netip"
98
"time"
109

1110
"github.com/google/uuid"
12-
"tailscale.com/types/key"
1311
)
1412

1513
type WorkspaceAgentStatus string
@@ -54,11 +52,6 @@ type WorkspaceAgent struct {
5452
StartupScript string `json:"startup_script,omitempty"`
5553
Directory string `json:"directory,omitempty"`
5654
Apps []WorkspaceApp `json:"apps"`
57-
58-
// For internal routing only.
59-
IPAddresses []netip.Addr `json:"ip_addresses"`
60-
NodePublicKey key.NodePublic `json:"node_public_key"`
61-
DiscoPublicKey key.DiscoPublic `json:"disco_public_key"`
6255
// PreferredDERP represents the connected region.
6356
PreferredDERP int `json:"preferred_derp"`
6457
// Maps DERP region to MS latency.

site/site.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ func secureHeaders(next http.Handler) http.Handler {
346346
return secure.New(secure.Options{
347347
PermissionsPolicy: permissions,
348348

349-
// Prevent the browser from sending Referer header with requests
349+
// Prevent the browser from sending Referrer header with requests
350350
ReferrerPolicy: "no-referrer",
351351
}).Handler(cspHeaders(next))
352352
}

tailnet/coordinator.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
)
1313

1414
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
15-
func ServeCoordinator(ctx context.Context, socket *websocket.Conn, updateNodes func(node []*Node)) (func(node *Node), <-chan error) {
15+
func ServeCoordinator(ctx context.Context, socket *websocket.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
1616
errChan := make(chan error, 1)
1717
go func() {
1818
for {
@@ -22,12 +22,19 @@ func ServeCoordinator(ctx context.Context, socket *websocket.Conn, updateNodes f
2222
errChan <- xerrors.Errorf("read: %w", err)
2323
return
2424
}
25-
updateNodes(nodes)
25+
err = updateNodes(nodes)
26+
if err != nil {
27+
errChan <- xerrors.Errorf("update nodes: %w", err)
28+
}
2629
}
2730
}()
2831

2932
return func(node *Node) {
3033
err := wsjson.Write(ctx, socket, node)
34+
if errors.Is(err, context.Canceled) || errors.As(err, &websocket.CloseError{}) {
35+
errChan <- nil
36+
return
37+
}
3138
if err != nil {
3239
errChan <- xerrors.Errorf("write: %w", err)
3340
}

tailnet/coordinator_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ func TestCoordinator(t *testing.T) {
2323
t.Parallel()
2424
coordinator := tailnet.NewCoordinator()
2525
client, server := pipeWS(t)
26-
sendNode, errChan := tailnet.ServeCoordinator(context.Background(), client, func(node []*tailnet.Node) {})
26+
sendNode, errChan := tailnet.ServeCoordinator(context.Background(), client, func(node []*tailnet.Node) error {
27+
return nil
28+
})
2729
id := uuid.New()
2830
closeChan := make(chan struct{})
2931
go func() {
@@ -45,7 +47,9 @@ func TestCoordinator(t *testing.T) {
4547
t.Parallel()
4648
coordinator := tailnet.NewCoordinator()
4749
client, server := pipeWS(t)
48-
sendNode, errChan := tailnet.ServeCoordinator(context.Background(), client, func(node []*tailnet.Node) {})
50+
sendNode, errChan := tailnet.ServeCoordinator(context.Background(), client, func(node []*tailnet.Node) error {
51+
return nil
52+
})
4953
id := uuid.New()
5054
closeChan := make(chan struct{})
5155
go func() {
@@ -70,8 +74,9 @@ func TestCoordinator(t *testing.T) {
7074
agentWS, agentServerWS := pipeWS(t)
7175
defer agentWS.Close(websocket.StatusNormalClosure, "")
7276
agentNodeChan := make(chan []*tailnet.Node)
73-
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(context.Background(), agentWS, func(nodes []*tailnet.Node) {
77+
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(context.Background(), agentWS, func(nodes []*tailnet.Node) error {
7478
agentNodeChan <- nodes
79+
return nil
7580
})
7681
agentID := uuid.New()
7782
closeAgentChan := make(chan struct{})
@@ -89,8 +94,9 @@ func TestCoordinator(t *testing.T) {
8994
defer clientWS.Close(websocket.StatusNormalClosure, "")
9095
defer clientServerWS.Close(websocket.StatusNormalClosure, "")
9196
clientNodeChan := make(chan []*tailnet.Node)
92-
sendClientNode, clientErrChan := tailnet.ServeCoordinator(context.Background(), clientWS, func(nodes []*tailnet.Node) {
97+
sendClientNode, clientErrChan := tailnet.ServeCoordinator(context.Background(), clientWS, func(nodes []*tailnet.Node) error {
9398
clientNodeChan <- nodes
99+
return nil
94100
})
95101
clientID := uuid.New()
96102
closeClientChan := make(chan struct{})
@@ -120,8 +126,9 @@ func TestCoordinator(t *testing.T) {
120126
agentWS, agentServerWS = pipeWS(t)
121127
defer agentWS.Close(websocket.StatusNormalClosure, "")
122128
agentNodeChan = make(chan []*tailnet.Node)
123-
_, agentErrChan = tailnet.ServeCoordinator(context.Background(), agentWS, func(nodes []*tailnet.Node) {
129+
_, agentErrChan = tailnet.ServeCoordinator(context.Background(), agentWS, func(nodes []*tailnet.Node) error {
124130
agentNodeChan <- nodes
131+
return nil
125132
})
126133
closeAgentChan = make(chan struct{})
127134
go func() {

0 commit comments

Comments
 (0)