From 3fcaf4aaf0c12aa4c3a6aacc367dccbaf12c7a38 Mon Sep 17 00:00:00 2001
From: Atharva262005
Date: Sun, 25 Jan 2026 01:52:40 +0530
Subject: [PATCH 1/2] feat(sparkreceiver): Implement parallel reading with
 configurable numReaders

Adds withNumReaders() to SparkReceiverIO and implements parallel execution
using Create.of(shards) + Reshuffle. This addresses scalability limitations
by allowing work distribution across multiple workers.
---
 CHANGES.md                                     |  1 +
 .../ReadFromSparkReceiverWithOffsetDoFn.java   | 18 ++++---
 .../sdk/io/sparkreceiver/SparkReceiverIO.java  | 49 +++++++++++++++++--
 ...adFromSparkReceiverWithOffsetDoFnTest.java  |  2 +-
 .../io/sparkreceiver/SparkReceiverIOTest.java  | 27 ++++++++++
 5 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index ff931802addf..53934ea2b707 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -65,6 +65,7 @@
 ## I/Os
 
 * Add support for Datadog IO (Java) ([#37318](https://github.com/apache/beam/issues/37318)).
+* Support for parallel reading in SparkReceiverIO (Java) ([#37410](https://github.com/apache/beam/issues/37410)).
 
 ## New Features / Improvements
 
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
index e641e01cb3a4..1cd3e9a94039 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
@@ -57,7 +57,7 @@
  * ReadFromSparkReceiverWithOffsetDoFn} will move to process the next element.
  */
 @UnboundedPerElement
-class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<byte[], V> {
+class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<Integer, V> {
 
   private static final Logger LOG =
       LoggerFactory.getLogger(ReadFromSparkReceiverWithOffsetDoFn.class);
@@ -118,7 +118,7 @@ class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<byte[], V> {
   }
 
   @GetInitialRestriction
-  public OffsetRange initialRestriction(@Element byte[] element) {
+  public OffsetRange initialRestriction(@Element Integer element) {
     return new OffsetRange(startOffset, Long.MAX_VALUE);
   }
 
@@ -134,13 +134,13 @@ public WatermarkEstimator<Instant> newWatermarkEstimator(
   }
 
   @GetSize
-  public double getSize(@Element byte[] element, @Restriction OffsetRange offsetRange) {
+  public double getSize(@Element Integer element, @Restriction OffsetRange offsetRange) {
     return restrictionTracker(element, offsetRange).getProgress().getWorkRemaining();
   }
 
   @NewTracker
   public OffsetRangeTracker restrictionTracker(
-      @Element byte[] element, @Restriction OffsetRange restriction) {
+      @Element Integer element, @Restriction OffsetRange restriction) {
     return new OffsetRangeTracker(restriction) {
       private final AtomicBoolean isCheckDoneCalled = new AtomicBoolean(false);
 
@@ -178,7 +178,8 @@ public Coder<OffsetRange> restrictionCoder() {
   }
 
   // Need to do an unchecked cast from Object
-  // because org.apache.spark.streaming.receiver.ReceiverSupervisor accepts Object in push methods
+  // because org.apache.spark.streaming.receiver.ReceiverSupervisor accepts Object
+  // in push methods
   @SuppressWarnings("unchecked")
   private static class SparkConsumerWithOffset<V> implements SparkConsumer<V> {
     private final Queue<V> recordsQueue;
@@ -210,8 +211,9 @@ public void start(Receiver<V> sparkReceiver) {
             return null;
           }
           /*
-          Use only [0] element - data.
-          The other elements are not needed because they are related to Spark environment options.
+           * Use only [0] element - data.
+           * The other elements are not needed because they are related to Spark
+           * environment options.
           */
           Object data = input[0];
 
@@ -265,7 +267,7 @@ public void stop() {
 
   @ProcessElement
   public ProcessContinuation processElement(
-      @Element byte[] element,
+      @Element Integer element,
       RestrictionTracker<OffsetRange, Long> tracker,
       WatermarkEstimator<Instant> watermarkEstimator,
       OutputReceiver<V> receiver) {
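
The element type change above (byte[] -> Integer) turns each input element into a
shard id, with every shard tracking its own unbounded offset range. For orientation,
a stripped-down sketch of that splittable-DoFn shape; the class and field names here
are illustrative, not code from the patch:

    import org.apache.beam.sdk.io.range.OffsetRange;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;

    // Each Integer element is a shard id; each shard reads an unbounded
    // offset range [startOffset, Long.MAX_VALUE).
    @DoFn.UnboundedPerElement
    class ShardedReadFn<V> extends DoFn<Integer, V> {
      private final long startOffset = 0L; // assumed default for this sketch

      @GetInitialRestriction
      public OffsetRange initialRestriction(@Element Integer shardId) {
        return new OffsetRange(startOffset, Long.MAX_VALUE);
      }

      @ProcessElement
      public ProcessContinuation processElement(
          @Element Integer shardId,
          RestrictionTracker<OffsetRange, Long> tracker,
          OutputReceiver<V> receiver) {
        // Poll the shard's receiver, claim offsets via tracker.tryClaim(...),
        // emit records, then yield with resume() so the runner can reschedule.
        return ProcessContinuation.resume();
      }
    }
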
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
index d51c26154328..14158cb5a253 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
@@ -21,12 +21,19 @@
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.auto.value.AutoValue;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.spark.streaming.receiver.Receiver;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
@@ -99,6 +106,8 @@ public abstract static class Read<V> extends PTransform<PBegin, PCollection<V>> {
 
     abstract @Nullable Long getStartOffset();
 
+    abstract @Nullable Integer getNumReaders();
+
     abstract Builder<V> toBuilder();
 
     @AutoValue.Builder
@@ -117,6 +126,8 @@ abstract Builder<V> setSparkReceiverBuilder(
 
       abstract Builder<V> setStartOffset(Long startOffset);
 
+      abstract Builder<V> setNumReaders(Integer numReaders);
+
       abstract Read<V> build();
     }
 
@@ -151,12 +162,21 @@ public Read<V> withStartPollTimeoutSec(Long startPollTimeoutSec) {
       return toBuilder().setStartPollTimeoutSec(startPollTimeoutSec).build();
     }
 
-    /** Inclusive start offset from which the reading should be started. */
     public Read<V> withStartOffset(Long startOffset) {
       checkArgument(startOffset != null, "Start offset can not be null");
       return toBuilder().setStartOffset(startOffset).build();
     }
 
+    /**
+     * The number of workers that read from the Spark {@link Receiver} in parallel.
+     *
+     * <p>If this value is not set, or set to 1, the reading will be performed on a single worker.
+     */
+    public Read<V> withNumReaders(int numReaders) {
+      checkArgument(numReaders > 0, "Number of readers should be greater than 0");
+      return toBuilder().setNumReaders(numReaders).build();
+    }
+
     @Override
     public PCollection<V> expand(PBegin input) {
       validateTransform();
@@ -191,10 +211,29 @@ public PCollection<V> expand(PBegin input) {
             sparkReceiverBuilder.getSparkReceiverClass().getName()));
       } else {
         LOG.info("{} started reading", ReadFromSparkReceiverWithOffsetDoFn.class.getSimpleName());
-        return input
-            .apply(Impulse.create())
-            .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
-        // TODO: Split data from SparkReceiver into multiple workers
+        Integer numReadersObj = sparkReceiverRead.getNumReaders();
+        if (numReadersObj == null || numReadersObj == 1) {
+          return input
+              .apply(Impulse.create())
+              .apply(
+                  MapElements.into(TypeDescriptors.integers())
+                      .via(
+                          new SerializableFunction<byte[], Integer>() {
+                            @Override
+                            public Integer apply(byte[] input) {
+                              return 0;
+                            }
+                          }))
+              .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
+        } else {
+          int numReaders = numReadersObj;
+          List<Integer> shards =
+              IntStream.range(0, numReaders).boxed().collect(Collectors.toList());
+          return input
+              .apply(Create.of(shards))
+              .apply(Reshuffle.viaRandomKey())
+              .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
+        }
       }
     }
   }
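
The else-branch above is the standard Beam fan-out idiom: materialize one element
per shard with Create, then apply Reshuffle.viaRandomKey() to break fusion so the
runner is free to schedule each shard's read on a different worker. The same pattern
in isolation (a self-contained sketch; the numReaders value is arbitrary):

    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.Reshuffle;
    import org.apache.beam.sdk.values.PCollection;

    public class ShardFanOutSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

        int numReaders = 3; // arbitrary illustrative value
        List<Integer> shards =
            IntStream.range(0, numReaders).boxed().collect(Collectors.toList());

        // Without the Reshuffle, a runner could fuse Create and the downstream
        // read into one bundle on one worker; the shuffle boundary is what
        // actually distributes the shard ids.
        PCollection<Integer> shardIds =
            p.apply(Create.of(shards)).apply(Reshuffle.viaRandomKey());

        // shardIds now feeds the splittable read DoFn, one unbounded read per shard.
        p.run().waitUntilFinish();
      }
    }
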
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
index 6ab5d8393def..98b7ddccfa12 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
@@ -35,7 +35,7 @@
 /** Test class for {@link ReadFromSparkReceiverWithOffsetDoFn}. */
 public class ReadFromSparkReceiverWithOffsetDoFnTest {
 
-  private static final byte[] TEST_ELEMENT = new byte[] {};
+  private static final Integer TEST_ELEMENT = 0;
 
   private final ReadFromSparkReceiverWithOffsetDoFn<String> dofnInstance =
       new ReadFromSparkReceiverWithOffsetDoFn<>(makeReadTransform());
 
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
index bb482e798387..45aa347f18cc 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
@@ -241,4 +241,31 @@ public void testReadFromReceiverIteratorData() {
     PAssert.that(actual).containsInAnyOrder(expected);
     pipeline.run().waitUntilFinish(Duration.standardSeconds(15));
   }
+
+  @Test
+  public void testReadFromCustomReceiverWithParallelism() {
+    CustomReceiverWithOffset.shouldFailInTheMiddle = false;
+    ReceiverBuilder<String, CustomReceiverWithOffset> receiverBuilder =
+        new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs();
+    SparkReceiverIO.Read<String> reader =
+        SparkReceiverIO.<String>read()
+            .withGetOffsetFn(Long::valueOf)
+            .withTimestampFn(Instant::parse)
+            .withPullFrequencySec(PULL_FREQUENCY_SEC)
+            .withStartPollTimeoutSec(START_POLL_TIMEOUT_SEC)
+            .withStartOffset(START_OFFSET)
+            .withSparkReceiverBuilder(receiverBuilder)
+            .withNumReaders(3);
+
+    List<String> expected = new ArrayList<>();
+    for (int j = 0; j < 3; j++) {
+      for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) {
+        expected.add(String.valueOf(i));
+      }
+    }
+    PCollection<String> actual = pipeline.apply(reader).setCoder(StringUtf8Coder.of());
+
+    PAssert.that(actual).containsInAnyOrder(expected);
+    pipeline.run().waitUntilFinish(Duration.standardSeconds(15));
+  }
 }
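
With patch 1 applied, a pipeline opts in to parallel reading as sketched below.
MyReceiver is a hypothetical placeholder for any Receiver<String> implementing
HasOffset; the wiring mirrors the test above:

    // Illustrative wiring only; MyReceiver is not part of this patch.
    ReceiverBuilder<String, MyReceiver> receiverBuilder =
        new ReceiverBuilder<>(MyReceiver.class).withConstructorArgs();
    PCollection<String> records =
        pipeline.apply(
            SparkReceiverIO.<String>read()
                .withSparkReceiverBuilder(receiverBuilder)
                .withGetOffsetFn(Long::valueOf)   // record -> offset
                .withTimestampFn(Instant::parse)  // record -> event timestamp
                .withStartOffset(0L)              // inclusive start offset
                .withNumReaders(3));              // fan out to three readers

Note that with patch 1 alone every reader still consumes the full stream, which is
why the test above expects each record three times; the follow-up patch below adds
receiver-side sharding to remove that duplication.
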
From 1481d9c1232b77586f4cf66bc9a18460ac8516af Mon Sep 17 00:00:00 2001
From: Atharva262005
Date: Sun, 25 Jan 2026 02:26:02 +0530
Subject: [PATCH 2/2] fix(sparkreceiver): address PR review comments

1. Implemented setShard() in HasOffset to allow receivers to handle
   partitioning (prevents data duplication).
2. Updated CustomReceiverWithOffset to filter records based on shardId.
3. Updated DoFn to pass shardId/numShards to the receiver.
4. Restored Javadoc for withStartOffset.
5. Simplified backward compatibility logic in expand() using Create.of().
6. Updated tests to verify parallel reading produces the correct, unique
   record set.
---
 .../beam/sdk/io/sparkreceiver/HasOffset.java      | 10 +++++++++-
 .../ReadFromSparkReceiverWithOffsetDoFn.java      |  7 +++++++
 .../sdk/io/sparkreceiver/SparkReceiverIO.java     | 15 ++------------
 .../sparkreceiver/CustomReceiverWithOffset.java   | 16 +++++++++++---
 .../io/sparkreceiver/SparkReceiverIOTest.java     |  9 +++++----
 5 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
index 92bf082112e6..e3fb2d9f4f1b 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
@@ -33,5 +33,13 @@ public interface HasOffset {
    * Some {@link org.apache.spark.streaming.receiver.Receiver} support mechanism of checkpoint (e.g.
    * ack). This method should be called before stopping the receiver.
    */
-  default void setCheckpoint(Long recordsProcessed) {};
+  default void setCheckpoint(Long recordsProcessed) {}
+
+  /**
+   * Set the shard identifier and the total number of shards for parallel reading.
+   *
+   * @param shardId The unique identifier for this shard (reader).
+   * @param numShards The total number of shards (readers).
+   */
+  default void setShard(int shardId, int numShards) {}
 }
 
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
index 1cd3e9a94039..2752c30f6dfe 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
@@ -79,6 +79,7 @@ class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<Integer, V> {
   private final Long pullFrequencySec;
   private final Long startPollTimeoutSec;
   private final Long startOffset;
+  private final int numReaders;
 
   ReadFromSparkReceiverWithOffsetDoFn(SparkReceiverIO.Read<V> transform) {
     createWatermarkEstimatorFn = WatermarkEstimators.Manual::new;
@@ -115,6 +116,9 @@ class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<Integer, V> {
       startOffset = DEFAULT_START_OFFSET;
     }
     this.startOffset = startOffset;
+
+    Integer numReadersObj = transform.getNumReaders();
+    this.numReaders = (numReadersObj != null) ? numReadersObj : 1;
   }
 
   @GetInitialRestriction
@@ -286,6 +290,9 @@ public ProcessContinuation processElement(
     }
     LOG.debug("Restriction {}", tracker.currentRestriction().toString());
     sparkConsumer = new SparkConsumerWithOffset<>(tracker.currentRestriction().getFrom());
+    if (sparkReceiver instanceof HasOffset) {
+      ((HasOffset) sparkReceiver).setShard(element, numReaders);
+    }
     sparkConsumer.start(sparkReceiver);
 
     Long recordsProcessed = 0L;
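
The call added above establishes the partitioning contract: the DoFn hands
(shardId, numShards) to the receiver before starting it, and a cooperating receiver
keeps only the offsets where offset % numShards == shardId. A standalone helper,
not part of the patch, makes the arithmetic concrete:

    // For numShards = 3, offsets 0..5 map to shards 0, 1, 2, 0, 1, 2:
    // shard 0 -> {0, 3}, shard 1 -> {1, 4}, shard 2 -> {2, 5}.
    // Every offset lands in exactly one shard, so nothing is duplicated or lost.
    static boolean belongsToShard(long offset, int shardId, int numShards) {
      return offset % numShards == shardId;
    }
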
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
index 14158cb5a253..91c8d40e50cb 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
@@ -25,15 +25,12 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.Impulse;
-import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.spark.streaming.receiver.Receiver;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
@@ -162,6 +159,7 @@ public Read<V> withStartPollTimeoutSec(Long startPollTimeoutSec) {
       return toBuilder().setStartPollTimeoutSec(startPollTimeoutSec).build();
     }
 
+    /** Inclusive start offset from which the reading should be started. */
     public Read<V> withStartOffset(Long startOffset) {
       checkArgument(startOffset != null, "Start offset can not be null");
       return toBuilder().setStartOffset(startOffset).build();
     }
@@ -214,16 +212,7 @@ public PCollection<V> expand(PBegin input) {
         Integer numReadersObj = sparkReceiverRead.getNumReaders();
         if (numReadersObj == null || numReadersObj == 1) {
           return input
-              .apply(Impulse.create())
-              .apply(
-                  MapElements.into(TypeDescriptors.integers())
-                      .via(
-                          new SerializableFunction<byte[], Integer>() {
-                            @Override
-                            public Integer apply(byte[] input) {
-                              return 0;
-                            }
-                          }))
+              .apply(Create.of(0))
               .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
         } else {
           int numReaders = numReadersObj;
 
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
index 084208556fee..8937c88c3e1d 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
@@ -36,11 +36,13 @@ public class CustomReceiverWithOffset extends Receiver<String> implements HasOffset {
   public static final int RECORDS_COUNT = 20;
 
   /*
-  Used in test for imitation of reading with exception
-  */
+   * Used in test for imitation of reading with exception
+   */
   public static boolean shouldFailInTheMiddle = false;
 
   private Long startOffset;
+  private int shardId = 0;
+  private int numShards = 1;
 
   CustomReceiverWithOffset() {
     super(StorageLevel.MEMORY_AND_DISK_2());
@@ -53,6 +55,12 @@ public void setStartOffset(Long startOffset) {
     }
   }
 
+  @Override
+  public void setShard(int shardId, int numShards) {
+    this.shardId = shardId;
+    this.numShards = numShards;
+  }
+
   @Override
   @SuppressWarnings("FutureReturnValueIgnored")
   public void onStart() {
@@ -76,7 +84,9 @@ private void receive() {
           LOG.debug("Expected fail in the middle of reading");
           throw new IllegalStateException("Expected exception");
         }
-        store(String.valueOf(currentOffset));
+        if (currentOffset % numShards == shardId) {
+          store(String.valueOf(currentOffset));
+        }
         currentOffset++;
       } else {
         break;
 
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
index 45aa347f18cc..c8fa9f0526b3 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
@@ -258,10 +258,11 @@ public void testReadFromCustomReceiverWithParallelism() {
             .withNumReaders(3);
 
     List<String> expected = new ArrayList<>();
-    for (int j = 0; j < 3; j++) {
-      for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) {
-        expected.add(String.valueOf(i));
-      }
+    // With sharding enabled in CustomReceiverWithOffset, the total records read
+    // across all workers should be exactly the set 0..RECORDS_COUNT-1, each
+    // read exactly once.
+    for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) {
+      expected.add(String.valueOf(i));
     }
     PCollection<String> actual = pipeline.apply(reader).setCoder(StringUtf8Coder.of());
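
For receiver authors, a minimal shard-aware receiver in the spirit of
CustomReceiverWithOffset could look like the sketch below. The class name, its data
source, and the getEndOffset value are illustrative; setShard() is the hook added by
this patch, while setStartOffset() and getEndOffset() belong to the existing
HasOffset interface:

    import org.apache.beam.sdk.io.sparkreceiver.HasOffset;
    import org.apache.spark.storage.StorageLevel;
    import org.apache.spark.streaming.receiver.Receiver;

    // Illustrative shard-aware receiver: stores only the offsets assigned to
    // its shard, so records are neither duplicated nor lost across readers.
    public class ShardAwareReceiver extends Receiver<String> implements HasOffset {
      private long startOffset;
      private int shardId = 0;
      private int numShards = 1;

      public ShardAwareReceiver() {
        super(StorageLevel.MEMORY_AND_DISK_2());
      }

      @Override
      public void setStartOffset(Long startOffset) {
        if (startOffset != null) {
          this.startOffset = startOffset;
        }
      }

      @Override
      public Long getEndOffset() {
        return Long.MAX_VALUE; // unbounded stream in this sketch
      }

      @Override
      public void setShard(int shardId, int numShards) {
        this.shardId = shardId;
        this.numShards = numShards;
      }

      @Override
      public void onStart() {
        new Thread(this::receive).start();
      }

      @Override
      public void onStop() {}

      private void receive() {
        long offset = startOffset;
        while (!isStopped()) {
          // Store only this shard's slice of the offset space.
          if (offset % numShards == shardId) {
            store(String.valueOf(offset));
          }
          offset++;
        }
      }
    }
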