Skip to content

Commit 0fef6b0

Browse files
committed
fix tests
1 parent 9061a1c commit 0fef6b0

File tree

2 files changed

+55
-31
lines changed

2 files changed

+55
-31
lines changed

enterprise/wsproxy/keycache.go

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
4545
opt(cache)
4646
}
4747

48-
m, latest, err := cache.fetch(ctx)
48+
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
49+
cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh)
50+
m, latest, err := cache.fetchKeys(ctx)
4951
if err != nil {
52+
cache.refreshCancel()
5053
return nil, xerrors.Errorf("initial fetch: %w", err)
5154
}
5255
cache.keys, cache.latest = m, latest
53-
cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh)
5456

5557
return cache, nil
5658
}
@@ -77,9 +79,12 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
7779
}
7880

7981
k.keysMu.RLock()
80-
if k.latest.CanSign(now) {
81-
k.keysMu.RUnlock()
82-
return k.latest, nil
82+
latest = k.latest
83+
k.keysMu.RUnlock()
84+
85+
now = k.Clock.Now()
86+
if latest.CanSign(now) {
87+
return latest, nil
8388
}
8489

8590
_, latest, err := k.fetch(ctx)
@@ -91,27 +96,28 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
9196
}
9297

9398
func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
94-
now := k.Clock.Now()
95-
k.keysMu.RLock()
9699
if k.isClosed() {
97-
k.keysMu.RUnlock()
98100
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
99101
}
100102

103+
now := k.Clock.Now()
104+
k.keysMu.RLock()
101105
key, ok := k.keys[sequence]
102106
k.keysMu.RUnlock()
103107
if ok {
104108
return validKey(key, now)
105109
}
106110

107-
k.keysMu.Lock()
108-
defer k.keysMu.Unlock()
111+
k.fetchLock.Lock()
112+
defer k.fetchLock.Unlock()
109113

110114
if k.isClosed() {
111115
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
112116
}
113117

118+
k.keysMu.RLock()
114119
key, ok = k.keys[sequence]
120+
k.keysMu.RUnlock()
115121
if ok {
116122
return validKey(key, now)
117123
}
@@ -134,14 +140,23 @@ func (k *CryptoKeyCache) refresh() {
134140
return
135141
}
136142

137-
k.keysMu.RLock()
138-
if k.Clock.Now().Sub(k.lastFetch) < time.Minute*10 {
139-
k.keysMu.Unlock()
143+
k.fetchLock.Lock()
144+
defer k.fetchLock.Unlock()
145+
146+
if k.isClosed() {
140147
return
141148
}
142149

143-
k.fetchLock.Lock()
144-
defer k.fetchLock.Unlock()
150+
k.keysMu.RLock()
151+
lastFetch := k.lastFetch
152+
k.keysMu.RUnlock()
153+
154+
// There's a window we must account for where the timer fires while a fetch
155+
// is ongoing but prior to the timer getting reset. In this case we want to
156+
// avoid double fetching.
157+
if k.Clock.Now().Sub(lastFetch) < time.Minute*10 {
158+
return
159+
}
145160

146161
_, _, err := k.fetch(k.refreshCtx)
147162
if err != nil {
@@ -150,19 +165,28 @@ func (k *CryptoKeyCache) refresh() {
150165
}
151166
}
152167

153-
func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
154-
168+
func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
155169
keys, err := k.client.CryptoKeys(ctx)
156170
if err != nil {
157-
return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err)
171+
return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err)
158172
}
173+
cache, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now())
174+
return cache, latest, nil
175+
}
159176

160-
if len(keys.CryptoKeys) == 0 {
177+
// fetch fetches the keys from the control plane and updates the cache. The fetchMu
178+
// must be held when calling this function to avoid multiple concurrent fetches.
179+
func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
180+
keys, latest, err := k.fetchKeys(ctx)
181+
if err != nil {
182+
return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err)
183+
}
184+
185+
if len(keys) == 0 {
161186
return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
162187
}
163188

164189
now := k.Clock.Now()
165-
kmap, latest := toKeyMap(keys.CryptoKeys, now)
166190
if !latest.CanSign(now) {
167191
return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
168192
}
@@ -172,9 +196,9 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe
172196

173197
k.lastFetch = k.Clock.Now()
174198
k.refresher.Reset(time.Minute * 10)
175-
k.keys, k.latest = kmap, latest
199+
k.keys, k.latest = keys, latest
176200

177-
return kmap, latest, nil
201+
return keys, latest, nil
178202
}
179203

180204
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) {
@@ -202,6 +226,11 @@ func (k *CryptoKeyCache) isClosed() bool {
202226
}
203227

204228
func (k *CryptoKeyCache) Close() {
229+
// The fetch lock must always be held before holding the keys lock
230+
// otherwise we risk a deadlock.
231+
k.fetchLock.Lock()
232+
defer k.fetchLock.Unlock()
233+
205234
k.keysMu.Lock()
206235
defer k.keysMu.Unlock()
207236

enterprise/wsproxy/keycache_test.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,6 @@ func TestCryptoKeyCache(t *testing.T) {
317317
clock = quartz.NewMock(t)
318318
)
319319

320-
trap := clock.Trap().TickerFunc()
321-
322320
now := clock.Now().UTC()
323321
expected := codersdk.CryptoKey{
324322
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
@@ -339,8 +337,6 @@ func TestCryptoKeyCache(t *testing.T) {
339337
require.Equal(t, expected, got)
340338
require.Equal(t, 1, fc.called)
341339

342-
wait := trap.MustWait(ctx)
343-
344340
newKey := codersdk.CryptoKey{
345341
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
346342
Secret: "key2",
@@ -349,8 +345,6 @@ func TestCryptoKeyCache(t *testing.T) {
349345
}
350346
fc.keys = []codersdk.CryptoKey{newKey}
351347

352-
wait.Release()
353-
354348
// The ticker should fire and cause a request to coderd.
355349
_, advance := clock.AdvanceNext()
356350
advance.MustWait(ctx)
@@ -362,9 +356,10 @@ func TestCryptoKeyCache(t *testing.T) {
362356
require.Equal(t, newKey, got)
363357
require.Equal(t, 2, fc.called)
364358

365-
// Assert we do not have the old key.
366-
_, err = cache.Verifying(ctx, expected.Sequence)
367-
require.Error(t, err)
359+
// The ticker should fire and cause a request to coderd.
360+
_, advance = clock.AdvanceNext()
361+
advance.MustWait(ctx)
362+
require.Equal(t, 3, fc.called)
368363
})
369364

370365
t.Run("Closed", func(t *testing.T) {

0 commit comments

Comments
 (0)