diff --git a/.gitignore b/.gitignore index 91e9c41a..139d9604 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ /revision_logs /*.log /*.status +/bench \ No newline at end of file diff --git a/src/decoder/bench.rs b/src/decoder/bench.rs index 7c0c9446..5e4c3087 100644 --- a/src/decoder/bench.rs +++ b/src/decoder/bench.rs @@ -47,6 +47,25 @@ fn estimate_summary(estimate: CircuitBenchEstimate) -> CircuitBenchSummary { } } +fn scale_independent_summary_biguint( + summary: CircuitBenchSummary, + count: &BigUint, +) -> CircuitBenchSummary { + let scaled = CircuitBenchSummary::from_nanos( + summary.total_time.clone() * count, + summary.latency, + summary.max_parallelism.clone() * count, + ); + #[cfg(feature = "gpu")] + { + scaled.with_peak_vram(summary.peak_vram) + } + #[cfg(not(feature = "gpu"))] + { + scaled + } +} + fn sequential_summaries(parts: &[CircuitBenchSummary]) -> CircuitBenchSummary { let total_time = parts.iter().map(|part| part.total_time.clone()).sum::(); let latency = parts.iter().map(|part| part.latency).sum::(); @@ -201,12 +220,9 @@ pub(crate) fn bit_decomposed_polynomial_mask_decrypt_contribution_count( ring_dim: usize, mask_bits: usize, output_count: usize, -) -> usize { +) -> BigUint { assert!(mask_bits > 0, "mask_bits must be positive"); - ring_dim - .checked_mul(mask_bits) - .and_then(|count| count.checked_mul(output_count)) - .expect("bit-decomposed polynomial mask decrypt contribution count overflow") + BigUint::from(ring_dim) * BigUint::from(mask_bits) * BigUint::from(output_count) } /// Scale a representative one-ciphertext-bit decrypt contribution to a full @@ -216,13 +232,13 @@ pub(crate) fn scale_bit_decomposed_polynomial_mask_decrypt_contributions( ring_dim: usize, mask_bits: usize, output_count: usize, -) -> (CircuitBenchSummary, usize) { +) -> (CircuitBenchSummary, BigUint) { let contribution_count = bit_decomposed_polynomial_mask_decrypt_contribution_count( ring_dim, mask_bits, output_count, ); - (scale_independent_summary(unit, 
contribution_count), contribution_count) + (scale_independent_summary_biguint(unit, &contribution_count), contribution_count) } /// Number of additions needed to reduce `mask_bits` decrypted bit terms. @@ -293,3 +309,35 @@ pub(crate) fn bit_decomposed_polynomial_mask_reduction_summary( } scale_independent_summary(sequential_summaries(&per_polynomial_parts), polynomial_count) } + +#[cfg(test)] +mod tests { + use super::*; + + fn summary(total_time: u64, latency: f64, max_parallelism: u64) -> CircuitBenchSummary { + CircuitBenchSummary::from_nanos( + BigUint::from(total_time), + latency, + BigUint::from(max_parallelism), + ) + } + + #[test] + fn mask_decrypt_contribution_count_supports_large_aky24_cascade_outputs() { + let output_count = usize::MAX / 2; + let contribution_count = + bit_decomposed_polynomial_mask_decrypt_contribution_count(65_536, 1_418, output_count); + assert!(contribution_count > BigUint::from(usize::MAX)); + + let (scaled, returned_count) = scale_bit_decomposed_polynomial_mask_decrypt_contributions( + summary(7, 1.25, 3), + 65_536, + 1_418, + output_count, + ); + assert_eq!(returned_count, contribution_count); + assert_eq!(scaled.total_time, &contribution_count * BigUint::from(7u32)); + assert_eq!(scaled.latency, 1.25); + assert_eq!(scaled.max_parallelism, contribution_count * BigUint::from(3u32)); + } +} diff --git a/src/io/aky24_io.rs b/src/io/aky24_io.rs index 9ee438a7..98f06fa3 100644 --- a/src/io/aky24_io.rs +++ b/src/io/aky24_io.rs @@ -73,8 +73,8 @@ pub struct Aky24IO< pub output_size: usize, /// Number of private PRF seed bits encrypted into Ring-GSW ciphertexts. pub seed_bits: usize, - /// Number of public PRF seed bits, hence PRF seed-refresh rounds. - pub public_prf_seed_bits: usize, + /// Number of input bits processed by one seed-refresh round. + pub prf_batch_bits: usize, /// Number of bit-decomposed PRF mask output coefficients to compute. 
pub prf_mask_output_coeff_bits: usize, /// Number of low bits retained in the noise-refresh rounding material. @@ -113,7 +113,7 @@ where input_size: usize, output_size: usize, seed_bits: usize, - public_prf_seed_bits: usize, + prf_batch_bits: usize, prf_mask_output_coeff_bits: usize, noise_refresh_v_bits: usize, noise_refresh_cbd_n: usize, @@ -127,7 +127,16 @@ where assert!(input_size > 0, "AKY24IO input_size must be positive"); assert!(output_size > 0, "AKY24IO output_size must be positive"); assert!(seed_bits > 0, "AKY24IO seed_bits must be positive"); - assert!(public_prf_seed_bits > 0, "AKY24IO public_prf_seed_bits must be positive"); + assert!(prf_batch_bits > 0, "AKY24IO prf_batch_bits must be positive"); + assert!( + prf_batch_bits < usize::BITS as usize, + "AKY24IO prf_batch_bits must fit in a usize branch count" + ); + assert_eq!( + input_size % prf_batch_bits, + 0, + "AKY24IO input_size must be divisible by prf_batch_bits" + ); assert!( prf_mask_output_coeff_bits > 0, "AKY24IO prf_mask_output_coeff_bits must be positive" @@ -146,7 +155,7 @@ where input_size, output_size, seed_bits, - public_prf_seed_bits, + prf_batch_bits, prf_mask_output_coeff_bits, noise_refresh_v_bits, noise_refresh_cbd_n, @@ -160,7 +169,15 @@ where } } + pub(crate) fn prf_round_count(&self) -> usize { + self.input_size / self.prf_batch_bits + } + + pub(crate) fn prf_branch_count(&self) -> usize { + 1usize.checked_shl(self.prf_batch_bits as u32).expect("AKY24IO PRF branch count overflow") + } + pub(crate) fn prf_final_round_idx(&self) -> usize { - self.public_prf_seed_bits + self.prf_round_count() } } diff --git a/src/io/aky24_io/bench_estimator.rs b/src/io/aky24_io/bench_estimator.rs index b528965f..ec7ec87f 100644 --- a/src/io/aky24_io/bench_estimator.rs +++ b/src/io/aky24_io/bench_estimator.rs @@ -42,6 +42,18 @@ pub struct Aky24IOBenchEstimate { pub eval: CircuitBenchSummary, /// Compact bytes for persisted obfuscated-circuit material modeled by this estimator. 
pub obfuscated_circuit_bytes: BigUint, + /// Obfuscation contribution from Section 3.2 FE-to-iO cascade layers. + pub fe_to_io_obfuscate: CircuitBenchSummary, + /// Obfuscation contribution from the final FE layer. + pub final_fe_obfuscate: CircuitBenchSummary, + /// Online evaluation contribution from Section 3.2 FE-to-iO cascade layers. + pub fe_to_io_eval: CircuitBenchSummary, + /// Online evaluation contribution from the final FE layer. + pub final_fe_eval: CircuitBenchSummary, + /// Obfuscation total-time contribution from Section 3.2 FE-to-iO cascade layers. + pub fe_to_io_obfuscate_total_time: BigUint, + /// Obfuscation total-time contribution from the final FE layer. + pub final_fe_obfuscate_total_time: BigUint, /// Online evaluation total-time contribution from Section 3.2 FE-to-iO cascade layers. pub fe_to_io_eval_total_time: BigUint, /// Online evaluation total-time contribution from the final FE layer. @@ -67,16 +79,28 @@ impl Aky24IOBenchEstimate { let eval_parts = layers.iter().map(|layer| layer.eval.clone()).collect::>(); let obfuscated_circuit_bytes = layers.iter().map(|layer| layer.obfuscated_circuit_bytes.clone()).sum::(); - let fe_to_io_eval_total_time = - fe_to_io.iter().map(|layer| layer.eval.total_time.clone()).sum::(); + let fe_to_io_obfuscate_parts = + fe_to_io.iter().map(|layer| layer.obfuscate.clone()).collect::>(); + let fe_to_io_eval_parts = + fe_to_io.iter().map(|layer| layer.eval.clone()).collect::>(); + let fe_to_io_obfuscate = sequential_summaries(&fe_to_io_obfuscate_parts); + let final_fe_obfuscate = final_fe.obfuscate.clone(); + let fe_to_io_eval = sequential_summaries(&fe_to_io_eval_parts); + let final_fe_eval = final_fe.eval.clone(); let fe_to_io_obfuscated_circuit_bytes = fe_to_io.iter().map(|layer| layer.obfuscated_circuit_bytes.clone()).sum::(); Self { obfuscate: sequential_summaries(&obfuscate_parts), eval: sequential_summaries(&eval_parts), obfuscated_circuit_bytes, - fe_to_io_eval_total_time, - final_fe_eval_total_time: 
final_fe.eval.total_time, + fe_to_io_obfuscate_total_time: fe_to_io_obfuscate.total_time.clone(), + final_fe_obfuscate_total_time: final_fe_obfuscate.total_time.clone(), + fe_to_io_eval_total_time: fe_to_io_eval.total_time.clone(), + final_fe_eval_total_time: final_fe_eval.total_time.clone(), + fe_to_io_obfuscate, + final_fe_obfuscate, + fe_to_io_eval, + final_fe_eval, fe_to_io_obfuscated_circuit_bytes, final_fe_obfuscated_circuit_bytes: final_fe.obfuscated_circuit_bytes, } @@ -154,8 +178,9 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { { let final_shape = Aky24IOBenchShape::from_scheme(scheme, scheme.input_size, func.output_bits()); - let final_fe = self.estimate_layer(scheme, &final_shape); - let mut fe_to_io = Vec::with_capacity(final_shape.input_size.saturating_sub(1)); + let prf_units = self.estimate_prf_bench_units(scheme); + let final_fe = self.estimate_layer(scheme, &final_shape, &prf_units); + let mut fe_to_io = Vec::with_capacity(final_shape.prf_round_count.saturating_sub(1)); for stage_input_count in final_shape.cascade_stage_input_counts() { let stage_shape = Aky24IOBenchShape::from_scheme( scheme, @@ -167,15 +192,77 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { final_shape.modulus_digits, ), ); - fe_to_io.push(self.estimate_layer(scheme, &stage_shape)); + fe_to_io.push(self.estimate_layer(scheme, &stage_shape, &prf_units)); } Aky24IOBenchEstimate::sequential(final_fe, fe_to_io) } + fn estimate_prf_bench_units( + &self, + scheme: &Aky24IO, + ) -> Aky24IOPrfBenchUnits + where + M: PolyMatrix + Send + Sync + 'static, + M::P: 'static, + PKBE: BenchEstimator> + Sync, + PKBE: PublicKeyAuxBenchEstimator, + EncBE: BenchEstimator> + Sync, + NestedRnsPoly: DecomposeArithmeticGadget + ModularArithmeticPlanner, + { + let prg_circuit = build_representative_goldreich_prg_one_output_circuit(scheme); + let prf_mask_decrypt_circuit = prf_mask_decrypt_one_ciphertext_bit_circuit(scheme); + let prg_aux = self + 
.public_key_estimator + .estimate_public_lut_sample_aux_matrices_for_circuit(&scheme.params, &prg_circuit); + let scalar_one = vec![BigUint::from(1u32); scheme.params.ring_dimension() as usize]; + let scalar_target = [BigUint::from(1u32)]; + let pk_seed_lift_unit = self.public_key_estimator.estimate_large_scalar_mul(&scalar_one); + let enc_seed_lift_unit = self.encoding_estimator.estimate_large_scalar_mul(&scalar_one); + let pk_noise_refresh_matrix_mul = + self.public_key_estimator.estimate_large_scalar_mul(&scalar_target); + let enc_noise_refresh_matrix_mul = + self.encoding_estimator.estimate_large_scalar_mul(&scalar_target); + Aky24IOPrfBenchUnits { + public_key_preprocess: Aky24IOPrfBenchModeUnits { + final_mask_decrypt_unit: estimate_public_key_circuit_bench_with_aux::< + NaiveBGGPublicKeyVec, + PKBE, + >( + self.public_key_estimator, + &scheme.params, + &prf_mask_decrypt_circuit, + ), + prg_unit: estimate_public_key_circuit_bench_with_aux::, PKBE>( + self.public_key_estimator, + &scheme.params, + &prg_circuit, + ), + add: self.public_key_estimator.estimate_add(), + sub: self.public_key_estimator.estimate_sub(), + mul: self.public_key_estimator.estimate_mul(), + seed_lift_unit: pk_seed_lift_unit, + noise_refresh_matrix_mul_unit: pk_noise_refresh_matrix_mul, + }, + encoding_online: Aky24IOPrfBenchModeUnits { + final_mask_decrypt_unit: self + .encoding_estimator + .estimate_circuit_bench(&prf_mask_decrypt_circuit), + prg_unit: self.encoding_estimator.estimate_circuit_bench(&prg_circuit), + add: self.encoding_estimator.estimate_add(), + sub: self.encoding_estimator.estimate_sub(), + mul: self.encoding_estimator.estimate_mul(), + seed_lift_unit: enc_seed_lift_unit, + noise_refresh_matrix_mul_unit: enc_noise_refresh_matrix_mul, + }, + public_lut_prg_aux_compact_bytes: prg_aux.compact_bytes, + } + } + fn estimate_layer( &self, scheme: &Aky24IO, shape: &Aky24IOBenchShape, + prf_units: &Aky24IOPrfBenchUnits, ) -> Aky24IOBenchLayerEstimate where M: PolyMatrix + Send + 
Sync + 'static, @@ -186,9 +273,12 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { NBE: DiamondIONativeBenchEstimator, NestedRnsPoly: DecomposeArithmeticGadget + ModularArithmeticPlanner, { - let obfuscate = self.estimate_obfuscate(scheme, shape); - let eval = self.estimate_eval(scheme, shape); - let public_lut_aux_bytes = self.estimate_public_lut_aux_storage_bytes(scheme, shape); + let obfuscate = self.estimate_obfuscate(scheme, shape, &prf_units.public_key_preprocess); + let eval = self.estimate_eval(scheme, shape, &prf_units.encoding_online); + let public_lut_aux_bytes = self.estimate_public_lut_aux_storage_bytes( + shape, + &prf_units.public_lut_prg_aux_compact_bytes, + ); Aky24IOBenchLayerEstimate { obfuscate, eval, @@ -200,6 +290,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { &self, scheme: &Aky24IO, shape: &Aky24IOBenchShape, + prf_units: &Aky24IOPrfBenchModeUnits, ) -> CircuitBenchSummary where M: PolyMatrix + Send + Sync + 'static, @@ -210,7 +301,6 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { NBE: DiamondIONativeBenchEstimator, NestedRnsPoly: DecomposeArithmeticGadget + ModularArithmeticPlanner, { - let scalar_one = vec![BigUint::from(1u32); shape.ring_dim]; let bgg_public_keys = scale_estimate( self.bgg_public_key_sample.clone(), shape.input_size.checked_add(2).expect("AKY24IO public-key count overflow"), @@ -219,7 +309,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { estimate_summary(self.ring_gsw_public_key_sample.clone()); let seed_encrypt = scale_estimate(self.ring_gsw_encrypt_bit.clone(), shape.seed_bits); let seed_lift = scale_estimate( - self.public_key_estimator.estimate_large_scalar_mul(scalar_one.as_slice()), + prf_units.seed_lift_unit.clone(), shape .seed_bits .checked_mul(shape.ring_gsw_wire_count) @@ -231,6 +321,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { scheme, shape, 
PrfBenchMode::PublicKeyPreprocess, + prf_units, ); let final_projection = self.estimate_final_projection_preprocess(shape); sequential_summaries(&[ @@ -245,6 +336,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { &self, _scheme: &Aky24IO, shape: &Aky24IOBenchShape, + prf_units: &Aky24IOPrfBenchModeUnits, ) -> CircuitBenchSummary where M: PolyMatrix + Send + Sync + 'static, @@ -254,14 +346,13 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { EncBE: BenchEstimator> + Sync, NBE: DiamondIONativeBenchEstimator, { - let scalar_one = vec![BigUint::from(1u32); shape.ring_dim]; let input_projection = scale_summary( self.native_estimator .estimate_vector_matrix_product(shape.state_col_size, shape.modulus_digits), shape.input_size.checked_add(2).expect("AKY24IO input projection count overflow"), ); let seed_lift = scale_estimate( - self.encoding_estimator.estimate_large_scalar_mul(scalar_one.as_slice()), + prf_units.seed_lift_unit.clone(), shape .seed_bits .checked_mul(shape.ring_gsw_wire_count) @@ -271,6 +362,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { _scheme, shape, PrfBenchMode::EncodingOnline, + prf_units, ); let decoder_projection = scale_summary( self.native_estimator.estimate_vector_matrix_product(shape.state_col_size, 1), @@ -295,9 +387,10 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { fn estimate_prf_path( &self, - scheme: &Aky24IO, + _scheme: &Aky24IO, shape: &Aky24IOBenchShape, mode: PrfBenchMode, + units: &Aky24IOPrfBenchModeUnits, ) -> Aky24IOPrfBenchEstimateParts where M: PolyMatrix + Send + Sync + 'static, @@ -308,35 +401,11 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { NBE: DiamondIONativeBenchEstimator, NestedRnsPoly: DecomposeArithmeticGadget + ModularArithmeticPlanner, { - let final_mask_decrypt_unit = self - .estimate_prf_mask_decrypt_one_ciphertext_bit_unit::( - scheme, mode, - ); - let prg_circuit = 
build_representative_goldreich_prg_one_output_circuit(scheme); - let prg_unit = match mode { - PrfBenchMode::PublicKeyPreprocess => { - estimate_public_key_circuit_bench_with_aux::, PKBE>( - self.public_key_estimator, - &scheme.params, - &prg_circuit, - ) - } - PrfBenchMode::EncodingOnline => { - self.encoding_estimator.estimate_circuit_bench(&prg_circuit) - } - }; - let (add, sub, mul) = match mode { - PrfBenchMode::PublicKeyPreprocess => ( - self.public_key_estimator.estimate_add(), - self.public_key_estimator.estimate_sub(), - self.public_key_estimator.estimate_mul(), - ), - PrfBenchMode::EncodingOnline => ( - self.encoding_estimator.estimate_add(), - self.encoding_estimator.estimate_sub(), - self.encoding_estimator.estimate_mul(), - ), - }; + let final_mask_decrypt_unit = units.final_mask_decrypt_unit.clone(); + let prg_unit = units.prg_unit.clone(); + let add = units.add.clone(); + let sub = units.sub.clone(); + let mul = units.mul.clone(); let noise_refresh_mask_prg_unit = prg_unit.clone(); let noise_refresh_error_prg_unit = goldreich_cbd_error_prg_summary( prg_unit.clone(), @@ -379,12 +448,13 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { shape.prf_round_count, ); let refresh_parts = self.estimate_noise_refresh_sparse::( - scheme, + _scheme, shape, mode, noise_refresh_error_prg_unit, noise_refresh_mask_prg_unit, final_mask_decrypt_unit.clone(), + units, ); let noise_refresh_branch_count = match mode { PrfBenchMode::PublicKeyPreprocess => shape.prf_branch_count, @@ -434,7 +504,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { ]); let noise_refresh = repeat_sequential_summary(noise_refresh_per_round.clone(), shape.prf_round_count); - let final_prg = scale_summary(prg_unit.clone(), shape.final_prg_output_count()); + let final_prg = scale_summary_biguint(prg_unit.clone(), &shape.final_prg_output_count()); let (final_mask_decrypt_contributions, _final_mask_decrypt_contribution_count) = 
scale_bit_decomposed_polynomial_mask_decrypt_contributions( final_mask_decrypt_unit.clone(), @@ -511,34 +581,6 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { } } - fn estimate_prf_mask_decrypt_one_ciphertext_bit_unit( - &self, - scheme: &Aky24IO, - mode: PrfBenchMode, - ) -> CircuitBenchSummary - where - M: PolyMatrix + Send + Sync + 'static, - M::P: 'static, - PKBE: BenchEstimator> + Sync, - PKBE: PublicKeyAuxBenchEstimator, - EncBE: BenchEstimator> + Sync, - NestedRnsPoly: DecomposeArithmeticGadget + ModularArithmeticPlanner, - { - let circuit = prf_mask_decrypt_one_ciphertext_bit_circuit(scheme); - match mode { - PrfBenchMode::PublicKeyPreprocess => { - estimate_public_key_circuit_bench_with_aux::, PKBE>( - self.public_key_estimator, - &scheme.params, - &circuit, - ) - } - PrfBenchMode::EncodingOnline => { - self.encoding_estimator.estimate_circuit_bench(&circuit) - } - } - } - fn estimate_noise_refresh_sparse( &self, _scheme: &Aky24IO, @@ -547,6 +589,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { error_prg_unit: CircuitBenchSummary, mask_prg_unit: CircuitBenchSummary, decrypt_contribution_unit: CircuitBenchSummary, + units: &Aky24IOPrfBenchModeUnits, ) -> NoiseRefreshBenchEstimateParts where M: PolyMatrix + Send + Sync + 'static, @@ -562,10 +605,7 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { shape.noise_refresh_v_bits, false, ); - let add = match mode { - PrfBenchMode::PublicKeyPreprocess => self.public_key_estimator.estimate_add(), - PrfBenchMode::EncodingOnline => self.encoding_estimator.estimate_add(), - }; + let add = units.add.clone(); let material = bit_decomposed_refresh_material_summary( error_prg_unit, mask_prg_unit, @@ -574,7 +614,6 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { material_counts, shape.noise_refresh_v_bits, ); - let scalar_target = [BigUint::from(1u32)]; let combine_task_count = shape .ring_dim 
.checked_mul(shape.crt_depth) @@ -587,17 +626,23 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { .expect("AKY24IO noise-refresh collapse add count overflow"); let per_refresh = match mode { PrfBenchMode::PublicKeyPreprocess => { - let pk_matrix_mul = - self.public_key_estimator.estimate_large_scalar_mul(&scalar_target); - let pk_add = self.public_key_estimator.estimate_add(); - let pk_sub = self.public_key_estimator.estimate_sub(); let combine_unit = sequential_summaries(&[ - estimate_summary(pk_matrix_mul.clone()), - estimate_summary(pk_matrix_mul.clone()), - scale_estimate(pk_matrix_mul, shape.modulus_digits), - scale_estimate(pk_add.clone(), collapse_add_count), - estimate_summary(pk_add), - estimate_summary(pk_sub), + // Compute the public-key one-term: `one.key(...).matrix_mul((q / q_i) * A')`. + estimate_summary(units.noise_refresh_matrix_mul_unit.clone()), + // Compute the refreshed-input public-key term: + // `refreshed_input.key(...).matrix_mul((q / q_i) * G)`. + estimate_summary(units.noise_refresh_matrix_mul_unit.clone()), + // Apply the one-column target to each decoded material column. + scale_estimate( + units.noise_refresh_matrix_mul_unit.clone(), + shape.modulus_digits, + ), + // Collapse the ring-position contributions into one accumulator. + scale_estimate(units.add.clone(), collapse_add_count), + // Add the refreshed share into the collapsed accumulator. + estimate_summary(units.add.clone()), + // Subtract the mask share to finish the refreshed ciphertext entry. 
+ estimate_summary(units.sub.clone()), ]); sequential_summaries(&[ a_prime_sampling_stage.clone(), @@ -605,19 +650,28 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { ]) } PrfBenchMode::EncodingOnline => { - let enc_matrix_mul = - self.encoding_estimator.estimate_large_scalar_mul(&scalar_target); - let enc_add = self.encoding_estimator.estimate_add(); - let enc_sub = self.encoding_estimator.estimate_sub(); let crt_recompose = self.native_estimator.estimate_vector_add(shape.modulus_digits); let combine_unit = sequential_summaries(&[ - estimate_summary(enc_matrix_mul.clone()), - estimate_summary(enc_matrix_mul.clone()), - scale_estimate(enc_matrix_mul, shape.modulus_digits), - scale_estimate(enc_add.clone(), collapse_add_count), - estimate_summary(enc_add), - estimate_summary(enc_sub.clone()), - estimate_summary(enc_sub), + // Compute the encoding one-term: + // `one.encoding(...).matrix_mul((q / q_i) * A')`. + estimate_summary(units.noise_refresh_matrix_mul_unit.clone()), + // Compute the refreshed-input encoding term: + // `refreshed_input.encoding(...).matrix_mul((q / q_i) * G)`. + estimate_summary(units.noise_refresh_matrix_mul_unit.clone()), + // Apply the one-column target to each decoded material column. + scale_estimate( + units.noise_refresh_matrix_mul_unit.clone(), + shape.modulus_digits, + ), + // Collapse the ring-position contributions into one accumulator. + scale_estimate(units.add.clone(), collapse_add_count), + // Add the refreshed share into the collapsed accumulator. + estimate_summary(units.add.clone()), + // Subtract the mask share before native recomposition. + estimate_summary(units.sub.clone()), + // Remove the pre-recomposition value from the encoding accumulator. + estimate_summary(units.sub.clone()), + // Recompose the CRT digits into the native refreshed ciphertext entry. 
estimate_summary(crt_recompose), ]); sequential_summaries(&[ @@ -661,25 +715,16 @@ impl<'a, PKBE, EncBE, NBE> Aky24IOBenchEstimator<'a, PKBE, EncBE, NBE> { sequential_summaries(&[inputs, preimages]) } - fn estimate_public_lut_aux_storage_bytes( + fn estimate_public_lut_aux_storage_bytes( &self, - scheme: &Aky24IO, shape: &Aky24IOBenchShape, - ) -> BigUint - where - M: PolyMatrix + Send + Sync + 'static, - M::P: 'static, - PKBE: PublicKeyAuxBenchEstimator, - { - let prg_circuit = build_representative_goldreich_prg_one_output_circuit(scheme); - let prg_aux = self - .public_key_estimator - .estimate_public_lut_sample_aux_matrices_for_circuit(&scheme.params, &prg_circuit); - if prg_aux.compact_bytes == BigUint::default() { + prg_aux_compact_bytes: &BigUint, + ) -> BigUint { + if prg_aux_compact_bytes == &BigUint::default() { return BigUint::default(); } - prg_aux.compact_bytes.clone() * shape.public_lut_prg_output_count() + prg_aux_compact_bytes * shape.public_lut_prg_output_count() } } @@ -689,6 +734,24 @@ enum PrfBenchMode { EncodingOnline, } +#[derive(Debug, Clone, PartialEq)] +struct Aky24IOPrfBenchUnits { + public_key_preprocess: Aky24IOPrfBenchModeUnits, + encoding_online: Aky24IOPrfBenchModeUnits, + public_lut_prg_aux_compact_bytes: BigUint, +} + +#[derive(Debug, Clone, PartialEq)] +struct Aky24IOPrfBenchModeUnits { + final_mask_decrypt_unit: CircuitBenchSummary, + prg_unit: CircuitBenchSummary, + add: CircuitBenchEstimate, + sub: CircuitBenchEstimate, + mul: CircuitBenchEstimate, + seed_lift_unit: CircuitBenchEstimate, + noise_refresh_matrix_mul_unit: CircuitBenchEstimate, +} + #[derive(Debug, Clone, PartialEq)] struct NoiseRefreshBenchEstimateParts { material: CircuitBenchSummary, @@ -767,6 +830,7 @@ struct Aky24IOBenchShape { input_size: usize, output_size: usize, seed_bits: usize, + prf_batch_bits: usize, prf_round_count: usize, prf_branch_count: usize, prf_mask_output_coeff_bits: usize, @@ -809,13 +873,19 @@ impl Aky24IOBenchShape { .and_then(|count| 
count.checked_mul(ring_gsw_active_levels)) .and_then(|count| count.checked_mul(scheme.ring_gsw_context.p_moduli.len())) .expect("AKY24IO Ring-GSW wire count overflow"); + assert_eq!( + input_size % scheme.prf_batch_bits, + 0, + "AKY24IO benchmark layer input_size must be divisible by prf_batch_bits" + ); Self { ring_dim, input_size, output_size, seed_bits: scheme.seed_bits, - prf_round_count: scheme.public_prf_seed_bits, - prf_branch_count: 2, + prf_batch_bits: scheme.prf_batch_bits, + prf_round_count: input_size / scheme.prf_batch_bits, + prf_branch_count: scheme.prf_branch_count(), prf_mask_output_coeff_bits: scheme.prf_mask_output_coeff_bits, noise_refresh_v_bits: scheme.noise_refresh_v_bits, cbd_n: scheme.noise_refresh_cbd_n, @@ -827,8 +897,12 @@ impl Aky24IOBenchShape { } } - fn cascade_stage_input_counts(&self) -> std::ops::Range { - 1..self.input_size + fn cascade_stage_input_counts(&self) -> impl Iterator + '_ { + (1..self.prf_round_count).map(|round_count| { + round_count + .checked_mul(self.prf_batch_bits) + .expect("AKY24IO cascade stage input size overflow") + }) } fn fe_to_io_output_size_for_stage( @@ -848,25 +922,21 @@ impl Aky24IOBenchShape { self.output_size.checked_mul(self.ring_dim).expect("AKY24IO final decoder count overflow") } - fn final_mask_prg_output_count(&self) -> usize { - self.final_decoder_count() - .checked_mul(self.prf_mask_output_coeff_bits) - .expect("AKY24IO final mask PRG output count overflow") + fn final_mask_prg_output_count(&self) -> BigUint { + BigUint::from(self.final_decoder_count()) * BigUint::from(self.prf_mask_output_coeff_bits) } - fn final_prg_output_count(&self) -> usize { - self.final_mask_prg_output_count() - .checked_add(self.output_size) - .expect("AKY24IO final PRG output count overflow") + fn final_prg_output_count(&self) -> BigUint { + self.final_mask_prg_output_count() + BigUint::from(self.output_size) } fn obfuscated_circuit_bytes(&self, public_lut_aux_bytes: BigUint) -> BigUint { - 
BigUint::from(self.final_projection_preimage_bytes()) + + self.final_projection_preimage_bytes() + self.prf_refresh_preimage_bytes() + public_lut_aux_bytes } - fn final_projection_preimage_bytes(&self) -> usize { + fn final_projection_preimage_bytes(&self) -> BigUint { let standard_preimages = self.input_size.checked_add(2).expect("AKY24IO projection count overflow"); let output_preimage_bytes = bench_utils::matrix_compact_bytes_for_shape( @@ -881,16 +951,9 @@ impl Aky24IOBenchShape { self.ring_dim, self.modulus_bits, ); - output_preimage_bytes - .checked_mul(standard_preimages) - .and_then(|bytes| { - bytes.checked_add( - final_decoder_preimage_bytes - .checked_mul(self.final_decoder_count()) - .expect("AKY24IO final decoder preimage byte count overflow"), - ) - }) - .expect("AKY24IO final projection byte count overflow") + BigUint::from(output_preimage_bytes) * BigUint::from(standard_preimages) + + BigUint::from(final_decoder_preimage_bytes) * + BigUint::from(self.final_decoder_count()) } fn selected_prg_output_count(&self) -> usize { @@ -954,14 +1017,70 @@ impl Aky24IOBenchShape { .checked_mul(self.prf_branch_count) .expect("AKY24IO branch-specific noise-refresh material count overflow"), ); - let final_prg_outputs = BigUint::from(self.final_prg_output_count()); - selected_prg_outputs + noise_refresh_material_prg_outputs + final_prg_outputs + selected_prg_outputs + noise_refresh_material_prg_outputs + self.final_prg_output_count() } } #[cfg(test)] mod tests { use super::*; + use crate::{ + func_enc::NoCircuitEvaluator, + gadgets::arith::{ModularArithmeticContext, NestedRnsPolyContext}, + matrix::dcrt_poly::DCRTPolyMatrix, + poly::dcrt::{params::DCRTPolyParams, poly::DCRTPoly}, + }; + + type TestAky24IO = Aky24IO< + DCRTPolyMatrix, + NoCircuitEvaluator, + NoCircuitEvaluator, + NoCircuitEvaluator, + NoCircuitEvaluator, + >; + + fn test_scheme(input_size: usize, prf_batch_bits: usize) -> TestAky24IO { + let params = DCRTPolyParams::new(2, 1, 10, 5); + let mut 
setup_circuit = PolyCircuit::::new(); + let ring_gsw_context = Arc::new(NestedRnsPolyContext::setup( + &mut setup_circuit, + &params, + 5, + 2, + 1 << 8, + false, + Some(1), + )); + let ring_gsw_width = 2 * + >::gadget_len( + ring_gsw_context.as_ref(), + Some(1), + Some(0), + ); + TestAky24IO::new( + params.clone(), + params, + ring_gsw_context, + ring_gsw_width, + 0, + Some(1), + Some(0.0), + b"aky24_io_bench_estimator_test".to_vec(), + input_size, + 1, + 6, + prf_batch_bits, + 1, + 1, + 1, + [0x24; 32], + [0x42; 32], + None, + None, + None, + None, + ) + } fn test_shape(input_size: usize) -> Aky24IOBenchShape { Aky24IOBenchShape { @@ -969,7 +1088,8 @@ mod tests { input_size, output_size: 2, seed_bits: 5, - prf_round_count: 3, + prf_batch_bits: 1, + prf_round_count: input_size, prf_branch_count: 2, prf_mask_output_coeff_bits: 4, noise_refresh_v_bits: 3, @@ -982,6 +1102,14 @@ mod tests { } } + fn batched_test_shape(input_size: usize, prf_batch_bits: usize) -> Aky24IOBenchShape { + Aky24IOBenchShape { + prf_batch_bits, + prf_round_count: input_size / prf_batch_bits, + ..test_shape(input_size) + } + } + fn summary(total_time: u64, latency: f64, max_parallelism: u64) -> CircuitBenchSummary { CircuitBenchSummary::from_nanos( BigUint::from(total_time), @@ -999,6 +1127,29 @@ mod tests { assert_eq!(test_shape(4).cascade_stage_input_counts().collect::>(), vec![1, 2, 3]); } + #[test] + fn test_cascade_stage_input_counts_follow_batch_boundaries() { + assert_eq!( + batched_test_shape(12, 4).cascade_stage_input_counts().collect::>(), + vec![4, 8] + ); + assert_eq!( + batched_test_shape(4, 4).cascade_stage_input_counts().collect::>(), + Vec::::new() + ); + } + + #[test] + fn test_from_scheme_uses_layer_input_size_for_prf_round_count() { + let scheme = test_scheme(12, 4); + let final_shape = Aky24IOBenchShape::from_scheme(&scheme, 12, 1); + let stage_shape = Aky24IOBenchShape::from_scheme(&scheme, 4, 1); + + assert_eq!(final_shape.prf_round_count, 3); + 
assert_eq!(stage_shape.prf_round_count, 1); + assert_eq!(stage_shape.prf_branch_count, 16); + } + #[test] fn test_fe_to_io_stage_output_size_matches_ciphertext_bit_width() { let shape = test_shape(3); @@ -1013,6 +1164,15 @@ mod tests { ); } + #[test] + fn test_large_cascade_storage_counts_use_biguint() { + let mut shape = test_shape(1); + shape.output_size = usize::MAX / shape.ring_dim / 2; + + assert!(shape.final_projection_preimage_bytes() > BigUint::from(usize::MAX)); + assert!(shape.final_prg_output_count() > BigUint::from(usize::MAX)); + } + #[test] fn test_bench_estimate_sequential_adds_and_splits_cascade_layers() { let estimate = Aky24IOBenchEstimate::sequential( @@ -1035,7 +1195,17 @@ mod tests { assert_eq!(estimate.eval.latency, 6.0); assert_eq!(estimate.eval.max_parallelism, BigUint::from(5u64)); assert_eq!(estimate.obfuscated_circuit_bytes, BigUint::from(12u64)); + assert_eq!(estimate.fe_to_io_obfuscate.latency, 1.0); + assert_eq!(estimate.fe_to_io_obfuscate.max_parallelism, BigUint::from(2u64)); + assert_eq!(estimate.fe_to_io_obfuscate_total_time, BigUint::from(10u64)); + assert_eq!(estimate.final_fe_obfuscate.latency, 3.0); + assert_eq!(estimate.final_fe_obfuscate.max_parallelism, BigUint::from(4u64)); + assert_eq!(estimate.final_fe_obfuscate_total_time, BigUint::from(30u64)); + assert_eq!(estimate.fe_to_io_eval.latency, 2.0); + assert_eq!(estimate.fe_to_io_eval.max_parallelism, BigUint::from(3u64)); assert_eq!(estimate.fe_to_io_eval_total_time, BigUint::from(20u64)); + assert_eq!(estimate.final_fe_eval.latency, 4.0); + assert_eq!(estimate.final_fe_eval.max_parallelism, BigUint::from(5u64)); assert_eq!(estimate.final_fe_eval_total_time, BigUint::from(40u64)); assert_eq!(estimate.fe_to_io_obfuscated_circuit_bytes, BigUint::from(5u64)); assert_eq!(estimate.final_fe_obfuscated_circuit_bytes, BigUint::from(7u64)); diff --git a/src/io/aky24_io/simulation.rs b/src/io/aky24_io/simulation.rs index 1e1399f5..e0301d6e 100644 --- a/src/io/aky24_io/simulation.rs 
+++ b/src/io/aky24_io/simulation.rs @@ -35,7 +35,6 @@ use super::{Aky24IO, Aky24IOFuncType}; use crate::io::utils::simulation::{self as sim_utils, assert_same_matrix_shape, scale_error_norm}; const AKY24_IO_SECRET_SIZE: usize = 1; -const AKY24_IO_PRF_BRANCH_COUNT: usize = 2; const REPRESENTATIVE_GOLDREICH_SEED_BITS: usize = 5; /// Error-growth summary for the conventional AKY24 FE-to-iO online path. @@ -70,6 +69,18 @@ pub struct Aky24IOPrfMaskOutputCoeffBitsSearchResult { pub simulation: Aky24IOErrorSimulation, } +#[derive(Debug, Clone)] +struct Aky24IOFinalMaskCache { + prg_output: ErrorNorm, + base_error: ErrorNorm, +} + +#[derive(Debug, Clone)] +struct Aky24IOPrfMaskOutputCoeffBitsSearchEvaluation { + result: Aky24IOPrfMaskOutputCoeffBitsSearchResult, + final_mask_cache: Aky24IOFinalMaskCache, +} + /// Successful AKY24 iO CRT-depth search result. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Aky24IOCrtDepthSearchResult { @@ -187,12 +198,14 @@ pub fn minimum_aky24_io_prf_seed_bits( params: &DCRTPolyParams, output_size: usize, function_output_bits: usize, + prf_batch_bits: usize, prf_mask_output_coeff_bits: usize, noise_refresh_v_bits: usize, cbd_n: usize, ) -> usize { let ring_dim = params.ring_dimension() as usize; - let seed_refresh_seed_bits = minimum_seed_refresh_prf_seed_bits(AKY24_IO_PRF_BRANCH_COUNT); + let seed_refresh_seed_bits = + minimum_seed_refresh_prf_seed_bits(aky24_io_prf_branch_count(prf_batch_bits)); let final_mask_seed_bits = minimum_goldreich_input_size(aky24_io_final_prg_uniform_output_bits( output_size, @@ -205,6 +218,15 @@ pub fn minimum_aky24_io_prf_seed_bits( seed_refresh_seed_bits.max(final_mask_seed_bits).max(noise_refresh_seed_bits) } +fn aky24_io_prf_branch_count(prf_batch_bits: usize) -> usize { + assert!(prf_batch_bits > 0, "AKY24IO prf_batch_bits must be positive"); + assert!( + prf_batch_bits < usize::BITS as usize, + "AKY24IO prf_batch_bits must fit in a usize branch count" + ); + 1usize.checked_shl(prf_batch_bits as 
u32).expect("AKY24IO PRF branch count overflow") +} + /// Returns the largest noise-refresh `v_bits` allowed before pre-rounding error is added. pub fn aky24_io_max_noise_refresh_v_bits_without_pre_rounding_error( params: &DCRTPolyParams, @@ -335,13 +357,14 @@ where low = crt_depth + 1; continue; }; - let mask_bits = mask_search.prf_mask_output_coeff_bits; + let mask_bits = mask_search.result.prf_mask_output_coeff_bits; let final_candidate = build_candidate(ring_dim, crt_depth, mask_bits, Some(global_noise_refresh_v_bits)); let expected_seed_bits = minimum_aky24_io_prf_seed_bits( &cpu_params, func_type.output_bits(), func_type.output_bits(), + final_candidate.prf_batch_bits, mask_bits, global_noise_refresh_v_bits, final_candidate.noise_refresh_cbd_n, @@ -350,25 +373,39 @@ where low = crt_depth + 1; continue; } - let final_simulation = final_candidate - .build_fixed_noise_refresh_prefix_from_base( - provisional_base.clone(), + let final_simulation = if final_candidate + .can_reuse_mask_independent_prefix_from(&mask_search_candidate, &cpu_params) + { + Some(final_candidate.finish_error_growth_from_mask_independent_prefix( + &mask_search_prefix, func_type, - global_noise_refresh_v_bits, + mask_bits, + Some(mask_search.final_mask_cache.prg_output.clone()), + Some(mask_search.final_mask_cache.base_error.clone()), plt_evaluator, slot_transfer_evaluator, - ) - .map(|prefix| { - final_candidate.finish_error_growth_from_mask_independent_prefix( - &prefix, + )) + } else { + final_candidate + .build_fixed_noise_refresh_prefix_from_base( + provisional_base.clone(), func_type, - mask_bits, - None, - None, + global_noise_refresh_v_bits, plt_evaluator, slot_transfer_evaluator, ) - }); + .map(|prefix| { + final_candidate.finish_error_growth_from_mask_independent_prefix( + &prefix, + func_type, + mask_bits, + None, + None, + plt_evaluator, + slot_transfer_evaluator, + ) + }) + }; let Some(final_simulation) = final_simulation else { low = crt_depth + 1; continue; @@ -408,14 +445,6 
@@ impl Aky24IO where M: PolyMatrix, { - fn prf_round_count(&self) -> usize { - self.public_prf_seed_bits - } - - fn prf_branch_count(&self) -> usize { - AKY24_IO_PRF_BRANCH_COUNT - } - /// Simulate AKY24 iO error growth for the selected function family. pub fn simulate_error_growth( &self, @@ -467,6 +496,7 @@ where plt_evaluator, slot_transfer_evaluator, ) + .map(|evaluation| evaluation.result) } fn simulate_error_growth_with_prf_mask_output_coeff_bits( @@ -838,17 +868,25 @@ where PE: PltEvaluator, ST: SlotTransferEvaluator, { + let material_state = self.prf_refresh_material_state( + base, + round_idx, + seed_errors, + seed_ciphertext_randomizer_norm, + plt_evaluator, + slot_transfer_evaluator, + ); let mut low = 1usize; let mut high = max_candidate; let mut best = None; while low <= high { let candidate = low + (high - low) / 2; - let Some(evaluation) = self.simulate_prf_refresh_round_fixed( + let Some(evaluation) = self.evaluate_prf_refresh_round_from_material_state( base, round_idx, candidate, seed_errors, - seed_ciphertext_randomizer_norm, + &material_state, plt_evaluator, slot_transfer_evaluator, ) else { @@ -897,6 +935,31 @@ where plt_evaluator, slot_transfer_evaluator, ); + self.evaluate_prf_refresh_round_from_material_state( + base, + round_idx, + noise_refresh_v_bits, + seed_errors, + &material_state, + plt_evaluator, + slot_transfer_evaluator, + ) + } + + fn evaluate_prf_refresh_round_from_material_state( + &self, + base: &Aky24IOMaskIndependentErrorBase, + round_idx: usize, + noise_refresh_v_bits: usize, + seed_errors: &[ErrorNorm], + material_state: &Aky24IOPrfRefreshMaterialState, + plt_evaluator: &PE, + slot_transfer_evaluator: &ST, + ) -> Option + where + PE: PltEvaluator, + ST: SlotTransferEvaluator, + { let core = sim_utils::simulate_prf_refresh_round_fixed_core( &base.cpu_params, base.noise_refresh_ring_gsw_context.clone(), @@ -918,15 +981,17 @@ where Some(Aky24IOPrfRefreshRoundEvaluation { round: Aky24IOPrfRoundErrorSimulation { round_idx, - 
representative_selected_prg_output: material_state.selected, + representative_selected_prg_output: material_state.selected.clone(), representative_selected_prg_ciphertext_decryption_error: material_state - .representative_selected_prg_ciphertext_decryption_error, + .representative_selected_prg_ciphertext_decryption_error + .clone(), noise_refresh: core.refresh, }, refreshed_seed: core.refreshed_seed, refreshed_seed_ciphertext_randomizer_norm: material_state - .refreshed_seed_ciphertext_randomizer_norm, - material_wire_error: material_state.material_wire_error, + .refreshed_seed_ciphertext_randomizer_norm + .clone(), + material_wire_error: material_state.material_wire_error.clone(), }) } @@ -1009,7 +1074,7 @@ where security_bit: Option, plt_evaluator: &PE, slot_transfer_evaluator: &ST, - ) -> Option + ) -> Option where PE: PltEvaluator, ST: SlotTransferEvaluator, @@ -1080,10 +1145,16 @@ where final_margin > *error }; if valid { - best = Some(Aky24IOPrfMaskOutputCoeffBitsSearchResult { - prf_mask_output_coeff_bits: candidate, - noise_refresh_v_bits, - simulation, + best = Some(Aky24IOPrfMaskOutputCoeffBitsSearchEvaluation { + result: Aky24IOPrfMaskOutputCoeffBitsSearchResult { + prf_mask_output_coeff_bits: candidate, + noise_refresh_v_bits, + simulation, + }, + final_mask_cache: Aky24IOFinalMaskCache { + prg_output: cached_final_mask_prg_output.clone(), + base_error: cached_final_mask_base_error.clone(), + }, }); low = candidate + 1; } else if candidate == 1 { @@ -1189,6 +1260,36 @@ where ) } + fn can_reuse_mask_independent_prefix_from( + &self, + other: &Self, + cpu_params: &DCRTPolyParams, + ) -> bool { + let self_cpu_params = cpu_params_from_poly_params(&self.params); + let other_cpu_params = cpu_params_from_poly_params(&other.params); + self_cpu_params == *cpu_params && + other_cpu_params == *cpu_params && + self.input_size == other.input_size && + self.seed_bits == other.seed_bits && + self.prf_batch_bits == other.prf_batch_bits && + self.noise_refresh_v_bits == 
other.noise_refresh_v_bits && + self.noise_refresh_cbd_n == other.noise_refresh_cbd_n && + self.ring_gsw_enable_levels == other.ring_gsw_enable_levels && + self.ring_gsw_level_offset == other.ring_gsw_level_offset && + self.ring_gsw_public_key_error_sigma == other.ring_gsw_public_key_error_sigma && + self.full_active_levels(cpu_params) == other.full_active_levels(cpu_params) && + self.cpu_ring_gsw_config_matches(other) + } + + fn cpu_ring_gsw_config_matches(&self, other: &Self) -> bool { + let left = self.cpu_ring_gsw_config(); + let right = other.cpu_ring_gsw_config(); + left.p_moduli_bits == right.p_moduli_bits && + left.max_unreduced_muls == right.max_unreduced_muls && + left.scale == right.scale && + left.level_offset == right.level_offset + } + fn cpu_ring_gsw_config(&self) -> sim_utils::CpuRingGswContextConfig { sim_utils::CpuRingGswContextConfig { p_moduli_bits: self.ring_gsw_context.p_moduli_bits, @@ -1339,7 +1440,7 @@ mod tests { NoCircuitEvaluator, >; - fn test_scheme(active_levels: usize, public_prf_seed_bits: usize) -> TestAky24IO { + fn test_scheme(active_levels: usize, input_size: usize, prf_batch_bits: usize) -> TestAky24IO { let params = DCRTPolyParams::new(2, active_levels, 10, 5); let mut setup_circuit = PolyCircuit::::new(); let ring_gsw_context = Arc::new(NestedRnsPolyContext::setup( @@ -1366,10 +1467,10 @@ mod tests { Some(active_levels), Some(0.0), b"aky24_io_error_simulation_test".to_vec(), - 1, + input_size, 1, 6, - public_prf_seed_bits, + prf_batch_bits, 1, 1, 1, @@ -1386,6 +1487,7 @@ mod tests { fn test_minimum_aky24_io_prf_seed_bits_covers_seed_refresh_outputs() { let params = DCRTPolyParams::new(2, 1, 10, 5); let output_size = 3usize; + let prf_batch_bits = 3usize; let prf_mask_output_coeff_bits = 2usize; let noise_refresh_v_bits = 1usize; let cbd_n = 1usize; @@ -1393,13 +1495,15 @@ mod tests { &params, output_size, output_size, + prf_batch_bits, prf_mask_output_coeff_bits, noise_refresh_v_bits, cbd_n, ); let ring_dim = 
params.ring_dimension() as usize; + let prf_branch_count = aky24_io_prf_branch_count(prf_batch_bits); - assert!(goldreich_output_bound_holds(seed_bits, AKY24_IO_PRF_BRANCH_COUNT * seed_bits)); + assert!(goldreich_output_bound_holds(seed_bits, prf_branch_count * seed_bits)); assert!(goldreich_output_bound_holds( seed_bits, aky24_io_final_prg_uniform_output_bits( @@ -1413,14 +1517,16 @@ mod tests { } #[test] - fn test_prf_final_round_separator_uses_public_prf_seed_bits() { - let scheme = test_scheme(1, 7); - assert_eq!(scheme.prf_final_round_idx(), 7); + fn test_prf_final_round_separator_uses_batched_public_prf_rounds() { + let scheme = test_scheme(1, 12, 4); + assert_eq!(scheme.prf_round_count(), 3); + assert_eq!(scheme.prf_branch_count(), 16); + assert_eq!(scheme.prf_final_round_idx(), 3); } #[test] fn test_fresh_error_base_does_not_require_input_injection() { - let scheme = test_scheme(1, 1); + let scheme = test_scheme(1, 1, 1); let base = scheme.simulate_mask_independent_error_base( Aky24IOFuncType::GoldreichPRF { output_bits: 1 }, 2.0, @@ -1441,7 +1547,7 @@ mod tests { #[test] fn test_finish_error_growth_composes_projection_and_decoder_error() { - let scheme = test_scheme(1, 1); + let scheme = test_scheme(1, 1, 1); let params = DCRTPolyParams::new(2, 1, 10, 5); let ctx = simulator_context(&params); let matrix = |norm: u32, ncol: usize| { diff --git a/src/io/diamond_io/bench_estimator.rs b/src/io/diamond_io/bench_estimator.rs index 5622f7f0..e8b858ca 100644 --- a/src/io/diamond_io/bench_estimator.rs +++ b/src/io/diamond_io/bench_estimator.rs @@ -81,12 +81,20 @@ pub struct DiamondIOBenchEstimate { pub obfuscate_input_injection: CircuitBenchSummary, /// Input-injection online evaluation work performed during `DiamondIO::eval`. pub eval_input_injection: CircuitBenchSummary, + /// Eval work after input-injection, modeled as the final FE portion for comparisons. 
+ pub final_fe_eval: CircuitBenchSummary, /// Total compact bytes written as the persisted obfuscated circuit artifacts. pub obfuscated_circuit_bytes: BigUint, /// Compact bytes contributed by the Diamond input-injection artifacts. pub input_injection_bytes: BigUint, } +#[derive(Debug, Clone, PartialEq)] +struct DiamondIOEvalBenchEstimateParts { + total: CircuitBenchSummary, + final_fe: CircuitBenchSummary, +} + impl DiamondIOBenchEstimate { pub fn obfuscate_input_injection_latency_percent(&self) -> f64 { percent_f64(self.obfuscate_input_injection.latency, self.obfuscate.latency) @@ -549,17 +557,29 @@ where info!("completed DiamondIO eval benchmark estimation"); let estimate = DiamondIOBenchEstimate { obfuscate, - eval, + eval: eval.total, obfuscate_input_injection: input_injection.obfuscate, eval_input_injection: input_injection.eval, + final_fe_eval: eval.final_fe, obfuscated_circuit_bytes: storage.total_bytes, input_injection_bytes: storage.input_injection_bytes, }; info!( + obfuscate_input_injection_latency = estimate.obfuscate_input_injection.latency, + obfuscate_input_injection_total_time_nanos = + %estimate.obfuscate_input_injection.total_time, + obfuscate_input_injection_max_parallelism = + %estimate.obfuscate_input_injection.max_parallelism, obfuscate_input_injection_latency_percent = estimate.obfuscate_input_injection_latency_percent(), obfuscate_input_injection_total_time_percent = estimate.obfuscate_input_injection_total_time_percent(), + eval_input_injection_latency = estimate.eval_input_injection.latency, + eval_input_injection_total_time_nanos = %estimate.eval_input_injection.total_time, + eval_input_injection_max_parallelism = %estimate.eval_input_injection.max_parallelism, + final_fe_eval_latency = estimate.final_fe_eval.latency, + final_fe_eval_total_time_nanos = %estimate.final_fe_eval.total_time, + final_fe_eval_max_parallelism = %estimate.final_fe_eval.max_parallelism, eval_input_injection_latency_percent = 
estimate.eval_input_injection_latency_percent(), eval_input_injection_total_time_percent = estimate.eval_input_injection_total_time_percent(), @@ -733,7 +753,7 @@ where func: DiamondIOFuncType, shape: DiamondIOBenchShape, _persisted_storage: &DiamondIOStorageEstimate, - ) -> CircuitBenchSummary + ) -> DiamondIOEvalBenchEstimateParts where M: PolyMatrix + Send + Sync + 'static, M::P: 'static, @@ -807,14 +827,14 @@ where ]); let final_decode = scale_summary(final_decode_unit.clone(), shape.final_decoder_count()); - let total = sequential_summaries(&[ - input_injection.clone(), + let final_fe = sequential_summaries(&[ input_encoding_projection.clone(), seed_ciphertext_lift.clone(), prf_and_function.clone(), decoder_projection.clone(), final_decode.clone(), ]); + let total = sequential_summaries(&[input_injection.clone(), final_fe.clone()]); debug!( ?input_injection, @@ -824,11 +844,12 @@ where ?function_encoding_eval, ?decoder_projection, ?final_decode, + ?final_fe, ?total, "estimated DiamondIO eval benchmark" ); - total + DiamondIOEvalBenchEstimateParts { total, final_fe } } fn estimate_function_circuit( @@ -1171,7 +1192,7 @@ where parallel_summaries(&[final_mask_decrypt.clone(), final_function_decrypt.clone()]); info!( ?mode, - final_mask_decrypt_contribution_count, + final_mask_decrypt_contribution_count = %final_mask_decrypt_contribution_count, final_mask_reduce_add_count, ?final_mask_decrypt_unit, ?final_mask_decrypt_contributions, diff --git a/tests/test_gpu_aky24_io.rs b/tests/test_gpu_aky24_io.rs index f7e80e9e..5592bb0c 100644 --- a/tests/test_gpu_aky24_io.rs +++ b/tests/test_gpu_aky24_io.rs @@ -66,6 +66,7 @@ const DEFAULT_MIN_LOG_RING_DIM: usize = 16; const DEFAULT_MAX_LOG_RING_DIM: usize = 16; const DEFAULT_INPUT_SIZE: usize = 5; const DEFAULT_OUTPUT_SIZE: usize = 6; +const DEFAULT_PRF_BATCH_BITS: usize = 1; const DEFAULT_CRT_BITS: usize = 28; const DEFAULT_BASE_BITS: u32 = 14; const DEFAULT_P_MODULI_BITS: usize = 7; @@ -118,7 +119,7 @@ struct 
Aky24IOGpuBenchConfig { max_log_ring_dim: usize, input_size: usize, output_size: usize, - public_prf_seed_bits: usize, + prf_batch_bits: usize, crt_bits: usize, base_bits: u32, p_moduli_bits: usize, @@ -164,9 +165,9 @@ impl Aky24IOGpuBenchConfig { ), input_size, output_size: env_or_parse_usize("AKY24_IO_GPU_BENCH_OUTPUT_SIZE", DEFAULT_OUTPUT_SIZE), - public_prf_seed_bits: env_or_parse_usize( - "AKY24_IO_GPU_BENCH_PUBLIC_PRF_SEED_BITS", - input_size, + prf_batch_bits: env_or_parse_usize( + "AKY24_IO_GPU_BENCH_PRF_BATCH_BITS", + DEFAULT_PRF_BATCH_BITS, ), crt_bits: env_or_parse_usize("AKY24_IO_GPU_BENCH_CRT_BITS", DEFAULT_CRT_BITS), base_bits: env_or_parse_u32("AKY24_IO_GPU_BENCH_BASE_BITS", DEFAULT_BASE_BITS), @@ -222,9 +223,15 @@ impl Aky24IOGpuBenchConfig { ); assert!(cfg.input_size > 0, "AKY24_IO_GPU_BENCH_INPUT_SIZE must be positive"); assert!(cfg.output_size > 0, "AKY24_IO_GPU_BENCH_OUTPUT_SIZE must be positive"); + assert!(cfg.prf_batch_bits > 0, "AKY24_IO_GPU_BENCH_PRF_BATCH_BITS must be positive"); assert!( - cfg.public_prf_seed_bits > 0, - "AKY24_IO_GPU_BENCH_PUBLIC_PRF_SEED_BITS must be positive" + cfg.prf_batch_bits < usize::BITS as usize, + "AKY24_IO_GPU_BENCH_PRF_BATCH_BITS must fit in a usize branch count" + ); + assert_eq!( + cfg.input_size % cfg.prf_batch_bits, + 0, + "AKY24_IO_GPU_BENCH_INPUT_SIZE must be divisible by AKY24_IO_GPU_BENCH_PRF_BATCH_BITS" ); assert!(cfg.crt_bits > 0, "AKY24_IO_GPU_BENCH_CRT_BITS must be positive"); assert!(cfg.base_bits > 0, "AKY24_IO_GPU_BENCH_BASE_BITS must be positive"); @@ -257,6 +264,7 @@ impl Aky24IOGpuBenchConfig { params, self.output_size, self.output_size, + self.prf_batch_bits, prf_mask_output_coeff_bits, noise_refresh_v_bits, self.noise_refresh_cbd_n, @@ -522,7 +530,7 @@ fn build_aky24_io( cfg.input_size, cfg.output_size(), seed_bits, - cfg.public_prf_seed_bits, + cfg.prf_batch_bits, prf_mask_output_coeff_bits, noise_refresh_v_bits, cfg.noise_refresh_cbd_n, @@ -582,7 +590,7 @@ fn 
build_cpu_aky24_io_for_search( cfg.input_size, cfg.output_size(), seed_bits, - cfg.public_prf_seed_bits, + cfg.prf_batch_bits, prf_mask_output_coeff_bits, noise_refresh_v_bits, cfg.noise_refresh_cbd_n, @@ -1150,11 +1158,21 @@ async fn test_gpu_aky24_io_error_search_and_bench_estimate() { obfuscate_latency = estimate.obfuscate.latency, obfuscate_total_time_nanos = %estimate.obfuscate.total_time, obfuscate_max_parallelism = %estimate.obfuscate.max_parallelism, + fe_to_io_obfuscate_latency = estimate.fe_to_io_obfuscate.latency, + fe_to_io_obfuscate_total_time_nanos = %estimate.fe_to_io_obfuscate_total_time, + fe_to_io_obfuscate_max_parallelism = %estimate.fe_to_io_obfuscate.max_parallelism, + final_fe_obfuscate_latency = estimate.final_fe_obfuscate.latency, + final_fe_obfuscate_total_time_nanos = %estimate.final_fe_obfuscate_total_time, + final_fe_obfuscate_max_parallelism = %estimate.final_fe_obfuscate.max_parallelism, eval_latency = estimate.eval.latency, eval_total_time_nanos = %estimate.eval.total_time, eval_max_parallelism = %estimate.eval.max_parallelism, + fe_to_io_eval_latency = estimate.fe_to_io_eval.latency, fe_to_io_eval_total_time_nanos = %estimate.fe_to_io_eval_total_time, + fe_to_io_eval_max_parallelism = %estimate.fe_to_io_eval.max_parallelism, + final_fe_eval_latency = estimate.final_fe_eval.latency, final_fe_eval_total_time_nanos = %estimate.final_fe_eval_total_time, + final_fe_eval_max_parallelism = %estimate.final_fe_eval.max_parallelism, obfuscated_circuit_bytes = %estimate.obfuscated_circuit_bytes, fe_to_io_obfuscated_circuit_bytes = %estimate.fe_to_io_obfuscated_circuit_bytes, final_fe_obfuscated_circuit_bytes = %estimate.final_fe_obfuscated_circuit_bytes, @@ -1163,12 +1181,15 @@ async fn test_gpu_aky24_io_error_search_and_bench_estimate() { assert!(estimate.obfuscate.total_time >= BigUint::from(0u32)); assert!(estimate.eval.total_time >= BigUint::from(0u32)); assert!(estimate.obfuscated_circuit_bytes > BigUint::from(0u32)); + 
assert!(estimate.final_fe_obfuscate_total_time > BigUint::from(0u32)); assert!(estimate.final_fe_eval_total_time > BigUint::from(0u32)); assert!(estimate.final_fe_obfuscated_circuit_bytes > BigUint::from(0u32)); - if cfg.input_size > 1 { + if cfg.input_size / cfg.prf_batch_bits > 1 { + assert!(estimate.fe_to_io_obfuscate_total_time > BigUint::from(0u32)); assert!(estimate.fe_to_io_eval_total_time > BigUint::from(0u32)); assert!(estimate.fe_to_io_obfuscated_circuit_bytes > BigUint::from(0u32)); } else { + assert_eq!(estimate.fe_to_io_obfuscate_total_time, BigUint::from(0u32)); assert_eq!(estimate.fe_to_io_eval_total_time, BigUint::from(0u32)); assert_eq!(estimate.fe_to_io_obfuscated_circuit_bytes, BigUint::from(0u32)); } diff --git a/tests/test_gpu_diamond_io.rs b/tests/test_gpu_diamond_io.rs index 84e5f643..748e2c1c 100644 --- a/tests/test_gpu_diamond_io.rs +++ b/tests/test_gpu_diamond_io.rs @@ -972,10 +972,21 @@ async fn test_gpu_diamond_io_error_search_and_bench_estimate() { eval_latency = estimate.eval.latency, eval_total_time_nanos = %estimate.eval.total_time, eval_max_parallelism = %estimate.eval.max_parallelism, + obfuscate_input_injection_latency = estimate.obfuscate_input_injection.latency, + obfuscate_input_injection_total_time_nanos = + %estimate.obfuscate_input_injection.total_time, + obfuscate_input_injection_max_parallelism = + %estimate.obfuscate_input_injection.max_parallelism, obfuscate_input_injection_latency_percent = estimate.obfuscate_input_injection_latency_percent(), obfuscate_input_injection_total_time_percent = estimate.obfuscate_input_injection_total_time_percent(), + eval_input_injection_latency = estimate.eval_input_injection.latency, + eval_input_injection_total_time_nanos = %estimate.eval_input_injection.total_time, + eval_input_injection_max_parallelism = %estimate.eval_input_injection.max_parallelism, + final_fe_eval_latency = estimate.final_fe_eval.latency, + final_fe_eval_total_time_nanos = %estimate.final_fe_eval.total_time, + 
final_fe_eval_max_parallelism = %estimate.final_fe_eval.max_parallelism, eval_input_injection_latency_percent = estimate.eval_input_injection_latency_percent(), eval_input_injection_total_time_percent = estimate.eval_input_injection_total_time_percent(), @@ -985,6 +996,7 @@ async fn test_gpu_diamond_io_error_search_and_bench_estimate() { ); assert!(estimate.obfuscate.total_time >= BigUint::from(0u32)); assert!(estimate.eval.total_time >= BigUint::from(0u32)); + assert!(estimate.final_fe_eval.total_time > BigUint::from(0u32)); assert!(estimate.obfuscated_circuit_bytes > BigUint::from(0u32)); assert!(estimate.input_injection_bytes > BigUint::from(0u32)); }