diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml index 4f252572fb22d..05d2af82bced2 100644 --- a/checkstyle/suppressions.xml +++ b/checkstyle/suppressions.xml @@ -225,7 +225,8 @@ - + diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionRuntime.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionRuntime.java new file mode 100644 index 0000000000000..8d62fc8e679b3 --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionRuntime.java @@ -0,0 +1,579 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.utils.internals.LogContext; +import org.apache.kafka.streams.TopologyConfig.TaskConfig; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.GlobalStateManager; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.RecordCollectorImpl; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.StreamsProducer; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Owns the multi-partition task graph and record routing for {@link TopologyTestDriver}. It builds + * one {@link StreamTask} per {@code (subtopologyId, partition)} pair from a + * {@link MultiPartitionTopologyPlan} and drives processing, punctuation, state-store access and + * shutdown of those tasks. It collaborates with the driver through the {@link Host} callbacks for + * the work that remains the driver's responsibility (transactional commit, global-state updates, + * global-partition lookups and recording output records). + */ +final class MultiPartitionRuntime { + + private static final Logger log = LoggerFactory.getLogger(MultiPartitionRuntime.class); + + private final MultiPartitionTopologyPlan plan; + private final InternalTopologyBuilder internalTopologyBuilder; + private final MockConsumer consumer; + private final MockProducer producer; + private final StreamsProducer testDriverProducer; + private final GlobalStateManager globalStateManager; + private final StreamsConfig streamsConfig; + private final TaskConfig taskConfig; + private final StreamsMetricsImpl streamsMetrics; + private final ThreadCache cache; + private final StateDirectory stateDirectory; + private final LogContext logContext; + private final Time wallClockTime; + private final Host host; + + private final TreeMap tasks = new TreeMap<>(); + private final Map taskByTopicPartition = new HashMap<>(); + private final Map>>> outputByTopicPartition = new HashMap<>(); + private final Map nullKeyRoundRobinByTopic = new HashMap<>(); + private final Map offsets = new HashMap<>(); + + interface Host { + void commit(Map offsets); + void processGlobalRecord(TopicPartition partition, long timestamp, byte[] key, byte[] value, Headers headers); + TopicPartition globalPartitionOrNull(String topic); + void recordOutput(String topic, ProducerRecord record); + } + + MultiPartitionRuntime(final MultiPartitionTopologyPlan plan, + final InternalTopologyBuilder internalTopologyBuilder, + final MockConsumer consumer, + final MockProducer producer, + final StreamsProducer testDriverProducer, + final GlobalStateManager globalStateManager, + final StreamsConfig streamsConfig, + final TaskConfig taskConfig, + final StreamsMetricsImpl streamsMetrics, + final ThreadCache cache, + final StateDirectory stateDirectory, + final LogContext logContext, + final Time wallClockTime, + final Host host) { + this.plan = plan; + this.internalTopologyBuilder = internalTopologyBuilder; + this.consumer = consumer; + this.producer = producer; + this.testDriverProducer = testDriverProducer; + this.globalStateManager = globalStateManager; + this.streamsConfig = streamsConfig; + this.taskConfig = taskConfig; + this.streamsMetrics = streamsMetrics; + this.cache = cache; + this.stateDirectory = stateDirectory; + this.logContext = logContext; + this.wallClockTime = wallClockTime; + this.host = host; + } + + /** + * Build one {@link StreamTask} per {@code (subtopologyId, partition)} pair using the structures + * computed by the plan. All tasks share the driver's single {@link #consumer} and the + * {@link #testDriverProducer} as their record collector's producer. + */ + void build() { + final List allSourcePartitions = new ArrayList<>(); + final String threadId = Thread.currentThread().getName(); + + for (final int sid : plan.subtopologyIds()) { + final ProcessorTopology pt = plan.subtopology(sid); + if (pt.sourceTopics().isEmpty()) { + continue; + } + final int numPartitions = plan.partitionsOfSubtopology(sid); + + // Register an offset counter for every (source-topic, partition) the sub-topology consumes. + for (final String src : pt.sourceTopics()) { + final int n = plan.partitionsOfTopic(src); + for (int p = 0; p < n; p++) { + final TopicPartition tp = new TopicPartition(src, p); + offsets.putIfAbsent(tp, new AtomicLong()); + allSourcePartitions.add(tp); + } + } + + for (int p = 0; p < numPartitions; p++) { + // Build a fresh ProcessorTopology per task: ProcessorNode state (sources, processors, + // store handles) is single-init and would otherwise throw "The processor is not closed" + // when the second task tries to initialize the same instance. + final ProcessorTopology freshPt = internalTopologyBuilder.buildSubtopology(sid); + buildOneTask(sid, p, freshPt, threadId); + } + } + + if (!allSourcePartitions.isEmpty()) { + consumer.assign(allSourcePartitions); + final Map startOffsets = new HashMap<>(); + for (final TopicPartition tp : allSourcePartitions) { + startOffsets.put(tp, 0L); + } + consumer.updateBeginningOffsets(startOffsets); + consumer.updateEndOffsets(startOffsets); + } + } + + private void buildOneTask(final int sid, + final int partition, + final ProcessorTopology pt, + final String threadId) { + final TaskId taskId = new TaskId(sid, partition); + TaskMetrics.droppedRecordsSensor(threadId, taskId.toString(), streamsMetrics); + + // This task owns partition {@code p} of each source topic that has at least p+1 partitions. + final Set inputPartitions = new HashSet<>(); + for (final String src : pt.sourceTopics()) { + final int n = plan.partitionsOfTopic(src); + if (partition < n) { + final TopicPartition tp = new TopicPartition(src, partition); + inputPartitions.add(tp); + taskByTopicPartition.put(tp, taskId); + } + } + if (inputPartitions.isEmpty()) { + return; + } + + final ProcessorStateManager stateManager = new ProcessorStateManager( + taskId, + Task.TaskType.ACTIVE, + StreamsConfig.EXACTLY_ONCE_V2.equals(streamsConfig.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)), + streamsConfig.getBoolean(StreamsConfig.TRANSACTIONAL_STATE_STORES_CONFIG), + logContext, + stateDirectory, + pt.storeToChangelogTopic(), + new HashSet<>(inputPartitions)); + final RecordCollector recordCollector = new RecordCollectorImpl( + logContext, + taskId, + testDriverProducer, + streamsConfig.productionExceptionHandler(), + streamsMetrics, + pt + ); + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + streamsConfig, + stateManager, + streamsMetrics, + cache + ); + final StreamTask task = new StreamTask( + taskId, + new HashSet<>(inputPartitions), + pt, + consumer, + taskConfig, + streamsMetrics, + stateDirectory, + cache, + wallClockTime, + stateManager, + recordCollector, + context, + logContext, + false + ); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.processorContext().setRecordContext(null); + tasks.put(taskId, task); + } + + /** + * Resolve the partition a record routes to. + * Explicit partition wins; otherwise {@code Utils.toPositive(Utils.murmur2(keyBytes)) % n} matches + * {@code BuiltInPartitioner.partitionForKey}; null key or n == 1 routes to partition 0. + */ + private int resolvePartition(final String topic, final byte[] keyBytes, final int explicit) { + final int n = Math.max(1, plan.partitionsOfTopic(topic)); + // A negative explicit partition is the "unset" sentinel (TestRecord default): route by key instead. + if (explicit >= 0) { + if (explicit >= n) { + throw new IllegalArgumentException( + "Partition " + explicit + " is out of range for topic '" + topic + + "' (has " + n + " partitions). Declare a higher count via declareTopic() if needed."); + } + return explicit; + } + if (n == 1) { + return 0; + } + if (keyBytes == null) { + // Distribute null-key records round-robin across the topic's partitions. + final int count = nullKeyRoundRobinByTopic.merge(topic, 1, Integer::sum); + return (count - 1) % n; + } + return Utils.toPositive(Utils.murmur2(keyBytes)) % n; + } + + /** + * Multi-sub-topology pipe path. Routes the record to the task owning the resolved + * (topic, partition) and drains every task to quiescence before returning. + */ + void pipeRecord(final String topicName, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers, + final int explicitPartition) { + final boolean isTaskInput = (plan.subtopologyForInputTopic(topicName) != null); + final TopicPartition globalPartition = host.globalPartitionOrNull(topicName); + final boolean isGlobal = globalPartition != null; + if (!isTaskInput && !isGlobal) { + throw new IllegalArgumentException("Unknown topic: " + topicName); + } + if (isTaskInput) { + final int partition = resolvePartition(topicName, key, explicitPartition); + enqueueTaskRecord(topicName, new TopicPartition(topicName, partition), + timestamp, key, value, headers); + completeAllProcessableWork(); + } + if (isGlobal) { + host.processGlobalRecord(globalPartition, timestamp, key, value, headers); + } + } + + private void enqueueTaskRecord(final String topic, + final TopicPartition tp, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers) { + final TaskId taskId = taskByTopicPartition.get(tp); + if (taskId == null) { + throw new IllegalStateException( + "No task owns " + tp + ". This typically means init() was not called or the topic " + + "was not declared with enough partitions."); + } + final StreamTask owner = tasks.get(taskId); + if (owner == null) { + throw new IllegalStateException("Task " + taskId + " is registered but no StreamTask exists for it."); + } + final long offset = offsets + .computeIfAbsent(tp, k -> new AtomicLong()) + .getAndIncrement(); + owner.addRecords(tp, Collections.singleton(new ConsumerRecord<>( + topic, tp.partition(), offset, timestamp, TimestampType.CREATE_TIME, + key == null ? ConsumerRecord.NULL_SIZE : key.length, + value == null ? ConsumerRecord.NULL_SIZE : value.length, + key, value, headers, Optional.empty()))); + } + + /** + * Drain every multi-sub-topology task to quiescence, picking the processable task with the lowest + * current stream time on each iteration to mirror {@code PartitionGroup} ordering across tasks. + */ + void completeAllProcessableWork() { + captureOutputs(); + if (tasks.isEmpty()) { + return; + } + StreamTask next; + while ((next = pickNextProcessableTask()) != null) { + next.resumePollingForPartitionsWithAvailableSpace(); + next.updateLags(); + next.process(wallClockTime.milliseconds()); + next.maybePunctuateStreamTime(); + host.commit(next.prepareCommit(true)); + next.postCommit(true); + captureOutputs(); + } + for (final StreamTask t : tasks.values()) { + if (t.hasRecordsQueued()) { + log.info("Multi-sub task {} has records that cannot be processed right now; advance " + + "wall-clock time or pipe records on co-partitioned topics (see {}).", + t.id(), StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + } + } + } + + private StreamTask pickNextProcessableTask() { + StreamTask best = null; + long bestTime = Long.MAX_VALUE; + final long now = wallClockTime.milliseconds(); + for (final StreamTask t : tasks.values()) { + if (!t.hasRecordsQueued() || !t.isProcessable(now)) { + continue; + } + final long streamTime = ((ProcessorContextImpl) t.processorContext()).currentStreamTimeMs(); + if (streamTime < bestTime) { + bestTime = streamTime; + best = t; + } + } + return best; + } + + /** + * Capture all records emitted by the shared producer this round, partition them in + * {@link #outputByTopicPartition} (and hand them to the host for back-compat with the existing + * read accessors), and loop back into any sub-topology that consumes the topic. + * Honours an explicit producer partition when set (custom {@link org.apache.kafka.streams.processor.StreamPartitioner} + * on a sink); otherwise resolves by key. + */ + private void captureOutputs() { + final List> output = producer.history(); + producer.clear(); + for (final ProducerRecord record : output) { + final String topic = record.topic(); + final Integer producedPartition = record.partition(); + // MockProducer leaves partition() null when the upstream code did not pin one. Resolve it + // ourselves so the output record reflects the partition the test driver actually routes to. + final int capturedPartition = producedPartition != null + ? producedPartition + : resolvePartition(topic, record.key(), -1); + final ProducerRecord stamped = producedPartition != null + ? record + : new ProducerRecord<>(topic, capturedPartition, record.timestamp(), + record.key(), record.value(), record.headers()); + + host.recordOutput(topic, stamped); + outputByTopicPartition + .computeIfAbsent(topic, k -> new HashMap<>()) + .computeIfAbsent(capturedPartition, k -> new LinkedList<>()) + .add(stamped); + + if (plan.subtopologyForInputTopic(topic) != null) { + enqueueTaskRecord(topic, new TopicPartition(topic, capturedPartition), + record.timestamp(), record.key(), record.value(), record.headers()); + } + final TopicPartition globalPartition = host.globalPartitionOrNull(topic); + if (globalPartition != null) { + host.processGlobalRecord(globalPartition, + record.timestamp(), record.key(), record.value(), record.headers()); + } + } + } + + /** + * Multi-sub-topology lookup. A global store match wins. Otherwise, the no-argument + * accessors are only valid for a store registered in a sub-topology that has exactly one + * partition in the declared plan -- not merely "happens to have one task built right + * now". A store registered in a sub-topology whose declared partition count is > 1 always + * throws {@link IllegalStateException}, even if only one of its tasks has been instantiated + * (e.g. because one of its source topics has fewer partitions than the sub-topology's max): no + * single partition can be inferred, and silently resolving to whichever task happens to exist + * would be surprising. + */ + StateStore getStateStore(final String name, final boolean throwForBuiltInStores) { + if (globalStateManager != null) { + final StateStore gs = globalStateManager.store(name); + if (gs != null) { + if (throwForBuiltInStores) { + TopologyTestDriver.throwIfBuiltInStore(gs); + } + return gs; + } + } + final Integer sid = subtopologyOwningStore(name); + if (sid == null) { + return null; + } + final int declaredPartitions = plan.partitionsOfSubtopology(sid); + if (declaredPartitions > 1) { + throw new IllegalStateException( + "Store '" + name + "' is registered in sub-topology " + sid + ", which is declared " + + "with " + declaredPartitions + " partitions; no single partition can be inferred. " + + "Use getStateStore(name, partition) to access a specific partition."); + } + // declaredPartitions == 1: exactly one task exists for this sub-topology (partition 0). + final StreamTask only = tasks.get(new TaskId(sid, 0)); + if (only == null) { + return null; + } + only.processorContext().setRecordContext( + new ProcessorRecordContext(0L, -1L, -1, null, new RecordHeaders())); + final StateStore stateStore = ((ProcessorContextImpl) only.processorContext()).stateManager().store(name); + if (throwForBuiltInStores && stateStore != null) { + TopologyTestDriver.throwIfBuiltInStore(stateStore); + } + return stateStore; + } + + private Integer subtopologyOwningStore(final String name) { + Integer found = null; + for (final StreamTask t : tasks.values()) { + final StateStore s = ((ProcessorContextImpl) t.processorContext()).stateManager().store(name); + if (s == null) { + continue; + } + final int sid = t.id().subtopology(); + if (found != null && found != sid) { + throw new IllegalStateException( + "Store '" + name + "' is registered in more than one sub-topology (" + + found + " and " + sid + ")."); + } + found = sid; + } + return found; + } + + /** + * Return the {@link StateStore} for the task owning {@code partition} of the sub-topology that + * registers a store named {@code name}. If the store name appears in multiple + * sub-topologies, throws {@link IllegalStateException}. + * + * @param name the store name + * @param partition the partition whose owning task should be queried + * @return the {@link StateStore}, or {@code null} if no sub-topology registers a store with this name + */ + StateStore getStateStore(final String name, final int partition) { + if (globalStateManager != null) { + final StateStore gs = globalStateManager.store(name); + if (gs != null) { + return gs; + } + } + final Integer sid = subtopologyOwningStore(name); + if (sid == null) { + return null; + } + return getStateStore(name, sid, partition); + } + + /** + * Internal fully-qualified {@link StateStore} accessor: resolves a store to the task owning + * {@code (subtopologyId, partition)}. + * + * @param name the store name + * @param subtopologyId the sub-topology id + * @param partition the partition whose owning task should be queried + * @return the {@link StateStore}, or {@code null} if the task does not register a store with this name + * @throws IllegalArgumentException if no task exists for {@code (subtopologyId, partition)} + */ + StateStore getStateStore(final String name, final int subtopologyId, final int partition) { + final TaskId taskId = new TaskId(subtopologyId, partition); + final StreamTask owner = tasks.get(taskId); + if (owner == null) { + throw new IllegalArgumentException( + "No task exists for " + taskId + " (sub-topology " + subtopologyId + " has " + + plan.partitionsOfSubtopology(subtopologyId) + " partition(s))."); + } + owner.processorContext().setRecordContext( + new ProcessorRecordContext(0L, -1L, -1, null, new RecordHeaders())); + return ((ProcessorContextImpl) owner.processorContext()).stateManager().store(name); + } + + /** + * @return the number of partitions of the sub-topology that registers {@code storeName}, or 0 + * if no sub-topology registers it (or 1 for a global store). + */ + int partitionsOf(final String storeName) { + if (globalStateManager != null && globalStateManager.store(storeName) != null) { + return 1; + } + final Integer sid = subtopologyOwningStore(storeName); + return sid == null ? 0 : plan.partitionsOfSubtopology(sid); + } + + /** + * @return the number of partitions of the given sub-topology, or 0 if the id is unknown. + */ + int partitionsOfSubtopology(final int subtopologyId) { + return plan.partitionsOfSubtopology(subtopologyId); + } + + /** + * @return the list of the sub-topology ids in this runtime. + */ + List subtopologies() { + return plan.subtopologyIds(); + } + + /** + * Advance wall-clock time across every multi-sub-topology task, firing system-time punctuators, + * committing and then draining processable work. + */ + void advanceWallClockTime() { + for (final StreamTask t : tasks.values()) { + t.maybePunctuateSystemTime(); + host.commit(t.prepareCommit(true)); + t.postCommit(true); + } + completeAllProcessableWork(); + } + + /** + * Suspend, commit and close every multi-sub-topology task, swallowing per-task close failures. + */ + void closeTasks() { + for (final StreamTask t : tasks.values()) { + try { + t.suspend(); + t.prepareCommit(true); + t.postCommit(true); + t.closeClean(); + } catch (final RuntimeException e) { + log.warn("Error closing multi-sub task {}: {}", t.id(), e.toString()); + } + } + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionTopologyPlan.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionTopologyPlan.java new file mode 100644 index 0000000000000..ceab09041934e --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/MultiPartitionTopologyPlan.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Computes the multi-partition layout for a {@link TopologyTestDriver}: which node groups are task + * sub-topologies, how many partitions each (input, output, and internal repartition) topic has, and + * the resulting partition count of each sub-topology. This is pure planning logic — it reads the + * topology and the user-declared partition counts and produces data structures; it builds no tasks + * and touches no producer or consumer. {@link TopologyTestDriver#init()} runs {@link #compute()} + * once and hands the results to the runtime that builds the task graph. + * + *

The partition count of an internal repartition topic is resolved by a layered rule, highest + * precedence first: an explicit {@code Repartitioned.withNumberOfPartitions(N)} (or a count the user + * declared); inheritance from a co-partition-group peer; the max partition count across the producing + * sub-topology's source topics; otherwise a fallback to 1 with a warning. The resolution iterates to + * a fixed point because chains of internal topics can depend on each other.

+ */ +final class MultiPartitionTopologyPlan { + + private static final Logger log = LoggerFactory.getLogger(MultiPartitionTopologyPlan.class); + + private final InternalTopologyBuilder internalTopologyBuilder; + private final Set globalSourceTopics; + + // Resolved partition count per topic. Seeded with the user's declarations and completed by compute(). + private final Map partitionsByTopic; + + private final List subtopologyIds = new ArrayList<>(); + private final Map subtopologyTopologies = new HashMap<>(); + private final Map partitionsBySubtopology = new HashMap<>(); + private final Map subtopologyByInputTopic = new HashMap<>(); + private final Map sinkTopicToSubtopology = new HashMap<>(); + + /** + * @param internalTopologyBuilder the rewritten topology under test + * @param globalTopology the global {@link ProcessorTopology}, or {@code null} if the + * topology has no global stores; its source topics are excluded + * from task sub-topologies + * @param declaredPartitions the user-declared partition counts (copied defensively) + */ + MultiPartitionTopologyPlan(final InternalTopologyBuilder internalTopologyBuilder, + final ProcessorTopology globalTopology, + final Map declaredPartitions) { + this.internalTopologyBuilder = internalTopologyBuilder; + this.globalSourceTopics = globalTopology == null + ? Collections.emptySet() + : new HashSet<>(globalTopology.sourceTopics()); + this.partitionsByTopic = new HashMap<>(declaredPartitions); + } + + /** + * Run the full planning sequence: enumerate task sub-topologies, resolve every topic's partition + * count, validate co-partitioning, and compute each sub-topology's partition count. + * + * @throws TopologyException if a co-partition group has mismatching declared partition counts + */ + void compute() { + enumerateTaskSubtopologies(); + final Set internalRepartitionTopics = collectInternalRepartitionTopics(); + seedPartitionCounts(internalRepartitionTopics); + resolveInternalRepartitionTopicPartitions(internalRepartitionTopics); + validateCopartitioning(); + computePartitionsBySubtopology(); + } + + // --- results --- + + List subtopologyIds() { + return Collections.unmodifiableList(subtopologyIds); + } + + ProcessorTopology subtopology(final int subtopologyId) { + return subtopologyTopologies.get(subtopologyId); + } + + int partitionsOfSubtopology(final int subtopologyId) { + return partitionsBySubtopology.getOrDefault(subtopologyId, 0); + } + + int partitionsOfTopic(final String topic) { + return partitionsByTopic.getOrDefault(topic, 1); + } + + Integer subtopologyForInputTopic(final String topic) { + return subtopologyByInputTopic.get(topic); + } + + Map resolvedTopicPartitions() { + return Collections.unmodifiableMap(partitionsByTopic); + } + + // --- planning steps --- + + /** + * Build each task sub-topology's {@link ProcessorTopology} and record the {@code (sink topic -> + * sub-topology)} mapping. A node group whose only source topics are global is not a task + * sub-topology (global state is fed separately), so it is skipped -- this mirrors + * {@code InternalTopologyBuilder#subtopologyToTopicsInfo()}, which drops any group left with no + * non-global source topic. + */ + private void enumerateTaskSubtopologies() { + for (final int id : internalTopologyBuilder.nodeGroups().keySet()) { + final ProcessorTopology pt = internalTopologyBuilder.buildSubtopology(id); + if (!hasNonGlobalSourceTopic(pt, globalSourceTopics)) { + continue; + } + subtopologyIds.add(id); + subtopologyTopologies.put(id, pt); + // A repartition topic is produced by the sub-topology whose sink writes it; remember that + // mapping so the upstream-max resolution can find the producer. + for (final String sink : pt.sinkTopics()) { + sinkTopicToSubtopology.putIfAbsent(sink, id); + } + } + Collections.sort(subtopologyIds); + } + + private static boolean hasNonGlobalSourceTopic(final ProcessorTopology pt, + final Set globalSourceTopics) { + for (final String src : pt.sourceTopics()) { + if (!globalSourceTopics.contains(src)) { + return true; + } + } + return false; + } + + /** + * Seed the partition-count map before the 3-layer resolution runs: pin any repartition topic the + * builder already fixed (e.g. via {@code Repartitioned.withNumberOfPartitions}) so the upstream-max + * layer cannot overwrite it, and default user-declared source topics to 1. Internal repartition + * topics are deliberately left unset here so {@link #resolveInternalRepartitionTopicPartitions(Set)} + * resolves them. + */ + private void seedPartitionCounts(final Set allInternalRepartitionTopics) { + for (final Map.Entry entry : explicitRepartitionTopicPartitionCounts().entrySet()) { + partitionsByTopic.putIfAbsent(entry.getKey(), entry.getValue()); + } + for (final int sid : subtopologyIds) { + final ProcessorTopology pt = subtopologyTopologies.get(sid); + for (final String src : pt.sourceTopics()) { + subtopologyByInputTopic.put(src, sid); + if (!allInternalRepartitionTopics.contains(src)) { + partitionsByTopic.putIfAbsent(src, 1); + } + } + } + } + + /** Per-sub-topology partition count = max across its source topics. */ + private void computePartitionsBySubtopology() { + for (final int sid : subtopologyIds) { + final ProcessorTopology pt = subtopologyTopologies.get(sid); + int max = 1; + for (final String src : pt.sourceTopics()) { + max = Math.max(max, partitionsByTopic.getOrDefault(src, 1)); + } + partitionsBySubtopology.put(sid, max); + } + } + + private Set collectInternalRepartitionTopics() { + final Set internalTopics = new HashSet<>(); + for (final InternalTopologyBuilder.TopicsInfo info : internalTopologyBuilder.subtopologyToTopicsInfo().values()) { + internalTopics.addAll(info.repartitionSourceTopics.keySet()); + } + return internalTopics; + } + + /** + * Internal repartition topics whose partition count was pinned explicitly (e.g. via + * {@code Repartitioned.withNumberOfPartitions}); topics left to upstream inheritance are absent. + */ + private Map explicitRepartitionTopicPartitionCounts() { + final Map result = new HashMap<>(); + for (final InternalTopologyBuilder.TopicsInfo info : internalTopologyBuilder.subtopologyToTopicsInfo().values()) { + info.repartitionSourceTopics.forEach((name, config) -> + config.numberOfPartitions().ifPresent(n -> result.put(name, n))); + } + return result; + } + + /** + * Resolve partition counts for internal repartition topics using the 3-layer rule: + * (1) explicit declaration wins; (2) co-partition group inheritance from a declared peer; (3) max + * partition count across the producing sub-topology's source topics; iterate to a fixed point + * because chains of internal topics can depend on each other. Topics still unresolved fall back to + * 1 partition with a warning. + */ + private void resolveInternalRepartitionTopicPartitions(final Set internalTopics) { + boolean changed = true; + while (changed) { + changed = false; + for (final String topic : internalTopics) { + if (tryResolveOneInternalTopic(topic)) { + changed = true; + } + } + } + + for (final String topic : internalTopics) { + if (!partitionsByTopic.containsKey(topic)) { + log.warn("Could not resolve partition count for internal repartition topic '{}'; defaulting to 1. " + + "Declare it explicitly via declareTopic() if a different count is needed.", topic); + partitionsByTopic.put(topic, 1); + } + } + } + + /** + * Try to resolve a single internal repartition topic's partition count this round. Returns + * {@code true} if progress was made (the topic now has a count); {@code false} if it cannot be + * resolved yet (caller will iterate to a fixed point) or is already resolved. + */ + private boolean tryResolveOneInternalTopic(final String topic) { + if (partitionsByTopic.containsKey(topic)) { + return false; + } + final Integer fromCopartition = resolveFromCopartitionGroup(topic); + if (fromCopartition != null) { + partitionsByTopic.put(topic, fromCopartition); + return true; + } + return tryResolveFromUpstreamSubtopology(topic); + } + + private boolean tryResolveFromUpstreamSubtopology(final String topic) { + final Integer producerSid = sinkTopicToSubtopology.get(topic); + if (producerSid == null) { + return false; + } + final ProcessorTopology pt = subtopologyTopologies.get(producerSid); + if (pt == null) { + return false; + } + Integer max = null; + for (final String src : pt.sourceTopics()) { + final Integer n = partitionsByTopic.get(src); + if (n == null) { + return false; + } + if (max == null || n > max) { + max = n; + } + } + if (max == null) { + return false; + } + partitionsByTopic.put(topic, max); + return true; + } + + /** + * If {@code topic} participates in a co-partition group with any topic that already has a declared + * count, return that count. Returns {@code null} if the topic is unconstrained by co-partitioning. + */ + private Integer resolveFromCopartitionGroup(final String topic) { + for (final Set group : internalTopologyBuilder.copartitionGroups()) { + if (!group.contains(topic)) { + continue; + } + for (final String peer : group) { + if (peer.equals(topic)) { + continue; + } + final Integer n = partitionsByTopic.get(peer); + if (n != null) { + return n; + } + } + } + return null; + } + + /** + * Validate that all topics in each co-partition group share the same declared partition count. + * Throws {@link TopologyException} naming the two witnessing topics on conflict. + */ + private void validateCopartitioning() { + for (final Set group : internalTopologyBuilder.copartitionGroups()) { + Integer expected = null; + String witness = null; + for (final String topic : group) { + final Integer n = partitionsByTopic.get(topic); + if (n == null) { + continue; + } + if (expected == null) { + expected = n; + witness = topic; + } else if (!expected.equals(n)) { + throw new TopologyException( + "Co-partitioned topics have mismatching partition counts: '" + witness + "' has " + + expected + " but '" + topic + "' has " + n + + ". Declare matching counts via declareTopic() before piping records."); + } + } + } + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java index 7573f8a0f1e19..e12d93ca35124 100644 --- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java @@ -109,6 +109,7 @@ import java.io.IOException; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -255,6 +256,16 @@ public class TopologyTestDriver implements Closeable { private final Map>> outputRecordsByTopic = new HashMap<>(); private final StreamsConfigUtils.ProcessingMode processingMode; + // Multi-partition lifecycle (declareTopic/init). The fields below back the new API only; + // the legacy single-partition execution path does not consult them and continues to work unchanged. + private final Map declaredPartitionsByTopic = new HashMap<>(); + private boolean multiPartitionModeActive = false; + private MultiPartitionRuntime runtime; + private final StreamsConfig multiSubStreamsConfig; + private final TaskConfig multiSubTaskConfig; + private final StreamsMetricsImpl multiSubStreamsMetrics; + private final ThreadCache multiSubCache; + private final StateRestoreListener stateRestoreListener = new StateRestoreListener() { @Override public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) {} @@ -364,7 +375,17 @@ public TopologyTestDriver(final Topology topology, producer = new MockProducer<>(Cluster.empty(), true, null, bytesSerializer, bytesSerializer) { @Override public List partitionsFor(final String topic) { - return Collections.singletonList(new PartitionInfo(topic, PARTITION_ID, null, null, null)); + // When topics are declared with > 1 partition, the sink-side partitioner + // (DefaultStreamPartitioner) must see them all to compute the right output partition. + final int n = Math.max(1, declaredPartitionsByTopic.getOrDefault(topic, 1)); + if (n == 1) { + return Collections.singletonList(new PartitionInfo(topic, PARTITION_ID, null, null, null)); + } + final List partitionInfos = new ArrayList<>(n); + for (int p = 0; p < n; p++) { + partitionInfos.add(new PartitionInfo(topic, p, null, null, null)); + } + return partitionInfos; } }; @@ -377,6 +398,12 @@ public List partitionsFor(final String topic) { setupGlobalTask(mockWallClockTime, streamsConfig, streamsMetrics, cache); setupTask(streamsConfig, streamsMetrics, cache, internalTopologyBuilder.topologyConfigs().getTaskConfig()); + + // Capture references the multi-sub-topology runtime path needs at init() time. + this.multiSubStreamsConfig = streamsConfig; + this.multiSubTaskConfig = internalTopologyBuilder.topologyConfigs().getTaskConfig(); + this.multiSubStreamsMetrics = streamsMetrics; + this.multiSubCache = cache; } private static void logIfTaskIdleEnabled(final StreamsConfig streamsConfig) { @@ -558,7 +585,13 @@ private void pipeRecord(final String topicName, final long timestamp, final byte[] key, final byte[] value, - final Headers headers) { + final Headers headers, + final int explicitPartition) { + if (multiPartitionModeActive) { + runtime.pipeRecord(topicName, timestamp, key, value, headers, explicitPartition); + return; + } + final TopicPartition inputTopicOrPatternPartition = getInputTopicOrPatternPartition(topicName); final TopicPartition globalInputTopicPartition = globalPartitionsByInputTopic.get(topicName); @@ -732,6 +765,10 @@ private void captureOutputsAndReEnqueueInternalResults() { public void advanceWallClockTime(final Duration advance) { Objects.requireNonNull(advance, "advance cannot be null"); mockWallClockTime.sleep(advance.toMillis()); + if (multiPartitionModeActive) { + runtime.advanceWallClockTime(); + return; + } if (task != null) { task.maybePunctuateSystemTime(); commit(task.prepareCommit(true)); @@ -744,8 +781,8 @@ private Queue> getRecordsQueue(final String topic final Queue> outputRecords = outputRecordsByTopic.get(topicName); if (outputRecords == null && !processorTopology.sinkTopics().contains(topicName)) { log.warn("Unrecognized topic: {}, this can occur if dynamic routing is used and no output has been " - + "sent to this topic yet. If not using a TopicNameExtractor, check that the output topic " - + "is correct.", topicName); + + "sent to this topic yet. If not using a TopicNameExtractor, check that the output topic " + + "is correct.", topicName); } return outputRecords; } @@ -812,6 +849,105 @@ public final TestOutputTopic createOutputTopic(final String topicNa return new TestOutputTopic<>(this, topicName, keyDeserializer, valueDeserializer); } + /** + * Declare the number of partitions for an input, output, or generated repartition topic. + * Must be called before any record is piped. Subsequent calls with the same count are no-ops; calls + * with a different count throw {@link IllegalArgumentException}. Calls after the driver has been + * initialised throw {@link IllegalStateException}. + * + * @param topicName the topic to declare + * @param partitions the number of partitions (must be at least 1) + * @throws IllegalStateException if the driver has already been initialised + * @throws IllegalArgumentException if {@code partitions} is less than 1, or the topic was already declared with a different count + */ + void declareTopic(final String topicName, final int partitions) { + Objects.requireNonNull(topicName, "topicName cannot be null"); + if (multiPartitionModeActive) { + throw new IllegalStateException( + "Cannot declare topic '" + topicName + "' after multi-partition mode has been activated; " + + "declare all multi-partition topics before piping records."); + } + if (partitions < 1) { + throw new IllegalArgumentException( + "Partition count must be at least 1 (topic='" + topicName + "', partitions=" + partitions + ")."); + } + final Integer existing = declaredPartitionsByTopic.get(topicName); + if (existing != null && existing != partitions) { + throw new IllegalArgumentException( + "Topic '" + topicName + "' was already declared with " + existing + + " partitions; cannot redeclare with " + partitions + "."); + } + declaredPartitionsByTopic.put(topicName, partitions); + } + + /** + * Activate multi-partition mode. Idempotent. Call this after declaring all multi-partition + * topics and before piping records. The single-partition back-compat path auto-activates on first use, + * so existing tests do not need to call this method. + * + *

This builds the sub-topology task graph: for each sub-topology, it constructs its + * {@link ProcessorTopology}, resolves the partition count of any internal repartition topic + * (declared explicit count > co-partition group inheritance > max upstream sources > + * fallback to 1), validates co-partitioning, and computes the per-sub-topology partition count + * as the max across its source topics.

+ */ + void activateMultiPartitionMode() { + if (multiPartitionModeActive) { + return; + } + + // Plan the multi-partition layout (task sub-topologies, per-topic and per-sub-topology + // partition counts, co-partition validation). + final MultiPartitionTopologyPlan plan = + new MultiPartitionTopologyPlan(internalTopologyBuilder, globalTopology, declaredPartitionsByTopic); + plan.compute(); + + // Mirror the resolved partition counts so the shared MockProducer sees the right topic layout. + declaredPartitionsByTopic.putAll(plan.resolvedTopicPartitions()); + + runtime = new MultiPartitionRuntime( + plan, + internalTopologyBuilder, + consumer, + producer, + testDriverProducer, + globalStateManager, + multiSubStreamsConfig, + multiSubTaskConfig, + multiSubStreamsMetrics, + multiSubCache, + stateDirectory, + logContext, + mockWallClockTime, + new MultiPartitionRuntime.Host() { + @Override + public void commit(final Map offsets) { + TopologyTestDriver.this.commit(offsets); + } + + @Override + public void processGlobalRecord(final TopicPartition partition, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers) { + TopologyTestDriver.this.processGlobalRecord(partition, timestamp, key, value, headers); + } + + @Override + public TopicPartition globalPartitionOrNull(final String topic) { + return globalPartitionsByInputTopic.get(topic); + } + + @Override + public void recordOutput(final String topic, final ProducerRecord record) { + outputRecordsByTopic.computeIfAbsent(topic, k -> new LinkedList<>()).add(record); + } + }); + runtime.build(); + multiPartitionModeActive = true; + } + /** * Get all the names of all the topics to which records have been produced during the test run. *

@@ -848,7 +984,7 @@ TestRecord readRecord(final String topic, } final K key = keyDeserializer.deserialize(record.topic(), record.headers(), record.key()); final V value = valueDeserializer.deserialize(record.topic(), record.headers(), record.value()); - final int outputPartition = -1; + final int outputPartition = multiPartitionModeActive && record.partition() != null ? record.partition() : -1; return new TestRecord<>(key, value, record.headers(), Instant.ofEpochMilli(record.timestamp()), outputPartition); } @@ -868,7 +1004,7 @@ void pipeRecord(final String topic, throw new IllegalStateException("Provided `TestRecord` does not have a timestamp and no timestamp overwrite was provided via `time` parameter."); } - pipeRecord(topic, timestamp, serializedKey, serializedValue, record.headers()); + pipeRecord(topic, timestamp, serializedKey, serializedValue, record.headers(), record.partition()); } final long queueSize(final String topic) { @@ -949,6 +1085,9 @@ public StateStore getStateStore(final String name) throws IllegalArgumentExcepti private StateStore getStateStore(final String name, final boolean throwForBuiltInStores) { + if (multiPartitionModeActive) { + return runtime.getStateStore(name, throwForBuiltInStores); + } if (task != null) { // Accessing a store must not corrupt the task's record context. Only set a dummy // context when none exists yet (i.e. before any record has been processed) so that @@ -979,7 +1118,91 @@ private StateStore getStateStore(final String name, return null; } - private void throwIfBuiltInStore(final StateStore stateStore) { + /** + * Return the {@link StateStore} for the task owning {@code partition} of the sub-topology that + * registers a store named {@code name}. If the store name appears in multiple + * sub-topologies, throws {@link IllegalStateException}. + * + * @param name the store name + * @param partition the partition whose owning task should be queried + * @return the {@link StateStore}, or {@code null} if no sub-topology registers a store with this name + */ + public StateStore getStateStore(final String name, final int partition) { + requireMultiPartitionMode(); + return runtime.getStateStore(name, partition); + } + + /** + * Guard for the partition-aware accessors below ({@link #getStateStore(String, int)} and + * friends). Unlike the original implementation, this never activates multi-partition mode as a + * side effect of what looks like a read-only getter: a {@code getXxx()} method that can silently + * rebuild the entire task graph on first call -- and behave differently the second time it's + * called -- is surprising and hides a non-trivial effect behind an innocuous-looking signature. + * + *

Multi-partition mode must already be active by the time these accessors are called, either + * because {@link TopologyTestDriverBuilder#build()} activated it (at least one declared topic has + * more than one partition), or because the caller invoked {@link #activateMultiPartitionMode()} + * explicitly. If it isn't, this throws -- it does not activate it for you. + * + * @throws IllegalStateException if the driver is not operating in multi-partition mode + */ + private void requireMultiPartitionMode() { + if (!multiPartitionModeActive) { + throw new IllegalStateException( + "This driver is not operating in multi-partition mode. Declare a topic with more than " + + "one partition (via TopologyTestDriverBuilder#declareTopic() or declareTopic()) " + + "and call activateMultiPartitionMode() -- or simply pipe a record first, which " + + "activates it automatically -- before calling partition-aware accessors like " + + "getStateStore(name, partition). Use getStateStore(name) for single-partition mode."); + } + } + + /** + * Internal fully-qualified {@link StateStore} accessor: resolves a store to the task owning + * {@code (subtopologyId, partition)}. Package-private -- not part of the KIP-1238 public API; + * callers use {@link #getStateStore(String, int)}, which resolves the sub-topology by store name. + * + * @param name the store name + * @param subtopologyId the sub-topology id + * @param partition the partition whose owning task should be queried + * @return the {@link StateStore}, or {@code null} if the task does not register a store with this name + * @throws IllegalArgumentException if no task exists for {@code (subtopologyId, partition)} + * @throws IllegalStateException if the driver is not operating in multi-partition mode + */ + StateStore getStateStore(final String name, final int subtopologyId, final int partition) { + requireMultiPartitionMode(); + return runtime.getStateStore(name, subtopologyId, partition); + } + + /** + * @return the number of partitions of the sub-topology that registers {@code storeName}, or 0 + * if no sub-topology registers it (or 1 for a global store). + * @throws IllegalStateException if the driver is not operating in multi-partition mode + */ + int partitionsOf(final String storeName) { + requireMultiPartitionMode(); + return runtime.partitionsOf(storeName); + } + + /** + * @return the number of partitions of the given sub-topology, or 0 if the id is unknown. + * @throws IllegalStateException if the driver is not operating in multi-partition mode + */ + int partitionsOfSubtopology(final int subtopologyId) { + requireMultiPartitionMode(); + return runtime.partitionsOfSubtopology(subtopologyId); + } + + /** + * @return an unmodifiable list of the sub-topology ids in this driver. + * @throws IllegalStateException if the driver is not operating in multi-partition mode + */ + List subtopologies() { + requireMultiPartitionMode(); + return runtime.subtopologies(); + } + + static void throwIfBuiltInStore(final StateStore stateStore) { if (stateStore instanceof VersionedKeyValueStore) { throw new IllegalArgumentException("Store " + stateStore.name() + " is a versioned key-value store and should be accessed via `getVersionedKeyValueStore()`"); @@ -1298,6 +1521,95 @@ public SessionStoreWithHeaders getSessionStoreWithHeaders(final Str return store instanceof SessionStoreWithHeaders ? (SessionStoreWithHeaders) store : null; } + /** + * Partition-aware {@link KeyValueStore} accessor. + */ + @SuppressWarnings("unchecked") + public KeyValueStore getKeyValueStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + if (store instanceof TimestampedKeyValueStore) { + log.info("Method #getTimestampedKeyValueStore() should be used to access a TimestampedKeyValueStore."); + return new KeyValueStoreFacade<>((TimestampedKeyValueStore) store); + } + return store instanceof KeyValueStore ? (KeyValueStore) store : null; + } + + /** + * Partition-aware {@link TimestampedKeyValueStore} accessor. + */ + @SuppressWarnings("unchecked") + public KeyValueStore> getTimestampedKeyValueStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof TimestampedKeyValueStore ? (TimestampedKeyValueStore) store : null; + } + + /** + * Partition-aware {@link VersionedKeyValueStore} accessor. + */ + @SuppressWarnings("unchecked") + public VersionedKeyValueStore getVersionedKeyValueStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof VersionedKeyValueStore ? (VersionedKeyValueStore) store : null; + } + + /** + * Partition-aware {@link WindowStore} accessor. + */ + @SuppressWarnings("unchecked") + public WindowStore getWindowStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + if (store instanceof TimestampedWindowStore) { + log.info("Method #getTimestampedWindowStore() should be used to access a TimestampedWindowStore."); + return new WindowStoreFacade<>((TimestampedWindowStore) store); + } + return store instanceof WindowStore ? (WindowStore) store : null; + } + + /** + * Partition-aware {@link TimestampedWindowStore} accessor. + */ + @SuppressWarnings("unchecked") + public WindowStore> getTimestampedWindowStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof TimestampedWindowStore ? (TimestampedWindowStore) store : null; + } + + /** + * Partition-aware {@link SessionStore} accessor. + */ + @SuppressWarnings("unchecked") + public SessionStore getSessionStore(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof SessionStore ? (SessionStore) store : null; + } + + /** + * Partition-aware {@link TimestampedKeyValueStoreWithHeaders} accessor. + */ + @SuppressWarnings("unchecked") + public KeyValueStore> getTimestampedKeyValueStoreWithHeaders(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof TimestampedKeyValueStoreWithHeaders ? (TimestampedKeyValueStoreWithHeaders) store : null; + } + + /** + * Partition-aware {@link TimestampedWindowStoreWithHeaders} accessor. + */ + @SuppressWarnings("unchecked") + public WindowStore> getTimestampedWindowStoreWithHeaders(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof TimestampedWindowStoreWithHeaders ? (TimestampedWindowStoreWithHeaders) store : null; + } + + /** + * Partition-aware {@link SessionStoreWithHeaders} accessor. + */ + @SuppressWarnings("unchecked") + public SessionStoreWithHeaders getSessionStoreWithHeaders(final String name, final int partition) { + final StateStore store = getStateStore(name, partition); + return store instanceof SessionStoreWithHeaders ? (SessionStoreWithHeaders) store : null; + } + /** * Close the driver, its topology, and all processors. */ @@ -1308,6 +1620,9 @@ public void close() { task.postCommit(true); task.closeClean(); } + if (multiPartitionModeActive) { + runtime.closeTasks(); + } if (globalStateTask != null) { try { globalStateTask.close(false); @@ -1315,7 +1630,11 @@ public void close() { // ignore } } - completeAllProcessableWork(); + if (multiPartitionModeActive) { + runtime.completeAllProcessableWork(); + } else { + completeAllProcessableWork(); + } if (task != null && task.hasRecordsQueued()) { log.warn("Found some records that cannot be processed due to the" + " {} configuration during TopologyTestDriver#close().", diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java index bd7aa527bdfc0..c36cb0d6ada23 100644 --- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java @@ -19,20 +19,27 @@ import java.time.Instant; import java.util.Objects; import java.util.Optional; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; import java.util.Properties; /** * Fluent builder for a {@link TopologyTestDriver}. * - *

This is the entry point for constructing a {@link TopologyTestDriver}. - * Configure the builder and call {@link #build()}. - * The {@link TopologyTestDriver} constructors remain functional but are deprecated in favor of - * this builder.

+ *

This is the entry point for constructing a {@link TopologyTestDriver}, for both + * single and multi-partition mode. Declare the partition count of each relevant topic, then + * call {@link #build()}: when at least one declared topic has more than one partition the driver wires + * its multi-partition task graph; declaring only single-partition topics (or none) keeps the legacy + * single-flat-task behaviour. + * The {@link TopologyTestDriver} constructors remain functional but are + * deprecated in favour of this builder.

* *
{@code
  * TopologyTestDriver driver = new TopologyTestDriverBuilder(topology)
  *     .withConfig(props)
  *     .withInitialWallClockTime(Instant.ofEpochMilli(0))
+ *     .declareTopic("input", 4)
  *     .build();
  * }
*/ @@ -41,6 +48,7 @@ public class TopologyTestDriverBuilder { private final Topology topology; private Properties config = new Properties(); private Optional initialWallClockTime = Optional.empty(); + private final Map declaredTopics = new LinkedHashMap<>(); /** * Start building a driver for the given topology. @@ -75,14 +83,43 @@ public TopologyTestDriverBuilder withInitialWallClockTime(final Instant initialW } /** - * Build the driver: construct it and apply all declared topic partition counts. + * Declare the number of partitions for an input, output, or internal repartition topic. + * + * @param topicName the topic to declare + * @param partitions the number of partitions (must be at least 1) + * @return this builder + * @throws IllegalArgumentException if {@code partitions} is less than 1, or the topic was already + * declared with a different count + */ + public TopologyTestDriverBuilder declareTopic(final String topicName, final int partitions) { + Objects.requireNonNull(topicName, "topicName cannot be null"); + if (partitions < 1) { + throw new IllegalArgumentException( + "Partition count must be at least 1 (topic='" + topicName + "', partitions=" + partitions + ")."); + } + final Integer existing = declaredTopics.putIfAbsent(topicName, partitions); + if (existing != null && existing != partitions) { + throw new IllegalArgumentException( + "Topic '" + topicName + "' was already declared with " + existing + + " partitions; cannot redeclare with " + partitions + "."); + } + return this; + } + + /** + * Build the driver: construct it, declare all topics, and—when at least one declared topic has more + * than one partition—create the multi-partition task graph. * * @return a ready-to-use {@link TopologyTestDriver} */ public TopologyTestDriver build() { - return new TopologyTestDriver( + final TopologyTestDriver driver = new TopologyTestDriver( topology.internalTopologyBuilder, config, initialWallClockTime.map(Instant::toEpochMilli).orElseGet(System::currentTimeMillis)); - } + declaredTopics.forEach(driver::declareTopic); + if (declaredTopics.values().stream().anyMatch(count -> count > 1)) { + driver.activateMultiPartitionMode(); + } + return driver;} }