Skip to content

Commit a801d0c

Browse files
committed
craft test for race condition
1 parent 0fef6b0 commit a801d0c

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

enterprise/wsproxy/keycache.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package wsproxy
22

33
import (
44
"context"
5+
"maps"
56
"sync"
67
"sync/atomic"
78
"time"
@@ -47,7 +48,7 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
4748

4849
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
4950
cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh)
50-
m, latest, err := cache.fetchKeys(ctx)
51+
m, latest, err := cache.cryptoKeys(ctx)
5152
if err != nil {
5253
cache.refreshCancel()
5354
return nil, xerrors.Errorf("initial fetch: %w", err)
@@ -100,10 +101,11 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd
100101
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
101102
}
102103

103-
now := k.Clock.Now()
104104
k.keysMu.RLock()
105105
key, ok := k.keys[sequence]
106106
k.keysMu.RUnlock()
107+
108+
now := k.Clock.Now()
107109
if ok {
108110
return validKey(key, now)
109111
}
@@ -135,11 +137,13 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd
135137
return validKey(key, now)
136138
}
137139

140+
// refresh fetches the keys from the control plane and updates the cache.
138141
func (k *CryptoKeyCache) refresh() {
139142
if k.isClosed() {
140143
return
141144
}
142145

146+
now := k.Clock.Now("CryptoKeyCache", "refresh")
143147
k.fetchLock.Lock()
144148
defer k.fetchLock.Unlock()
145149

@@ -154,7 +158,7 @@ func (k *CryptoKeyCache) refresh() {
154158
// There's a window we must account for where the timer fires while a fetch
155159
// is ongoing but prior to the timer getting reset. In this case we want to
156160
// avoid double fetching.
157-
if k.Clock.Now().Sub(lastFetch) < time.Minute*10 {
161+
if now.Sub(lastFetch) < time.Minute*10 {
158162
return
159163
}
160164

@@ -165,7 +169,9 @@ func (k *CryptoKeyCache) refresh() {
165169
}
166170
}
167171

168-
func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
172+
// cryptoKeys queries the control plane for the crypto keys.
173+
// Outside of initialization, this should only be called by fetch.
174+
func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
169175
keys, err := k.client.CryptoKeys(ctx)
170176
if err != nil {
171177
return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err)
@@ -176,8 +182,9 @@ func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.Cryp
176182

177183
// fetch fetches the keys from the control plane and updates the cache. The fetchLock
178184
// must be held when calling this function to avoid multiple concurrent fetches.
185+
// The returned keys are safe to use without additional locking.
179186
func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
180-
keys, latest, err := k.fetchKeys(ctx)
187+
keys, latest, err := k.cryptoKeys(ctx)
181188
if err != nil {
182189
return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err)
183190
}
@@ -196,7 +203,7 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe
196203

197204
k.lastFetch = k.Clock.Now()
198205
k.refresher.Reset(time.Minute * 10)
199-
k.keys, k.latest = keys, latest
206+
k.keys, k.latest = maps.Clone(keys), latest
200207

201208
return keys, latest, nil
202209
}
@@ -226,8 +233,8 @@ func (k *CryptoKeyCache) isClosed() bool {
226233
}
227234

228235
func (k *CryptoKeyCache) Close() {
229-
// The fetch lock must always be held before holding the keys lock
230-
// otherwise we risk a deadlock.
236+
// It's important to hold the locks here so that we don't unintentionally
237+
// reset the timer via an in-flight request when Close is called.
231238
k.fetchLock.Lock()
232239
defer k.fetchLock.Unlock()
233240

@@ -239,5 +246,6 @@ func (k *CryptoKeyCache) Close() {
239246
}
240247

241248
k.refreshCancel()
249+
k.refresher.Stop()
242250
k.closed.Store(true)
243251
}

enterprise/wsproxy/keycache_test.go

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/stretchr/testify/require"
12+
"go.uber.org/goleak"
1213

1314
"cdr.dev/slog/sloggers/slogtest"
1415

@@ -20,6 +21,10 @@ import (
2021
"github.com/coder/quartz"
2122
)
2223

24+
func TestMain(m *testing.M) {
25+
goleak.VerifyTestMain(m)
26+
}
27+
2328
func TestCryptoKeyCache(t *testing.T) {
2429
t.Parallel()
2530

@@ -346,20 +351,78 @@ func TestCryptoKeyCache(t *testing.T) {
346351
fc.keys = []codersdk.CryptoKey{newKey}
347352

348353
// The timer should fire and cause a request to coderd.
349-
_, advance := clock.AdvanceNext()
354+
dur, advance := clock.AdvanceNext()
350355
advance.MustWait(ctx)
351356
require.Equal(t, 2, fc.called)
357+
require.Equal(t, time.Minute*10, dur)
352358

353359
// Assert hits cache.
354360
got, err = cache.Signing(ctx)
355361
require.NoError(t, err)
356362
require.Equal(t, newKey, got)
357363
require.Equal(t, 2, fc.called)
358364

359-
// The ticker should fire and cause a request to coderd.
365+
// We check again to ensure the timer has been reset.
360366
_, advance = clock.AdvanceNext()
361367
advance.MustWait(ctx)
362368
require.Equal(t, 3, fc.called)
369+
require.Equal(t, time.Minute*10, dur)
370+
})
371+
372+
// This test ensures that if the refresh timer races with an in-flight request
373+
// and loses, it doesn't cause a redundant fetch.
374+
375+
t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
376+
t.Parallel()
377+
378+
var (
379+
ctx = testutil.Context(t, testutil.WaitShort)
380+
logger = slogtest.Make(t, nil)
381+
clock = quartz.NewMock(t)
382+
)
383+
384+
now := clock.Now().UTC()
385+
expected := codersdk.CryptoKey{
386+
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
387+
Secret: "key1",
388+
Sequence: 12,
389+
StartsAt: now,
390+
DeletesAt: now.Add(time.Minute * 10),
391+
}
392+
fc := newFakeCoderd(t, []codersdk.CryptoKey{
393+
expected,
394+
})
395+
396+
// Create a trap that blocks when the refresh timer fires.
397+
trap := clock.Trap().Now("refresh")
398+
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
399+
require.NoError(t, err)
400+
401+
_, wait := clock.AdvanceNext()
402+
trapped := trap.MustWait(ctx)
403+
404+
newKey := codersdk.CryptoKey{
405+
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
406+
Secret: "key2",
407+
Sequence: 13,
408+
StartsAt: now,
409+
}
410+
fc.keys = []codersdk.CryptoKey{newKey}
411+
412+
_, err = cache.Verifying(ctx, newKey.Sequence)
413+
require.NoError(t, err)
414+
require.Equal(t, 2, fc.called)
415+
416+
trapped.Release()
417+
wait.MustWait(ctx)
418+
require.Equal(t, 2, fc.called)
419+
trap.Close()
420+
421+
// The next timer should fire in 10 minutes.
422+
dur, wait := clock.AdvanceNext()
423+
wait.MustWait(ctx)
424+
require.Equal(t, time.Minute*10, dur)
425+
require.Equal(t, 3, fc.called)
363426
})
364427

365428
t.Run("Closed", func(t *testing.T) {

0 commit comments

Comments
 (0)