Skip to content

Commit 1fd4e9a

Browse files
committed
Add state wait func where caller holds the lock
1 parent 88a6b96 commit 1fd4e9a

File tree

2 files changed

+49
-44
lines changed

2 files changed

+49
-44
lines changed

agent/reconnectingpty/buffered.go

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -178,39 +178,9 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
178178
ctx, cancel := context.WithCancel(ctx)
179179
defer cancel()
180180

181-
// Once we are ready, attach the active connection while we hold the mutex.
182-
state, err := rpty.state.waitForStateOrContext(ctx, StateReady, func(state State) error {
183-
// Write any previously stored data for the TTY. Since the command might be
184-
// short-lived and have already exited, make sure we always at least output
185-
// the buffer before returning.
186-
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
187-
_, err := conn.Write(prevBuf)
188-
if err != nil {
189-
rpty.metrics.WithLabelValues("write").Add(1)
190-
return xerrors.Errorf("write buffer to conn: %w", err)
191-
}
192-
193-
if state != StateReady {
194-
return nil
195-
}
196-
197-
go heartbeat(ctx, rpty.timer, rpty.timeout)
198-
199-
// Resize the PTY to initial height + width.
200-
err = rpty.ptty.Resize(height, width)
201-
if err != nil {
202-
// We can continue after this, it's not fatal!
203-
logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
204-
rpty.metrics.WithLabelValues("resize").Add(1)
205-
}
206-
207-
// Store the connection for future writes.
208-
rpty.activeConns[connID] = conn
209-
210-
return nil
211-
})
212-
if state != StateReady || err != nil {
213-
return xerrors.Errorf("reconnecting pty ready wait: %w", err)
181+
err := rpty.doAttach(ctx, connID, conn, height, width, logger)
182+
if err != nil {
183+
return err
214184
}
215185

216186
defer func() {
@@ -224,6 +194,43 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
224194
return nil
225195
}
226196

197+
// doAttach adds the connection to the map, replays the buffer, and starts the
198+
// heartbeat. It exists separately only so we can defer the mutex unlock which
199+
// is not possible in Attach since it blocks.
200+
func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
201+
rpty.state.cond.L.Lock()
202+
defer rpty.state.cond.L.Unlock()
203+
204+
// Write any previously stored data for the TTY. Since the command might be
205+
// short-lived and have already exited, make sure we always at least output
206+
// the buffer before returning, mostly just so tests pass.
207+
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
208+
_, err := conn.Write(prevBuf)
209+
if err != nil {
210+
rpty.metrics.WithLabelValues("write").Add(1)
211+
return xerrors.Errorf("write buffer to conn: %w", err)
212+
}
213+
214+
state, err := rpty.state.waitForStateOrContextLocked(ctx, StateReady)
215+
if state != StateReady {
216+
return xerrors.Errorf("reconnecting pty ready wait: %w", err)
217+
}
218+
219+
go heartbeat(ctx, rpty.timer, rpty.timeout)
220+
221+
// Resize the PTY to initial height + width.
222+
err = rpty.ptty.Resize(height, width)
223+
if err != nil {
224+
// We can continue after this, it's not fatal!
225+
logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
226+
rpty.metrics.WithLabelValues("resize").Add(1)
227+
}
228+
229+
rpty.activeConns[connID] = conn
230+
231+
return nil
232+
}
233+
227234
func (rpty *bufferedReconnectingPTY) Wait() {
228235
_, _ = rpty.state.waitForState(StateClosing)
229236
}

agent/reconnectingpty/reconnectingpty.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,16 @@ func (s *ptyState) waitForState(state State) (State, error) {
167167
}
168168

169169
// waitForStateOrContext blocks until the state or a greater one is reached or
170-
// the provided context ends. If fn is non-nil it will be ran while the lock is
171-
// held unless the context ends. If fn returns an error then fn's error will
172-
// replace waitForStateOrContext's error.
170+
// the provided context ends.
173171
func (s *ptyState) waitForStateOrContext(ctx context.Context, state State, fn func(state State) error) (State, error) {
172+
s.cond.L.Lock()
173+
defer s.cond.L.Unlock()
174+
return s.waitForStateOrContextLocked(ctx, state)
175+
}
176+
177+
// waitForStateOrContextLocked is the same as waitForStateOrContext except it
178+
// assumes the caller has already locked cond.
179+
func (s *ptyState) waitForStateOrContextLocked(ctx context.Context, state State) (State, error) {
174180
nevermind := make(chan struct{})
175181
defer close(nevermind)
176182
go func() {
@@ -182,20 +188,12 @@ func (s *ptyState) waitForStateOrContext(ctx context.Context, state State, fn fu
182188
}
183189
}()
184190

185-
s.cond.L.Lock()
186-
defer s.cond.L.Unlock()
187191
for ctx.Err() == nil && state > s.state {
188192
s.cond.Wait()
189193
}
190194
if ctx.Err() != nil {
191195
return s.state, ctx.Err()
192196
}
193-
if fn != nil {
194-
err := fn(s.state)
195-
if err != nil {
196-
return s.state, err
197-
}
198-
}
199197
return s.state, s.error
200198
}
201199

0 commit comments

Comments
 (0)