diff --git a/CHANGES.md b/CHANGES.md
index ff931802addf..53934ea2b707 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -65,6 +65,7 @@
 ## I/Os
 
 * Add support for Datadog IO (Java) ([#37318](https://github.com/apache/beam/issues/37318)).
+* Add support for parallel reading in SparkReceiverIO (Java) ([#37410](https://github.com/apache/beam/issues/37410)).
 
 ## New Features / Improvements
 
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
index 92bf082112e6..e3fb2d9f4f1b 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/HasOffset.java
@@ -33,5 +33,13 @@ public interface HasOffset {
    * Some {@link org.apache.spark.streaming.receiver.Receiver} support mechanism of checkpoint (e.g.
    * ack). This method should be called before stopping the receiver.
    */
-  default void setCheckpoint(Long recordsProcessed) {};
+  default void setCheckpoint(Long recordsProcessed) {}
+
+  /**
+   * Set the shard identifier and the total number of shards for parallel reading.
+   *
+   * @param shardId The unique identifier for this shard (reader).
+   * @param numShards The total number of shards (readers).
+   */
+  default void setShard(int shardId, int numShards) {}
 }
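For context, a receiver can use the new hook to keep only its own shard's records. Below is a minimal sketch (not part of this patch) of a `HasOffset` receiver that partitions by `offset % numShards`, mirroring the `CustomReceiverWithOffset` test class further down. The `ShardedReceiver` name and its synthetic offset-counter source are hypothetical, and the `setStartOffset`/`getEndOffset` overrides assume the interface's existing methods.

```java
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.receiver.Receiver;

// Hypothetical receiver, for illustration only: emits offsets as strings and
// keeps only the records belonging to its shard.
public class ShardedReceiver extends Receiver<String> implements HasOffset {

  private long offset = 0;
  private int shardId = 0;
  private int numShards = 1;

  public ShardedReceiver() {
    super(StorageLevel.MEMORY_AND_DISK_2());
  }

  @Override
  public void setStartOffset(Long startOffset) {
    if (startOffset != null) {
      this.offset = startOffset;
    }
  }

  @Override
  public Long getEndOffset() {
    return Long.MAX_VALUE;
  }

  @Override
  public void setShard(int shardId, int numShards) {
    this.shardId = shardId;
    this.numShards = numShards;
  }

  @Override
  public void onStart() {
    new Thread(
            () -> {
              while (!isStopped()) {
                // Modulo partitioning: N parallel readers produce disjoint
                // subsets that together cover every offset exactly once.
                if (offset % numShards == shardId) {
                  store(String.valueOf(offset));
                }
                offset++;
              }
            })
        .start();
  }

  @Override
  public void onStop() {}
}
```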
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
index e641e01cb3a4..2752c30f6dfe 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java
@@ -57,7 +57,7 @@
  * ReadFromSparkReceiverWithOffsetDoFn} will move to process the next element.
  */
 @UnboundedPerElement
-class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<byte[], V> {
+class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<Integer, V> {
 
   private static final Logger LOG =
       LoggerFactory.getLogger(ReadFromSparkReceiverWithOffsetDoFn.class);
@@ -79,6 +79,7 @@ class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<byte[], V> {
   private final Long pullFrequencySec;
   private final Long startPollTimeoutSec;
   private final Long startOffset;
+  private final int numReaders;
 
   ReadFromSparkReceiverWithOffsetDoFn(SparkReceiverIO.Read<V> transform) {
     createWatermarkEstimatorFn = WatermarkEstimators.Manual::new;
@@ -115,10 +116,13 @@ class ReadFromSparkReceiverWithOffsetDoFn<V> extends DoFn<byte[], V> {
       startOffset = DEFAULT_START_OFFSET;
     }
     this.startOffset = startOffset;
+
+    Integer numReadersObj = transform.getNumReaders();
+    this.numReaders = (numReadersObj != null) ? numReadersObj : 1;
   }
 
   @GetInitialRestriction
-  public OffsetRange initialRestriction(@Element byte[] element) {
+  public OffsetRange initialRestriction(@Element Integer element) {
     return new OffsetRange(startOffset, Long.MAX_VALUE);
   }
 
@@ -134,13 +138,13 @@ public WatermarkEstimator<Instant> newWatermarkEstimator(
   }
 
   @GetSize
-  public double getSize(@Element byte[] element, @Restriction OffsetRange offsetRange) {
+  public double getSize(@Element Integer element, @Restriction OffsetRange offsetRange) {
     return restrictionTracker(element, offsetRange).getProgress().getWorkRemaining();
   }
 
   @NewTracker
   public OffsetRangeTracker restrictionTracker(
-      @Element byte[] element, @Restriction OffsetRange restriction) {
+      @Element Integer element, @Restriction OffsetRange restriction) {
     return new OffsetRangeTracker(restriction) {
       private final AtomicBoolean isCheckDoneCalled = new AtomicBoolean(false);
 
@@ -178,7 +182,8 @@ public Coder<OffsetRange> restrictionCoder() {
   }
 
   // Need to do an unchecked cast from Object
-  // because org.apache.spark.streaming.receiver.ReceiverSupervisor accepts Object in push methods
+  // because org.apache.spark.streaming.receiver.ReceiverSupervisor accepts Object
+  // in push methods
   @SuppressWarnings("unchecked")
   private static class SparkConsumerWithOffset<V> implements SparkConsumer<V> {
     private final Queue<V> recordsQueue;
@@ -210,8 +215,9 @@ public void start(Receiver<V> sparkReceiver) {
               return null;
             }
             /*
-            Use only [0] element - data.
-            The other elements are not needed because they are related to Spark environment options.
+             * Use only [0] element - data.
+             * The other elements are not needed because they are related to Spark
+             * environment options.
              */
             Object data = input[0];
 
@@ -265,7 +271,7 @@ public void stop() {
 
   @ProcessElement
   public ProcessContinuation processElement(
-      @Element byte[] element,
+      @Element Integer element,
       RestrictionTracker<OffsetRange, Long> tracker,
       WatermarkEstimator<Instant> watermarkEstimator,
       OutputReceiver<V> receiver) {
@@ -284,6 +290,9 @@ public ProcessContinuation processElement(
     }
     LOG.debug("Restriction {}", tracker.currentRestriction().toString());
     sparkConsumer = new SparkConsumerWithOffset<>(tracker.currentRestriction().getFrom());
+    if (sparkReceiver instanceof HasOffset) {
+      ((HasOffset) sparkReceiver).setShard(element, numReaders);
+    }
     sparkConsumer.start(sparkReceiver);
 
     Long recordsProcessed = 0L;
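Note how parallelism composes with the splittable DoFn model here: every shard id arrives as its own `Integer` element, and each element starts from the same unbounded restriction, so deduplication relies entirely on the receiver-side `setShard` filter. A standalone illustration (not part of the patch, using Beam's `OffsetRange` directly):

```java
import org.apache.beam.sdk.io.range.OffsetRange;

public class RestrictionSketch {
  public static void main(String[] args) {
    long startOffset = 0L;
    int numReaders = 3;
    // Each shard element 0..numReaders-1 gets the same initial restriction,
    // as returned by initialRestriction() above; the receiver's modulo filter
    // keeps the shards from emitting duplicate records.
    for (int shardId = 0; shardId < numReaders; shardId++) {
      OffsetRange restriction = new OffsetRange(startOffset, Long.MAX_VALUE);
      System.out.println("shard " + shardId + " -> " + restriction);
    }
  }
}
```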
diff --git a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
index d51c26154328..91c8d40e50cb 100644
--- a/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
+++ b/sdks/java/io/sparkreceiver/3/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java
@@ -21,9 +21,13 @@
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.auto.value.AutoValue;
-import org.apache.beam.sdk.transforms.Impulse;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
@@ -99,6 +103,8 @@ public abstract static class Read<V> extends PTransform<PBegin, PCollection<V>>
 
     abstract @Nullable Long getStartOffset();
 
+    abstract @Nullable Integer getNumReaders();
+
     abstract Builder<V> toBuilder();
 
     @AutoValue.Builder
@@ -117,6 +123,8 @@ abstract Builder<V> setSparkReceiverBuilder(
 
       abstract Builder<V> setStartOffset(Long startOffset);
 
+      abstract Builder<V> setNumReaders(Integer numReaders);
+
       abstract Read<V> build();
     }
 
@@ -157,6 +165,16 @@ public Read<V> withStartOffset(Long startOffset) {
       return toBuilder().setStartOffset(startOffset).build();
     }
 
+    /**
+     * The number of workers to read from Spark {@link Receiver}.
+     *
+     * <p>If this value is not set, or set to 1, the reading will be performed on a single worker.
+     */
+    public Read<V> withNumReaders(int numReaders) {
+      checkArgument(numReaders > 0, "Number of readers should be greater than 0");
+      return toBuilder().setNumReaders(numReaders).build();
+    }
+
     @Override
     public PCollection<V> expand(PBegin input) {
       validateTransform();
@@ -191,10 +209,20 @@ public PCollection<V> expand(PBegin input) {
                 sparkReceiverBuilder.getSparkReceiverClass().getName()));
       } else {
         LOG.info("{} started reading", ReadFromSparkReceiverWithOffsetDoFn.class.getSimpleName());
-        return input
-            .apply(Impulse.create())
-            .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
-        // TODO: Split data from SparkReceiver into multiple workers
+        Integer numReadersObj = sparkReceiverRead.getNumReaders();
+        if (numReadersObj == null || numReadersObj == 1) {
+          return input
+              .apply(Create.of(0))
+              .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
+        } else {
+          int numReaders = numReadersObj;
+          List<Integer> shards =
+              IntStream.range(0, numReaders).boxed().collect(Collectors.toList());
+          return input
+              .apply(Create.of(shards))
+              .apply(Reshuffle.viaRandomKey())
+              .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead)));
+        }
       }
     }
   }
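End-to-end, enabling parallel reading is a one-line change for the pipeline author. A minimal usage sketch, reusing the hypothetical `ShardedReceiver` from earlier; the offset and timestamp functions are placeholders modeled on the tests in this patch:

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.sparkreceiver.ReceiverBuilder;
import org.apache.beam.sdk.io.sparkreceiver.SparkReceiverIO;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Instant;

public class ParallelReadExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // ShardedReceiver is the hypothetical HasOffset receiver sketched above.
    ReceiverBuilder<String, ShardedReceiver> receiverBuilder =
        new ReceiverBuilder<>(ShardedReceiver.class).withConstructorArgs();

    PCollection<String> records =
        pipeline.apply(
            SparkReceiverIO.<String>read()
                .withSparkReceiverBuilder(receiverBuilder)
                .withGetOffsetFn(Long::valueOf) // record -> offset, as in the tests
                .withTimestampFn(Instant::parse) // record -> event time (placeholder)
                .withNumReaders(3)); // shard ids 0..2, distributed via Reshuffle

    pipeline.run().waitUntilFinish();
  }
}
```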
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
index 084208556fee..8937c88c3e1d 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java
@@ -36,11 +36,13 @@ public class CustomReceiverWithOffset extends Receiver<String> implements HasOffset {
   public static final int RECORDS_COUNT = 20;
 
   /*
-  Used in test for imitation of reading with exception
-  */
+   * Used in test for imitation of reading with exception
+   */
   public static boolean shouldFailInTheMiddle = false;
 
   private Long startOffset;
+  private int shardId = 0;
+  private int numShards = 1;
 
   CustomReceiverWithOffset() {
     super(StorageLevel.MEMORY_AND_DISK_2());
@@ -53,6 +55,12 @@ public void setStartOffset(Long startOffset) {
     }
   }
 
+  @Override
+  public void setShard(int shardId, int numShards) {
+    this.shardId = shardId;
+    this.numShards = numShards;
+  }
+
   @Override
   @SuppressWarnings("FutureReturnValueIgnored")
   public void onStart() {
@@ -76,7 +84,9 @@ private void receive() {
           LOG.debug("Expected fail in the middle of reading");
           throw new IllegalStateException("Expected exception");
         }
-        store(String.valueOf(currentOffset));
+        if (currentOffset % numShards == shardId) {
+          store(String.valueOf(currentOffset));
+        }
         currentOffset++;
       } else {
         break;
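As a sanity check on the test setup below (illustrative only, not part of the patch): with `RECORDS_COUNT = 20` and three readers, the per-shard record counts are 7, 7, and 6, which sum to 20 with no overlap.

```java
public class ShardCountCheck {
  public static void main(String[] args) {
    int recordsCount = 20; // CustomReceiverWithOffset.RECORDS_COUNT
    int numShards = 3; // matches withNumReaders(3) in the new test
    for (int shardId = 0; shardId < numShards; shardId++) {
      int count = 0;
      for (int offset = 0; offset < recordsCount; offset++) {
        if (offset % numShards == shardId) {
          count++;
        }
      }
      System.out.println("shard " + shardId + " reads " + count + " records");
    }
    // Prints 7, 7, 6 - every offset 0..19 is read by exactly one shard.
  }
}
```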
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
index 6ab5d8393def..98b7ddccfa12 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFnTest.java
@@ -35,7 +35,7 @@
 
 /** Test class for {@link ReadFromSparkReceiverWithOffsetDoFn}. */
 public class ReadFromSparkReceiverWithOffsetDoFnTest {
 
-  private static final byte[] TEST_ELEMENT = new byte[] {};
+  private static final Integer TEST_ELEMENT = 0;
 
   private final ReadFromSparkReceiverWithOffsetDoFn<String> dofnInstance =
       new ReadFromSparkReceiverWithOffsetDoFn<>(makeReadTransform());
diff --git a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
index bb482e798387..c8fa9f0526b3 100644
--- a/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
+++ b/sdks/java/io/sparkreceiver/3/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java
@@ -241,4 +241,32 @@ public void testReadFromReceiverIteratorData() {
     PAssert.that(actual).containsInAnyOrder(expected);
     pipeline.run().waitUntilFinish(Duration.standardSeconds(15));
   }
+
+  @Test
+  public void testReadFromCustomReceiverWithParallelism() {
+    CustomReceiverWithOffset.shouldFailInTheMiddle = false;
+    ReceiverBuilder<String, CustomReceiverWithOffset> receiverBuilder =
+        new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs();
+    SparkReceiverIO.Read<String> reader =
+        SparkReceiverIO.<String>read()
+            .withGetOffsetFn(Long::valueOf)
+            .withTimestampFn(Instant::parse)
+            .withPullFrequencySec(PULL_FREQUENCY_SEC)
+            .withStartPollTimeoutSec(START_POLL_TIMEOUT_SEC)
+            .withStartOffset(START_OFFSET)
+            .withSparkReceiverBuilder(receiverBuilder)
+            .withNumReaders(3);
+
+    List<String> expected = new ArrayList<>();
+    // With sharding enabled in CustomReceiverWithOffset, the total records read
+    // across all workers should be exactly the set 0..RECORDS_COUNT-1,
+    // each read exactly once.
+    for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) {
+      expected.add(String.valueOf(i));
+    }
+    PCollection<String> actual = pipeline.apply(reader).setCoder(StringUtf8Coder.of());
+
+    PAssert.that(actual).containsInAnyOrder(expected);
+    pipeline.run().waitUntilFinish(Duration.standardSeconds(15));
+  }
 }