
Commit 65f57b1

Protect against NOTIFY races with a slow receiver

Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 28a96de · commit 65f57b1
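In brief: the previous LatencyMeasurer published an empty payload, so any subscriber on the measurement channel could complete a measurement with a message that was not meant for it. This commit gives each measurement a unique UUID payload, has the subscriber discard (and log) payloads that are not its own, moves Subscribe onto a goroutine that reports failures over a dedicated subscribeErr channel, and exposes the payload size as LatencyMessageLength so the metrics tests can account for the extra published and received bytes.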

4 files changed: +131 -32 lines

coderd/database/pubsub/latency.go (40 additions, 24 deletions)

@@ -1,46 +1,65 @@
 package pubsub
 
 import (
+	"bytes"
 	"context"
 	"fmt"
 	"time"
 
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
 )
 
 // LatencyMeasurer is used to measure the send & receive latencies of the underlying Pubsub implementation. We use these
 // measurements to export metrics which can indicate when a Pubsub implementation's queue is overloaded and/or full.
 type LatencyMeasurer struct {
-	// Create unique pubsub channel names so that multiple replicas do not clash when performing latency measurements,
-	// and only create one UUID per Pubsub impl (and not request) to limit the number of notification channels that need
-	// to be maintained by the Pubsub impl.
-	channelIDs map[Pubsub]uuid.UUID
+	// Create unique pubsub channel names so that multiple coderd replicas do not clash when performing latency measurements.
+	channel uuid.UUID
+	logger  slog.Logger
 }
 
-func NewLatencyMeasurer() *LatencyMeasurer {
+// LatencyMessageLength is the length of a UUIDv4 in its canonical string form (36 bytes).
+const LatencyMessageLength = 36
+
+func NewLatencyMeasurer(logger slog.Logger) *LatencyMeasurer {
 	return &LatencyMeasurer{
-		channelIDs: make(map[Pubsub]uuid.UUID),
+		channel: uuid.New(),
+		logger:  logger,
 	}
 }
 
 // Measure takes a given Pubsub implementation, publishes a message & immediately receives it, and returns the observed latency.
 func (lm *LatencyMeasurer) Measure(ctx context.Context, p Pubsub) (send float64, recv float64, err error) {
 	var (
-		start time.Time
-		res   = make(chan float64, 1)
+		start        time.Time
+		res          = make(chan float64, 1)
+		subscribeErr = make(chan error, 1)
 	)
 
-	cancel, err := p.Subscribe(lm.latencyChannelName(p), func(ctx context.Context, _ []byte) {
-		res <- time.Since(start).Seconds()
-	})
-	if err != nil {
-		return -1, -1, xerrors.Errorf("failed to subscribe: %w", err)
-	}
-	defer cancel()
+	msg := []byte(uuid.New().String())
+	log := lm.logger.With(slog.F("msg", msg))
+
+	go func() {
+		_, err = p.Subscribe(lm.latencyChannelName(), func(ctx context.Context, in []byte) {
+			p := p
+			_ = p
+
+			if !bytes.Equal(in, msg) {
+				log.Warn(ctx, "received unexpected message!", slog.F("in", in))
+				return
+			}
+
+			res <- time.Since(start).Seconds()
+		})
+		if err != nil {
+			subscribeErr <- xerrors.Errorf("failed to subscribe: %w", err)
+		}
+	}()
 
 	start = time.Now()
-	err = p.Publish(lm.latencyChannelName(p), []byte{})
+	err = p.Publish(lm.latencyChannelName(), msg)
 	if err != nil {
 		return -1, -1, xerrors.Errorf("failed to publish: %w", err)
 	}
@@ -49,18 +68,15 @@ func (lm *LatencyMeasurer) Measure(ctx context.Context, p Pubsub) (send float64, recv float64, err error) {
 
 	select {
 	case <-ctx.Done():
+		log.Error(ctx, "context canceled before message could be received", slog.Error(ctx.Err()))
 		return send, -1, ctx.Err()
 	case val := <-res:
 		return send, val, nil
+	case err = <-subscribeErr:
+		return send, -1, err
 	}
 }
 
-func (lm *LatencyMeasurer) latencyChannelName(p Pubsub) string {
-	cid, found := lm.channelIDs[p]
-	if !found {
-		cid = uuid.New()
-		lm.channelIDs[p] = cid
-	}
-
-	return fmt.Sprintf("latency-measure:%s", cid.String())
+func (lm *LatencyMeasurer) latencyChannelName() string {
+	return fmt.Sprintf("latency-measure:%s", lm.channel)
 }
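The essence of the fix is correlating each measurement with its own message rather than trusting delivery order on a shared channel. Below is a minimal, self-contained sketch of that pattern. It uses a toy in-process pubsub for illustration only; toyPubsub and its methods are hypothetical stand-ins, not coder's Pubsub interface.

package main

import (
	"bytes"
	"fmt"
	"sync"
	"time"

	"github.com/google/uuid"
)

// toyPubsub fans every published message out to all subscribers of a channel.
type toyPubsub struct {
	mu   sync.Mutex
	subs map[string][]func([]byte)
}

func newToyPubsub() *toyPubsub {
	return &toyPubsub{subs: make(map[string][]func([]byte))}
}

func (t *toyPubsub) subscribe(channel string, fn func([]byte)) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.subs[channel] = append(t.subs[channel], fn)
}

func (t *toyPubsub) publish(channel string, msg []byte) {
	t.mu.Lock()
	defer t.mu.Unlock()
	for _, fn := range t.subs[channel] {
		fn(msg)
	}
}

func main() {
	ps := newToyPubsub()

	// Each measurement gets a unique payload, as in the commit.
	msg := []byte(uuid.New().String())
	res := make(chan float64, 1)

	var start time.Time
	ps.subscribe("latency-measure:demo", func(in []byte) {
		// Discard messages that belong to some other, concurrent measurement.
		if !bytes.Equal(in, msg) {
			fmt.Printf("ignoring unexpected message: %s\n", in)
			return
		}
		res <- time.Since(start).Seconds()
	})

	start = time.Now()
	// A racing measurement's message arrives first and is ignored...
	ps.publish("latency-measure:demo", []byte(uuid.New().String()))
	// ...so only our own message completes the measurement.
	ps.publish("latency-measure:demo", msg)

	fmt.Printf("recv latency: %.9fs\n", <-res)
}

Without the bytes.Equal guard, the stray first publish would complete the measurement and report a latency for a message the measurer never sent.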

coderd/database/pubsub/pubsub.go (1 addition, 1 deletion)

@@ -589,7 +589,7 @@ func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
 		listenDone:      make(chan struct{}),
 		db:              database,
 		queues:          make(map[string]map[uuid.UUID]*msgQueue),
-		latencyMeasurer: NewLatencyMeasurer(),
+		latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
 
 		publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
 			Namespace: "coder",

coderd/database/pubsub/pubsub_linux_test.go (84 additions, 3 deletions)

@@ -3,14 +3,17 @@
 package pubsub_test
 
 import (
+	"bytes"
 	"context"
 	"database/sql"
 	"fmt"
 	"math/rand"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
+	"cdr.dev/slog/sloggers/sloghuman"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/xerrors"
@@ -319,13 +322,14 @@ func TestMeasureLatency(t *testing.T) {
 	t.Run("MeasureLatency", func(t *testing.T) {
 		t.Parallel()
 
+		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
 		ps, done := newPubsub()
 		defer done()
 
-		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
 		defer cancel()
 
-		send, recv, err := pubsub.NewLatencyMeasurer().Measure(ctx, ps)
+		send, recv, err := pubsub.NewLatencyMeasurer(logger).Measure(ctx, ps)
 		require.NoError(t, err)
 		require.Greater(t, send, 0.0)
 		require.Greater(t, recv, 0.0)
@@ -334,16 +338,93 @@
 	t.Run("MeasureLatencyRecvTimeout", func(t *testing.T) {
 		t.Parallel()
 
+		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
 		ps, done := newPubsub()
 		defer done()
 
 		// nolint:gocritic // need a very short timeout here to trigger error
 		ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
 		defer cancel()
 
-		send, recv, err := pubsub.NewLatencyMeasurer().Measure(ctx, ps)
+		send, recv, err := pubsub.NewLatencyMeasurer(logger).Measure(ctx, ps)
 		require.ErrorContains(t, err, context.DeadlineExceeded.Error())
 		require.Greater(t, send, 0.0)
 		require.EqualValues(t, recv, -1)
 	})
+
+	t.Run("MeasureLatencyNotifyRace", func(t *testing.T) {
+		t.Parallel()
+
+		var buf bytes.Buffer
+		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+		logger = logger.AppendSinks(sloghuman.Sink(&buf))
+
+		lm := pubsub.NewLatencyMeasurer(logger)
+		ps, done := newPubsub()
+		defer done()
+
+		slow := newDelayedListener(ps, time.Second)
+		fast := newDelayedListener(ps, time.Nanosecond)
+
+		var wg sync.WaitGroup
+		wg.Add(2)
+
+		// Publish a message concurrently to a slow receiver.
+		go func() {
+			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+			defer cancel()
+			defer wg.Done()
+
+			// The slow receiver will not receive its latency message because the fast one receives it first.
+			_, _, err := lm.Measure(ctx, slow)
+			require.ErrorContains(t, err, context.DeadlineExceeded.Error())
+		}()
+
+		// Publish a message concurrently to a fast receiver, which will receive both its own message and the slow
+		// receiver's. It should ignore the unexpected message and consume its own, leaving the slow receiver to
+		// time out since it will never receive its own message.
+		go func() {
+			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+			defer cancel()
+			defer wg.Done()
+
+			send, recv, err := lm.Measure(ctx, fast)
+			require.NoError(t, err)
+			require.Greater(t, send, 0.0)
+			require.Greater(t, recv, 0.0)
+		}()
+
+		wg.Wait()
+
+		// Flush the contents of the logger to its buffer.
+		logger.Sync()
+		require.Contains(t, buf.String(), "received unexpected message!")
+	})
+}
+
+type delayedListener struct {
+	pubsub.Pubsub
+	delay time.Duration
+}
+
+func newDelayedListener(ps pubsub.Pubsub, delay time.Duration) *delayedListener {
+	return &delayedListener{Pubsub: ps, delay: delay}
+}
+
+func (s *delayedListener) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
+	time.Sleep(s.delay)
+	return s.Pubsub.Subscribe(event, listener)
+}
+
+func (s *delayedListener) SubscribeWithErr(event string, listener pubsub.ListenerWithErr) (cancel func(), err error) {
+	time.Sleep(s.delay)
+	return s.Pubsub.SubscribeWithErr(event, listener)
+}
+
+func (s *delayedListener) Publish(event string, message []byte) error {
+	return s.Pubsub.Publish(event, message)
+}
+
+func (s *delayedListener) Close() error {
+	return s.Pubsub.Close()
 }
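A note on the mechanics of this test, as I read it: delayedListener delays only Subscribe (Publish passes straight through), and both Measure calls share the measurer's single channel name. The slow measurer therefore publishes its UUID a full second before its own subscription is registered, and a Postgres NOTIFY is delivered only to sessions listening at that moment, so the stray message reaches only the already-subscribed fast measurer. With this fix the fast measurer discards it and logs "received unexpected message!", which is exactly the line the test asserts on via the captured buffer; previously it would have treated the stray payload as its own and reported a bogus latency.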

coderd/database/pubsub/pubsub_test.go (6 additions, 4 deletions)

@@ -63,6 +63,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
 	_ = testutil.RequireRecvCtx(ctx, t, messageChannel)
 
 	require.Eventually(t, func() bool {
+		latencyBytes := gatherCount * pubsub.LatencyMessageLength
 		metrics, err = registry.Gather()
 		gatherCount++
 		assert.NoError(t, err)
@@ -72,8 +73,8 @@ func TestPGPubsub_Metrics(t *testing.T) {
 			testutil.PromCounterHasValue(t, metrics, gatherCount, "coder_pubsub_publishes_total", "true") &&
 			testutil.PromCounterHasValue(t, metrics, gatherCount, "coder_pubsub_subscribes_total", "true") &&
 			testutil.PromCounterHasValue(t, metrics, gatherCount, "coder_pubsub_messages_total", "normal") &&
-			testutil.PromCounterHasValue(t, metrics, float64(len(data)), "coder_pubsub_received_bytes_total") &&
-			testutil.PromCounterHasValue(t, metrics, float64(len(data)), "coder_pubsub_published_bytes_total") &&
+			testutil.PromCounterHasValue(t, metrics, float64(len(data))+latencyBytes, "coder_pubsub_received_bytes_total") &&
+			testutil.PromCounterHasValue(t, metrics, float64(len(data))+latencyBytes, "coder_pubsub_published_bytes_total") &&
 			testutil.PromGaugeAssertion(t, metrics, func(in float64) bool { return in > 0 }, "coder_pubsub_send_latency_seconds") &&
 			testutil.PromGaugeAssertion(t, metrics, func(in float64) bool { return in > 0 }, "coder_pubsub_receive_latency_seconds") &&
 			!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total")
@@ -98,6 +99,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
 	_ = testutil.RequireRecvCtx(ctx, t, messageChannel)
 
 	require.Eventually(t, func() bool {
+		latencyBytes := gatherCount * pubsub.LatencyMessageLength
 		metrics, err = registry.Gather()
 		gatherCount++
 		assert.NoError(t, err)
@@ -108,8 +110,8 @@ func TestPGPubsub_Metrics(t *testing.T) {
 			testutil.PromCounterHasValue(t, metrics, 1+gatherCount, "coder_pubsub_subscribes_total", "true") &&
 			testutil.PromCounterHasValue(t, metrics, gatherCount, "coder_pubsub_messages_total", "normal") &&
 			testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "colossal") &&
-			testutil.PromCounterHasValue(t, metrics, float64(colossalSize+len(data)), "coder_pubsub_received_bytes_total") &&
-			testutil.PromCounterHasValue(t, metrics, float64(colossalSize+len(data)), "coder_pubsub_published_bytes_total") &&
+			testutil.PromCounterHasValue(t, metrics, float64(colossalSize+len(data))+latencyBytes, "coder_pubsub_received_bytes_total") &&
+			testutil.PromCounterHasValue(t, metrics, float64(colossalSize+len(data))+latencyBytes, "coder_pubsub_published_bytes_total") &&
 			testutil.PromGaugeAssertion(t, metrics, func(in float64) bool { return in > 0 }, "coder_pubsub_send_latency_seconds") &&
 			testutil.PromGaugeAssertion(t, metrics, func(in float64) bool { return in > 0 }, "coder_pubsub_receive_latency_seconds") &&
 			!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total")