@@ -50,6 +50,7 @@ type Server struct {
50
50
mu sync.RWMutex // Protects following.
51
51
listeners map [net.Listener ]struct {}
52
52
conns map [net.Conn ]struct {}
53
+ sessions map [ssh.Session ]struct {}
53
54
closing chan struct {}
54
55
// Wait for goroutines to exit, waited without
55
56
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
86
87
s := & Server {
87
88
listeners : make (map [net.Listener ]struct {}),
88
89
conns : make (map [net.Conn ]struct {}),
90
+ sessions : make (map [ssh.Session ]struct {}),
89
91
logger : logger ,
90
92
}
91
93
@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129
131
}
130
132
},
131
133
SubsystemHandlers : map [string ]ssh.SubsystemHandler {
132
- "sftp" : s .sftpHandler ,
134
+ "sftp" : s .sessionHandler ,
133
135
},
134
136
MaxTimeout : maxTimeout ,
135
137
}
@@ -152,7 +154,25 @@ func (s *Server) ConnStats() ConnStats {
152
154
}
153
155
154
156
func (s * Server ) sessionHandler (session ssh.Session ) {
157
+ if ! s .trackSession (session , true ) {
158
+ session .Exit (MagicSessionErrorCode )
159
+ return
160
+ }
161
+ defer s .trackSession (session , false )
162
+
155
163
ctx := session .Context ()
164
+
165
+ switch ss := session .Subsystem (); ss {
166
+ case "" :
167
+ case "sftp" :
168
+ s .sftpHandler (session )
169
+ return
170
+ default :
171
+ s .logger .Debug (ctx , "unsupported subsystem" , slog .F ("subsystem" , ss ))
172
+ _ = session .Exit (1 )
173
+ return
174
+ }
175
+
156
176
err := s .sessionStart (session )
157
177
var exitError * exec.ExitError
158
178
if xerrors .As (err , & exitError ) {
@@ -560,6 +580,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560
580
return true
561
581
}
562
582
583
+ // trackSession registers the session with the server. If the server is
584
+ // closing, the session is not registered and should be closed.
585
+ //
586
+ //nolint:revive
587
+ func (s * Server ) trackSession (ss ssh.Session , add bool ) (ok bool ) {
588
+ s .mu .Lock ()
589
+ defer s .mu .Unlock ()
590
+ if add {
591
+ if s .closing != nil {
592
+ // Server closed.
593
+ return false
594
+ }
595
+ s .sessions [ss ] = struct {}{}
596
+ return true
597
+ }
598
+ delete (s .sessions , ss )
599
+ return true
600
+ }
601
+
563
602
// Close the server and all active connections. Server can be re-used
564
603
// after Close is done.
565
604
func (s * Server ) Close () error {
@@ -573,6 +612,12 @@ func (s *Server) Close() error {
573
612
}
574
613
s .closing = make (chan struct {})
575
614
615
+ // Close all active sessions to gracefully
616
+ // terminate client connections.
617
+ for ss := range s .sessions {
618
+ _ = ss .Close ()
619
+ }
620
+
576
621
// Close all active listeners and connections.
577
622
for l := range s .listeners {
578
623
_ = l .Close ()
0 commit comments