record) {
+ outputRecordsByTopic.computeIfAbsent(topic, k -> new LinkedList<>()).add(record);
+ }
+ });
+ runtime.build();
+ multiPartitionModeActive = true;
+ }
+
/**
* Get all the names of all the topics to which records have been produced during the test run.
*
@@ -848,7 +984,7 @@ TestRecord readRecord(final String topic,
}
final K key = keyDeserializer.deserialize(record.topic(), record.headers(), record.key());
final V value = valueDeserializer.deserialize(record.topic(), record.headers(), record.value());
- final int outputPartition = -1;
+ final int outputPartition = multiPartitionModeActive && record.partition() != null ? record.partition() : -1;
return new TestRecord<>(key, value, record.headers(), Instant.ofEpochMilli(record.timestamp()), outputPartition);
}
@@ -868,7 +1004,7 @@ void pipeRecord(final String topic,
throw new IllegalStateException("Provided `TestRecord` does not have a timestamp and no timestamp overwrite was provided via `time` parameter.");
}
- pipeRecord(topic, timestamp, serializedKey, serializedValue, record.headers());
+ pipeRecord(topic, timestamp, serializedKey, serializedValue, record.headers(), record.partition());
}
final long queueSize(final String topic) {
@@ -949,6 +1085,9 @@ public StateStore getStateStore(final String name) throws IllegalArgumentExcepti
private StateStore getStateStore(final String name,
final boolean throwForBuiltInStores) {
+ if (multiPartitionModeActive) {
+ return runtime.getStateStore(name, throwForBuiltInStores);
+ }
if (task != null) {
// Accessing a store must not corrupt the task's record context. Only set a dummy
// context when none exists yet (i.e. before any record has been processed) so that
@@ -979,7 +1118,91 @@ private StateStore getStateStore(final String name,
return null;
}
- private void throwIfBuiltInStore(final StateStore stateStore) {
+ /**
+ * Return the {@link StateStore} for the task owning {@code partition} of the sub-topology that
+ * registers a store named {@code name}. If the store name appears in multiple
+ * sub-topologies, throws {@link IllegalStateException}.
+ *
+ * @param name the store name
+ * @param partition the partition whose owning task should be queried
+ * @return the {@link StateStore}, or {@code null} if no sub-topology registers a store with this name
+ */
+ public StateStore getStateStore(final String name, final int partition) {
+ requireMultiPartitionMode();
+ return runtime.getStateStore(name, partition);
+ }
+
+ /**
+ * Guard for the partition-aware accessors below ({@link #getStateStore(String, int)} and
+ * friends). Unlike the original implementation, this never activates multi-partition mode as a
+ * side effect of what looks like a read-only getter: a {@code getXxx()} method that can silently
+ * rebuild the entire task graph on first call -- and behave differently the second time it's
+ * called -- is surprising and hides a non-trivial effect behind an innocuous-looking signature.
+ *
+ * Multi-partition mode must already be active by the time these accessors are called, either
+ * because {@link TopologyTestDriverBuilder#build()} activated it (at least one declared topic has
+ * more than one partition), or because the caller invoked {@link #activateMultiPartitionMode()}
+ * explicitly. If it isn't, this throws -- it does not activate it for you.
+ *
+ * @throws IllegalStateException if the driver is not operating in multi-partition mode
+ */
+ private void requireMultiPartitionMode() {
+ if (!multiPartitionModeActive) {
+ throw new IllegalStateException(
+ "This driver is not operating in multi-partition mode. Declare a topic with more than "
+ + "one partition (via TopologyTestDriverBuilder#declareTopic() or declareTopic()) "
+ + "and call activateMultiPartitionMode() -- or simply pipe a record first, which "
+ + "activates it automatically -- before calling partition-aware accessors like "
+ + "getStateStore(name, partition). Use getStateStore(name) for single-partition mode.");
+ }
+ }
+
+ /**
+ * Internal fully-qualified {@link StateStore} accessor: resolves a store to the task owning
+ * {@code (subtopologyId, partition)}. Package-private -- not part of the KIP-1238 public API;
+ * callers use {@link #getStateStore(String, int)}, which resolves the sub-topology by store name.
+ *
+ * @param name the store name
+ * @param subtopologyId the sub-topology id
+ * @param partition the partition whose owning task should be queried
+ * @return the {@link StateStore}, or {@code null} if the task does not register a store with this name
+ * @throws IllegalArgumentException if no task exists for {@code (subtopologyId, partition)}
+ * @throws IllegalStateException if the driver is not operating in multi-partition mode
+ */
+ StateStore getStateStore(final String name, final int subtopologyId, final int partition) {
+ requireMultiPartitionMode();
+ return runtime.getStateStore(name, subtopologyId, partition);
+ }
+
+ /**
+ * @return the number of partitions of the sub-topology that registers {@code storeName}, or 0
+ * if no sub-topology registers it (or 1 for a global store).
+ * @throws IllegalStateException if the driver is not operating in multi-partition mode
+ */
+ int partitionsOf(final String storeName) {
+ requireMultiPartitionMode();
+ return runtime.partitionsOf(storeName);
+ }
+
+ /**
+ * @return the number of partitions of the given sub-topology, or 0 if the id is unknown.
+ * @throws IllegalStateException if the driver is not operating in multi-partition mode
+ */
+ int partitionsOfSubtopology(final int subtopologyId) {
+ requireMultiPartitionMode();
+ return runtime.partitionsOfSubtopology(subtopologyId);
+ }
+
+ /**
+ * @return an unmodifiable list of the sub-topology ids in this driver.
+ * @throws IllegalStateException if the driver is not operating in multi-partition mode
+ */
+ List subtopologies() {
+ requireMultiPartitionMode();
+ return runtime.subtopologies();
+ }
+
+ static void throwIfBuiltInStore(final StateStore stateStore) {
if (stateStore instanceof VersionedKeyValueStore) {
throw new IllegalArgumentException("Store " + stateStore.name()
+ " is a versioned key-value store and should be accessed via `getVersionedKeyValueStore()`");
@@ -1298,6 +1521,95 @@ public SessionStoreWithHeaders getSessionStoreWithHeaders(final Str
return store instanceof SessionStoreWithHeaders ? (SessionStoreWithHeaders) store : null;
}
+ /**
+ * Partition-aware {@link KeyValueStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public KeyValueStore getKeyValueStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ if (store instanceof TimestampedKeyValueStore) {
+ log.info("Method #getTimestampedKeyValueStore() should be used to access a TimestampedKeyValueStore.");
+ return new KeyValueStoreFacade<>((TimestampedKeyValueStore) store);
+ }
+ return store instanceof KeyValueStore ? (KeyValueStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link TimestampedKeyValueStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public KeyValueStore> getTimestampedKeyValueStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof TimestampedKeyValueStore ? (TimestampedKeyValueStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link VersionedKeyValueStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public VersionedKeyValueStore getVersionedKeyValueStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof VersionedKeyValueStore ? (VersionedKeyValueStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link WindowStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public WindowStore getWindowStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ if (store instanceof TimestampedWindowStore) {
+ log.info("Method #getTimestampedWindowStore() should be used to access a TimestampedWindowStore.");
+ return new WindowStoreFacade<>((TimestampedWindowStore) store);
+ }
+ return store instanceof WindowStore ? (WindowStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link TimestampedWindowStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public WindowStore> getTimestampedWindowStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof TimestampedWindowStore ? (TimestampedWindowStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link SessionStore} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public SessionStore getSessionStore(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof SessionStore ? (SessionStore) store : null;
+ }
+
+ /**
+ * Partition-aware {@link TimestampedKeyValueStoreWithHeaders} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public KeyValueStore> getTimestampedKeyValueStoreWithHeaders(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof TimestampedKeyValueStoreWithHeaders ? (TimestampedKeyValueStoreWithHeaders) store : null;
+ }
+
+ /**
+ * Partition-aware {@link TimestampedWindowStoreWithHeaders} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public WindowStore> getTimestampedWindowStoreWithHeaders(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof TimestampedWindowStoreWithHeaders ? (TimestampedWindowStoreWithHeaders) store : null;
+ }
+
+ /**
+ * Partition-aware {@link SessionStoreWithHeaders} accessor.
+ */
+ @SuppressWarnings("unchecked")
+ public SessionStoreWithHeaders getSessionStoreWithHeaders(final String name, final int partition) {
+ final StateStore store = getStateStore(name, partition);
+ return store instanceof SessionStoreWithHeaders ? (SessionStoreWithHeaders) store : null;
+ }
+
/**
* Close the driver, its topology, and all processors.
*/
@@ -1308,6 +1620,9 @@ public void close() {
task.postCommit(true);
task.closeClean();
}
+ if (multiPartitionModeActive) {
+ runtime.closeTasks();
+ }
if (globalStateTask != null) {
try {
globalStateTask.close(false);
@@ -1315,7 +1630,11 @@ public void close() {
// ignore
}
}
- completeAllProcessableWork();
+ if (multiPartitionModeActive) {
+ runtime.completeAllProcessableWork();
+ } else {
+ completeAllProcessableWork();
+ }
if (task != null && task.hasRecordsQueued()) {
log.warn("Found some records that cannot be processed due to the" +
" {} configuration during TopologyTestDriver#close().",
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java
index bd7aa527bdfc0..c36cb0d6ada23 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriverBuilder.java
@@ -19,20 +19,27 @@
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
import java.util.Properties;
/**
* Fluent builder for a {@link TopologyTestDriver}.
*
- * This is the entry point for constructing a {@link TopologyTestDriver}.
- * Configure the builder and call {@link #build()}.
- * The {@link TopologyTestDriver} constructors remain functional but are deprecated in favor of
- * this builder.
+ * This is the entry point for constructing a {@link TopologyTestDriver}, for both
+ * single and multi-partition mode. Declare the partition count of each relevant topic, then
+ * call {@link #build()}: when at least one declared topic has more than one partition the driver wires
+ * its multi-partition task graph; declaring only single-partition topics (or none) keeps the legacy
+ * single-flat-task behaviour.
+ * The {@link TopologyTestDriver} constructors remain functional but are
+ * deprecated in favour of this builder.
*
* {@code
* TopologyTestDriver driver = new TopologyTestDriverBuilder(topology)
* .withConfig(props)
* .withInitialWallClockTime(Instant.ofEpochMilli(0))
+ * .declareTopic("input", 4)
* .build();
* }
*/
@@ -41,6 +48,7 @@ public class TopologyTestDriverBuilder {
private final Topology topology;
private Properties config = new Properties();
private Optional initialWallClockTime = Optional.empty();
+ private final Map declaredTopics = new LinkedHashMap<>();
/**
* Start building a driver for the given topology.
@@ -75,14 +83,43 @@ public TopologyTestDriverBuilder withInitialWallClockTime(final Instant initialW
}
/**
- * Build the driver: construct it and apply all declared topic partition counts.
+ * Declare the number of partitions for an input, output, or internal repartition topic.
+ *
+ * @param topicName the topic to declare
+ * @param partitions the number of partitions (must be at least 1)
+ * @return this builder
+ * @throws IllegalArgumentException if {@code partitions} is less than 1, or the topic was already
+ * declared with a different count
+ */
+ public TopologyTestDriverBuilder declareTopic(final String topicName, final int partitions) {
+ Objects.requireNonNull(topicName, "topicName cannot be null");
+ if (partitions < 1) {
+ throw new IllegalArgumentException(
+ "Partition count must be at least 1 (topic='" + topicName + "', partitions=" + partitions + ").");
+ }
+ final Integer existing = declaredTopics.putIfAbsent(topicName, partitions);
+ if (existing != null && existing != partitions) {
+ throw new IllegalArgumentException(
+ "Topic '" + topicName + "' was already declared with " + existing
+ + " partitions; cannot redeclare with " + partitions + ".");
+ }
+ return this;
+ }
+
+ /**
+ * Build the driver: construct it, declare all topics, and—when at least one declared topic has more
+ * than one partition—create the multi-partition task graph.
*
* @return a ready-to-use {@link TopologyTestDriver}
*/
public TopologyTestDriver build() {
- return new TopologyTestDriver(
+ final TopologyTestDriver driver = new TopologyTestDriver(
topology.internalTopologyBuilder,
config,
initialWallClockTime.map(Instant::toEpochMilli).orElseGet(System::currentTimeMillis));
- }
+ declaredTopics.forEach(driver::declareTopic);
+ if (declaredTopics.values().stream().anyMatch(count -> count > 1)) {
+ driver.activateMultiPartitionMode();
+ }
+ return driver;}
}