Skip to content

Commit 14a9916

Browse files
committed
feat(agentssh): Gracefully close SSH sessions on Close
By tracking and closing sessions manually before closing the underlying connections, we ensure that the termination is propagated to SSH/SFTP clients and they're not left waiting for a connection timeout. Refs: #6177
1 parent ed63a2b commit 14a9916

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

agent/agentssh/agentssh.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Server struct {
5050
mu sync.RWMutex // Protects following.
5151
listeners map[net.Listener]struct{}
5252
conns map[net.Conn]struct{}
53+
sessions map[ssh.Session]struct{}
5354
closing chan struct{}
5455
// Wait for goroutines to exit, waited without
5556
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8687
s := &Server{
8788
listeners: make(map[net.Listener]struct{}),
8889
conns: make(map[net.Conn]struct{}),
90+
sessions: make(map[ssh.Session]struct{}),
8991
logger: logger,
9092
}
9193

@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129131
}
130132
},
131133
SubsystemHandlers: map[string]ssh.SubsystemHandler{
132-
"sftp": s.sftpHandler,
134+
"sftp": s.sessionHandler,
133135
},
134136
MaxTimeout: maxTimeout,
135137
}
@@ -152,7 +154,25 @@ func (s *Server) ConnStats() ConnStats {
152154
}
153155

154156
func (s *Server) sessionHandler(session ssh.Session) {
157+
if !s.trackSession(session, true) {
158+
session.Exit(MagicSessionErrorCode)
159+
return
160+
}
161+
defer s.trackSession(session, false)
162+
155163
ctx := session.Context()
164+
165+
switch ss := session.Subsystem(); ss {
166+
case "":
167+
case "sftp":
168+
s.sftpHandler(session)
169+
return
170+
default:
171+
s.logger.Debug(ctx, "unsupported subsystem", slog.F("subsystem", ss))
172+
_ = session.Exit(1)
173+
return
174+
}
175+
156176
err := s.sessionStart(session)
157177
var exitError *exec.ExitError
158178
if xerrors.As(err, &exitError) {
@@ -560,6 +580,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560580
return true
561581
}
562582

583+
// trackSession registers the session with the server. If the server is
584+
// closing, the session is not registered and should be closed.
585+
//
586+
//nolint:revive
587+
func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
588+
s.mu.Lock()
589+
defer s.mu.Unlock()
590+
if add {
591+
if s.closing != nil {
592+
// Server closed.
593+
return false
594+
}
595+
s.sessions[ss] = struct{}{}
596+
return true
597+
}
598+
delete(s.sessions, ss)
599+
return true
600+
}
601+
563602
// Close the server and all active connections. Server can be re-used
564603
// after Close is done.
565604
func (s *Server) Close() error {
@@ -573,6 +612,12 @@ func (s *Server) Close() error {
573612
}
574613
s.closing = make(chan struct{})
575614

615+
// Close all active sessions to gracefully
616+
// terminate client connections.
617+
for ss := range s.sessions {
618+
_ = ss.Close()
619+
}
620+
576621
// Close all active listeners and connections.
577622
for l := range s.listeners {
578623
_ = l.Close()

0 commit comments

Comments
 (0)