Skip to content

Commit 6e7856a

Browse files
committed
add heartbeat
1 parent f7810e5 commit 6e7856a

File tree

4 files changed

+117
-14
lines changed

4 files changed

+117
-14
lines changed

coderd/coderd.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,9 +1190,9 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
11901190
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
11911191
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
11921192
srv, err := provisionerdserver.NewServer(
1193-
api.ctx,
1193+
api.ctx, // use the same ctx as the API
11941194
api.AccessURL,
1195-
uuid.New(),
1195+
daemon.ID,
11961196
logger,
11971197
daemon.Provisioners,
11981198
provisionerdserver.Tags(daemon.Tags),

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,15 @@ import (
4444
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
4545
)
4646

47-
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
48-
// canceling and returning an empty job.
49-
const DefaultAcquireJobLongPollDur = time.Second * 5
47+
const (
48+
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
49+
// canceling and returning an empty job.
50+
DefaultAcquireJobLongPollDur = time.Second * 5
51+
52+
// DefaultHeartbeatInterval is the interval at which the provisioner daemon
53+
// will update its last seen at timestamp in the database.
54+
DefaultHeartbeatInterval = time.Minute
55+
)
5056

5157
type Options struct {
5258
OIDCConfig httpmw.OAuth2Config
@@ -56,6 +62,15 @@ type Options struct {
5662

5763
// AcquireJobLongPollDur is used in tests
5864
AcquireJobLongPollDur time.Duration
65+
66+
// HeartbeatInterval is the interval at which the provisioner daemon
67+
// will update its last seen at timestamp in the database.
68+
HeartbeatInterval time.Duration
69+
70+
// HeartbeatFn is the function that will be called at the interval
71+
// specified by HeartbeatInterval.
72+
// This is only used in tests.
73+
HeartbeatFn func(context.Context) error
5974
}
6075

6176
type server struct {
@@ -85,6 +100,9 @@ type server struct {
85100
TimeNowFn func() time.Time
86101

87102
acquireJobLongPollDur time.Duration
103+
104+
HeartbeatInterval time.Duration
105+
HeartbeatFn func(ctx context.Context) error
88106
}
89107

90108
// We use the null byte (0x00) in generating a canonical map key for tags, so
@@ -161,7 +179,11 @@ func NewServer(
161179
if options.AcquireJobLongPollDur == 0 {
162180
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
163181
}
164-
return &server{
182+
if options.HeartbeatInterval == 0 {
183+
options.HeartbeatInterval = DefaultHeartbeatInterval
184+
}
185+
186+
s := &server{
165187
lifecycleCtx: lifecycleCtx,
166188
AccessURL: accessURL,
167189
ID: id,
@@ -182,7 +204,13 @@ func NewServer(
182204
OIDCConfig: options.OIDCConfig,
183205
TimeNowFn: options.TimeNowFn,
184206
acquireJobLongPollDur: options.AcquireJobLongPollDur,
185-
}, nil
207+
HeartbeatInterval: options.HeartbeatInterval,
208+
HeartbeatFn: options.HeartbeatFn,
209+
}
210+
211+
go s.heartbeat()
212+
213+
return s, nil
186214
}
187215

188216
// timeNow should be used when trying to get the current time for math
@@ -194,6 +222,44 @@ func (s *server) timeNow() time.Time {
194222
return dbtime.Now()
195223
}
196224

225+
// heartbeat runs heartbeatOnce at the interval specified by HeartbeatInterval
226+
// until the lifecycle context is canceled.
227+
func (s *server) heartbeat() {
228+
tick := time.NewTicker(time.Nanosecond)
229+
defer tick.Stop()
230+
for {
231+
select {
232+
case <-s.lifecycleCtx.Done():
233+
return
234+
case <-tick.C:
235+
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.HeartbeatInterval)
236+
if err := s.heartbeatOnce(hbCtx); err != nil {
237+
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
238+
}
239+
hbCancel()
240+
tick.Reset(s.HeartbeatInterval)
241+
}
242+
}
243+
}
244+
245+
// heartbeatOnce updates the last seen at timestamp in the database.
246+
// If HeartbeatFn is set, it will be called instead.
247+
func (s *server) heartbeatOnce(ctx context.Context) error {
248+
if s.HeartbeatFn != nil {
249+
return s.HeartbeatFn(ctx)
250+
}
251+
252+
if s.lifecycleCtx.Err() != nil {
253+
return nil
254+
}
255+
256+
//nolint:gocritic // Provisionerd has specific authz rules.
257+
return s.Database.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsProvisionerd(ctx), database.UpdateProvisionerDaemonLastSeenAtParams{
258+
ID: s.ID,
259+
LastSeenAt: sql.NullTime{Time: s.timeNow(), Valid: true},
260+
})
261+
}
262+
197263
// AcquireJob queries the database to lock a job.
198264
//
199265
// Deprecated: This method is only available for back-level provisioner daemons.

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
9595
require.Equal(t, "", job.JobId)
9696
}
9797

98+
func TestHeartbeat(t *testing.T) {
99+
t.Parallel()
100+
101+
ctx, cancel := context.WithCancel(context.Background())
102+
t.Cleanup(cancel)
103+
heartbeatChan := make(chan struct{})
104+
heartbeatFn := func(context.Context) error {
105+
heartbeatChan <- struct{}{}
106+
return nil
107+
}
108+
//nolint:dogsled // this is a test
109+
_, _, _ = setup(t, false, &overrides{
110+
ctx: ctx,
111+
heartbeatFn: heartbeatFn,
112+
heartbeatInterval: testutil.IntervalFast,
113+
})
114+
115+
<-heartbeatChan
116+
cancel()
117+
close(heartbeatChan)
118+
<-time.After(testutil.IntervalFast)
119+
}
120+
98121
func TestAcquireJob(t *testing.T) {
99122
t.Parallel()
100123

@@ -1686,19 +1709,20 @@ func TestInsertWorkspaceResource(t *testing.T) {
16861709
}
16871710

16881711
type overrides struct {
1712+
ctx context.Context
16891713
deploymentValues *codersdk.DeploymentValues
16901714
externalAuthConfigs []*externalauth.Config
16911715
id *uuid.UUID
16921716
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
16931717
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
16941718
timeNowFn func() time.Time
16951719
acquireJobLongPollDuration time.Duration
1720+
heartbeatFn func(ctx context.Context) error
1721+
heartbeatInterval time.Duration
16961722
}
16971723

16981724
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
16991725
t.Helper()
1700-
ctx, cancel := context.WithCancel(context.Background())
1701-
t.Cleanup(cancel)
17021726
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
17031727
db := dbmem.New()
17041728
ps := pubsub.NewInMemory()
@@ -1710,6 +1734,14 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17101734
var timeNowFn func() time.Time
17111735
pollDur := time.Duration(0)
17121736
if ov != nil {
1737+
if ov.ctx == nil {
1738+
ctx, cancel := context.WithCancel(context.Background())
1739+
t.Cleanup(cancel)
1740+
ov.ctx = ctx
1741+
}
1742+
if ov.heartbeatInterval == 0 {
1743+
ov.heartbeatInterval = testutil.IntervalMedium
1744+
}
17131745
if ov.deploymentValues != nil {
17141746
deploymentValues = ov.deploymentValues
17151747
}
@@ -1744,15 +1776,15 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17441776
}
17451777

17461778
srv, err := provisionerdserver.NewServer(
1747-
ctx,
1779+
ov.ctx,
17481780
&url.URL{},
17491781
srvID,
17501782
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
17511783
[]database.ProvisionerType{database.ProvisionerTypeEcho},
17521784
provisionerdserver.Tags{},
17531785
db,
17541786
ps,
1755-
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
1787+
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
17561788
telemetry.NewNoop(),
17571789
trace.NewNoopTracerProvider().Tracer("noop"),
17581790
&atomic.Pointer[proto.QuotaCommitter]{},
@@ -1765,6 +1797,8 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17651797
TimeNowFn: timeNowFn,
17661798
OIDCConfig: &oauth2.Config{},
17671799
AcquireJobLongPollDur: pollDur,
1800+
HeartbeatInterval: ov.heartbeatInterval,
1801+
HeartbeatFn: ov.heartbeatFn,
17681802
},
17691803
)
17701804
require.NoError(t, err)

enterprise/coderd/provisionerdaemons.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
234234
}
235235

236236
// Create the daemon in the database.
237-
_, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
237+
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
238238
Name: name,
239239
Provisioners: provisioners,
240240
Tags: tags,
@@ -295,11 +295,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
295295
}
296296
mux := drpcmux.New()
297297
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
298+
srvCtx, srvCancel := context.WithCancel(ctx)
299+
defer srvCancel()
298300
logger.Info(ctx, "starting external provisioner daemon")
299301
srv, err := provisionerdserver.NewServer(
300-
api.ctx,
302+
srvCtx,
301303
api.AccessURL,
302-
id,
304+
daemon.ID,
303305
logger,
304306
provisioners,
305307
tags,
@@ -339,6 +341,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
339341
},
340342
})
341343
err = server.Serve(ctx, session)
344+
srvCancel()
342345
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
343346
if err != nil && !xerrors.Is(err, io.EOF) {
344347
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))

0 commit comments

Comments
 (0)