Skip to content

Commit 7d4f170

Browse files
committed
Add agent shutdown lifecycle states
1 parent e5b404f commit 7d4f170

17 files changed

+499
-169
lines changed

agent/agent.go

Lines changed: 81 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@ func New(options Options) io.Closer {
112112
logDir: options.LogDir,
113113
tempDir: options.TempDir,
114114
lifecycleUpdate: make(chan struct{}, 1),
115+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
115116
}
116117
a.init(ctx)
117118
return a
@@ -139,9 +140,10 @@ type agent struct {
139140
sessionToken atomic.Pointer[string]
140141
sshServer *ssh.Server
141142

142-
lifecycleUpdate chan struct{}
143-
lifecycleMu sync.Mutex // Protects following.
144-
lifecycleState codersdk.WorkspaceAgentLifecycle
143+
lifecycleUpdate chan struct{}
144+
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
145+
lifecycleMu sync.RWMutex // Protects following.
146+
lifecycleState codersdk.WorkspaceAgentLifecycle
145147

146148
network *tailnet.Conn
147149
}
@@ -187,9 +189,9 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
187189
}
188190

189191
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
190-
a.lifecycleMu.Lock()
192+
a.lifecycleMu.RLock()
191193
state := a.lifecycleState
192-
a.lifecycleMu.Unlock()
194+
a.lifecycleMu.RUnlock()
193195

194196
if state == lastReported {
195197
break
@@ -202,6 +204,11 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
202204
})
203205
if err == nil {
204206
lastReported = state
207+
select {
208+
case a.lifecycleReported <- state:
209+
case <-a.lifecycleReported:
210+
a.lifecycleReported <- state
211+
}
205212
break
206213
}
207214
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
@@ -213,13 +220,20 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
213220
}
214221
}
215222

223+
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
224+
// The state is only updated if it's a valid state transition.
216225
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) {
217226
a.lifecycleMu.Lock()
218-
defer a.lifecycleMu.Unlock()
219-
220-
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("previous", a.lifecycleState))
221-
227+
lastState := a.lifecycleState
228+
if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastState) > slices.Index(codersdk.WorkspaceAgentLifecycleOrder, state) {
229+
a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastState), slog.F("state", state))
230+
a.lifecycleMu.Unlock()
231+
return
232+
}
222233
a.lifecycleState = state
234+
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("last", lastState))
235+
a.lifecycleMu.Unlock()
236+
223237
select {
224238
case a.lifecycleUpdate <- struct{}{}:
225239
default:
@@ -310,15 +324,15 @@ func (a *agent) run(ctx context.Context) error {
310324
return
311325
}
312326
execTime := time.Since(scriptStart)
313-
lifecycleStatus := codersdk.WorkspaceAgentLifecycleReady
327+
lifecycleState := codersdk.WorkspaceAgentLifecycleReady
314328
if err != nil {
315329
a.logger.Warn(ctx, "startup script failed", slog.F("execution_time", execTime), slog.Error(err))
316-
lifecycleStatus = codersdk.WorkspaceAgentLifecycleStartError
330+
lifecycleState = codersdk.WorkspaceAgentLifecycleStartError
317331
} else {
318332
a.logger.Info(ctx, "startup script completed", slog.F("execution_time", execTime))
319333
}
320334

321-
a.setLifecycle(ctx, lifecycleStatus)
335+
a.setLifecycle(ctx, lifecycleState)
322336
}()
323337
}
324338

@@ -1193,25 +1207,72 @@ func (a *agent) Close() error {
11931207
if a.isClosed() {
11941208
return nil
11951209
}
1196-
close(a.closed)
1197-
a.closeCancel()
11981210

1211+
ctx := context.Background()
1212+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)
1213+
1214+
// Close services before running shutdown script.
1215+
// TODO(mafredri): Gracefully shutdown:
1216+
// - Close active SSH server connections
1217+
// - Close processes (send HUP, wait, etc.)
1218+
1219+
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
11991220
if metadata, ok := a.metadata.Load().(agentsdk.Metadata); ok {
1200-
ctx := context.Background()
1201-
err := a.runShutdownScript(ctx, metadata.ShutdownScript)
1221+
scriptDone := make(chan error, 1)
1222+
scriptStart := time.Now()
1223+
go func() {
1224+
defer close(scriptDone)
1225+
scriptDone <- a.runShutdownScript(ctx, metadata.ShutdownScript)
1226+
}()
1227+
1228+
var timeout <-chan time.Time
1229+
// If timeout is zero, an older version of the coder
1230+
// provider was used. Otherwise a timeout is always > 0.
1231+
if metadata.ShutdownScriptTimeout > 0 {
1232+
t := time.NewTimer(metadata.ShutdownScriptTimeout)
1233+
defer t.Stop()
1234+
timeout = t.C
1235+
}
1236+
1237+
var err error
1238+
select {
1239+
case err = <-scriptDone:
1240+
case <-timeout:
1241+
a.logger.Warn(ctx, "shutdown script timed out")
1242+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout)
1243+
err = <-scriptDone // The script can still complete after a timeout.
1244+
}
1245+
execTime := time.Since(scriptStart)
12021246
if err != nil {
1203-
a.logger.Error(ctx, "shutdown script failed", slog.Error(err))
1247+
a.logger.Warn(ctx, "shutdown script failed", slog.F("execution_time", execTime), slog.Error(err))
1248+
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
1249+
} else {
1250+
a.logger.Info(ctx, "shutdown script completed", slog.F("execution_time", execTime))
12041251
}
1205-
} else {
1206-
// No metadata.. halt?
1252+
}
1253+
1254+
// Set final state and wait for it to be reported because context
1255+
// cancellation will stop the report loop.
1256+
a.setLifecycle(ctx, lifecycleState)
1257+
for s := range a.lifecycleReported {
1258+
if s == lifecycleState {
1259+
break
1260+
}
1261+
}
1262+
1263+
if lifecycleState != codersdk.WorkspaceAgentLifecycleOff {
1264+
// TODO(mafredri): Delay shutdown, ensure debugging is possible.
12071265
_ = false
12081266
}
12091267

1268+
close(a.closed)
1269+
a.closeCancel()
1270+
_ = a.sshServer.Close()
12101271
if a.network != nil {
12111272
_ = a.network.Close()
12121273
}
1213-
_ = a.sshServer.Close()
12141274
a.connCloseWait.Wait()
1275+
12151276
return nil
12161277
}
12171278

0 commit comments

Comments
 (0)