Skip to content

Commit 20fb40e

Browse files
committed
chore: implement vpn client & dylib tunnel
1 parent 32fc844 commit 20fb40e

File tree

7 files changed

+520
-71
lines changed

7 files changed

+520
-71
lines changed

tailnet/conn.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/cenkalti/backoff/v4"
1616
"github.com/google/uuid"
17+
"github.com/tailscale/wireguard-go/tun"
1718
"golang.org/x/xerrors"
1819
"google.golang.org/protobuf/types/known/durationpb"
1920
"google.golang.org/protobuf/types/known/wrapperspb"
@@ -113,6 +114,8 @@ type Options struct {
113114
DNSConfigurator dns.OSConfigurator
114115
// Router is optional, and is passed to the underlying wireguard engine.
115116
Router router.Router
117+
// TUNDev is optional, and is passed to the underlying wireguard engine.
118+
TUNDev tun.Device
116119
}
117120

118121
// TelemetrySink allows tailnet.Conn to send network telemetry to the Coder
@@ -143,6 +146,8 @@ func NewConn(options *Options) (conn *Conn, err error) {
143146
return nil, xerrors.New("At least one IP range must be provided")
144147
}
145148

149+
netns.SetEnabled(options.TUNDev != nil)
150+
146151
var telemetryStore *TelemetryStore
147152
if options.TelemetrySink != nil {
148153
var err error
@@ -187,6 +192,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
187192
SetSubsystem: sys.Set,
188193
DNS: options.DNSConfigurator,
189194
Router: options.Router,
195+
Tun: options.TUNDev,
190196
})
191197
if err != nil {
192198
return nil, xerrors.Errorf("create wgengine: %w", err)
@@ -197,11 +203,14 @@ func NewConn(options *Options) (conn *Conn, err error) {
197203
}
198204
}()
199205
wireguardEngine.InstallCaptureHook(options.CaptureHook)
200-
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
201-
_, ok := wireguardEngine.PeerForIP(ip)
202-
return ok
206+
if options.TUNDev == nil {
207+
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
208+
_, ok := wireguardEngine.PeerForIP(ip)
209+
return ok
210+
}
203211
}
204212

213+
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
205214
sys.Set(wireguardEngine)
206215

207216
magicConn := sys.MagicSock.Get()
@@ -244,11 +253,12 @@ func NewConn(options *Options) (conn *Conn, err error) {
244253
return nil, xerrors.Errorf("create netstack: %w", err)
245254
}
246255

247-
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
248-
return netStack.DialContextTCP(ctx, dst)
256+
if options.TUNDev == nil {
257+
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
258+
return netStack.DialContextTCP(ctx, dst)
259+
}
260+
netStack.ProcessLocalIPs = true
249261
}
250-
netStack.ProcessLocalIPs = true
251-
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
252262

253263
cfgMaps := newConfigMaps(
254264
options.Logger,

tailnet/controllers.go

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ type WorkspaceUpdatesClient interface {
104104

105105
type WorkspaceUpdatesController interface {
106106
New(WorkspaceUpdatesClient) CloserWaiter
107+
CurrentState() *proto.WorkspaceUpdate
107108
}
108109

109110
// DNSHostsSetter is something that you can set a mapping of DNS names to IPs on. It's the subset
@@ -856,10 +857,14 @@ func (r *basicResumeTokenRefresher) refresh() {
856857
}
857858

858859
type tunnelAllWorkspaceUpdatesController struct {
859-
coordCtrl *TunnelSrcCoordController
860-
dnsHostSetter DNSHostsSetter
861-
ownerUsername string
862-
logger slog.Logger
860+
coordCtrl *TunnelSrcCoordController
861+
dnsHostSetter DNSHostsSetter
862+
updateCallback func(*proto.WorkspaceUpdate)
863+
ownerUsername string
864+
logger slog.Logger
865+
866+
sync.Mutex
867+
updater *tunnelUpdater
863868
}
864869

865870
type workspace struct {
@@ -902,18 +907,51 @@ type agent struct {
902907
}
903908

904909
func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter {
910+
t.Lock()
911+
defer t.Unlock()
905912
updater := &tunnelUpdater{
906913
client: client,
907914
errChan: make(chan error, 1),
908915
logger: t.logger,
909916
coordCtrl: t.coordCtrl,
910917
dnsHostsSetter: t.dnsHostSetter,
918+
updateCallback: t.updateCallback,
911919
ownerUsername: t.ownerUsername,
912920
recvLoopDone: make(chan struct{}),
913921
workspaces: make(map[uuid.UUID]*workspace),
914922
}
915-
go updater.recvLoop()
916-
return updater
923+
t.updater = updater
924+
go t.updater.recvLoop()
925+
return t.updater
926+
}
927+
928+
func (t *tunnelAllWorkspaceUpdatesController) CurrentState() *proto.WorkspaceUpdate {
929+
t.Lock()
930+
defer t.Unlock()
931+
if t.updater == nil {
932+
return nil
933+
}
934+
t.updater.Lock()
935+
defer t.updater.Unlock()
936+
out := &proto.WorkspaceUpdate{
937+
UpsertedWorkspaces: make([]*proto.Workspace, 0, len(t.updater.workspaces)),
938+
UpsertedAgents: make([]*proto.Agent, 0, len(t.updater.workspaces)),
939+
}
940+
for _, w := range t.updater.workspaces {
941+
upw := &proto.Workspace{
942+
Id: UUIDToByteSlice(w.id),
943+
Name: w.name,
944+
}
945+
out.UpsertedWorkspaces = append(out.UpsertedWorkspaces, upw)
946+
for _, a := range w.agents {
947+
out.UpsertedAgents = append(out.UpsertedAgents, &proto.Agent{
948+
Id: UUIDToByteSlice(a.id),
949+
Name: a.name,
950+
WorkspaceId: UUIDToByteSlice(w.id),
951+
})
952+
}
953+
}
954+
return out
917955
}
918956

919957
type tunnelUpdater struct {
@@ -922,14 +960,13 @@ type tunnelUpdater struct {
922960
client WorkspaceUpdatesClient
923961
coordCtrl *TunnelSrcCoordController
924962
dnsHostsSetter DNSHostsSetter
963+
updateCallback func(*proto.WorkspaceUpdate)
925964
ownerUsername string
926965
recvLoopDone chan struct{}
927966

928-
// don't need the mutex since only manipulated by the recvLoop
929-
workspaces map[uuid.UUID]*workspace
930-
931967
sync.Mutex
932-
closed bool
968+
workspaces map[uuid.UUID]*workspace
969+
closed bool
933970
}
934971

935972
func (t *tunnelUpdater) Close(ctx context.Context) error {
@@ -991,6 +1028,8 @@ func (t *tunnelUpdater) recvLoop() {
9911028
}
9921029

9931030
func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
1031+
t.Lock()
1032+
defer t.Unlock()
9941033
for _, uw := range update.UpsertedWorkspaces {
9951034
workspaceID, err := uuid.FromBytes(uw.Id)
9961035
if err != nil {
@@ -1056,6 +1095,10 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error {
10561095
} else {
10571096
t.logger.Debug(context.Background(), "skipping setting DNS names because we have no setter")
10581097
}
1098+
if t.updateCallback != nil {
1099+
t.logger.Debug(context.Background(), "calling update callback")
1100+
t.updateCallback(update)
1101+
}
10591102
return nil
10601103
}
10611104

@@ -1128,6 +1171,12 @@ func WithDNS(d DNSHostsSetter, ownerUsername string) TunnelAllOption {
11281171
}
11291172
}
11301173

1174+
func WithCallback(cb func(*proto.WorkspaceUpdate)) TunnelAllOption {
1175+
return func(t *tunnelAllWorkspaceUpdatesController) {
1176+
t.updateCallback = cb
1177+
}
1178+
}
1179+
11311180
// NewTunnelAllWorkspaceUpdatesController creates a WorkspaceUpdatesController that creates tunnels
11321181
// (via the TunnelSrcCoordController) to all agents received over the WorkspaceUpdates RPC. If a
11331182
// DNSHostSetter is provided, it also programs DNS hosts based on the agent and workspace names.

tailnet/controllers_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,6 +1778,10 @@ type fakeWorkspaceUpdatesController struct {
17781778
calls chan *newWorkspaceUpdatesCall
17791779
}
17801780

1781+
func (*fakeWorkspaceUpdatesController) CurrentState() *proto.WorkspaceUpdate {
1782+
panic("unimplemented")
1783+
}
1784+
17811785
type newWorkspaceUpdatesCall struct {
17821786
client tailnet.WorkspaceUpdatesClient
17831787
resp chan<- tailnet.CloserWaiter

vpn/client.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
package vpn
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/netip"
7+
8+
"github.com/tailscale/wireguard-go/tun"
9+
"golang.org/x/xerrors"
10+
"nhooyr.io/websocket"
11+
"tailscale.com/net/dns"
12+
"tailscale.com/wgengine/router"
13+
14+
"cdr.dev/slog"
15+
"github.com/coder/coder/v2/codersdk"
16+
"github.com/coder/coder/v2/codersdk/workspacesdk"
17+
"github.com/coder/coder/v2/tailnet"
18+
"github.com/coder/coder/v2/tailnet/proto"
19+
"github.com/coder/quartz"
20+
)
21+
22+
type Client struct {
23+
sdk *codersdk.Client
24+
}
25+
26+
// NewClient creates a new VPN client.
27+
func NewClient(c *codersdk.Client) *Client {
28+
return &Client{
29+
sdk: c,
30+
}
31+
}
32+
33+
type DialOptions struct {
34+
Logger slog.Logger
35+
DNSConfigurator dns.OSConfigurator
36+
Router router.Router
37+
TUNDev tun.Device
38+
UpdatesCallback func(*proto.WorkspaceUpdate)
39+
}
40+
41+
func (c *Client) Dial(dialCtx context.Context, options *DialOptions) (vpnConn Conn, err error) {
42+
if options == nil {
43+
options = &DialOptions{}
44+
}
45+
46+
var headers http.Header
47+
if headerTransport, ok := c.sdk.HTTPClient.Transport.(*codersdk.HeaderTransport); ok {
48+
headers = headerTransport.Header
49+
}
50+
headers.Set(codersdk.SessionTokenHeader, c.sdk.SessionToken())
51+
52+
// New context, separate from dialCtx. We don't want to cancel the
53+
// connection if dialCtx is canceled.
54+
ctx, cancel := context.WithCancel(context.Background())
55+
defer func() {
56+
if err != nil {
57+
cancel()
58+
}
59+
}()
60+
61+
rpcURL, err := c.sdk.URL.Parse("/api/v2/tailnet")
62+
if err != nil {
63+
return Conn{}, xerrors.Errorf("parse rpc url: %w", err)
64+
}
65+
66+
me, err := c.sdk.User(dialCtx, codersdk.Me)
67+
if err != nil {
68+
return Conn{}, xerrors.Errorf("get user: %w", err)
69+
}
70+
71+
connInfo, err := workspacesdk.New(c.sdk).AgentConnectionInfoGeneric(dialCtx)
72+
if err != nil {
73+
return Conn{}, xerrors.Errorf("get connection info: %w", err)
74+
}
75+
76+
dialer := workspacesdk.NewWebsocketDialer(options.Logger, rpcURL, &websocket.DialOptions{
77+
HTTPClient: c.sdk.HTTPClient,
78+
HTTPHeader: headers,
79+
CompressionMode: websocket.CompressionDisabled,
80+
}, workspacesdk.WithWorkspaceUpdates(&proto.WorkspaceUpdatesRequest{
81+
WorkspaceOwnerId: tailnet.UUIDToByteSlice(me.ID),
82+
}))
83+
84+
ip := tailnet.CoderServicePrefix.RandomAddr()
85+
conn, err := tailnet.NewConn(&tailnet.Options{
86+
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
87+
DERPMap: connInfo.DERPMap,
88+
DERPHeader: &headers,
89+
DERPForceWebSockets: connInfo.DERPForceWebSockets,
90+
Logger: options.Logger,
91+
BlockEndpoints: connInfo.DisableDirectConnections,
92+
DNSConfigurator: options.DNSConfigurator,
93+
Router: options.Router,
94+
TUNDev: options.TUNDev,
95+
})
96+
if err != nil {
97+
return Conn{}, xerrors.Errorf("create tailnet: %w", err)
98+
}
99+
defer func() {
100+
if err != nil {
101+
_ = conn.Close()
102+
}
103+
}()
104+
105+
clk := quartz.NewReal()
106+
controller := tailnet.NewController(options.Logger, dialer)
107+
coordCtrl := tailnet.NewTunnelSrcCoordController(options.Logger, conn)
108+
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
109+
controller.CoordCtrl = coordCtrl
110+
controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn)
111+
controller.WorkspaceUpdatesCtrl = tailnet.NewTunnelAllWorkspaceUpdatesController(
112+
options.Logger,
113+
coordCtrl,
114+
tailnet.WithDNS(conn, me.Name),
115+
tailnet.WithCallback(options.UpdatesCallback),
116+
)
117+
controller.Run(ctx)
118+
119+
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
120+
121+
select {
122+
case <-dialCtx.Done():
123+
return Conn{}, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
124+
case err = <-dialer.Connected():
125+
if err != nil {
126+
options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err))
127+
return Conn{}, xerrors.Errorf("start connector: %w", err)
128+
}
129+
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
130+
}
131+
132+
return Conn{
133+
Conn: conn,
134+
cancelFn: cancel,
135+
controller: controller,
136+
}, nil
137+
}
138+
139+
type Conn struct {
140+
*tailnet.Conn
141+
142+
cancelFn func()
143+
controller *tailnet.Controller
144+
}
145+
146+
func (c Conn) CurrentWorkspaceState() *proto.WorkspaceUpdate {
147+
return c.controller.WorkspaceUpdatesCtrl.CurrentState()
148+
}
149+
150+
func (c Conn) Close() error {
151+
c.cancelFn()
152+
<-c.controller.Closed()
153+
return c.Conn.Close()
154+
}

0 commit comments

Comments
 (0)