Skip to content

feat(agent/agentssh): use tcp for X11 forwarding #14560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ type Config struct {
// where users will land when they connect via SSH. Default is the home
// directory of the user.
WorkingDirectory func() string
// X11SocketDir is the directory where X11 sockets are created. Default is
// /tmp/.X11-unix.
X11SocketDir string
// X11DisplayOffset is the offset to add to the X11 display number.
// Default is 10.
X11DisplayOffset *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
}
Expand Down Expand Up @@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
if config == nil {
config = &Config{}
}
if config.X11SocketDir == "" {
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
if config.X11DisplayOffset == nil {
offset := X11DefaultDisplayOffset
config.X11DisplayOffset = &offset
}
if config.UpdateEnv == nil {
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
Expand Down Expand Up @@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
extraEnv := make([]string, 0)
x11, hasX11 := session.X11()
if hasX11 {
handled := s.x11Handler(session.Context(), x11)
display, handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
return
}
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
}

if s.fileTransferBlocked(session) {
Expand Down
141 changes: 88 additions & 53 deletions agent/agentssh/x11.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"os"
"path/filepath"
Expand All @@ -22,102 +23,136 @@ import (
"cdr.dev/slog"
)

const (
// X11StartPort is the starting port for X11 forwarding, this is the
// port used for "DISPLAY=localhost:0".
X11StartPort = 6000
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
X11DefaultDisplayOffset = 10
)

// x11Callback is called when the client requests X11 forwarding.
// It adds an Xauthority entry to the Xauthority file.
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
// Always allow.
return true
}

// x11Handler is called when a session has requested X11 forwarding.
// It listens for X11 connections and forwards them to the client.
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) {
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !valid {
s.logger.Warn(ctx, "failed to get server connection")
return -1, false
}

hostname, err := os.Hostname()
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
return false
return -1, false
}

err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset)
if err != nil {
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
return false
}
s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
return -1, false
}
s.trackListener(ln, true)
defer func() {
if !handled {
s.trackListener(ln, false)
_ = ln.Close()
}
}()

err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
return false
return -1, false
}
return true
}

// x11Handler is called when a session has requested X11 forwarding.
// It listens for X11 connections and forwards them to the client.
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !valid {
s.logger.Warn(ctx, "failed to get server connection")
return false
}
// We want to overwrite the socket so that subsequent connections will succeed.
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
err := os.Remove(socketPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
return false
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
return false
}
s.trackListener(listener, true)
go func() {
// Don't leave the listener open after the session is gone.
<-ctx.Done()
_ = ln.Close()
}()

go func() {
defer listener.Close()
defer s.trackListener(listener, false)
handledFirstConnection := false
defer ln.Close()
defer s.trackListener(ln, false)

for {
conn, err := listener.Accept()
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
return
}
if x11.SingleConnection && handledFirstConnection {
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
_ = conn.Close()
continue
if x11.SingleConnection {
s.logger.Debug(ctx, "single connection requested, closing X11 listener")
_ = ln.Close()
}
handledFirstConnection = true

unixConn, ok := conn.(*net.UnixConn)
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
return
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
_ = conn.Close()
continue
}
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
return
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
_ = conn.Close()
continue
}

channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
OriginatorAddress string
OriginatorPort uint32
}{
OriginatorAddress: unixAddr.Name,
OriginatorPort: 0,
OriginatorAddress: tcpAddr.IP.String(),
OriginatorPort: uint32(tcpAddr.Port),
}))
if err != nil {
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
return
_ = conn.Close()
continue
}
go gossh.DiscardRequests(reqs)
go Bicopy(ctx, conn, channel)

if !s.trackConn(ln, conn, true) {
s.logger.Warn(ctx, "failed to track X11 connection")
_ = conn.Close()
continue
}
go func() {
defer s.trackConn(ln, conn, false)
Bicopy(ctx, conn, channel)
}()
}
}()
return true

return display, true
}

// createX11Listener creates a listener for X11 forwarding, it will use
// the next available port starting from X11StartPort and displayOffset.
func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) {
var lc net.ListenConfig
// Look for an open port to listen on.
for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ {
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err == nil {
display = port - X11StartPort
return ln, display, nil
}
}
return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err)
}

// addXauthEntry adds an Xauthority entry to the Xauthority file.
Expand Down
40 changes: 33 additions & 7 deletions agent/agentssh/x11_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package agentssh_test

import (
"bufio"
"bytes"
"context"
"encoding/hex"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"

"github.com/gliderlabs/ssh"
Expand All @@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fs := afero.NewOsFs()
dir := t.TempDir()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
X11SocketDir: dir,
})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
require.NoError(t, err)
defer s.Close()

Expand All @@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
sess, err := c.NewSession()
require.NoError(t, err)

wantScreenNumber := 1
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
AuthProtocol: "MIT-MAGIC-COOKIE-1",
AuthCookie: hex.EncodeToString([]byte("cookie")),
ScreenNumber: 0,
ScreenNumber: uint32(wantScreenNumber),
}))
require.NoError(t, err)
assert.True(t, reply)

err = sess.Shell()
// Want: ~DISPLAY=localhost:10.1
out, err := sess.Output("echo DISPLAY=$DISPLAY")
require.NoError(t, err)

sc := bufio.NewScanner(bytes.NewReader(out))
displayNumber := -1
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
t.Log(line)
if strings.HasPrefix(line, "DISPLAY=") {
parts := strings.SplitN(line, "=", 2)
display := parts[1]
parts = strings.SplitN(display, ":", 2)
parts = strings.SplitN(parts[1], ".", 2)
displayNumber, err = strconv.Atoi(parts[0])
require.NoError(t, err)
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
gotScreenNumber, err := strconv.Atoi(parts[1])
require.NoError(t, err)
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
break
}
}
require.NoError(t, sc.Err())
require.NotEqual(t, -1, displayNumber)

x11Chans := c.HandleChannelOpen("x11")
payload := "hello world"
require.Eventually(t, func() bool {
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
if err == nil {
_, err = conn.Write([]byte(payload))
assert.NoError(t, err)
Expand Down
Loading