diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a63c6c5..ca2f367 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,22 +10,18 @@ jobs: test: runs-on: ubuntu-latest steps: + - name: Checkout code + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: 1.25.x - - name: Checkout code - uses: actions/checkout@v2 - - name: Go mod cache - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.25.x + cache: true - name: Tools bin cache uses: actions/cache@v4 with: path: .bin - key: ${{ runner.os }}-go1.25.x-${{ hashFiles('Makefile') }} + key: ${{ runner.os }}-tools-${{ hashFiles('Makefile') }} - name: Check run: make check - name: Test diff --git a/errors/errors.go b/errors/errors.go index 82b382e..72d6851 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -2,6 +2,7 @@ package errors import ( "errors" + "sort" "strings" ) @@ -22,10 +23,15 @@ type ErrorTags struct { // Error returns the error message with the tags attached. func (e *ErrorTags) Error() string { - md := []string{} + keys := make([]string, 0, len(e.tags)) + for k := range e.tags { + keys = append(keys, k) + } + sort.Strings(keys) - for k, v := range e.tags { - md = append(md, k+"="+v) + md := make([]string, 0, len(e.tags)) + for _, k := range keys { + md = append(md, k+"="+e.tags[k]) } return e.err.Error() + " [" + strings.Join(md, ", ") + "]" diff --git a/errors/errors_test.go b/errors/errors_test.go index 44ae125..dab5ada 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -28,7 +28,7 @@ func TestWrap_MultipleWraps(t *testing.T) { wrapped1 := errors.Wrap(err, "foo", "bar") wrapped2 := errors.Wrap(wrapped1, "baz", "qux") - assert.Equal(t, "test error [foo=bar, baz=qux]", wrapped2.Error()) + assert.Equal(t, "test error [baz=qux, foo=bar]", wrapped2.Error()) } func TestWrap_PanicsOnOddNumberOfAttrs(t *testing.T) { diff --git a/xkafka/batch.go b/xkafka/batch.go index 3765351..c7bac94 100644 --- a/xkafka/batch.go +++ b/xkafka/batch.go @@ -85,7 +85,9 @@ func (b *Batch) GroupMaxOffset() []kafka.TopicPartition { offsets := make(map[string]map[int32]int64) for _, m := range b.Messages { if _, ok := offsets[m.Topic]; !ok { - offsets[m.Topic] = make(map[int32]int64) + offsets[m.Topic] = map[int32]int64{ + m.Partition: m.Offset, + } } if m.Offset > offsets[m.Topic][m.Partition] { @@ -102,7 +104,7 @@ func (b *Batch) GroupMaxOffset() []kafka.TopicPartition { tps = append(tps, kafka.TopicPartition{ Topic: &topic, Partition: partition, - Offset: kafka.Offset(offset + 1), + Offset: kafka.Offset(offset), }) } } diff --git a/xkafka/batch_consumer.go b/xkafka/batch_consumer.go index 31b0c7e..1a7f543 100644 --- a/xkafka/batch_consumer.go +++ b/xkafka/batch_consumer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" "sync/atomic" "time" @@ -20,6 +21,10 @@ type BatchConsumer struct { middlewares []BatchMiddlewarer config *consumerConfig stopOffset atomic.Bool + + // partition tracking + mu sync.Mutex + activePartitions map[string]map[int32]struct{} } // NewBatchConsumer creates a new BatchConsumer instance. @@ -44,10 +49,11 @@ func NewBatchConsumer(name string, handler BatchHandler, opts ...ConsumerOption) } return &BatchConsumer{ - name: name, - config: cfg, - kafka: consumer, - handler: handler, + name: name, + config: cfg, + kafka: consumer, + handler: handler, + activePartitions: make(map[string]map[int32]struct{}), }, nil } @@ -255,7 +261,24 @@ func (c *BatchConsumer) storeBatch(batch *Batch) error { return nil } - tps := batch.GroupMaxOffset() + allTps := batch.GroupMaxOffset() + + // filter to only active partitions + tps := make([]kafka.TopicPartition, 0, len(allTps)) + for _, tp := range allTps { + if tp.Topic != nil && c.isPartitionActive(*tp.Topic, tp.Partition) { + // similar to StoreMessage in confluent-kafka-go/consumer.go + // tp.Offset + 1 it ensures that the consumer starts with + // next message when it restarts + tp.Offset = kafka.Offset(tp.Offset + 1) + + tps = append(tps, tp) + } + } + + if len(tps) == 0 { + return nil + } _, err := c.kafka.StoreOffsets(tps) if err != nil { @@ -281,7 +304,74 @@ func (c *BatchConsumer) concatMiddlewares(h BatchHandler) BatchHandler { } func (c *BatchConsumer) subscribe() error { - return c.kafka.SubscribeTopics(c.config.topics, nil) + return c.kafka.SubscribeTopics(c.config.topics, c.rebalanceCallback) +} + +func (c *BatchConsumer) rebalanceCallback(_ *kafka.Consumer, event kafka.Event) error { + switch e := event.(type) { + case kafka.AssignedPartitions: + c.onPartitionsAssigned(e.Partitions) + return c.kafka.Assign(e.Partitions) + + case kafka.RevokedPartitions: + if err := c.kafka.Unassign(); err != nil { + return err + } + + c.onPartitionsRevoked(e.Partitions) + } + + return nil +} + +func (c *BatchConsumer) onPartitionsAssigned(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] == nil { + c.activePartitions[topic] = make(map[int32]struct{}) + } + + c.activePartitions[topic][tp.Partition] = struct{}{} + } +} + +func (c *BatchConsumer) onPartitionsRevoked(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] != nil { + delete(c.activePartitions[topic], tp.Partition) + + if len(c.activePartitions[topic]) == 0 { + delete(c.activePartitions, topic) + } + } + } +} + +func (c *BatchConsumer) isPartitionActive(topic string, partition int32) bool { + c.mu.Lock() + defer c.mu.Unlock() + + if partitions, ok := c.activePartitions[topic]; ok { + _, active := partitions[partition] + return active + } + + return false } func (c *BatchConsumer) unsubscribe() error { diff --git a/xkafka/batch_consumer_test.go b/xkafka/batch_consumer_test.go index 4b99437..e650984 100644 --- a/xkafka/batch_consumer_test.go +++ b/xkafka/batch_consumer_test.go @@ -295,7 +295,12 @@ func TestBatchConsumer_Async(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) @@ -357,15 +362,19 @@ func TestBatchConsumer_StopOffsetOnError(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) mockKafka.On("Close").Return(nil) mockKafka.On("StoreOffsets", mock.Anything). - Return(nil, nil). - Times(2) + Return(nil, nil) consumer.handler = handler err := consumer.Run(ctx) @@ -416,7 +425,12 @@ func TestBatchConsumer_BatchTimeout(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) @@ -483,36 +497,6 @@ func TestBatchConsumer_MiddlewareExecutionOrder(t *testing.T) { mockKafka.AssertExpectations(t) } -func TestBatchConsumer_ManualCommit(t *testing.T) { - t.Parallel() - - consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) - - km := newFakeKafkaMessage() - ctx, cancel := context.WithCancel(context.Background()) - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) - mockKafka.On("Unsubscribe").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - mockKafka.On("Commit").Return(nil, nil) - mockKafka.On("ReadMessage", testTimeout).Return(km, nil) - mockKafka.On("Close").Return(nil) - - handler := BatchHandlerFunc(func(ctx context.Context, b *Batch) error { - b.AckSuccess() - - cancel() - - return nil - }) - - consumer.handler = handler - err := consumer.Run(ctx) - assert.NoError(t, err) - - mockKafka.AssertExpectations(t) -} - func TestBatchConsumer_ReadMessageTimeout(t *testing.T) { t.Parallel() @@ -669,6 +653,7 @@ func TestBatchConsumer_CommitError(t *testing.T) { mockKafka.On("Commit").Return(nil, expect) consumer.handler = handler + assignBatchPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) @@ -679,6 +664,459 @@ func TestBatchConsumer_CommitError(t *testing.T) { } } +func TestBatchConsumer_RebalanceCallback(t *testing.T) { + t.Parallel() + + t.Run("AssignedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + partitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + + // Verify active partitions are tracked + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.True(t, consumer.isPartitionActive(topic1, 1)) + assert.True(t, consumer.isPartitionActive(topic2, 0)) + assert.False(t, consumer.isPartitionActive(topic2, 1)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("AssignedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(assert.AnError) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + + // First assign partitions + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Now revoke some partitions + revokePartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: revokePartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify only non-revoked partitions are active + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.False(t, consumer.isPartitionActive(topic1, 1)) + assert.False(t, consumer.isPartitionActive(topic2, 0)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Unassign").Return(assert.AnError) + + event := kafka.RevokedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokeAllPartitionsRemovesTopic", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign a single partition + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Revoke all partitions for the topic + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: assignPartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify topic is removed from active partitions + consumer.mu.Lock() + _, exists := consumer.activePartitions[topic] + consumer.mu.Unlock() + assert.False(t, exists) + + mockKafka.AssertExpectations(t) + }) + + t.Run("NilTopicPartition", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + // Test with nil topic in partition + partitions := []kafka.TopicPartition{ + {Topic: nil, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + assert.NoError(t, err) + + // Verify no panic and empty active partitions + consumer.mu.Lock() + assert.Len(t, consumer.activePartitions, 0) + consumer.mu.Unlock() + + mockKafka.AssertExpectations(t) + }) + + t.Run("UnknownEvent", func(t *testing.T) { + consumer, _ := newTestBatchConsumer(t, defaultOpts...) + + // Test with an unknown event type (use kafka.Error as example) + event := kafka.NewError(kafka.ErrUnknown, "unknown", false) + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + }) +} + +func TestBatchConsumer_StoreBatch(t *testing.T) { + t.Parallel() + + t.Run("EmptyBatch", func(t *testing.T) { + consumer, _ := newTestBatchConsumer(t, defaultOpts...) + + batch := NewBatch() + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + }) + + t.Run("BatchWithFailStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + }) + batch.AckFail(assert.AnError) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("BatchWithSuccessStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("BatchWithSkipStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSkip() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StopOffsetPreventsStore", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + // Set stopOffset flag + consumer.stopOffset.Store(true) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("FilterInactivePartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Batch has messages from partition 0 (active) and partition 1 (inactive) + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 0, Offset: 100}, + &Message{Topic: topic, Partition: 1, Offset: 200}, + ) + batch.AckSuccess() + + // Only partition 0 should be stored + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + if len(tps) != 1 { + return false + } + return *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err = consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("AllPartitionsInactive", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Batch has only messages from inactive partitions + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 1, Offset: 100}, + &Message{Topic: topic, Partition: 2, Offset: 200}, + ) + batch.AckSuccess() + + // StoreOffsets should not be called since all partitions are inactive + err = consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StoreOffsetsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, assert.AnError) + + err := consumer.storeBatch(batch) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitEnabled", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestBatchConsumer(t, opts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitError", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestBatchConsumer(t, opts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, assert.AnError) + + err := consumer.storeBatch(batch) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("MultipleTopicsAndPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + + // Assign all partitions that will be used + assignBatchPartitions(t, consumer, mockKafka, topic1, 0, 1) + assignBatchPartitions(t, consumer, mockKafka, topic2, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic1, Partition: 0, Offset: 100}, + &Message{Topic: topic1, Partition: 0, Offset: 150}, + &Message{Topic: topic1, Partition: 1, Offset: 50}, + &Message{Topic: topic2, Partition: 0, Offset: 200}, + ) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + // Should have 3 topic partitions: topic1-0, topic1-1, topic2-0 + return len(tps) == 3 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("OffsetIncrementedByOne", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 0, Offset: 999}, + ) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && tps[0].Offset == kafka.Offset(1000) + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) +} + func noopBatchHandler() BatchHandler { return BatchHandlerFunc(func(ctx context.Context, b *Batch) error { return nil @@ -714,3 +1152,20 @@ func testBatchMiddleware(name string, preExec, postExec *[]string) BatchMiddlewa }) }) } + +// assignBatchPartitions is a test helper that assigns partitions to a batch consumer +// so that isPartitionActive returns true for those partitions. +func assignBatchPartitions(t *testing.T, consumer *BatchConsumer, mockKafka *MockConsumerClient, topic string, partitions ...int32) { + t.Helper() + + tps := make([]kafka.TopicPartition, len(partitions)) + for i, p := range partitions { + tps[i] = kafka.TopicPartition{Topic: &topic, Partition: p} + } + + mockKafka.On("Assign", tps).Return(nil).Once() + + event := kafka.AssignedPartitions{Partitions: tps} + err := consumer.rebalanceCallback(nil, event) + require.NoError(t, err) +} diff --git a/xkafka/batch_test.go b/xkafka/batch_test.go index e00952b..7ac3ff9 100644 --- a/xkafka/batch_test.go +++ b/xkafka/batch_test.go @@ -65,7 +65,7 @@ func TestBatch_OffsetMethods(t *testing.T) { { Topic: strPtr("topic1"), Partition: 0, - Offset: kafka.Offset(11), + Offset: kafka.Offset(10), }, }, }, @@ -81,17 +81,17 @@ func TestBatch_OffsetMethods(t *testing.T) { { Topic: strPtr("topic1"), Partition: 0, - Offset: kafka.Offset(6), + Offset: kafka.Offset(5), }, { Topic: strPtr("topic1"), Partition: 1, - Offset: kafka.Offset(11), + Offset: kafka.Offset(10), }, { Topic: strPtr("topic2"), Partition: 0, - Offset: kafka.Offset(16), + Offset: kafka.Offset(15), }, }, }, diff --git a/xkafka/confluent.go b/xkafka/confluent.go index 98c5055..0c5c68b 100644 --- a/xkafka/confluent.go +++ b/xkafka/confluent.go @@ -11,6 +11,8 @@ type consumerClient interface { ReadMessage(timeout time.Duration) (*kafka.Message, error) SubscribeTopics(topics []string, rebalanceCb kafka.RebalanceCb) error Unsubscribe() error + Assign(partitions []kafka.TopicPartition) error + Unassign() error StoreOffsets(offsets []kafka.TopicPartition) ([]kafka.TopicPartition, error) Commit() ([]kafka.TopicPartition, error) Close() error diff --git a/xkafka/consumer.go b/xkafka/consumer.go index 9d237ba..61e031c 100644 --- a/xkafka/consumer.go +++ b/xkafka/consumer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" "sync/atomic" "time" @@ -21,6 +22,10 @@ type Consumer struct { config *consumerConfig cancelCtx atomic.Pointer[context.CancelFunc] stopOffset atomic.Bool + + // partition tracking + mu sync.Mutex + activePartitions map[string]map[int32]struct{} } // NewConsumer creates a new Consumer instance. @@ -45,10 +50,11 @@ func NewConsumer(name string, handler Handler, opts ...ConsumerOption) (*Consume } return &Consumer{ - name: name, - config: cfg, - kafka: consumer, - handler: handler, + name: name, + config: cfg, + kafka: consumer, + handler: handler, + activePartitions: make(map[string]map[int32]struct{}), }, nil } @@ -234,6 +240,11 @@ func (c *Consumer) storeMessage(msg *Message) error { return nil } + // only store offset if partition is still active + if !c.isPartitionActive(msg.Topic, msg.Partition) { + return nil + } + // similar to StoreMessage in confluent-kafka-go/consumer.go // msg.Offset + 1 it ensures that the consumer starts with // next message when it restarts @@ -267,7 +278,73 @@ func (c *Consumer) concatMiddlewares(h Handler) Handler { } func (c *Consumer) subscribe() error { - return c.kafka.SubscribeTopics(c.config.topics, nil) + return c.kafka.SubscribeTopics(c.config.topics, c.rebalanceCallback) +} + +func (c *Consumer) rebalanceCallback(_ *kafka.Consumer, event kafka.Event) error { + switch e := event.(type) { + case kafka.AssignedPartitions: + c.onPartitionsAssigned(e.Partitions) + return c.kafka.Assign(e.Partitions) + + case kafka.RevokedPartitions: + if err := c.kafka.Unassign(); err != nil { + return err + } + c.onPartitionsRevoked(e.Partitions) + } + + return nil +} + +func (c *Consumer) onPartitionsAssigned(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] == nil { + c.activePartitions[topic] = make(map[int32]struct{}) + } + + c.activePartitions[topic][tp.Partition] = struct{}{} + } +} + +func (c *Consumer) onPartitionsRevoked(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] != nil { + delete(c.activePartitions[topic], tp.Partition) + + if len(c.activePartitions[topic]) == 0 { + delete(c.activePartitions, topic) + } + } + } +} + +func (c *Consumer) isPartitionActive(topic string, partition int32) bool { + c.mu.Lock() + defer c.mu.Unlock() + + if partitions, ok := c.activePartitions[topic]; ok { + _, active := partitions[partition] + return active + } + + return false } func (c *Consumer) unsubscribe() error { diff --git a/xkafka/consumer_mock_test.go b/xkafka/consumer_mock_test.go index 791a408..2ed8342 100644 --- a/xkafka/consumer_mock_test.go +++ b/xkafka/consumer_mock_test.go @@ -132,6 +132,34 @@ func (_m *MockConsumerClient) StoreOffsets(offsets []kafka.TopicPartition) ([]ka return r0, r1 } +// Assign provides a mock function with given fields: partitions +func (_m *MockConsumerClient) Assign(partitions []kafka.TopicPartition) error { + ret := _m.Called(partitions) + + var r0 error + if rf, ok := ret.Get(0).(func([]kafka.TopicPartition) error); ok { + r0 = rf(partitions) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Unassign provides a mock function with given fields: +func (_m *MockConsumerClient) Unassign() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SubscribeTopics provides a mock function with given fields: topics, rebalanceCb func (_m *MockConsumerClient) SubscribeTopics(topics []string, rebalanceCb kafka.RebalanceCb) error { ret := _m.Called(topics, rebalanceCb) diff --git a/xkafka/consumer_test.go b/xkafka/consumer_test.go index 0967fc1..480644c 100644 --- a/xkafka/consumer_test.go +++ b/xkafka/consumer_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" @@ -178,22 +179,6 @@ func TestConsumerGetMetadata(t *testing.T) { mockKafka.AssertExpectations(t) } -func TestConsumerSubscribeError(t *testing.T) { - t.Parallel() - - consumer, mockKafka := newTestConsumer(t, defaultOpts...) - - ctx := context.Background() - expectError := errors.New("error") - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(expectError) - - err := consumer.Run(ctx) - assert.EqualError(t, err, expectError.Error()) - - mockKafka.AssertExpectations(t) -} - func TestConsumerUnsubscribeError(t *testing.T) { t.Parallel() @@ -500,38 +485,6 @@ func TestConsumerMiddlewareExecutionOrder(t *testing.T) { mockKafka.AssertExpectations(t) } -func TestConsumerManualCommit(t *testing.T) { - t.Parallel() - - consumer, mockKafka := newTestConsumer(t, - append(defaultOpts, ManualCommit(true))...) - - km := newFakeKafkaMessage() - ctx, cancel := context.WithCancel(context.Background()) - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) - mockKafka.On("Unsubscribe").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - mockKafka.On("Commit").Return(nil, nil) - mockKafka.On("ReadMessage", testTimeout).Return(km, nil) - mockKafka.On("Close").Return(nil) - - handler := HandlerFunc(func(ctx context.Context, msg *Message) error { - cancel() - - msg.AckSuccess() - - return nil - }) - - consumer.handler = handler - - err := consumer.Run(ctx) - assert.NoError(t, err) - - mockKafka.AssertExpectations(t) -} - func TestConsumerAsync(t *testing.T) { t.Parallel() @@ -541,7 +494,13 @@ func TestConsumerAsync(t *testing.T) { km := newFakeKafkaMessage() ctx, cancel := context.WithCancel(context.Background()) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + // Capture rebalance callback and trigger partition assignment + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) @@ -583,26 +542,24 @@ func TestConsumerAsync_StopOffsetOnError(t *testing.T) { km := newFakeKafkaMessage() ctx, cancel := context.WithCancel(context.Background()) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + // Capture rebalance callback and trigger partition assignment + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("Close").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything). - Return(nil, nil). - Times(2) + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - var recv []*Message - var mu sync.Mutex + var recv atomic.Int32 handler := HandlerFunc(func(ctx context.Context, msg *Message) error { - mu.Lock() - defer mu.Unlock() - - recv = append(recv, msg) - - if len(recv) > 2 { + if recv.Load() > 2 { err := assert.AnError msg.AckFail(err) @@ -611,6 +568,7 @@ func TestConsumerAsync_StopOffsetOnError(t *testing.T) { return err } + recv.Add(1) msg.AckSuccess() return nil @@ -666,6 +624,7 @@ func TestConsumerStoreOffsetsError(t *testing.T) { mockKafka.On("ReadMessage", testTimeout).Return(km, nil) consumer.handler = handler + assignPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) @@ -719,6 +678,7 @@ func TestConsumerCommitError(t *testing.T) { mockKafka.On("ReadMessage", testTimeout).Return(km, nil) consumer.handler = handler + assignPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) @@ -741,6 +701,377 @@ func testMiddleware(name string, pre, post *[]string) MiddlewareFunc { } } +func TestConsumer_RebalanceCallback(t *testing.T) { + t.Parallel() + + t.Run("AssignedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + partitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + + // Verify active partitions are tracked + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.True(t, consumer.isPartitionActive(topic1, 1)) + assert.True(t, consumer.isPartitionActive(topic2, 0)) + assert.False(t, consumer.isPartitionActive(topic2, 1)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("AssignedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(assert.AnError) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + + // First assign partitions + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Now revoke some partitions + revokePartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: revokePartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify only non-revoked partitions are active + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.False(t, consumer.isPartitionActive(topic1, 1)) + assert.False(t, consumer.isPartitionActive(topic2, 0)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Unassign").Return(assert.AnError) + + event := kafka.RevokedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokeAllPartitionsRemovesTopic", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign a single partition + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Revoke all partitions for the topic + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: assignPartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify topic is removed from active partitions + consumer.mu.Lock() + _, exists := consumer.activePartitions[topic] + consumer.mu.Unlock() + assert.False(t, exists) + + mockKafka.AssertExpectations(t) + }) + + t.Run("NilTopicPartition", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + // Test with nil topic in partition + partitions := []kafka.TopicPartition{ + {Topic: nil, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + assert.NoError(t, err) + + // Verify no panic and empty active partitions + consumer.mu.Lock() + assert.Len(t, consumer.activePartitions, 0) + consumer.mu.Unlock() + + mockKafka.AssertExpectations(t) + }) + + t.Run("UnknownEvent", func(t *testing.T) { + consumer, _ := newTestConsumer(t, defaultOpts...) + + // Test with an unknown event type (use kafka.Error as example) + event := kafka.NewError(kafka.ErrUnknown, "unknown", false) + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + }) +} + +func TestConsumer_StoreMessage(t *testing.T) { + t.Parallel() + + t.Run("MessageWithFailStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Fail, + } + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("MessageWithSuccessStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 && tps[0].Offset == 101 + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("MessageWithSkipStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Skip, + } + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 && tps[0].Offset == 101 + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StopOffsetPreventsStore", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Success, + } + + // Set stopOffset flag + consumer.stopOffset.Store(true) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("FilterInactivePartition", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Message is from partition 1 (inactive) + msg := &Message{ + Topic: topic, + Partition: 1, + Offset: 100, + Status: Success, + } + + // StoreOffsets should not be called for inactive partition + err = consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StoreOffsetsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, assert.AnError) + + err := consumer.storeMessage(msg) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitEnabled", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestConsumer(t, opts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitError", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestConsumer(t, opts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, assert.AnError) + + err := consumer.storeMessage(msg) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("OffsetIncrementedByOne", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 999, + Status: Success, + } + + // Verify offset is incremented by 1 + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && tps[0].Offset == kafka.Offset(1000) + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) +} + func newTestConsumer(t *testing.T, opts ...ConsumerOption) (*Consumer, *MockConsumerClient) { t.Helper() @@ -778,3 +1109,20 @@ func noopHandler() Handler { return nil }) } + +// assignPartitions is a test helper that assigns partitions to a consumer +// so that isPartitionActive returns true for those partitions. +func assignPartitions(t *testing.T, consumer *Consumer, mockKafka *MockConsumerClient, topic string, partitions ...int32) { + t.Helper() + + tps := make([]kafka.TopicPartition, len(partitions)) + for i, p := range partitions { + tps[i] = kafka.TopicPartition{Topic: &topic, Partition: p} + } + + mockKafka.On("Assign", tps).Return(nil).Once() + + event := kafka.AssignedPartitions{Partitions: tps} + err := consumer.rebalanceCallback(nil, event) + require.NoError(t, err) +}