Skip to content

feat(agentssh): Gracefully close SSH sessions on Close #7027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Server struct {
mu sync.RWMutex // Protects following.
listeners map[net.Listener]struct{}
conns map[net.Conn]struct{}
sessions map[ssh.Session]struct{}
closing chan struct{}
// Wait for goroutines to exit, waited without
// a lock on mu but protected by closing.
Expand Down Expand Up @@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
s := &Server{
listeners: make(map[net.Listener]struct{}),
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
}

Expand Down Expand Up @@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
}
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": s.sftpHandler,
"sftp": s.sessionHandler,
},
MaxTimeout: maxTimeout,
}
Expand All @@ -152,7 +154,26 @@ func (s *Server) ConnStats() ConnStats {
}

func (s *Server) sessionHandler(session ssh.Session) {
if !s.trackSession(session, true) {
// See (*Server).Close() for why we call Close instead of Exit.
_ = session.Close()
return
}
defer s.trackSession(session, false)

ctx := session.Context()

switch ss := session.Subsystem(); ss {
case "":
case "sftp":
s.sftpHandler(session)
return
default:
s.logger.Debug(ctx, "unsupported subsystem", slog.F("subsystem", ss))
_ = session.Exit(1)
return
}

err := s.sessionStart(session)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
Expand Down Expand Up @@ -560,6 +581,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
return true
}

// trackSession registers the session with the server. If the server is
// closing, the session is not registered and should be closed.
//
//nolint:revive
func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
if add {
if s.closing != nil {
// Server closed.
return false
}
s.sessions[ss] = struct{}{}
return true
}
delete(s.sessions, ss)
return true
}

// Close the server and all active connections. Server can be re-used
// after Close is done.
func (s *Server) Close() error {
Expand All @@ -573,6 +613,15 @@ func (s *Server) Close() error {
}
s.closing = make(chan struct{})

// Close all active sessions to gracefully
// terminate client connections.
for ss := range s.sessions {
// We call Close on the underlying channel here because we don't
// want to send an exit status to the client (via Exit()).
// Typically OpenSSH clients will return 255 as the exit status.
_ = ss.Close()
}

// Close all active listeners and connections.
for l := range s.listeners {
_ = l.Close()
Expand Down