diff --git a/src/main/java/algebra/curves/barreto_lynn_scott/BLSBinaryReader.java b/src/main/java/algebra/curves/barreto_lynn_scott/BLSBinaryReader.java index 15d22af..e0bf1a0 100644 --- a/src/main/java/algebra/curves/barreto_lynn_scott/BLSBinaryReader.java +++ b/src/main/java/algebra/curves/barreto_lynn_scott/BLSBinaryReader.java @@ -78,14 +78,24 @@ public BLSFrT readFr() throws IOException { @Override public BLSG1T readG1() throws IOException { - return G1One.construct(readFq(), readFq(), FqOne); + final var x = readFq(); + final var y = readFq(); + if (x.isZero() && y.isOne()) { + return G1One.zero(); + } + + return G1One.construct(x, y, FqOne); } @Override public BLSG2T readG2() throws IOException { - final BLSFq2T X = readFq2(); - final BLSFq2T Y = readFq2(); - return G2One.construct(X, Y, Y.one()); + final BLSFq2T x = readFq2(); + final BLSFq2T y = readFq2(); + if (x.isZero() && y.isOne()) { + return G2One.zero(); + } + + return G2One.construct(x, y, y.one()); } protected BLSFqT readFq() throws IOException { diff --git a/src/main/java/algebra/curves/barreto_naehrig/BNBinaryReader.java b/src/main/java/algebra/curves/barreto_naehrig/BNBinaryReader.java index 9793202..0a60fa6 100644 --- a/src/main/java/algebra/curves/barreto_naehrig/BNBinaryReader.java +++ b/src/main/java/algebra/curves/barreto_naehrig/BNBinaryReader.java @@ -76,14 +76,24 @@ public BNFrT readFr() throws IOException { @Override public BNG1T readG1() throws IOException { - return G1One.construct(readFq(), readFq(), FqOne); + final var x = readFq(); + final var y = readFq(); + if (x.isZero() && y.isOne()) { + return G1One.zero(); + } + + return G1One.construct(x, y, FqOne); } @Override public BNG2T readG2() throws IOException { - final BNFq2T X = readFq2(); - final BNFq2T Y = readFq2(); - return G2One.construct(X, Y, Y.one()); + final var x = readFq2(); + final var y = readFq2(); + if (x.isZero() && y.isOne()) { + return G2One.zero(); + } + + return G2One.construct(x, y, y.one()); } protected BNFqT readFq() throws IOException { diff --git a/src/main/java/common/PairRDDAggregator.java b/src/main/java/common/PairRDDAggregator.java index 4790586..4027013 100644 --- a/src/main/java/common/PairRDDAggregator.java +++ b/src/main/java/common/PairRDDAggregator.java @@ -3,6 +3,8 @@ import java.util.ArrayList; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.storage.StorageLevel; + import scala.Tuple2; import scala.collection.JavaConverters; @@ -53,7 +55,18 @@ public JavaPairRDD aggregate() { } void processBatch() { - batches.add(sc.parallelizePairs(currentBatch, numPartitions)); + System.out.println("processBatch: " + String.valueOf(batches.size())); + final var newBatchRDD = sc.parallelizePairs(currentBatch, numPartitions); + + // To avoid running out of memory, 'checkpoint' the RDD. (The goal is to + // force it to be fully evaluated (and potentially evicted to disk), + // removing any need to recompute it, since receomputing requires that the + // original array of batch data must be present in memory somewhere). + newBatchRDD.cache(); + newBatchRDD.checkpoint(); + // newBatchRDD.persist(StorageLevel.MEMORY_AND_DISK()); + + batches.add(newBatchRDD); currentBatch = null; } diff --git a/src/main/java/prover/Prover.java b/src/main/java/prover/Prover.java index 3b17e57..9c6b68b 100644 --- a/src/main/java/prover/Prover.java +++ b/src/main/java/prover/Prover.java @@ -125,6 +125,12 @@ static JavaSparkContext createSparkContext(boolean local) { final SparkSession spark = sessionBuilder.getOrCreate(); spark.sparkContext().conf().set("spark.files.overwrite", "true"); + + // checkpoint directory + spark.sparkContext().setCheckpointDir("hdfs://ip-172-31-42-216:9000/checkpoints/"); + // clean checkpoint files if the reference is out of scope + // spark.sparkContext().conf().set("spark.cleaner.referenceTracking.cleanCheckpoints", "true"); + // TODO: reinstate this when it can be made to work // spark.sparkContext().conf().set( // "spark.serializer", diff --git a/src/test/java/algebra/curves/GenericBinaryWriterTest.java b/src/test/java/algebra/curves/GenericBinaryWriterTest.java index a3854c2..edd6422 100644 --- a/src/test/java/algebra/curves/GenericBinaryWriterTest.java +++ b/src/test/java/algebra/curves/GenericBinaryWriterTest.java @@ -30,14 +30,17 @@ public void testBinaryWriter( { var writer = mkWriter.apply(os); + writer.writeFr(frOne.zero()); writer.writeFr(frOne); writer.writeFr(frOne.construct(-1)); writer.writeFr(frOne.construct(2)); writer.writeFr(frOne.construct(-2)); + writer.writeG1(g1One.zero()); writer.writeG1(g1One); writer.writeG1(g1One.mul(frOne.construct(-1))); writer.writeG1(g1One.mul(frOne.construct(2))); writer.writeG1(g1One.mul(frOne.construct(-2))); + writer.writeG2(g2One.zero()); writer.writeG2(g2One); writer.writeG2(g2One.mul(frOne.construct(-1))); writer.writeG2(g2One.mul(frOne.construct(2))); @@ -51,14 +54,17 @@ public void testBinaryWriter( final var is = new ByteArrayInputStream(buffer); final var reader = mkReader.apply(is); + assertEquals(frOne.zero(), reader.readFr()); assertEquals(frOne, reader.readFr()); assertEquals(frOne.construct(-1), reader.readFr()); assertEquals(frOne.construct(2), reader.readFr()); assertEquals(frOne.construct(-2), reader.readFr()); + assertEquals(g1One.zero(), reader.readG1()); assertEquals(g1One, reader.readG1()); assertEquals(g1One.mul(frOne.construct(-1)), reader.readG1()); assertEquals(g1One.mul(frOne.construct(2)), reader.readG1()); assertEquals(g1One.mul(frOne.construct(-2)), reader.readG1()); + assertEquals(g2One.zero(), reader.readG2()); assertEquals(g2One, reader.readG2()); assertEquals(g2One.mul(frOne.construct(-1)), reader.readG2()); assertEquals(g2One.mul(frOne.construct(2)), reader.readG2());