@@ -45,12 +45,14 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
45
45
opt (cache )
46
46
}
47
47
48
- m , latest , err := cache .fetch (ctx )
48
+ cache .refreshCtx , cache .refreshCancel = context .WithCancel (ctx )
49
+ cache .refresher = cache .Clock .AfterFunc (time .Minute * 10 , cache .refresh )
50
+ m , latest , err := cache .fetchKeys (ctx )
49
51
if err != nil {
52
+ cache .refreshCancel ()
50
53
return nil , xerrors .Errorf ("initial fetch: %w" , err )
51
54
}
52
55
cache .keys , cache .latest = m , latest
53
- cache .refresher = cache .Clock .AfterFunc (time .Minute * 10 , cache .refresh )
54
56
55
57
return cache , nil
56
58
}
@@ -77,9 +79,12 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
77
79
}
78
80
79
81
k .keysMu .RLock ()
80
- if k .latest .CanSign (now ) {
81
- k .keysMu .RUnlock ()
82
- return k .latest , nil
82
+ latest = k .latest
83
+ k .keysMu .RUnlock ()
84
+
85
+ now = k .Clock .Now ()
86
+ if latest .CanSign (now ) {
87
+ return latest , nil
83
88
}
84
89
85
90
_ , latest , err := k .fetch (ctx )
@@ -91,27 +96,28 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
91
96
}
92
97
93
98
func (k * CryptoKeyCache ) Verifying (ctx context.Context , sequence int32 ) (codersdk.CryptoKey , error ) {
94
- now := k .Clock .Now ()
95
- k .keysMu .RLock ()
96
99
if k .isClosed () {
97
- k .keysMu .RUnlock ()
98
100
return codersdk.CryptoKey {}, cryptokeys .ErrClosed
99
101
}
100
102
103
+ now := k .Clock .Now ()
104
+ k .keysMu .RLock ()
101
105
key , ok := k .keys [sequence ]
102
106
k .keysMu .RUnlock ()
103
107
if ok {
104
108
return validKey (key , now )
105
109
}
106
110
107
- k .keysMu .Lock ()
108
- defer k .keysMu .Unlock ()
111
+ k .fetchLock .Lock ()
112
+ defer k .fetchLock .Unlock ()
109
113
110
114
if k .isClosed () {
111
115
return codersdk.CryptoKey {}, cryptokeys .ErrClosed
112
116
}
113
117
118
+ k .keysMu .RLock ()
114
119
key , ok = k .keys [sequence ]
120
+ k .keysMu .RUnlock ()
115
121
if ok {
116
122
return validKey (key , now )
117
123
}
@@ -134,14 +140,23 @@ func (k *CryptoKeyCache) refresh() {
134
140
return
135
141
}
136
142
137
- k .keysMu .RLock ()
138
- if k .Clock .Now ().Sub (k .lastFetch ) < time .Minute * 10 {
139
- k .keysMu .Unlock ()
143
+ k .fetchLock .Lock ()
144
+ defer k .fetchLock .Unlock ()
145
+
146
+ if k .isClosed () {
140
147
return
141
148
}
142
149
143
- k .fetchLock .Lock ()
144
- defer k .fetchLock .Unlock ()
150
+ k .keysMu .RLock ()
151
+ lastFetch := k .lastFetch
152
+ k .keysMu .RUnlock ()
153
+
154
+ // There's a window we must account for where the timer fires while a fetch
155
+ // is ongoing but prior to the timer getting reset. In this case we want to
156
+ // avoid double fetching.
157
+ if k .Clock .Now ().Sub (lastFetch ) < time .Minute * 10 {
158
+ return
159
+ }
145
160
146
161
_ , _ , err := k .fetch (k .refreshCtx )
147
162
if err != nil {
@@ -150,19 +165,28 @@ func (k *CryptoKeyCache) refresh() {
150
165
}
151
166
}
152
167
153
- func (k * CryptoKeyCache ) fetch (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey , error ) {
154
-
168
+ func (k * CryptoKeyCache ) fetchKeys (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey , error ) {
155
169
keys , err := k .client .CryptoKeys (ctx )
156
170
if err != nil {
157
- return nil , codersdk.CryptoKey {}, xerrors .Errorf ("get security keys: %w" , err )
171
+ return nil , codersdk.CryptoKey {}, xerrors .Errorf ("crypto keys: %w" , err )
158
172
}
173
+ cache , latest := toKeyMap (keys .CryptoKeys , k .Clock .Now ())
174
+ return cache , latest , nil
175
+ }
159
176
160
- if len (keys .CryptoKeys ) == 0 {
177
+ // fetch fetches the keys from the control plane and updates the cache. The fetchMu
178
+ // must be held when calling this function to avoid multiple concurrent fetches.
179
+ func (k * CryptoKeyCache ) fetch (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey , error ) {
180
+ keys , latest , err := k .fetchKeys (ctx )
181
+ if err != nil {
182
+ return nil , codersdk.CryptoKey {}, xerrors .Errorf ("fetch keys: %w" , err )
183
+ }
184
+
185
+ if len (keys ) == 0 {
161
186
return nil , codersdk.CryptoKey {}, cryptokeys .ErrKeyNotFound
162
187
}
163
188
164
189
now := k .Clock .Now ()
165
- kmap , latest := toKeyMap (keys .CryptoKeys , now )
166
190
if ! latest .CanSign (now ) {
167
191
return nil , codersdk.CryptoKey {}, cryptokeys .ErrKeyInvalid
168
192
}
@@ -172,9 +196,9 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe
172
196
173
197
k .lastFetch = k .Clock .Now ()
174
198
k .refresher .Reset (time .Minute * 10 )
175
- k .keys , k .latest = kmap , latest
199
+ k .keys , k .latest = keys , latest
176
200
177
- return kmap , latest , nil
201
+ return keys , latest , nil
178
202
}
179
203
180
204
func toKeyMap (keys []codersdk.CryptoKey , now time.Time ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ) {
@@ -202,6 +226,11 @@ func (k *CryptoKeyCache) isClosed() bool {
202
226
}
203
227
204
228
func (k * CryptoKeyCache ) Close () {
229
+ // The fetch lock must always be held before holding the keys lock
230
+ // otherwise we risk a deadlock.
231
+ k .fetchLock .Lock ()
232
+ defer k .fetchLock .Unlock ()
233
+
205
234
k .keysMu .Lock ()
206
235
defer k .keysMu .Unlock ()
207
236
0 commit comments