@@ -129,6 +129,7 @@ type Server struct {
129
129
listeners map [net.Listener ]struct {}
130
130
conns map [net.Conn ]struct {}
131
131
sessions map [ssh.Session ]struct {}
132
+ processes map [* os.Process ]struct {}
132
133
closing chan struct {}
133
134
// Wait for goroutines to exit, waited without
134
135
// a lock on mu but protected by closing.
@@ -188,6 +189,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
188
189
fs : fs ,
189
190
conns : make (map [net.Conn ]struct {}),
190
191
sessions : make (map [ssh.Session ]struct {}),
192
+ processes : make (map [* os.Process ]struct {}),
191
193
logger : logger ,
192
194
193
195
config : config ,
@@ -606,7 +608,10 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
606
608
// otherwise context cancellation will not propagate properly
607
609
// and SSH server close may be delayed.
608
610
cmd .SysProcAttr = cmdSysProcAttr ()
609
- cmd .Cancel = cmdCancel (session .Context (), logger , cmd )
611
+
612
+ // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
613
+ // c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
614
+ cmd .Cancel = nil
610
615
611
616
cmd .Stdout = session
612
617
cmd .Stderr = session .Stderr ()
@@ -629,6 +634,16 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
629
634
s .metrics .sessionErrors .WithLabelValues (magicTypeLabel , "no" , "start_command" ).Add (1 )
630
635
return xerrors .Errorf ("start: %w" , err )
631
636
}
637
+
638
+ // Since we don't cancel the process when the session stops, we still need to tear it down if we are closing. So
639
+ // track it here.
640
+ if ! s .trackProcess (cmd .Process , true ) {
641
+ // must be closing
642
+ err = cmdCancel (logger , cmd .Process )
643
+ return xerrors .Errorf ("failed to track process: %w" , err )
644
+ }
645
+ defer s .trackProcess (cmd .Process , false )
646
+
632
647
sigs := make (chan ssh.Signal , 1 )
633
648
session .Signals (sigs )
634
649
defer func () {
@@ -1089,6 +1104,27 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
1089
1104
return true
1090
1105
}
1091
1106
1107
+ // trackCommand registers the process with the server. If the server is
1108
+ // closing, the process is not registered and should be closed.
1109
+ //
1110
+ //nolint:revive
1111
+ func (s * Server ) trackProcess (p * os.Process , add bool ) (ok bool ) {
1112
+ s .mu .Lock ()
1113
+ defer s .mu .Unlock ()
1114
+ if add {
1115
+ if s .closing != nil {
1116
+ // Server closed.
1117
+ return false
1118
+ }
1119
+ s .wg .Add (1 )
1120
+ s .processes [p ] = struct {}{}
1121
+ return true
1122
+ }
1123
+ s .wg .Done ()
1124
+ delete (s .processes , p )
1125
+ return true
1126
+ }
1127
+
1092
1128
// Close the server and all active connections. Server can be re-used
1093
1129
// after Close is done.
1094
1130
func (s * Server ) Close () error {
@@ -1128,6 +1164,10 @@ func (s *Server) Close() error {
1128
1164
_ = c .Close ()
1129
1165
}
1130
1166
1167
+ for p := range s .processes {
1168
+ _ = cmdCancel (s .logger , p )
1169
+ }
1170
+
1131
1171
s .logger .Debug (ctx , "closing SSH server" )
1132
1172
err := s .srv .Close ()
1133
1173
0 commit comments