Skip to content

chore: refactor DERP setting loop #15344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions codersdk/workspacesdk/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"nhooyr.io/websocket"
"storj.io/drpc"
"storj.io/drpc/drpcerr"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/v2/buildinfo"
Expand All @@ -37,7 +36,7 @@ var tailnetConnectorGracefulTimeout = time.Second
// @typescript-ignore tailnetConn
type tailnetConn interface {
tailnet.Coordinatee
SetDERPMap(derpMap *tailcfg.DERPMap)
tailnet.DERPMapSetter
}

// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
Expand Down Expand Up @@ -65,7 +64,7 @@ type tailnetAPIConnector struct {
coordinateURL string
clock quartz.Clock
dialOptions *websocket.DialOptions
conn tailnetConn
derpCtrl tailnet.DERPController
coordCtrl tailnet.CoordinationController
customDialFn func() (proto.DRPCTailnetClient, error)

Expand All @@ -91,7 +90,6 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
coordinateURL: coordinateURL,
clock: clock,
dialOptions: dialOptions,
conn: nil,
connected: make(chan error, 1),
closed: make(chan struct{}),
}
Expand All @@ -112,7 +110,7 @@ func (tac *tailnetAPIConnector) manageGracefulTimeout() {

// Runs a tailnetAPIConnector using the provided connection
func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
tac.conn = conn
tac.derpCtrl = tailnet.NewBasicDERPController(tac.logger, conn)
tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID)
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
go tac.manageGracefulTimeout()
Expand Down Expand Up @@ -294,7 +292,9 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
}

func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
s := &tailnet.DERPFromDRPCWrapper{}
var err error
s.Client, err = client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
if err != nil {
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
}
Expand All @@ -304,21 +304,15 @@ func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
for {
dmp, err := s.Recv()
if err != nil {
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
if !xerrors.Is(err, io.EOF) {
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
}
return err
}
tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
dm := tailnet.DERPMapFromProto(dmp)
tac.conn.SetDERPMap(dm)
cw := tac.derpCtrl.New(s)
err = <-cw.Wait()
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
if err != nil && !xerrors.Is(err, io.EOF) {
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
}
return err
}

func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
Expand Down
2 changes: 0 additions & 2 deletions codersdk/workspacesdk/connector_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,6 @@ func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) {
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
conn: nil,
connected: make(chan error, 1),
closed: make(chan struct{}),
customDialFn: func() (proto.DRPCTailnetClient, error) {
Expand Down Expand Up @@ -481,7 +480,6 @@ func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) {
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
conn: nil,
connected: make(chan error, 1),
closed: make(chan struct{}),
customDialFn: func() (proto.DRPCTailnetClient, error) {
Expand Down
81 changes: 81 additions & 0 deletions tailnet/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,84 @@ func NewInMemoryCoordinatorClient(
)
return c
}

type DERPMapSetter interface {
SetDERPMap(derpMap *tailcfg.DERPMap)
}

type basicDERPController struct {
logger slog.Logger
setter DERPMapSetter
}

func (b *basicDERPController) New(client DERPClient) CloserWaiter {
l := &derpSetLoop{
logger: b.logger,
setter: b.setter,
client: client,
errChan: make(chan error, 1),
recvLoopDone: make(chan struct{}),
}
go l.recvLoop()
return l
}

func NewBasicDERPController(logger slog.Logger, setter DERPMapSetter) DERPController {
return &basicDERPController{
logger: logger,
setter: setter,
}
}

type derpSetLoop struct {
logger slog.Logger
setter DERPMapSetter
client DERPClient

sync.Mutex
closed bool
errChan chan error
recvLoopDone chan struct{}
}

func (l *derpSetLoop) Close(ctx context.Context) error {
l.Lock()
defer l.Unlock()
if l.closed {
select {
case <-ctx.Done():
return ctx.Err()
case <-l.recvLoopDone:
return nil
}
}
l.closed = true
cErr := l.client.Close()
select {
case <-ctx.Done():
return ctx.Err()
case <-l.recvLoopDone:
return cErr
}
}

func (l *derpSetLoop) Wait() <-chan error {
return l.errChan
}

func (l *derpSetLoop) recvLoop() {
defer close(l.recvLoopDone)
for {
dm, err := l.client.Recv()
if err != nil {
l.logger.Debug(context.Background(), "failed to receive DERP message", slog.Error(err))
select {
case l.errChan <- err:
default:
}
return
}
l.logger.Debug(context.Background(), "got new DERP Map", slog.F("derp_map", dm))
l.setter.SetDERPMap(dm)
}
}
70 changes: 70 additions & 0 deletions tailnet/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

Expand Down Expand Up @@ -281,3 +282,72 @@ func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
defer f.Unlock()
f.callback = callback
}

func TestNewBasicDERPController_Mainline(t *testing.T) {
t.Parallel()
fs := make(chan *tailcfg.DERPMap)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := tailnet.NewBasicDERPController(logger, fakeSetter(fs))
fc := fakeDERPClient{
ch: make(chan *tailcfg.DERPMap),
}
c := uut.New(fc)
ctx := testutil.Context(t, testutil.WaitShort)
expectDM := &tailcfg.DERPMap{}
testutil.RequireSendCtx(ctx, t, fc.ch, expectDM)
gotDM := testutil.RequireRecvCtx(ctx, t, fs)
require.Equal(t, expectDM, gotDM)
err := c.Close(ctx)
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, c.Wait())
require.ErrorIs(t, err, io.EOF)
// ensure Close is idempotent
err = c.Close(ctx)
require.NoError(t, err)
}

func TestNewBasicDERPController_RecvErr(t *testing.T) {
t.Parallel()
fs := make(chan *tailcfg.DERPMap)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := tailnet.NewBasicDERPController(logger, fakeSetter(fs))
expectedErr := xerrors.New("a bad thing happened")
fc := fakeDERPClient{
ch: make(chan *tailcfg.DERPMap),
err: expectedErr,
}
c := uut.New(fc)
ctx := testutil.Context(t, testutil.WaitShort)
err := testutil.RequireRecvCtx(ctx, t, c.Wait())
require.ErrorIs(t, err, expectedErr)
// ensure Close is idempotent
err = c.Close(ctx)
require.NoError(t, err)
}

type fakeSetter chan *tailcfg.DERPMap

func (s fakeSetter) SetDERPMap(derpMap *tailcfg.DERPMap) {
s <- derpMap
}

type fakeDERPClient struct {
ch chan *tailcfg.DERPMap
err error
}

func (f fakeDERPClient) Close() error {
close(f.ch)
return nil
}

func (f fakeDERPClient) Recv() (*tailcfg.DERPMap, error) {
if f.err != nil {
return nil, f.err
}
dm, ok := <-f.ch
if ok {
return dm, nil
}
return nil, io.EOF
}
18 changes: 18 additions & 0 deletions tailnet/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,21 @@ func WorkspaceStatusToProto(status codersdk.WorkspaceStatus) proto.Workspace_Sta
return proto.Workspace_UNKNOWN
}
}

type DERPFromDRPCWrapper struct {
Client proto.DRPCTailnet_StreamDERPMapsClient
}

func (w *DERPFromDRPCWrapper) Close() error {
return w.Client.Close()
}

func (w *DERPFromDRPCWrapper) Recv() (*tailcfg.DERPMap, error) {
p, err := w.Client.Recv()
if err != nil {
return nil, err
}
return DERPMapFromProto(p), nil
}

var _ DERPClient = &DERPFromDRPCWrapper{}
Loading