Skip to content

Commit f4e5887

Browse files
committed
It compiles!
1 parent cd88dca commit f4e5887

File tree

12 files changed

+124
-165
lines changed

12 files changed

+124
-165
lines changed

agent/agent.go

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"go.uber.org/atomic"
2929
gossh "golang.org/x/crypto/ssh"
3030
"golang.org/x/xerrors"
31+
"nhooyr.io/websocket"
3132
"tailscale.com/tailcfg"
3233

3334
"cdr.dev/slog"
@@ -51,10 +52,10 @@ const (
5152
)
5253

5354
type Options struct {
54-
EnableTailnet bool
55-
NodeDialer NodeDialer
56-
WebRTCDialer WebRTCDialer
57-
FetchMetadata FetchMetadata
55+
EnableTailnet bool
56+
CoordinatorDialer CoordinatorDialer
57+
WebRTCDialer WebRTCDialer
58+
FetchMetadata FetchMetadata
5859

5960
ReconnectingPTYTimeout time.Duration
6061
EnvironmentVariables map[string]string
@@ -71,18 +72,9 @@ type Metadata struct {
7172

7273
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
7374

74-
// NodeBroker handles the exchange of node information.
75-
type NodeBroker interface {
76-
io.Closer
77-
// Read will be a constant stream of incoming connection requests.
78-
Read(ctx context.Context) (*tailnet.Node, error)
79-
// Write should be called with the listening agent node information.
80-
Write(ctx context.Context, node *tailnet.Node) error
81-
}
82-
83-
// NodeDialer is a function that constructs a new broker.
75+
// CoordinatorDialer is a function that dials a new websocket connection to the coordinator.
8476
// A dialer must be passed in to allow for reconnects.
85-
type NodeDialer func(ctx context.Context) (NodeBroker, error)
77+
type CoordinatorDialer func(ctx context.Context) (*websocket.Conn, error)
8678

8779
// FetchMetadata is a function to obtain metadata for the agent.
8880
type FetchMetadata func(ctx context.Context) (Metadata, error)
@@ -100,7 +92,7 @@ func New(options Options) io.Closer {
10092
closed: make(chan struct{}),
10193
envVars: options.EnvironmentVariables,
10294
enableTailnet: options.EnableTailnet,
103-
nodeDialer: options.NodeDialer,
95+
coordinatorDialer: options.CoordinatorDialer,
10496
fetchMetadata: options.FetchMetadata,
10597
}
10698
server.init(ctx)
@@ -125,9 +117,9 @@ type agent struct {
125117
fetchMetadata FetchMetadata
126118
sshServer *ssh.Server
127119

128-
enableTailnet bool
129-
network *tailnet.Conn
130-
nodeDialer NodeDialer
120+
enableTailnet bool
121+
network *tailnet.Conn
122+
coordinatorDialer CoordinatorDialer
131123
}
132124

133125
func (a *agent) run(ctx context.Context) {
@@ -190,7 +182,7 @@ func (a *agent) runTailnet(ctx context.Context, addresses []netip.Addr, derpMap
190182
a.logger.Critical(ctx, "create tailnet", slog.Error(err))
191183
return
192184
}
193-
go a.runNodeBroker(ctx)
185+
go a.runCoordinator(ctx)
194186

195187
sshListener, err := a.network.Listen("tcp", ":12212")
196188
if err != nil {
@@ -208,14 +200,14 @@ func (a *agent) runTailnet(ctx context.Context, addresses []netip.Addr, derpMap
208200
}()
209201
}
210202

211-
// runNodeBroker listens for nodes and updates the self-node as it changes.
212-
func (a *agent) runNodeBroker(ctx context.Context) {
213-
var nodeBroker NodeBroker
203+
// runCoordinator listens for nodes and updates the self-node as it changes.
204+
func (a *agent) runCoordinator(ctx context.Context) {
205+
var coordinator *websocket.Conn
214206
var err error
215207
// An exponential back-off occurs when the connection is failing to dial.
216208
// This is to prevent server spam in case of a coderd outage.
217209
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
218-
nodeBroker, err = a.nodeDialer(ctx)
210+
coordinator, err = a.coordinatorDialer(ctx)
219211
if err != nil {
220212
if errors.Is(err, context.Canceled) {
221213
return
@@ -226,36 +218,24 @@ func (a *agent) runNodeBroker(ctx context.Context) {
226218
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
227219
continue
228220
}
229-
a.logger.Info(context.Background(), "connected to node broker")
221+
a.logger.Info(context.Background(), "connected to coordination server")
230222
break
231223
}
224+
sendNodes, errChan := tailnet.ServeCoordinator(ctx, coordinator, a.network.UpdateNodes)
225+
a.network.SetNodeCallback(sendNodes)
232226
select {
233227
case <-ctx.Done():
234228
return
235-
default:
236-
}
237-
238-
a.network.SetNodeCallback(func(node *tailnet.Node) {
239-
err := nodeBroker.Write(ctx, node)
240-
if err != nil {
241-
a.logger.Warn(context.Background(), "write node", slog.Error(err), slog.F("node", node))
242-
}
243-
})
244-
245-
for {
246-
node, err := nodeBroker.Read(ctx)
247-
if err != nil {
248-
if a.isClosed() {
249-
return
250-
}
251-
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
252-
a.runNodeBroker(ctx)
229+
case err := <-errChan:
230+
if a.isClosed() {
253231
return
254232
}
255-
err = a.network.UpdateNodes([]*tailnet.Node{node})
256-
if err != nil {
257-
a.logger.Error(ctx, "update tailnet nodes", slog.Error(err), slog.F("node", node))
233+
if errors.Is(err, context.Canceled) {
234+
return
258235
}
236+
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
237+
a.runCoordinator(ctx)
238+
return
259239
}
260240
}
261241

@@ -887,7 +867,9 @@ func (a *agent) Close() error {
887867
}
888868
close(a.closed)
889869
a.closeCancel()
870+
fmt.Printf("CLOSING NETWORK!!!!\n")
890871
if a.network != nil {
872+
fmt.Printf("ACTUALLY CLOSING NETWORK!!!!\n")
891873
_ = a.network.Close()
892874
}
893875
_ = a.sshServer.Close()

cli/agent.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ func workspaceAgent() *cobra.Command {
181181
// shells so "gitssh" works!
182182
"CODER_AGENT_TOKEN": client.SessionToken,
183183
},
184-
EnableTailnet: wireguard,
185-
NodeDialer: client.WorkspaceAgentNodeBroker,
184+
EnableTailnet: wireguard,
185+
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
186186
})
187187
<-cmd.Context().Done()
188188
return closer.Close()

cli/cliflag/cliflag.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
//
77
// Will produce the following usage docs:
88
//
9-
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
10-
//
9+
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
1110
package cliflag
1211

1312
import (

cli/cliflag/cliflag_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
)
1515

1616
// TestCliflag cannot run in parallel because it uses t.Setenv.
17+
//
1718
//nolint:paralleltest
1819
func TestCliflag(t *testing.T) {
1920
t.Run("StringDefault", func(t *testing.T) {

cli/configssh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ func currentBinPath(w io.Writer) (string, error) {
558558

559559
// diffBytes takes two byte slices and diffs them as if they were in a
560560
// file named name.
561-
//nolint: revive // Color is an option, not a control coupling.
561+
//nolint:revive // Color is an option, not a control coupling.
562562
func diffBytes(name string, b1, b2 []byte, color bool) ([]byte, error) {
563563
var buf bytes.Buffer
564564
var opts []write.Option

coderd/authorize.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ func AuthorizeFilter[O rbac.Objecter](api *API, r *http.Request, action rbac.Act
3131
// This function will log appropriately, but the caller must return an
3232
// error to the api client.
3333
// Eg:
34+
//
3435
// if !api.Authorize(...) {
3536
// httpapi.Forbidden(rw)
3637
// return

coderd/coderd.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,7 @@ func New(options *Options) *API {
341341
r.Get("/turn", api.workspaceAgentTurn)
342342
r.Get("/iceservers", api.workspaceAgentICEServers)
343343

344-
// Everything below this is Tailnet.
345-
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
344+
r.Get("/coordinate", api.workspaceAgentCoordinate)
346345
})
347346
r.Route("/{workspaceagent}", func(r chi.Router) {
348347
r.Use(
@@ -356,9 +355,7 @@ func New(options *Options) *API {
356355
r.Get("/pty", api.workspaceAgentPTY)
357356
r.Get("/iceservers", api.workspaceAgentICEServers)
358357

359-
r.Get("/derpmap", func(w http.ResponseWriter, r *http.Request) {
360-
httpapi.Write(w, http.StatusOK, options.DERPMap)
361-
})
358+
r.Get("/connection", api.workspaceAgentConnection)
362359
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
363360
})
364361
})

coderd/devtunnel/tunnel_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ func TestTunnel(t *testing.T) {
100100

101101
// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
102102
// that we want to test:
103-
// 1. Responding to a POST /tun from the client
104-
// 2. Sending an HTTP request down the wireguard connection
103+
// 1. Responding to a POST /tun from the client
104+
// 2. Sending an HTTP request down the wireguard connection
105105
//
106106
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
107107
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is

coderd/workspaceagents.go

Lines changed: 52 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package coderd
22

33
import (
4-
"bytes"
54
"context"
65
"database/sql"
76
"encoding/json"
@@ -17,7 +16,6 @@ import (
1716
"github.com/hashicorp/yamux"
1817
"golang.org/x/xerrors"
1918
"nhooyr.io/websocket"
20-
"nhooyr.io/websocket/wsjson"
2119

2220
"cdr.dev/slog"
2321
"github.com/coder/coder/agent"
@@ -120,6 +118,12 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
120118

121119
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
122120
workspaceAgent := httpmw.WorkspaceAgent(r)
121+
ips := make([]netip.Addr, 0)
122+
for _, ip := range workspaceAgent.IPAddresses {
123+
var ipData [16]byte
124+
copy(ipData[:], []byte(ip.IPNet.IP))
125+
ips = append(ips, netip.AddrFrom16(ipData))
126+
}
123127
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
124128
if err != nil {
125129
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@@ -130,7 +134,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
130134
}
131135

132136
httpapi.Write(rw, http.StatusOK, agent.Metadata{
133-
IPAddresses: apiAgent.IPAddresses,
137+
IPAddresses: ips,
134138
DERPMap: api.DERPMap,
135139
EnvironmentVariables: apiAgent.EnvironmentVariables,
136140
StartupScript: apiAgent.StartupScript,
@@ -507,10 +511,26 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Co
507511
}, nil
508512
}
509513

510-
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
511-
// After accept a PubSub starts listening for new connection node updates
512-
// which are written to the WebSocket.
513-
func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
514+
func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request) {
515+
workspaceAgent := httpmw.WorkspaceAgentParam(r)
516+
workspace := httpmw.WorkspaceParam(r)
517+
if !api.Authorize(r, rbac.ActionRead, workspace) {
518+
httpapi.ResourceNotFound(rw)
519+
return
520+
}
521+
ips := make([]netip.Addr, 0)
522+
for _, ip := range workspaceAgent.IPAddresses {
523+
var ipData [16]byte
524+
copy(ipData[:], []byte(ip.IPNet.IP))
525+
ips = append(ips, netip.AddrFrom16(ipData))
526+
}
527+
httpapi.Write(rw, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
528+
DERPMap: api.DERPMap,
529+
IPAddresses: ips,
530+
})
531+
}
532+
533+
func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
514534
api.websocketWaitMutex.Lock()
515535
api.websocketWaitGroup.Add(1)
516536
api.websocketWaitMutex.Unlock()
@@ -526,54 +546,36 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
526546
return
527547
}
528548
defer conn.Close(websocket.StatusNormalClosure, "")
529-
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
530-
subCancel, err := api.Pubsub.Subscribe("tailnet", func(ctx context.Context, message []byte) {
531-
// Since we subscribe to all peer broadcasts, we do a light check to
532-
// make sure we're the intended recipient without fully decoding the
533-
// message.
534-
if len(message) < len(agentIDBytes) {
535-
api.Logger.Error(ctx, "wireguard peer message too short", slog.F("got", len(message)))
536-
return
537-
}
538-
// We aren't the intended recipient.
539-
if !bytes.Equal(message[:len(agentIDBytes)], agentIDBytes) {
540-
return
541-
}
542-
_ = conn.Write(ctx, websocket.MessageText, message[len(agentIDBytes):])
543-
})
549+
err = api.ConnCoordinator.ServeAgent(r.Context(), conn, workspaceAgent.ID)
544550
if err != nil {
545-
api.Logger.Error(context.Background(), "pubsub listen", slog.Error(err))
551+
_ = conn.Close(websocket.StatusInternalError, err.Error())
546552
return
547553
}
548-
defer subCancel()
554+
}
549555

550-
for {
551-
var node tailnet.Node
552-
err = wsjson.Read(r.Context(), conn, &node)
553-
if err != nil {
554-
return
555-
}
556-
err := api.Database.UpdateWorkspaceAgentNetworkByID(r.Context(), database.UpdateWorkspaceAgentNetworkByIDParams{
557-
ID: workspaceAgent.ID,
558-
NodePublicKey: sql.NullString{
559-
String: node.Key.String(),
560-
Valid: true,
561-
},
562-
DERPLatency: node.DERPLatency,
563-
DiscoPublicKey: sql.NullString{
564-
String: node.DiscoKey.String(),
565-
Valid: true,
566-
},
567-
PreferredDERP: int32(node.PreferredDERP),
568-
UpdatedAt: database.Now(),
556+
// workspaceAgentClientCoordinate accepts a WebSocket and hands it to the
557+
// connection coordinator, which serves it as a client (with a newly generated
558+
// ID) of the workspace agent named in the request.
559+
func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
560+
api.websocketWaitMutex.Lock()
561+
api.websocketWaitGroup.Add(1)
562+
api.websocketWaitMutex.Unlock()
563+
defer api.websocketWaitGroup.Done()
564+
workspaceAgent := httpmw.WorkspaceAgentParam(r)
565+
566+
conn, err := websocket.Accept(rw, r, nil)
567+
if err != nil {
568+
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
569+
Message: "Failed to accept websocket.",
570+
Detail: err.Error(),
569571
})
570-
if err != nil {
571-
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
572-
Message: "Internal error setting agent keys.",
573-
Detail: err.Error(),
574-
})
575-
return
576-
}
572+
return
573+
}
574+
defer conn.Close(websocket.StatusNormalClosure, "")
575+
err = api.ConnCoordinator.ServeClient(r.Context(), conn, uuid.New(), workspaceAgent.ID)
576+
if err != nil {
577+
_ = conn.Close(websocket.StatusInternalError, err.Error())
578+
return
577579
}
578580
}
579581

@@ -655,20 +657,6 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
655657
Apps: apps,
656658
PreferredDERP: int(dbAgent.PreferredDERP),
657659
DERPLatency: dbAgent.DERPLatency,
658-
IPAddresses: ips,
659-
}
660-
661-
if dbAgent.NodePublicKey.Valid {
662-
err := workspaceAgent.NodePublicKey.UnmarshalText([]byte(dbAgent.NodePublicKey.String))
663-
if err != nil {
664-
return codersdk.WorkspaceAgent{}, xerrors.Errorf("parse node public key: %w", err)
665-
}
666-
}
667-
if dbAgent.DiscoPublicKey.Valid {
668-
err := workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.DiscoPublicKey.String))
669-
if err != nil {
670-
return codersdk.WorkspaceAgent{}, xerrors.Errorf("parse disco public key: %w", err)
671-
}
672660
}
673661

674662
if dbAgent.FirstConnectedAt.Valid {

0 commit comments

Comments
 (0)