Skip to content

Commit

Permalink
[FIXED] Race in MessageBatch (#1743)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <[email protected]>
  • Loading branch information
piotrpio authored Dec 17, 2024
1 parent d05f24a commit 6bc4159
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 127 deletions.
13 changes: 8 additions & 5 deletions jetstream/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,26 +393,26 @@ func (s *orderedSubscription) Closed() <-chan struct{} {
// reset the consumer for each subsequent Fetch call.
// Consider using [Consumer.Consume] or [Consumer.Messages] instead.
func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
c.Lock()
if c.consumerType == consumerTypeConsume {
c.Unlock()
return nil, ErrOrderConsumerUsedAsConsume
}
c.currentConsumer.Lock()
if c.runningFetch != nil {
if !c.runningFetch.done {
c.currentConsumer.Unlock()
if !c.runningFetch.closed() {
return nil, ErrOrderedConsumerConcurrentRequests
}
if c.runningFetch.sseq != 0 {
c.cursor.streamSeq = c.runningFetch.sseq
}
}
c.currentConsumer.Unlock()
c.consumerType = consumerTypeFetch
sub := orderedSubscription{
consumer: c,
done: make(chan struct{}),
}
c.subscription = &sub
c.Unlock()
err := c.reset()
if err != nil {
return nil, err
Expand All @@ -433,11 +433,13 @@ func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, erro
// reset the consumer for each subsequent Fetch call.
// Consider using [Consumer.Consume] or [Consumer.Messages] instead.
func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) {
c.Lock()
if c.consumerType == consumerTypeConsume {
c.Unlock()
return nil, ErrOrderConsumerUsedAsConsume
}
if c.runningFetch != nil {
if !c.runningFetch.done {
if !c.runningFetch.closed() {
return nil, ErrOrderedConsumerConcurrentRequests
}
if c.runningFetch.sseq != 0 {
Expand All @@ -450,6 +452,7 @@ func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBat
done: make(chan struct{}),
}
c.subscription = &sub
c.Unlock()
err := c.reset()
if err != nil {
return nil, err
Expand Down
25 changes: 20 additions & 5 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ type (
}

fetchResult struct {
sync.Mutex
msgs chan Msg
err error
done bool
Expand Down Expand Up @@ -780,7 +781,7 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
for {
select {
case msg := <-msgs:
p.Lock()
res.Lock()
if hbTimer != nil {
hbTimer.Reset(2 * req.Heartbeat)
}
Expand All @@ -791,11 +792,11 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
res.err = err
}
res.done = true
p.Unlock()
res.Unlock()
return
}
if !userMsg {
p.Unlock()
res.Unlock()
continue
}
res.msgs <- p.jetStream.toJSMsg(msg)
Expand All @@ -810,16 +811,20 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
}
if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes >= req.MaxBytes) {
res.done = true
p.Unlock()
res.Unlock()
return
}
p.Unlock()
res.Unlock()
case err := <-sub.errs:
res.Lock()
res.err = err
res.done = true
res.Unlock()
return
case <-time.After(req.Expires + 1*time.Second):
res.Lock()
res.done = true
res.Unlock()
return
}
}
Expand All @@ -828,13 +833,23 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
}

func (fr *fetchResult) Messages() <-chan Msg {
fr.Lock()
defer fr.Unlock()
return fr.msgs
}

func (fr *fetchResult) Error() error {
fr.Lock()
defer fr.Unlock()
return fr.err
}

func (fr *fetchResult) closed() bool {
fr.Lock()
defer fr.Unlock()
return fr.done
}

// Next is used to retrieve the next message from the stream. This
// method will block until the message is retrieved or timeout is
// reached.
Expand Down
216 changes: 104 additions & 112 deletions jetstream/test/ordered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,131 +580,123 @@ func TestOrderedConsumerConsume(t *testing.T) {
})

t.Run("wait for closed after drain", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
closed := cc.Closed()
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
publishTestMsgs(t, js)

// wait for the consumer to be recreated before calling drain
for i := 0; i < 5; i++ {
_, err = c.Info(ctx)
if err != nil {
if errors.Is(err, jetstream.ErrConsumerNotFound) {
time.Sleep(100 * time.Millisecond)
continue
}
t.Fatalf("Unexpected error: %v", err)
}
break
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
closed := cc.Closed()
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
publishTestMsgs(t, js)

// wait for the consumer to be recreated before calling drain
for i := 0; i < 5; i++ {
_, err = c.Info(ctx)
if err != nil {
if errors.Is(err, jetstream.ErrConsumerNotFound) {
time.Sleep(100 * time.Millisecond)
continue
}
t.Fatalf("Unexpected error: %v", err)
}
break
}

cc.Drain()
cc.Drain()

select {
case <-closed:
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}
select {
case <-closed:
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}

if len(msgs) != 2*len(testMsgs) {
t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs))
}
})
if len(msgs) != 2*len(testMsgs) {
t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs))
}
})

t.Run("wait for closed on already closed consume", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

cc.Stop()
cc.Stop()

time.Sleep(100 * time.Millisecond)
time.Sleep(100 * time.Millisecond)

select {
case <-cc.Closed():
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}
})
select {
case <-cc.Closed():
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}
})
}
Expand Down
Loading

0 comments on commit 6bc4159

Please sign in to comment.