7
7
"errors"
8
8
"fmt"
9
9
"io"
10
+ "math"
10
11
"net"
11
12
"os"
12
13
"path/filepath"
@@ -22,102 +23,133 @@ import (
22
23
"cdr.dev/slog"
23
24
)
24
25
25
- // x11Callback is called when the client requests X11 forwarding.
26
- // It adds an Xauthority entry to the Xauthority file.
27
- func (s * Server ) x11Callback (ctx ssh.Context , x11 ssh.X11 ) bool {
28
- hostname , err := os .Hostname ()
29
- if err != nil {
30
- s .logger .Warn (ctx , "failed to get hostname" , slog .Error (err ))
31
- s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
32
- return false
33
- }
34
-
35
- err = s .fs .MkdirAll (s .config .X11SocketDir , 0o700 )
36
- if err != nil {
37
- s .logger .Warn (ctx , "failed to make the x11 socket dir" , slog .F ("dir" , s .config .X11SocketDir ), slog .Error (err ))
38
- s .metrics .x11HandlerErrors .WithLabelValues ("socker_dir" ).Add (1 )
39
- return false
40
- }
26
+ const (
27
+ // X11StartPort is the starting port for X11 forwarding, this is the
28
+ // port used for "DISPLAY=localhost:0".
29
+ X11StartPort = 6000
30
+ // X11DefaultDisplayOffset is the default offset for X11 forwarding.
31
+ X11DefaultDisplayOffset = 10
32
+ )
41
33
42
- err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (int (x11 .ScreenNumber )), x11 .AuthProtocol , x11 .AuthCookie )
43
- if err != nil {
44
- s .logger .Warn (ctx , "failed to add Xauthority entry" , slog .Error (err ))
45
- s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
46
- return false
47
- }
34
+ // x11Callback is called when the client requests X11 forwarding.
35
+ func (* Server ) x11Callback (_ ssh.Context , _ ssh.X11 ) bool {
36
+ // Always allow.
48
37
return true
49
38
}
50
39
51
40
// x11Handler is called when a session has requested X11 forwarding.
52
41
// It listens for X11 connections and forwards them to the client.
53
- func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) bool {
42
+ func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) ( display int , handled bool ) {
54
43
serverConn , valid := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
55
44
if ! valid {
56
45
s .logger .Warn (ctx , "failed to get server connection" )
57
- return false
46
+ return - 1 , false
58
47
}
59
- // We want to overwrite the socket so that subsequent connections will succeed.
60
- socketPath := filepath .Join (s .config .X11SocketDir , fmt .Sprintf ("X%d" , x11 .ScreenNumber ))
61
- err := os .Remove (socketPath )
62
- if err != nil && ! errors .Is (err , os .ErrNotExist ) {
63
- s .logger .Warn (ctx , "failed to remove existing X11 socket" , slog .Error (err ))
64
- return false
65
- }
66
- listener , err := net .Listen ("unix" , socketPath )
48
+
49
+ hostname , err := os .Hostname ()
67
50
if err != nil {
51
+ s .logger .Warn (ctx , "failed to get hostname" , slog .Error (err ))
52
+ s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
53
+ return - 1 , false
54
+ }
55
+
56
+ var (
57
+ lc net.ListenConfig
58
+ ln net.Listener
59
+ port = X11StartPort + * s .config .X11DisplayOffset
60
+ )
61
+ // Look for an open port to listen on..
62
+ for ; port >= X11StartPort && port < math .MaxUint16 ; port ++ {
63
+ ln , err = lc .Listen (ctx , "tcp" , fmt .Sprintf ("localhost:%d" , port ))
64
+ if err == nil {
65
+ display = port - X11StartPort
66
+ break
67
+ }
68
+ }
69
+ if ln == nil {
68
70
s .logger .Warn (ctx , "failed to listen for X11" , slog .Error (err ))
69
- return false
71
+ s .metrics .x11HandlerErrors .WithLabelValues ("listen" ).Add (1 )
72
+ return - 1 , false
73
+ }
74
+ s .trackListener (ln , true )
75
+ defer func () {
76
+ if ! handled {
77
+ s .trackListener (ln , false )
78
+ _ = ln .Close ()
79
+ }
80
+ }()
81
+
82
+ err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (display ), x11 .AuthProtocol , x11 .AuthCookie )
83
+ if err != nil {
84
+ s .logger .Warn (ctx , "failed to add Xauthority entry" , slog .Error (err ))
85
+ s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
86
+ return - 1 , false
70
87
}
71
- s .trackListener (listener , true )
72
88
73
89
go func () {
74
- defer listener .Close ()
75
- defer s .trackListener (listener , false )
76
- handledFirstConnection := false
90
+ // Don't leave the listener open after the session is gone.
91
+ <- ctx .Done ()
92
+ _ = ln .Close ()
93
+ }()
94
+
95
+ go func () {
96
+ defer ln .Close ()
97
+ defer s .trackListener (ln , false )
77
98
78
99
for {
79
- conn , err := listener .Accept ()
100
+ conn , err := ln .Accept ()
80
101
if err != nil {
81
102
if errors .Is (err , net .ErrClosed ) {
82
103
return
83
104
}
84
105
s .logger .Warn (ctx , "failed to accept X11 connection" , slog .Error (err ))
85
106
return
86
107
}
87
- if x11 .SingleConnection && handledFirstConnection {
88
- s .logger .Warn (ctx , "X11 connection rejected because single connection is enabled" )
89
- _ = conn .Close ()
90
- continue
108
+ if x11 .SingleConnection {
109
+ s .logger .Debug (ctx , "single connection requested, closing X11 listener" )
110
+ _ = ln .Close ()
91
111
}
92
- handledFirstConnection = true
93
112
94
- unixConn , ok := conn .(* net.UnixConn )
113
+ tcpConn , ok := conn .(* net.TCPConn )
95
114
if ! ok {
96
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to UnixConn. got: %T" , conn ))
97
- return
115
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" , conn ))
116
+ _ = conn .Close ()
117
+ continue
98
118
}
99
- unixAddr , ok := unixConn .LocalAddr ().(* net.UnixAddr )
119
+ tcpAddr , ok := tcpConn .LocalAddr ().(* net.TCPAddr )
100
120
if ! ok {
101
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to UnixAddr. got: %T" , unixConn .LocalAddr ()))
102
- return
121
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" , tcpConn .LocalAddr ()))
122
+ _ = conn .Close ()
123
+ continue
103
124
}
104
125
105
126
channel , reqs , err := serverConn .OpenChannel ("x11" , gossh .Marshal (struct {
106
127
OriginatorAddress string
107
128
OriginatorPort uint32
108
129
}{
109
- OriginatorAddress : unixAddr . Name ,
110
- OriginatorPort : 0 ,
130
+ OriginatorAddress : tcpAddr . IP . String () ,
131
+ OriginatorPort : uint32 ( tcpAddr . Port ) ,
111
132
}))
112
133
if err != nil {
113
134
s .logger .Warn (ctx , "failed to open X11 channel" , slog .Error (err ))
114
- return
135
+ _ = conn .Close ()
136
+ continue
115
137
}
116
138
go gossh .DiscardRequests (reqs )
117
- go Bicopy (ctx , conn , channel )
139
+
140
+ if ! s .trackConn (ln , conn , true ) {
141
+ s .logger .Warn (ctx , "failed to track X11 connection" )
142
+ _ = conn .Close ()
143
+ continue
144
+ }
145
+ go func () {
146
+ defer s .trackConn (ln , conn , false )
147
+ Bicopy (ctx , conn , channel )
148
+ }()
118
149
}
119
150
}()
120
- return true
151
+
152
+ return display , true
121
153
}
122
154
123
155
// addXauthEntry adds an Xauthority entry to the Xauthority file.
0 commit comments