Skip to content

fix: change tailnet AwaitReachable to wait for wireguard #8492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
2 changes: 1 addition & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ func (a *agent) trackConnGoroutine(fn func()) error {
}

func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
network, err := tailnet.NewConn(&tailnet.Options{
network, err := tailnet.NewConn(tailnet.ConnTypeAgent, &tailnet.Options{
Addresses: a.wireguardAddresses(agentID),
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
Expand Down
7 changes: 6 additions & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
logger := slogtest.Make(t, nil).Named("testsetup").Leveled(slog.LevelDebug)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
listener, err := net.Listen("tcp", "127.0.0.1:0")
Expand All @@ -1815,16 +1816,20 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt
for {
conn, err := listener.Accept()
if err != nil {
logger.Debug(context.Background(), "error listening", slog.Error(err))
return
}
logger.Debug(context.Background(), "got local TCP connection")

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
ssh, err := agentConn.SSH(ctx)
cancel()
if err != nil {
logger.Debug(context.Background(), "failed to connect to agent SSH")
_ = conn.Close()
return
}
logger.Debug(context.Background(), "got SSH connection to agent")
waitGroup.Add(1)
go func() {
agentssh.Bicopy(context.Background(), conn, ssh)
Expand Down Expand Up @@ -1917,7 +1922,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(func() {
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: logger.Named("client"),
Expand Down
81 changes: 59 additions & 22 deletions cli/portforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"sync"
"testing"
"time"

"github.com/google/uuid"
"github.com/pion/udp"
Expand Down Expand Up @@ -153,15 +154,22 @@ func TestPortForward(t *testing.T) {
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// don't set a timeout on the port-forward command, because the command context has to stay active across
// the call to t.Parallel() below, during which all non-parallel parts run to completion. We don't know
// how long this will take, as it depends on the number of test cases.
ctxCmd, cancelCmd := context.WithCancel(context.Background())
defer cancelCmd()
errC := make(chan error)
go func() {
errC <- inv.WithContext(ctx).Run()
errC <- inv.WithContext(ctxCmd).Run()
}()
pty.ExpectMatchContext(ctx, "Ready!")
ctxExpect, cancelExpect := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancelExpect()
pty.ExpectMatchContext(ctxExpect, "Ready!")

t.Parallel() // Port is reserved, enable parallel execution.
// Now that we've unpaused for parallel execution, set a new timeout context for this part of the test.
ctx, cancel := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancel()

// Open two connections simultaneously and test them out of
// sync.
Expand All @@ -175,9 +183,13 @@ func TestPortForward(t *testing.T) {
testDial(t, c2)
testDial(t, c1)

cancel()
err = <-errC
require.ErrorIs(t, err, context.Canceled)
cancelCmd()
select {
case <-time.After(testutil.WaitLong):
t.Fatal("timeout canceling port forward")
case err = <-errC:
require.ErrorIs(t, err, context.Canceled)
}
})

t.Run(c.name+"_TwoPorts", func(t *testing.T) {
Expand All @@ -200,16 +212,25 @@ func TestPortForward(t *testing.T) {
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// don't set a timeout on the port-forward command, because the command context has to stay active across
// the call to t.Parallel() below, during which all non-parallel parts run to completion. We don't know
// how long this will take, as it depends on the number of test cases.
ctxCmd, cancelCmd := context.WithCancel(context.Background())
defer cancelCmd()
errC := make(chan error)
go func() {
errC <- inv.WithContext(ctx).Run()
errC <- inv.WithContext(ctxCmd).Run()
}()
pty.ExpectMatchContext(ctx, "Ready!")
ctxExpect, cancelExpect := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancelExpect()
pty.ExpectMatchContext(ctxExpect, "Ready!")

t.Parallel() // Port is reserved, enable parallel execution.

// Now that we've unpaused for parallel execution, set a new timeout context for this part of the test.
ctx, cancel := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancel()

// Open a connection to both listener 1 and 2 simultaneously and
// then test them out of order.
d := net.Dialer{Timeout: testutil.WaitShort}
Expand All @@ -222,9 +243,13 @@ func TestPortForward(t *testing.T) {
testDial(t, c2)
testDial(t, c1)

cancel()
err = <-errC
require.ErrorIs(t, err, context.Canceled)
cancelCmd()
select {
case <-time.After(testutil.WaitLong):
t.Fatal("timeout canceling port forward")
case err = <-errC:
require.ErrorIs(t, err, context.Canceled)
}
})
}

Expand Down Expand Up @@ -253,15 +278,23 @@ func TestPortForward(t *testing.T) {
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// don't set a timeout on the port-forward command, because the command context has to stay active across
// the call to t.Parallel() below, during which all non-parallel parts run to completion. We don't know
// how long this will take, as it depends on the number of test cases.
ctxCmd, cancelCmd := context.WithCancel(context.Background())
defer cancelCmd()
errC := make(chan error)
go func() {
errC <- inv.WithContext(ctx).Run()
errC <- inv.WithContext(ctxCmd).Run()
}()
pty.ExpectMatchContext(ctx, "Ready!")
ctxExpect, cancelExpect := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancelExpect()
pty.ExpectMatchContext(ctxExpect, "Ready!")

t.Parallel() // Port is reserved, enable parallel execution.
// Now that we've unpaused for parallel execution, set a new timeout context for this part of the test.
ctx, cancel := context.WithTimeout(ctxCmd, testutil.WaitLong)
defer cancel()

// Open connections to all items in the "dial" array.
var (
Expand All @@ -282,9 +315,13 @@ func TestPortForward(t *testing.T) {
testDial(t, conns[i])
}

cancel()
err := <-errC
require.ErrorIs(t, err, context.Canceled)
cancelCmd()
select {
case <-time.After(testutil.WaitLong):
t.Fatal("timeout canceling port forward")
case err := <-errC:
require.ErrorIs(t, err, context.Canceled)
}
})
}

Expand Down
1 change: 0 additions & 1 deletion cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ func (r *RootCmd) ssh() *clibase.Cmd {
return xerrors.Errorf("dial agent: %w", err)
}
defer conn.Close()
conn.AwaitReachable(ctx)
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

Expand Down
25 changes: 24 additions & 1 deletion cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
Expand Down Expand Up @@ -412,6 +413,28 @@ func TestSSH(t *testing.T) {
t.Parallel()

logDir := t.TempDir()
defer func() {
// Copy any log files into test logs for debugging
ents, err := os.ReadDir(logDir)
if err != nil {
t.Log("failed to read logDir")
}
for _, ent := range ents {
fn := path.Join(logDir, ent.Name())
f, err := os.Open(fn)
if err != nil {
t.Logf("failed to open logfile %s", fn)
}
logs, err := io.ReadAll(f)
f.Close()
if err != nil {
t.Logf("failed to read logfile %s", fn)
continue
}
t.Logf("logfile %s:", fn)
t.Log(logs)
}
}()

client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
inv, root := clitest.New(t, "ssh", "-l", logDir, workspace.Name)
Expand All @@ -425,7 +448,7 @@ func TestSSH(t *testing.T) {
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
Expand Down
1 change: 0 additions & 1 deletion cli/vscodessh.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd {
}
defer agentConn.Close()

agentConn.AwaitReachable(ctx)
rawSSH, err := agentConn.SSH(ctx)
if err != nil {
return err
Expand Down
7 changes: 5 additions & 2 deletions coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,18 @@ func TestDERP(t *testing.T) {
},
},
}
// it's a bit arbitrary which is the client and which is the agent,
// but, we need one of each because the client initiates the wireguard
// connection.
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
w1, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)

w2, err := tailnet.NewConn(&tailnet.Options{
w2, err := tailnet.NewConn(tailnet.ConnTypeAgent, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
Expand Down
2 changes: 1 addition & 1 deletion coderd/tailnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewServerTailnet(
cache *wsconncache.Cache,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap,
Logger: logger,
Expand Down
2 changes: 1 addition & 1 deletion coderd/tailnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
}, testutil.WaitShort, testutil.IntervalFast)

cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
Logger: logger.Named("client"),
Expand Down
2 changes: 1 addition & 1 deletion coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
// See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: api.DERPMap,
Logger: api.Logger.Named("tailnet"),
Expand Down
2 changes: 1 addition & 1 deletion coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(func() {
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
Expand Down
2 changes: 1 addition & 1 deletion codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
if ok {
header = headerTransport.Header()
}
conn, err := tailnet.NewConn(&tailnet.Options{
conn, err := tailnet.NewConn(tailnet.ConnTypeClient, &tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
DERPHeader: &header,
Expand Down
4 changes: 2 additions & 2 deletions scaletest/workspacetraffic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"io"
"sync"

"github.com/coder/coder/codersdk"

"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"github.com/coder/coder/codersdk"
)

func connectPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect uuid.UUID) (*countReadWriteCloser, error) {
Expand Down
Loading