From 8f044070ca193ef88b26b99b98eef1f43eae840d Mon Sep 17 00:00:00 2001 From: Ravi Atluri Date: Thu, 4 Dec 2025 14:26:50 +0530 Subject: [PATCH 1/5] Implement partition tracking and rebalance handling in Consumer and BatchConsumer --- xkafka/batch.go | 6 +- xkafka/batch_consumer.go | 107 +++++++- xkafka/batch_consumer_test.go | 466 +++++++++++++++++++++++++++++++--- xkafka/batch_test.go | 8 +- xkafka/confluent.go | 2 + xkafka/consumer.go | 92 ++++++- xkafka/consumer_mock_test.go | 28 ++ xkafka/consumer_test.go | 419 +++++++++++++++++++++++++----- 8 files changed, 1022 insertions(+), 106 deletions(-) diff --git a/xkafka/batch.go b/xkafka/batch.go index 3765351..c7bac94 100644 --- a/xkafka/batch.go +++ b/xkafka/batch.go @@ -85,7 +85,9 @@ func (b *Batch) GroupMaxOffset() []kafka.TopicPartition { offsets := make(map[string]map[int32]int64) for _, m := range b.Messages { if _, ok := offsets[m.Topic]; !ok { - offsets[m.Topic] = make(map[int32]int64) + offsets[m.Topic] = map[int32]int64{ + m.Partition: m.Offset, + } } if m.Offset > offsets[m.Topic][m.Partition] { @@ -102,7 +104,7 @@ func (b *Batch) GroupMaxOffset() []kafka.TopicPartition { tps = append(tps, kafka.TopicPartition{ Topic: &topic, Partition: partition, - Offset: kafka.Offset(offset + 1), + Offset: kafka.Offset(offset), }) } } diff --git a/xkafka/batch_consumer.go b/xkafka/batch_consumer.go index 31b0c7e..8e96c3d 100644 --- a/xkafka/batch_consumer.go +++ b/xkafka/batch_consumer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" "sync/atomic" "time" @@ -20,6 +21,10 @@ type BatchConsumer struct { middlewares []BatchMiddlewarer config *consumerConfig stopOffset atomic.Bool + + // partition tracking + mu sync.Mutex + activePartitions map[string]map[int32]struct{} } // NewBatchConsumer creates a new BatchConsumer instance. 
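With this patch, Batch.GroupMaxOffset returns the highest processed offset per topic/partition as-is; the +1 that StoreOffsets expects is applied later, in storeBatch (next hunks), after offsets for revoked partitions are filtered out. A self-contained sketch of that contract; msg and groupMaxOffset are simplified stand-ins for illustration, not the xkafka API:

	package main

	import "fmt"

	// msg is a simplified stand-in for xkafka's Message.
	type msg struct {
		topic     string
		partition int32
		offset    int64
	}

	// groupMaxOffset mirrors Batch.GroupMaxOffset after this change: it
	// reports the highest processed offset per topic/partition, without
	// adding 1.
	func groupMaxOffset(msgs []msg) map[string]map[int32]int64 {
		out := make(map[string]map[int32]int64)
		for _, m := range msgs {
			if out[m.topic] == nil {
				// seed with the first offset seen, so a lone offset 0 is recorded
				out[m.topic] = map[int32]int64{m.partition: m.offset}
			}
			if m.offset > out[m.topic][m.partition] {
				out[m.topic][m.partition] = m.offset
			}
		}
		return out
	}

	func main() {
		got := groupMaxOffset([]msg{{"orders", 0, 5}, {"orders", 0, 3}})
		fmt.Println(got["orders"][0]) // 5; storeBatch later stores 5+1 = 6

		lone := groupMaxOffset([]msg{{"payments", 0, 0}})
		_, ok := lone["payments"][0]
		fmt.Println(ok) // true; with make() plus comparison alone, offset 0 was lost
	}
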
@@ -44,10 +49,11 @@ func NewBatchConsumer(name string, handler BatchHandler, opts ...ConsumerOption)
 	}
 
 	return &BatchConsumer{
-		name:    name,
-		config:  cfg,
-		kafka:   consumer,
-		handler: handler,
+		name:             name,
+		config:           cfg,
+		kafka:            consumer,
+		handler:          handler,
+		activePartitions: make(map[string]map[int32]struct{}),
 	}, nil
 }
 
@@ -255,7 +261,24 @@ func (c *BatchConsumer) storeBatch(batch *Batch) error {
 		return nil
 	}
 
-	tps := batch.GroupMaxOffset()
+	allTps := batch.GroupMaxOffset()
+
+	// keep only offsets for partitions still assigned to this consumer
+	tps := make([]kafka.TopicPartition, 0, len(allTps))
+	for _, tp := range allTps {
+		if tp.Topic != nil && c.isPartitionActive(*tp.Topic, tp.Partition) {
+			// similar to StoreMessage in confluent-kafka-go/consumer.go:
+			// storing tp.Offset + 1 ensures that the consumer starts with
+			// the next message when it restarts
+			tp.Offset = kafka.Offset(tp.Offset + 1)
+
+			tps = append(tps, tp)
+		}
+	}
+
+	if len(tps) == 0 {
+		return nil
+	}
 
 	_, err := c.kafka.StoreOffsets(tps)
 	if err != nil {
@@ -281,7 +304,79 @@ func (c *BatchConsumer) concatMiddlewares(h BatchHandler) BatchHandler {
 }
 
 func (c *BatchConsumer) subscribe() error {
-	return c.kafka.SubscribeTopics(c.config.topics, nil)
+	return c.kafka.SubscribeTopics(c.config.topics, c.rebalanceCallback)
+}
+
+func (c *BatchConsumer) rebalanceCallback(_ *kafka.Consumer, event kafka.Event) error {
+	switch e := event.(type) {
+	case kafka.AssignedPartitions:
+		c.onPartitionsAssigned(e.Partitions)
+		return c.kafka.Assign(e.Partitions)
+
+	case kafka.RevokedPartitions:
+		if err := c.kafka.Unassign(); err != nil {
+			return err
+		}
+
+		c.onPartitionsRevoked(e.Partitions)
+	}
+
+	return nil
+}
+
+func (c *BatchConsumer) onPartitionsAssigned(partitions []kafka.TopicPartition) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for _, tp := range partitions {
+		if tp.Topic == nil {
+			continue
+		}
+
+		topic := *tp.Topic
+		if c.activePartitions[topic] == nil {
+			c.activePartitions[topic] = make(map[int32]struct{})
+		}
+
+		c.activePartitions[topic][tp.Partition] = struct{}{}
+	}
+}
+
+func (c *BatchConsumer) onPartitionsRevoked(partitions []kafka.TopicPartition) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for _, tp := range partitions {
+		if tp.Topic == nil {
+			continue
+		}
+
+		topic := *tp.Topic
+		if c.activePartitions[topic] != nil {
+			delete(c.activePartitions[topic], tp.Partition)
+
+			if len(c.activePartitions[topic]) == 0 {
+				delete(c.activePartitions, topic)
+			}
+		}
+	}
+}
+
+func (c *BatchConsumer) isPartitionActive(topic string, partition int32) bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// if no partitions tracked yet (before first rebalance), allow all
+	if len(c.activePartitions) == 0 {
+		return true
+	}
+
+	if partitions, ok := c.activePartitions[topic]; ok {
+		_, active := partitions[partition]
+		return active
+	}
+
+	return false
 }
 
 func (c *BatchConsumer) unsubscribe() error {
diff --git a/xkafka/batch_consumer_test.go b/xkafka/batch_consumer_test.go
index 4b99437..e23fd92 100644
--- a/xkafka/batch_consumer_test.go
+++ b/xkafka/batch_consumer_test.go
@@ -483,36 +483,6 @@ func TestBatchConsumer_MiddlewareExecutionOrder(t *testing.T) {
 	mockKafka.AssertExpectations(t)
 }
 
-func TestBatchConsumer_ManualCommit(t *testing.T) {
-	t.Parallel()
-
-	consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...)
- - km := newFakeKafkaMessage() - ctx, cancel := context.WithCancel(context.Background()) - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) - mockKafka.On("Unsubscribe").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - mockKafka.On("Commit").Return(nil, nil) - mockKafka.On("ReadMessage", testTimeout).Return(km, nil) - mockKafka.On("Close").Return(nil) - - handler := BatchHandlerFunc(func(ctx context.Context, b *Batch) error { - b.AckSuccess() - - cancel() - - return nil - }) - - consumer.handler = handler - err := consumer.Run(ctx) - assert.NoError(t, err) - - mockKafka.AssertExpectations(t) -} - func TestBatchConsumer_ReadMessageTimeout(t *testing.T) { t.Parallel() @@ -714,3 +684,439 @@ func testBatchMiddleware(name string, preExec, postExec *[]string) BatchMiddlewa }) }) } + +func TestBatchConsumer_RebalanceCallback(t *testing.T) { + t.Parallel() + + t.Run("AssignedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + partitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + + // Verify active partitions are tracked + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.True(t, consumer.isPartitionActive(topic1, 1)) + assert.True(t, consumer.isPartitionActive(topic2, 0)) + assert.False(t, consumer.isPartitionActive(topic2, 1)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("AssignedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(assert.AnError) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + + // First assign partitions + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Now revoke some partitions + revokePartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: revokePartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify only non-revoked partitions are active + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.False(t, consumer.isPartitionActive(topic1, 1)) + assert.False(t, consumer.isPartitionActive(topic2, 0)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) 
+ + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Unassign").Return(assert.AnError) + + event := kafka.RevokedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokeAllPartitionsRemovesTopic", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign a single partition + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Revoke all partitions for the topic + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: assignPartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify topic is removed from active partitions + consumer.mu.Lock() + _, exists := consumer.activePartitions[topic] + consumer.mu.Unlock() + assert.False(t, exists) + + mockKafka.AssertExpectations(t) + }) + + t.Run("NilTopicPartition", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + // Test with nil topic in partition + partitions := []kafka.TopicPartition{ + {Topic: nil, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + assert.NoError(t, err) + + // Verify no panic and empty active partitions + consumer.mu.Lock() + assert.Len(t, consumer.activePartitions, 0) + consumer.mu.Unlock() + + mockKafka.AssertExpectations(t) + }) + + t.Run("UnknownEvent", func(t *testing.T) { + consumer, _ := newTestBatchConsumer(t, defaultOpts...) + + // Test with an unknown event type (use kafka.Error as example) + event := kafka.NewError(kafka.ErrUnknown, "unknown", false) + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + }) +} + +func TestBatchConsumer_StoreBatch(t *testing.T) { + t.Parallel() + + t.Run("EmptyBatch", func(t *testing.T) { + consumer, _ := newTestBatchConsumer(t, defaultOpts...) + + batch := NewBatch() + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + }) + + t.Run("BatchWithFailStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + }) + batch.AckFail(assert.AnError) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("BatchWithSuccessStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("BatchWithSkipStatus", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) 
+ + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSkip() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StopOffsetPreventsStore", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + // Set stopOffset flag + consumer.stopOffset.Store(true) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("FilterInactivePartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Batch has messages from partition 0 (active) and partition 1 (inactive) + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 0, Offset: 100}, + &Message{Topic: topic, Partition: 1, Offset: 200}, + ) + batch.AckSuccess() + + // Only partition 0 should be stored + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + if len(tps) != 1 { + return false + } + return *tps[0].Topic == topic && tps[0].Partition == 0 + })).Return(nil, nil) + + err = consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("AllPartitionsInactive", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Batch has only messages from inactive partitions + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 1, Offset: 100}, + &Message{Topic: topic, Partition: 2, Offset: 200}, + ) + batch.AckSuccess() + + // StoreOffsets should not be called since all partitions are inactive + err = consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StoreOffsetsError", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, assert.AnError) + + err := consumer.storeBatch(batch) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitEnabled", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestBatchConsumer(t, opts...) 
+ + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitError", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestBatchConsumer(t, opts...) + + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + }) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, assert.AnError) + + err := consumer.storeBatch(batch) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("MultipleTopicsAndPartitions", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic1, Partition: 0, Offset: 100}, + &Message{Topic: topic1, Partition: 0, Offset: 150}, + &Message{Topic: topic1, Partition: 1, Offset: 50}, + &Message{Topic: topic2, Partition: 0, Offset: 200}, + ) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + // Should have 3 topic partitions: topic1-0, topic1-1, topic2-0 + return len(tps) == 3 + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("OffsetIncrementedByOne", func(t *testing.T) { + consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) 
+ + topic := "topic1" + batch := NewBatch() + batch.Messages = append(batch.Messages, + &Message{Topic: topic, Partition: 0, Offset: 999}, + ) + batch.AckSuccess() + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && tps[0].Offset == kafka.Offset(1000) + })).Return(nil, nil) + + err := consumer.storeBatch(batch) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) +} diff --git a/xkafka/batch_test.go b/xkafka/batch_test.go index e00952b..7ac3ff9 100644 --- a/xkafka/batch_test.go +++ b/xkafka/batch_test.go @@ -65,7 +65,7 @@ func TestBatch_OffsetMethods(t *testing.T) { { Topic: strPtr("topic1"), Partition: 0, - Offset: kafka.Offset(11), + Offset: kafka.Offset(10), }, }, }, @@ -81,17 +81,17 @@ func TestBatch_OffsetMethods(t *testing.T) { { Topic: strPtr("topic1"), Partition: 0, - Offset: kafka.Offset(6), + Offset: kafka.Offset(5), }, { Topic: strPtr("topic1"), Partition: 1, - Offset: kafka.Offset(11), + Offset: kafka.Offset(10), }, { Topic: strPtr("topic2"), Partition: 0, - Offset: kafka.Offset(16), + Offset: kafka.Offset(15), }, }, }, diff --git a/xkafka/confluent.go b/xkafka/confluent.go index 98c5055..0c5c68b 100644 --- a/xkafka/confluent.go +++ b/xkafka/confluent.go @@ -11,6 +11,8 @@ type consumerClient interface { ReadMessage(timeout time.Duration) (*kafka.Message, error) SubscribeTopics(topics []string, rebalanceCb kafka.RebalanceCb) error Unsubscribe() error + Assign(partitions []kafka.TopicPartition) error + Unassign() error StoreOffsets(offsets []kafka.TopicPartition) ([]kafka.TopicPartition, error) Commit() ([]kafka.TopicPartition, error) Close() error diff --git a/xkafka/consumer.go b/xkafka/consumer.go index 9d237ba..daf98a8 100644 --- a/xkafka/consumer.go +++ b/xkafka/consumer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" "sync/atomic" "time" @@ -21,6 +22,10 @@ type Consumer struct { config *consumerConfig cancelCtx atomic.Pointer[context.CancelFunc] stopOffset atomic.Bool + + // partition tracking + mu sync.Mutex + activePartitions map[string]map[int32]struct{} } // NewConsumer creates a new Consumer instance. 
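The hunks below give Consumer the same activePartitions bookkeeping the batch consumer gained above: a mutex-guarded topic-to-partition set, populated on AssignedPartitions and pruned on RevokedPartitions. A runnable sketch of that bookkeeping in isolation (tracker is an illustrative name, not part of xkafka; note that patch 1 also allows everything before the first assignment, a fallback removed later in this series):

	package main

	import (
		"fmt"
		"sync"
	)

	// tracker mirrors the activePartitions state added in this patch.
	type tracker struct {
		mu     sync.Mutex
		active map[string]map[int32]struct{}
	}

	func (t *tracker) assign(topic string, parts ...int32) {
		t.mu.Lock()
		defer t.mu.Unlock()
		if t.active[topic] == nil {
			t.active[topic] = make(map[int32]struct{})
		}
		for _, p := range parts {
			t.active[topic][p] = struct{}{}
		}
	}

	func (t *tracker) revoke(topic string, parts ...int32) {
		t.mu.Lock()
		defer t.mu.Unlock()
		for _, p := range parts {
			delete(t.active[topic], p)
		}
		if len(t.active[topic]) == 0 {
			delete(t.active, topic) // drop empty topics, as onPartitionsRevoked does
		}
	}

	func (t *tracker) isActive(topic string, partition int32) bool {
		t.mu.Lock()
		defer t.mu.Unlock()
		_, ok := t.active[topic][partition] // safe even if the topic is untracked
		return ok
	}

	func main() {
		tr := &tracker{active: make(map[string]map[int32]struct{})}
		tr.assign("orders", 0, 1)
		tr.revoke("orders", 1)
		fmt.Println(tr.isActive("orders", 0)) // true
		fmt.Println(tr.isActive("orders", 1)) // false; its offsets are no longer stored
	}
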
@@ -45,10 +50,11 @@ func NewConsumer(name string, handler Handler, opts ...ConsumerOption) (*Consume } return &Consumer{ - name: name, - config: cfg, - kafka: consumer, - handler: handler, + name: name, + config: cfg, + kafka: consumer, + handler: handler, + activePartitions: make(map[string]map[int32]struct{}), }, nil } @@ -234,6 +240,11 @@ func (c *Consumer) storeMessage(msg *Message) error { return nil } + // only store offset if partition is still active + if !c.isPartitionActive(msg.Topic, msg.Partition) { + return nil + } + // similar to StoreMessage in confluent-kafka-go/consumer.go // msg.Offset + 1 it ensures that the consumer starts with // next message when it restarts @@ -267,7 +278,78 @@ func (c *Consumer) concatMiddlewares(h Handler) Handler { } func (c *Consumer) subscribe() error { - return c.kafka.SubscribeTopics(c.config.topics, nil) + return c.kafka.SubscribeTopics(c.config.topics, c.rebalanceCallback) +} + +func (c *Consumer) rebalanceCallback(_ *kafka.Consumer, event kafka.Event) error { + switch e := event.(type) { + case kafka.AssignedPartitions: + c.onPartitionsAssigned(e.Partitions) + return c.kafka.Assign(e.Partitions) + + case kafka.RevokedPartitions: + if err := c.kafka.Unassign(); err != nil { + return err + } + c.onPartitionsRevoked(e.Partitions) + } + + return nil +} + +func (c *Consumer) onPartitionsAssigned(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] == nil { + c.activePartitions[topic] = make(map[int32]struct{}) + } + + c.activePartitions[topic][tp.Partition] = struct{}{} + } +} + +func (c *Consumer) onPartitionsRevoked(partitions []kafka.TopicPartition) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, tp := range partitions { + if tp.Topic == nil { + continue + } + + topic := *tp.Topic + if c.activePartitions[topic] != nil { + delete(c.activePartitions[topic], tp.Partition) + + if len(c.activePartitions[topic]) == 0 { + delete(c.activePartitions, topic) + } + } + } +} + +func (c *Consumer) isPartitionActive(topic string, partition int32) bool { + c.mu.Lock() + defer c.mu.Unlock() + + // if no partitions tracked yet (before first rebalance), allow all + if len(c.activePartitions) == 0 { + return true + } + + if partitions, ok := c.activePartitions[topic]; ok { + _, active := partitions[partition] + return active + } + + return false } func (c *Consumer) unsubscribe() error { diff --git a/xkafka/consumer_mock_test.go b/xkafka/consumer_mock_test.go index 791a408..2ed8342 100644 --- a/xkafka/consumer_mock_test.go +++ b/xkafka/consumer_mock_test.go @@ -132,6 +132,34 @@ func (_m *MockConsumerClient) StoreOffsets(offsets []kafka.TopicPartition) ([]ka return r0, r1 } +// Assign provides a mock function with given fields: partitions +func (_m *MockConsumerClient) Assign(partitions []kafka.TopicPartition) error { + ret := _m.Called(partitions) + + var r0 error + if rf, ok := ret.Get(0).(func([]kafka.TopicPartition) error); ok { + r0 = rf(partitions) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Unassign provides a mock function with given fields: +func (_m *MockConsumerClient) Unassign() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SubscribeTopics provides a mock function with given fields: topics, rebalanceCb func (_m *MockConsumerClient) SubscribeTopics(topics []string, rebalanceCb 
kafka.RebalanceCb) error { ret := _m.Called(topics, rebalanceCb) diff --git a/xkafka/consumer_test.go b/xkafka/consumer_test.go index 0967fc1..c97e30f 100644 --- a/xkafka/consumer_test.go +++ b/xkafka/consumer_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" @@ -178,22 +179,6 @@ func TestConsumerGetMetadata(t *testing.T) { mockKafka.AssertExpectations(t) } -func TestConsumerSubscribeError(t *testing.T) { - t.Parallel() - - consumer, mockKafka := newTestConsumer(t, defaultOpts...) - - ctx := context.Background() - expectError := errors.New("error") - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(expectError) - - err := consumer.Run(ctx) - assert.EqualError(t, err, expectError.Error()) - - mockKafka.AssertExpectations(t) -} - func TestConsumerUnsubscribeError(t *testing.T) { t.Parallel() @@ -500,38 +485,6 @@ func TestConsumerMiddlewareExecutionOrder(t *testing.T) { mockKafka.AssertExpectations(t) } -func TestConsumerManualCommit(t *testing.T) { - t.Parallel() - - consumer, mockKafka := newTestConsumer(t, - append(defaultOpts, ManualCommit(true))...) - - km := newFakeKafkaMessage() - ctx, cancel := context.WithCancel(context.Background()) - - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) - mockKafka.On("Unsubscribe").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - mockKafka.On("Commit").Return(nil, nil) - mockKafka.On("ReadMessage", testTimeout).Return(km, nil) - mockKafka.On("Close").Return(nil) - - handler := HandlerFunc(func(ctx context.Context, msg *Message) error { - cancel() - - msg.AckSuccess() - - return nil - }) - - consumer.handler = handler - - err := consumer.Run(ctx) - assert.NoError(t, err) - - mockKafka.AssertExpectations(t) -} - func TestConsumerAsync(t *testing.T) { t.Parallel() @@ -589,20 +542,12 @@ func TestConsumerAsync_StopOffsetOnError(t *testing.T) { mockKafka.On("Commit").Return(nil, nil) mockKafka.On("Close").Return(nil) - mockKafka.On("StoreOffsets", mock.Anything). - Return(nil, nil). - Times(2) + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) - var recv []*Message - var mu sync.Mutex + var recv atomic.Int32 handler := HandlerFunc(func(ctx context.Context, msg *Message) error { - mu.Lock() - defer mu.Unlock() - - recv = append(recv, msg) - - if len(recv) > 2 { + if recv.Load() > 2 { err := assert.AnError msg.AckFail(err) @@ -611,6 +556,7 @@ func TestConsumerAsync_StopOffsetOnError(t *testing.T) { return err } + recv.Add(1) msg.AckSuccess() return nil @@ -741,6 +687,361 @@ func testMiddleware(name string, pre, post *[]string) MiddlewareFunc { } } +func TestConsumer_RebalanceCallback(t *testing.T) { + t.Parallel() + + t.Run("AssignedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) 
+ + topic1 := "topic1" + topic2 := "topic2" + partitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + + // Verify active partitions are tracked + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.True(t, consumer.isPartitionActive(topic1, 1)) + assert.True(t, consumer.isPartitionActive(topic2, 0)) + assert.False(t, consumer.isPartitionActive(topic2, 1)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("AssignedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(assert.AnError) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitions", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic1 := "topic1" + topic2 := "topic2" + + // First assign partitions + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 0}, + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Now revoke some partitions + revokePartitions := []kafka.TopicPartition{ + {Topic: &topic1, Partition: 1}, + {Topic: &topic2, Partition: 0}, + } + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: revokePartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify only non-revoked partitions are active + assert.True(t, consumer.isPartitionActive(topic1, 0)) + assert.False(t, consumer.isPartitionActive(topic1, 1)) + assert.False(t, consumer.isPartitionActive(topic2, 0)) + + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokedPartitionsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + + mockKafka.On("Unassign").Return(assert.AnError) + + event := kafka.RevokedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("RevokeAllPartitionsRemovesTopic", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) 
+ + topic := "topic1" + + // Assign a single partition + assignPartitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", assignPartitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: assignPartitions} + err := consumer.rebalanceCallback(nil, assignEvent) + assert.NoError(t, err) + + // Revoke all partitions for the topic + mockKafka.On("Unassign").Return(nil) + + revokeEvent := kafka.RevokedPartitions{Partitions: assignPartitions} + err = consumer.rebalanceCallback(nil, revokeEvent) + assert.NoError(t, err) + + // Verify topic is removed from active partitions + consumer.mu.Lock() + _, exists := consumer.activePartitions[topic] + consumer.mu.Unlock() + assert.False(t, exists) + + mockKafka.AssertExpectations(t) + }) + + t.Run("NilTopicPartition", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + // Test with nil topic in partition + partitions := []kafka.TopicPartition{ + {Topic: nil, Partition: 0}, + } + + mockKafka.On("Assign", partitions).Return(nil) + + event := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, event) + assert.NoError(t, err) + + // Verify no panic and empty active partitions + consumer.mu.Lock() + assert.Len(t, consumer.activePartitions, 0) + consumer.mu.Unlock() + + mockKafka.AssertExpectations(t) + }) + + t.Run("UnknownEvent", func(t *testing.T) { + consumer, _ := newTestConsumer(t, defaultOpts...) + + // Test with an unknown event type (use kafka.Error as example) + event := kafka.NewError(kafka.ErrUnknown, "unknown", false) + err := consumer.rebalanceCallback(nil, event) + + assert.NoError(t, err) + }) +} + +func TestConsumer_StoreMessage(t *testing.T) { + t.Parallel() + + t.Run("MessageWithFailStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Fail, + } + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("MessageWithSuccessStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 && tps[0].Offset == 101 + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("MessageWithSkipStatus", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + msg := &Message{ + Topic: topic, + Partition: 0, + Offset: 100, + Status: Skip, + } + + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && *tps[0].Topic == topic && tps[0].Partition == 0 && tps[0].Offset == 101 + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StopOffsetPreventsStore", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) 
+ + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Success, + } + + // Set stopOffset flag + consumer.stopOffset.Store(true) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertNotCalled(t, "StoreOffsets") + }) + + t.Run("FilterInactivePartition", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + topic := "topic1" + + // Assign only partition 0 + partitions := []kafka.TopicPartition{ + {Topic: &topic, Partition: 0}, + } + mockKafka.On("Assign", partitions).Return(nil) + + assignEvent := kafka.AssignedPartitions{Partitions: partitions} + err := consumer.rebalanceCallback(nil, assignEvent) + require.NoError(t, err) + + // Message is from partition 1 (inactive) + msg := &Message{ + Topic: topic, + Partition: 1, + Offset: 100, + Status: Success, + } + + // StoreOffsets should not be called for inactive partition + err = consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("StoreOffsetsError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, assert.AnError) + + err := consumer.storeMessage(msg) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitEnabled", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestConsumer(t, opts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) + + t.Run("ManualCommitError", func(t *testing.T) { + opts := append(defaultOpts, ManualCommit(true)) + consumer, mockKafka := newTestConsumer(t, opts...) + + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 100, + Status: Success, + } + + mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) + mockKafka.On("Commit").Return(nil, assert.AnError) + + err := consumer.storeMessage(msg) + + assert.ErrorIs(t, err, assert.AnError) + mockKafka.AssertExpectations(t) + }) + + t.Run("OffsetIncrementedByOne", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, defaultOpts...) 
+ + msg := &Message{ + Topic: "topic1", + Partition: 0, + Offset: 999, + Status: Success, + } + + // Verify offset is incremented by 1 + mockKafka.On("StoreOffsets", mock.MatchedBy(func(tps []kafka.TopicPartition) bool { + return len(tps) == 1 && tps[0].Offset == kafka.Offset(1000) + })).Return(nil, nil) + + err := consumer.storeMessage(msg) + + assert.NoError(t, err) + mockKafka.AssertExpectations(t) + }) +} + func newTestConsumer(t *testing.T, opts ...ConsumerOption) (*Consumer, *MockConsumerClient) { t.Helper() From 8569d4c905a61960615ad3525fc21edadd0e1e24 Mon Sep 17 00:00:00 2001 From: Ravi Atluri Date: Thu, 4 Dec 2025 15:37:10 +0530 Subject: [PATCH 2/5] Refactor partition handling in BatchConsumer and Consumer --- xkafka/batch_consumer.go | 5 -- xkafka/batch_consumer_test.go | 130 +++++++++++++++++++++++----------- xkafka/consumer.go | 5 -- xkafka/consumer_test.go | 57 +++++++++++++-- 4 files changed, 140 insertions(+), 57 deletions(-) diff --git a/xkafka/batch_consumer.go b/xkafka/batch_consumer.go index 8e96c3d..1a7f543 100644 --- a/xkafka/batch_consumer.go +++ b/xkafka/batch_consumer.go @@ -366,11 +366,6 @@ func (c *BatchConsumer) isPartitionActive(topic string, partition int32) bool { c.mu.Lock() defer c.mu.Unlock() - // if no partitions tracked yet (before first rebalance), allow all - if len(c.activePartitions) == 0 { - return true - } - if partitions, ok := c.activePartitions[topic]; ok { _, active := partitions[partition] return active diff --git a/xkafka/batch_consumer_test.go b/xkafka/batch_consumer_test.go index e23fd92..c489c0b 100644 --- a/xkafka/batch_consumer_test.go +++ b/xkafka/batch_consumer_test.go @@ -295,7 +295,12 @@ func TestBatchConsumer_Async(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) @@ -357,15 +362,19 @@ func TestBatchConsumer_StopOffsetOnError(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) mockKafka.On("Close").Return(nil) mockKafka.On("StoreOffsets", mock.Anything). - Return(nil, nil). 
- Times(2) + Return(nil, nil) consumer.handler = handler err := consumer.Run(ctx) @@ -416,7 +425,12 @@ func TestBatchConsumer_BatchTimeout(t *testing.T) { return nil }) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) @@ -649,42 +663,6 @@ func TestBatchConsumer_CommitError(t *testing.T) { } } -func noopBatchHandler() BatchHandler { - return BatchHandlerFunc(func(ctx context.Context, b *Batch) error { - return nil - }) -} - -func newTestBatchConsumer(t *testing.T, opts ...ConsumerOption) (*BatchConsumer, *MockConsumerClient) { - t.Helper() - - mockConsumer := &MockConsumerClient{} - - opts = append(opts, mockConsumerFunc(mockConsumer)) - - consumer, err := NewBatchConsumer( - "test-batch-consumer", - noopBatchHandler(), - opts..., - ) - require.NoError(t, err) - require.NotNil(t, consumer) - - return consumer, mockConsumer -} - -func testBatchMiddleware(name string, preExec, postExec *[]string) BatchMiddlewarer { - return BatchMiddlewareFunc(func(next BatchHandler) BatchHandler { - return BatchHandlerFunc(func(ctx context.Context, b *Batch) error { - *preExec = append(*preExec, name) - err := next.HandleBatch(ctx, b) - *postExec = append(*postExec, name) - - return err - }) - }) -} - func TestBatchConsumer_RebalanceCallback(t *testing.T) { t.Parallel() @@ -883,6 +861,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{ Topic: topic, @@ -905,6 +885,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{ Topic: topic, @@ -1015,6 +997,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{ Topic: topic, @@ -1036,6 +1020,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, opts...) topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{ Topic: topic, @@ -1058,6 +1044,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, opts...) 
topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{ Topic: topic, @@ -1080,6 +1068,11 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { topic1 := "topic1" topic2 := "topic2" + + // Assign all partitions that will be used + assignBatchPartitions(t, consumer, mockKafka, topic1, 0, 1) + assignBatchPartitions(t, consumer, mockKafka, topic2, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{Topic: topic1, Partition: 0, Offset: 100}, @@ -1104,6 +1097,8 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { consumer, mockKafka := newTestBatchConsumer(t, defaultOpts...) topic := "topic1" + assignBatchPartitions(t, consumer, mockKafka, topic, 0) + batch := NewBatch() batch.Messages = append(batch.Messages, &Message{Topic: topic, Partition: 0, Offset: 999}, @@ -1120,3 +1115,56 @@ func TestBatchConsumer_StoreBatch(t *testing.T) { mockKafka.AssertExpectations(t) }) } + +func noopBatchHandler() BatchHandler { + return BatchHandlerFunc(func(ctx context.Context, b *Batch) error { + return nil + }) +} + +func newTestBatchConsumer(t *testing.T, opts ...ConsumerOption) (*BatchConsumer, *MockConsumerClient) { + t.Helper() + + mockConsumer := &MockConsumerClient{} + + opts = append(opts, mockConsumerFunc(mockConsumer)) + + consumer, err := NewBatchConsumer( + "test-batch-consumer", + noopBatchHandler(), + opts..., + ) + require.NoError(t, err) + require.NotNil(t, consumer) + + return consumer, mockConsumer +} + +func testBatchMiddleware(name string, preExec, postExec *[]string) BatchMiddlewarer { + return BatchMiddlewareFunc(func(next BatchHandler) BatchHandler { + return BatchHandlerFunc(func(ctx context.Context, b *Batch) error { + *preExec = append(*preExec, name) + err := next.HandleBatch(ctx, b) + *postExec = append(*postExec, name) + + return err + }) + }) +} + +// assignBatchPartitions is a test helper that assigns partitions to a batch consumer +// so that isPartitionActive returns true for those partitions. 
+func assignBatchPartitions(t *testing.T, consumer *BatchConsumer, mockKafka *MockConsumerClient, topic string, partitions ...int32) { + t.Helper() + + tps := make([]kafka.TopicPartition, len(partitions)) + for i, p := range partitions { + tps[i] = kafka.TopicPartition{Topic: &topic, Partition: p} + } + + mockKafka.On("Assign", tps).Return(nil).Once() + + event := kafka.AssignedPartitions{Partitions: tps} + err := consumer.rebalanceCallback(nil, event) + require.NoError(t, err) +} diff --git a/xkafka/consumer.go b/xkafka/consumer.go index daf98a8..61e031c 100644 --- a/xkafka/consumer.go +++ b/xkafka/consumer.go @@ -339,11 +339,6 @@ func (c *Consumer) isPartitionActive(topic string, partition int32) bool { c.mu.Lock() defer c.mu.Unlock() - // if no partitions tracked yet (before first rebalance), allow all - if len(c.activePartitions) == 0 { - return true - } - if partitions, ok := c.activePartitions[topic]; ok { _, active := partitions[partition] return active diff --git a/xkafka/consumer_test.go b/xkafka/consumer_test.go index c97e30f..b822aa9 100644 --- a/xkafka/consumer_test.go +++ b/xkafka/consumer_test.go @@ -494,7 +494,13 @@ func TestConsumerAsync(t *testing.T) { km := newFakeKafkaMessage() ctx, cancel := context.WithCancel(context.Background()) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + // Capture rebalance callback and trigger partition assignment + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) @@ -536,7 +542,13 @@ func TestConsumerAsync_StopOffsetOnError(t *testing.T) { km := newFakeKafkaMessage() ctx, cancel := context.WithCancel(context.Background()) - mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + // Capture rebalance callback and trigger partition assignment + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Run(func(args mock.Arguments) { + cb := args.Get(1).(kafka.RebalanceCb) + partitions := []kafka.TopicPartition{{Topic: km.TopicPartition.Topic, Partition: km.TopicPartition.Partition}} + mockKafka.On("Assign", partitions).Return(nil).Once() + cb(nil, kafka.AssignedPartitions{Partitions: partitions}) + }).Return(nil) mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) mockKafka.On("Commit").Return(nil, nil) @@ -875,6 +887,8 @@ func TestConsumer_StoreMessage(t *testing.T) { consumer, mockKafka := newTestConsumer(t, defaultOpts...) topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ Topic: topic, Partition: 0, @@ -896,6 +910,8 @@ func TestConsumer_StoreMessage(t *testing.T) { consumer, mockKafka := newTestConsumer(t, defaultOpts...) topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ Topic: topic, Partition: 0, @@ -965,8 +981,11 @@ func TestConsumer_StoreMessage(t *testing.T) { t.Run("StoreOffsetsError", func(t *testing.T) { consumer, mockKafka := newTestConsumer(t, defaultOpts...) 
+ topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ - Topic: "topic1", + Topic: topic, Partition: 0, Offset: 100, Status: Success, @@ -984,8 +1003,11 @@ func TestConsumer_StoreMessage(t *testing.T) { opts := append(defaultOpts, ManualCommit(true)) consumer, mockKafka := newTestConsumer(t, opts...) + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ - Topic: "topic1", + Topic: topic, Partition: 0, Offset: 100, Status: Success, @@ -1004,8 +1026,11 @@ func TestConsumer_StoreMessage(t *testing.T) { opts := append(defaultOpts, ManualCommit(true)) consumer, mockKafka := newTestConsumer(t, opts...) + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ - Topic: "topic1", + Topic: topic, Partition: 0, Offset: 100, Status: Success, @@ -1023,8 +1048,11 @@ func TestConsumer_StoreMessage(t *testing.T) { t.Run("OffsetIncrementedByOne", func(t *testing.T) { consumer, mockKafka := newTestConsumer(t, defaultOpts...) + topic := "topic1" + assignPartitions(t, consumer, mockKafka, topic, 0) + msg := &Message{ - Topic: "topic1", + Topic: topic, Partition: 0, Offset: 999, Status: Success, @@ -1079,3 +1107,20 @@ func noopHandler() Handler { return nil }) } + +// assignPartitions is a test helper that assigns partitions to a consumer +// so that isPartitionActive returns true for those partitions. +func assignPartitions(t *testing.T, consumer *Consumer, mockKafka *MockConsumerClient, topic string, partitions ...int32) { + t.Helper() + + tps := make([]kafka.TopicPartition, len(partitions)) + for i, p := range partitions { + tps[i] = kafka.TopicPartition{Topic: &topic, Partition: p} + } + + mockKafka.On("Assign", tps).Return(nil).Once() + + event := kafka.AssignedPartitions{Partitions: tps} + err := consumer.rebalanceCallback(nil, event) + require.NoError(t, err) +} From de05aa9433494a177eaa8baa2e5ef0b125dc87a9 Mon Sep 17 00:00:00 2001 From: Ravi Atluri Date: Tue, 27 Jan 2026 09:20:59 +0530 Subject: [PATCH 3/5] Fix missing partition assignment in Consumer and BatchConsumer error tests --- xkafka/batch_consumer_test.go | 1 + xkafka/consumer_test.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/xkafka/batch_consumer_test.go b/xkafka/batch_consumer_test.go index c489c0b..e650984 100644 --- a/xkafka/batch_consumer_test.go +++ b/xkafka/batch_consumer_test.go @@ -653,6 +653,7 @@ func TestBatchConsumer_CommitError(t *testing.T) { mockKafka.On("Commit").Return(nil, expect) consumer.handler = handler + assignBatchPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) diff --git a/xkafka/consumer_test.go b/xkafka/consumer_test.go index b822aa9..480644c 100644 --- a/xkafka/consumer_test.go +++ b/xkafka/consumer_test.go @@ -624,6 +624,7 @@ func TestConsumerStoreOffsetsError(t *testing.T) { mockKafka.On("ReadMessage", testTimeout).Return(km, nil) consumer.handler = handler + assignPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) @@ -677,6 +678,7 @@ func TestConsumerCommitError(t *testing.T) { mockKafka.On("ReadMessage", testTimeout).Return(km, nil) consumer.handler = handler + assignPartitions(t, consumer, mockKafka, testTopics[0], 1) err := consumer.Run(ctx) assert.Error(t, err) From f33bdf708d4115657f4640b4929fb9c52ea79cf9 Mon Sep 17 00:00:00 2001 From: Ravi Atluri Date: Tue, 27 Jan 2026 09:28:54 +0530 Subject: [PATCH 4/5] Fix flaky test by sorting tags in ErrorTags.Error() for deterministic output --- 
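Go deliberately randomizes map iteration order, so building the tag string by ranging over the map directly yields "foo=bar, baz=qux" on one run and "baz=qux, foo=bar" on the next, which is what made the assertion flaky. Sorting the keys first pins the output. The pattern in isolation, assuming a plain map[string]string of tags:

	package main

	import (
		"fmt"
		"sort"
		"strings"
	)

	func renderTags(tags map[string]string) string {
		keys := make([]string, 0, len(tags))
		for k := range tags {
			keys = append(keys, k)
		}
		sort.Strings(keys) // deterministic order, regardless of map iteration

		parts := make([]string, 0, len(keys))
		for _, k := range keys {
			parts = append(parts, k+"="+tags[k])
		}
		return "[" + strings.Join(parts, ", ") + "]"
	}

	func main() {
		fmt.Println(renderTags(map[string]string{"foo": "bar", "baz": "qux"}))
		// always prints: [baz=qux, foo=bar]
	}
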
errors/errors.go | 12 +++++++++--- errors/errors_test.go | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/errors/errors.go b/errors/errors.go index 82b382e..72d6851 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -2,6 +2,7 @@ package errors import ( "errors" + "sort" "strings" ) @@ -22,10 +23,15 @@ type ErrorTags struct { // Error returns the error message with the tags attached. func (e *ErrorTags) Error() string { - md := []string{} + keys := make([]string, 0, len(e.tags)) + for k := range e.tags { + keys = append(keys, k) + } + sort.Strings(keys) - for k, v := range e.tags { - md = append(md, k+"="+v) + md := make([]string, 0, len(e.tags)) + for _, k := range keys { + md = append(md, k+"="+e.tags[k]) } return e.err.Error() + " [" + strings.Join(md, ", ") + "]" diff --git a/errors/errors_test.go b/errors/errors_test.go index 44ae125..dab5ada 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -28,7 +28,7 @@ func TestWrap_MultipleWraps(t *testing.T) { wrapped1 := errors.Wrap(err, "foo", "bar") wrapped2 := errors.Wrap(wrapped1, "baz", "qux") - assert.Equal(t, "test error [foo=bar, baz=qux]", wrapped2.Error()) + assert.Equal(t, "test error [baz=qux, foo=bar]", wrapped2.Error()) } func TestWrap_PanicsOnOddNumberOfAttrs(t *testing.T) { From 0500965b46b9febf9285fa6d0ded6a49f8279b19 Mon Sep 17 00:00:00 2001 From: Ravi Atluri Date: Tue, 27 Jan 2026 09:32:32 +0530 Subject: [PATCH 5/5] Update GitHub Actions to use setup-go v5 with built-in caching --- .github/workflows/test.yml | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a63c6c5..ca2f367 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,22 +10,18 @@ jobs: test: runs-on: ubuntu-latest steps: + - name: Checkout code + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: 1.25.x - - name: Checkout code - uses: actions/checkout@v2 - - name: Go mod cache - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.25.x + cache: true - name: Tools bin cache uses: actions/cache@v4 with: path: .bin - key: ${{ runner.os }}-go1.25.x-${{ hashFiles('Makefile') }} + key: ${{ runner.os }}-tools-${{ hashFiles('Makefile') }} - name: Check run: make check - name: Test
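
Taken together, patches 1-3 narrow when an offset may reach StoreOffsets: the work must be acknowledged (Success or Skip, never Fail), offset storage must not have been stopped by an earlier async error, and the partition must still be assigned to this consumer; only then is offset+1 stored as the resume point. A condensed sketch of that decision with stand-in types, not the xkafka API:

	package main

	import "fmt"

	type status int

	const (
		success status = iota
		skip
		fail
	)

	// shouldStore condenses the checks storeMessage performs after this
	// series: failed work never advances the offset, a consumer that hit an
	// async error stops storing, and revoked partitions are left to their
	// new owner.
	func shouldStore(st status, stopped, partitionActive bool) bool {
		return st != fail && !stopped && partitionActive
	}

	func main() {
		processed := int64(999)
		if shouldStore(success, false, true) {
			fmt.Println("store offset", processed+1) // 1000: resume point after restart
		}
		if !shouldStore(success, false, false) {
			fmt.Println("partition revoked: offset left for the new owner")
		}
	}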