From d94f3a8a25b560cf899bd9eb87caf442e2464a89 Mon Sep 17 00:00:00 2001 From: SoraSuegami Date: Wed, 3 Jun 2026 09:18:29 +0900 Subject: [PATCH] Add lattice estimator checks to DiamondWE tests --- src/io/utils/simulation.rs | 152 ++++++++++++++++++++++++++++++++ src/we/diamond_we/simulation.rs | 84 ++++++++++++++++-- tests/test_gpu_diamond_we.rs | 95 ++++++++++++++++++-- 3 files changed, 319 insertions(+), 12 deletions(-) diff --git a/src/io/utils/simulation.rs b/src/io/utils/simulation.rs index 2444b065..32678a2c 100644 --- a/src/io/utils/simulation.rs +++ b/src/io/utils/simulation.rs @@ -287,6 +287,158 @@ where found } +pub(crate) fn select_min_secure_ring_dim_gaussian_only( + protocol_name: &str, + crt_depth: usize, + min_log_ring_dim: usize, + max_log_ring_dim: usize, + security_bits: usize, + error_sigma: f64, + skip_lattice_check: bool, + lattice_cache: &mut SecureRingDimLatticeCache, + mut build_params: BuildParams, +) -> Option +where + P: PolyParams, + BuildParams: FnMut(u32) -> P, +{ + assert!( + min_log_ring_dim <= max_log_ring_dim, + "{protocol_name} log-ring-dimension search range must be non-empty" + ); + assert!( + max_log_ring_dim < u32::BITS as usize, + "{protocol_name} max_log_ring_dim must be less than 32" + ); + assert!( + error_sigma >= 0.0, + "{protocol_name} lattice-estimator Gaussian stddev must be nonnegative" + ); + if skip_lattice_check { + assert_eq!( + min_log_ring_dim, max_log_ring_dim, + "{protocol_name} explicit lattice-check skip requires a single log_ring_dim" + ); + let ring_dim = 1u32 + .checked_shl(min_log_ring_dim.try_into().expect("log_ring_dim must fit in u32")) + .expect("ring_dim shift overflow"); + info!( + protocol_name, + crt_depth, + log_ring_dim = min_log_ring_dim, + ring_dim, + "skipping Gaussian lattice-estimator security check because log_ring_dim was explicit" + ); + return Some(SecureRingDimSearchResult { + log_ring_dim: min_log_ring_dim, + ring_dim, + achieved_secpar_for_gauss: None, + achieved_secpar_for_cbd: None, + }); + } + let s_dist = Distribution::Ternary; + let e_dist_gauss = + Distribution::DiscreteGaussian { stddev: error_sigma.to_string(), mean: None, n: None }; + let required_security: u64 = security_bits.try_into().expect("security_bits must fit in u64"); + let mut low = min_log_ring_dim; + let mut high = max_log_ring_dim; + let mut found = None; + while low <= high { + let log_ring_dim = low + (high - low) / 2; + let ring_dim = 1u32 + .checked_shl(log_ring_dim.try_into().expect("log_ring_dim must fit in u32")) + .expect("ring_dim shift overflow"); + if let Some(cached) = lattice_cache.secure_for(crt_depth, log_ring_dim) { + info!( + protocol_name, + crt_depth, + log_ring_dim, + ring_dim, + achieved_secpar_for_gauss = cached.achieved_secpar_for_gauss, + "skipping Gaussian lattice-estimator security check using larger CRT-depth cache" + ); + found = Some(cached); + if log_ring_dim == 0 { + break; + } + high = log_ring_dim - 1; + continue; + } + let params = build_params(ring_dim); + let q: Arc = params.modulus().into(); + let ring_dim_big = BigUint::from(ring_dim); + info!( + protocol_name, + crt_depth, + log_ring_dim, + ring_dim, + modulus_bits = q.bits(), + required_security, + error_sigma, + "running Gaussian lattice-estimator security check for CRT-depth ring-dimension candidate" + ); + match run_lattice_estimator_cli_with_timeout( + &ring_dim_big, + q.as_ref(), + &s_dist, + &e_dist_gauss, + None, + false, + LATTICE_ESTIMATOR_TIMEOUT, + ) { + Ok(achieved_secpar_for_gauss) => { + info!( + protocol_name, + crt_depth, + log_ring_dim, + ring_dim, + achieved_secpar_for_gauss, + required_security, + "evaluated CRT-depth ring-dimension Gaussian security candidate" + ); + if achieved_secpar_for_gauss >= required_security { + let result = SecureRingDimSearchResult { + log_ring_dim, + ring_dim, + achieved_secpar_for_gauss: Some(achieved_secpar_for_gauss), + achieved_secpar_for_cbd: None, + }; + lattice_cache.record(crt_depth, result); + found = Some(result); + if log_ring_dim == 0 { + break; + } + high = log_ring_dim - 1; + } else { + low = log_ring_dim + 1; + } + } + Err(err) => { + info!( + protocol_name, + crt_depth, + log_ring_dim, + ring_dim, + gauss_error = ?err, + "Gaussian lattice-estimator failed for CRT-depth ring-dimension candidate" + ); + low = log_ring_dim + 1; + } + } + } + if found.is_none() { + info!( + protocol_name, + crt_depth, + min_log_ring_dim, + max_log_ring_dim, + required_security, + "no Gaussian-secure ring dimension found for CRT-depth candidate" + ); + } + found +} + #[derive(Debug, Clone, Copy)] pub(crate) struct CpuRingGswContextConfig { pub p_moduli_bits: usize, diff --git a/src/we/diamond_we/simulation.rs b/src/we/diamond_we/simulation.rs index cd2bcc74..8aa93d85 100644 --- a/src/we/diamond_we/simulation.rs +++ b/src/we/diamond_we/simulation.rs @@ -1,6 +1,7 @@ use crate::{ circuit::{Evaluable, PolyCircuit}, input_injector::DiamondInputErrorSimulation, + io::utils::simulation as sim_utils, lookup::PltEvaluator, matrix::PolyMatrix, poly::{Poly, PolyParams, dcrt::poly::DCRTPoly}, @@ -36,6 +37,10 @@ pub struct DiamondWEErrorSimulation { #[derive(Debug, Clone, PartialEq, Eq)] pub struct DiamondWECrtDepthSearchResult { pub crt_depth: usize, + pub log_ring_dim: usize, + pub ring_dim: u32, + pub achieved_secpar_for_gauss: Option, + pub achieved_secpar_for_cbd: Option, pub simulation: DiamondWEErrorSimulation, } @@ -54,6 +59,9 @@ fn diamond_we_correctness_margin_holds( pub fn diamond_we_find_crt_depth( min_crt_depth: usize, max_crt_depth: usize, + min_log_ring_dim: usize, + max_log_ring_dim: usize, + security_bits: usize, circuit: &PolyCircuit, mut build_candidate: BuildCandidate, plt_evaluator: Option<&PE>, @@ -66,18 +74,48 @@ where TS: PolyTrapdoorSampler + Send + Sync, PE: PltEvaluator, ST: SlotTransferEvaluator, - BuildCandidate: FnMut(usize) -> DiamondWE, + BuildCandidate: FnMut(u32, usize) -> DiamondWE, { assert!(min_crt_depth > 0, "minimum CRT depth must be positive"); assert!(min_crt_depth <= max_crt_depth, "CRT-depth search range must be non-empty"); info!( min_crt_depth, - max_crt_depth, "starting DiamondWE CRT-depth search with q/4 correctness margin" + max_crt_depth, + min_log_ring_dim, + max_log_ring_dim, + security_bits, + "starting DiamondWE CRT-depth search with q/4 correctness margin" ); + let force_lattice_check = std::env::var_os("MXX_IO_FORCE_LATTICE_CHECK").is_some(); + let explicit_log_ring_dim = min_log_ring_dim == max_log_ring_dim && !force_lattice_check; + if !explicit_log_ring_dim { + sim_utils::assert_lattice_estimator_available("DiamondWE"); + } + let mut lattice_cache = sim_utils::SecureRingDimLatticeCache::default(); let mut high = max_crt_depth; let upper_valid = loop { info!(crt_depth = high, "evaluating DiamondWE CRT-depth upper-bound candidate"); - let candidate = build_candidate(high); + let min_ring_dim = 1u32 + .checked_shl(min_log_ring_dim.try_into().expect("min_log_ring_dim must fit in u32")) + .expect("minimum ring_dim shift overflow"); + let probe_candidate = build_candidate(min_ring_dim, high); + let Some(ring_dim_search) = sim_utils::select_min_secure_ring_dim_gaussian_only( + "DiamondWE", + high, + min_log_ring_dim, + max_log_ring_dim, + security_bits, + probe_candidate.injector.error_sigma, + explicit_log_ring_dim, + &mut lattice_cache, + |ring_dim| { + let candidate = build_candidate(ring_dim, high); + candidate.injector.params.clone() + }, + ) else { + return None; + }; + let candidate = build_candidate(ring_dim_search.ring_dim, high); let slot_transfer_evaluator = slot_transfer_evaluator .map(|evaluator| evaluator as &dyn SlotTransferEvaluator); let simulation = @@ -92,7 +130,14 @@ where "DiamondWE CRT-depth upper-bound candidate evaluated" ); if valid { - break DiamondWECrtDepthSearchResult { crt_depth: high, simulation }; + break DiamondWECrtDepthSearchResult { + crt_depth: high, + log_ring_dim: ring_dim_search.log_ring_dim, + ring_dim: ring_dim_search.ring_dim, + achieved_secpar_for_gauss: ring_dim_search.achieved_secpar_for_gauss, + achieved_secpar_for_cbd: ring_dim_search.achieved_secpar_for_cbd, + simulation, + }; } let next_high = high.checked_mul(2).expect("DiamondWE CRT-depth search upper bound overflowed usize"); @@ -110,7 +155,27 @@ where while low <= high { let crt_depth = low + (high - low) / 2; info!(crt_depth, low, high, "evaluating DiamondWE CRT-depth candidate"); - let candidate = build_candidate(crt_depth); + let min_ring_dim = 1u32 + .checked_shl(min_log_ring_dim.try_into().expect("min_log_ring_dim must fit in u32")) + .expect("minimum ring_dim shift overflow"); + let probe_candidate = build_candidate(min_ring_dim, crt_depth); + let Some(ring_dim_search) = sim_utils::select_min_secure_ring_dim_gaussian_only( + "DiamondWE", + crt_depth, + min_log_ring_dim, + max_log_ring_dim, + security_bits, + probe_candidate.injector.error_sigma, + explicit_log_ring_dim, + &mut lattice_cache, + |ring_dim| { + let candidate = build_candidate(ring_dim, crt_depth); + candidate.injector.params.clone() + }, + ) else { + return None; + }; + let candidate = build_candidate(ring_dim_search.ring_dim, crt_depth); let slot_transfer_evaluator = slot_transfer_evaluator .map(|evaluator| evaluator as &dyn SlotTransferEvaluator); let simulation = @@ -125,7 +190,14 @@ where "DiamondWE CRT-depth candidate evaluated" ); if valid { - result = Some(DiamondWECrtDepthSearchResult { crt_depth, simulation }); + result = Some(DiamondWECrtDepthSearchResult { + crt_depth, + log_ring_dim: ring_dim_search.log_ring_dim, + ring_dim: ring_dim_search.ring_dim, + achieved_secpar_for_gauss: ring_dim_search.achieved_secpar_for_gauss, + achieved_secpar_for_cbd: ring_dim_search.achieved_secpar_for_cbd, + simulation, + }); if crt_depth == min_crt_depth { break; } diff --git a/tests/test_gpu_diamond_we.rs b/tests/test_gpu_diamond_we.rs index efb901d0..8fec4abc 100644 --- a/tests/test_gpu_diamond_we.rs +++ b/tests/test_gpu_diamond_we.rs @@ -54,6 +54,7 @@ const DEFAULT_CRT_BITS: usize = 28; const DEFAULT_BASE_BITS: u32 = 14; const DEFAULT_MIN_CRT_DEPTH: usize = 1; const DEFAULT_MAX_CRT_DEPTH: usize = 64; +const DEFAULT_SECURITY_BITS: usize = 100; const DEFAULT_BENCH_ITERATIONS: usize = 1; const DEFAULT_ERROR_SIGMA: f64 = 4.0; const DEFAULT_TRAPDOOR_SIGMA: f64 = 4.578; @@ -104,6 +105,8 @@ type GpuPubKeySlotEvaluator = BggPublicKeySTEvaluator< #[derive(Debug, Clone)] struct DiamondWEGpuBenchConfig { ring_dim: u32, + min_log_ring_dim: usize, + max_log_ring_dim: usize, circuit_height: usize, witness_size: usize, injector_batch_bits: usize, @@ -111,6 +114,7 @@ struct DiamondWEGpuBenchConfig { base_bits: u32, min_crt_depth: usize, max_crt_depth: usize, + security_bits: usize, bench_iterations: usize, error_sigma: f64, trapdoor_sigma: f64, @@ -120,14 +124,26 @@ struct DiamondWEGpuBenchConfig { #[derive(Debug, Clone, Copy)] struct DiamondWEGpuBenchSelectedSimulation { crt_depth: usize, + ring_dim: u32, + log_ring_dim: usize, + achieved_secpar_for_gauss: Option, + achieved_secpar_for_cbd: Option, noisy_plaintext_error_bits: usize, input_injection_error_bits: usize, } impl DiamondWEGpuBenchConfig { fn from_env() -> Self { + let ring_dim = env_or_parse_u32("DIAMOND_WE_GPU_BENCH_RING_DIM", DEFAULT_RING_DIM); + assert!(ring_dim > 0, "DIAMOND_WE_GPU_BENCH_RING_DIM must be positive"); + assert!(ring_dim.is_power_of_two(), "DIAMOND_WE_GPU_BENCH_RING_DIM must be a power of two"); + let default_log_ring_dim = ring_dim.trailing_zeros() as usize; let cfg = Self { - ring_dim: env_or_parse_u32("DIAMOND_WE_GPU_BENCH_RING_DIM", DEFAULT_RING_DIM), + ring_dim, + min_log_ring_dim: env_or_parse_optional_usize("DIAMOND_WE_GPU_BENCH_MIN_LOG_RING_DIM") + .unwrap_or(default_log_ring_dim), + max_log_ring_dim: env_or_parse_optional_usize("DIAMOND_WE_GPU_BENCH_MAX_LOG_RING_DIM") + .unwrap_or(default_log_ring_dim), circuit_height: env_or_parse_usize( "DIAMOND_WE_GPU_BENCH_CIRCUIT_HEIGHT", DEFAULT_CIRCUIT_HEIGHT, @@ -150,6 +166,10 @@ impl DiamondWEGpuBenchConfig { "DIAMOND_WE_GPU_BENCH_MAX_CRT_DEPTH", DEFAULT_MAX_CRT_DEPTH, ), + security_bits: env_or_parse_usize( + "DIAMOND_WE_GPU_BENCH_SECURITY_BITS", + DEFAULT_SECURITY_BITS, + ), bench_iterations: env_or_parse_usize( "DIAMOND_WE_GPU_BENCH_ITERATIONS", DEFAULT_BENCH_ITERATIONS, @@ -161,7 +181,14 @@ impl DiamondWEGpuBenchConfig { ), d_secret: env_or_parse_usize("DIAMOND_WE_GPU_BENCH_D_SECRET", DEFAULT_D_SECRET), }; - assert!(cfg.ring_dim > 0, "DIAMOND_WE_GPU_BENCH_RING_DIM must be positive"); + assert!( + cfg.min_log_ring_dim <= cfg.max_log_ring_dim, + "DIAMOND_WE_GPU_BENCH_MIN_LOG_RING_DIM must be <= DIAMOND_WE_GPU_BENCH_MAX_LOG_RING_DIM" + ); + assert!( + cfg.max_log_ring_dim < u32::BITS as usize, + "DIAMOND_WE_GPU_BENCH_MAX_LOG_RING_DIM must be < 32" + ); assert!(cfg.circuit_height > 0, "DIAMOND_WE_GPU_BENCH_CIRCUIT_HEIGHT must be positive"); assert!(cfg.witness_size > 0, "DIAMOND_WE_GPU_BENCH_WITNESS_SIZE must be positive"); assert!( @@ -187,6 +214,7 @@ impl DiamondWEGpuBenchConfig { cfg.min_crt_depth <= cfg.max_crt_depth, "DIAMOND_WE_GPU_BENCH_MIN_CRT_DEPTH must be <= DIAMOND_WE_GPU_BENCH_MAX_CRT_DEPTH" ); + assert!(cfg.security_bits > 0, "DIAMOND_WE_GPU_BENCH_SECURITY_BITS must be positive"); assert!(cfg.bench_iterations > 0, "DIAMOND_WE_GPU_BENCH_ITERATIONS must be positive"); assert!(cfg.error_sigma >= 0.0, "DIAMOND_WE_GPU_BENCH_ERROR_SIGMA must be nonnegative"); assert!(cfg.trapdoor_sigma > 0.0, "DIAMOND_WE_GPU_BENCH_TRAPDOOR_SIGMA must be positive"); @@ -224,6 +252,21 @@ impl DiamondWEGpuBenchConfig { fn selected_simulation_from_env(&self) -> Option { let crt_depth = env_or_parse_optional_usize("DIAMOND_WE_GPU_BENCH_SELECTED_CRT_DEPTH")?; + let ring_dim = env_or_parse_optional_u32("DIAMOND_WE_GPU_BENCH_SELECTED_RING_DIM") + .unwrap_or(self.ring_dim); + let log_ring_dim = + env_or_parse_optional_usize("DIAMOND_WE_GPU_BENCH_SELECTED_LOG_RING_DIM") + .unwrap_or_else(|| { + assert!( + ring_dim.is_power_of_two(), + "DIAMOND_WE_GPU_BENCH_SELECTED_RING_DIM must be a power of two" + ); + ring_dim.trailing_zeros() as usize + }); + let achieved_secpar_for_gauss = + env_or_parse_optional_u64("DIAMOND_WE_GPU_BENCH_SELECTED_ACHIEVED_SECPAR_FOR_GAUSS"); + let achieved_secpar_for_cbd = + env_or_parse_optional_u64("DIAMOND_WE_GPU_BENCH_SELECTED_ACHIEVED_SECPAR_FOR_CBD"); let noisy_plaintext_error_bits = env_or_parse_optional_usize("DIAMOND_WE_GPU_BENCH_SELECTED_NOISY_PLAINTEXT_ERROR_BITS") .unwrap_or(0); @@ -233,6 +276,10 @@ impl DiamondWEGpuBenchConfig { assert!(crt_depth > 0, "DIAMOND_WE_GPU_BENCH_SELECTED_CRT_DEPTH must be positive"); Some(DiamondWEGpuBenchSelectedSimulation { crt_depth, + ring_dim, + log_ring_dim, + achieved_secpar_for_gauss, + achieved_secpar_for_cbd, noisy_plaintext_error_bits, input_injection_error_bits, }) @@ -259,6 +306,18 @@ fn env_or_parse_optional_usize(key: &str) -> Option { .map(|raw| raw.parse::().unwrap_or_else(|err| panic!("{key} must be usize: {err}"))) } +fn env_or_parse_optional_u32(key: &str) -> Option { + env::var(key) + .ok() + .map(|raw| raw.parse::().unwrap_or_else(|err| panic!("{key} must be u32: {err}"))) +} + +fn env_or_parse_optional_u64(key: &str) -> Option { + env::var(key) + .ok() + .map(|raw| raw.parse::().unwrap_or_else(|err| panic!("{key} must be u64: {err}"))) +} + fn env_or_parse_f64(key: &str, default: f64) -> f64 { env::var(key) .ok() @@ -312,10 +371,11 @@ fn build_circuit(height: usize) -> PolyCircuit

{ fn build_cpu_diamond_we_for_search( cfg: &DiamondWEGpuBenchConfig, + ring_dim: u32, crt_depth: usize, dir_path: PathBuf, ) -> CpuDiamondWE { - let params = DCRTPolyParams::new(cfg.ring_dim, crt_depth, cfg.crt_bits, cfg.base_bits); + let params = DCRTPolyParams::new(ring_dim, crt_depth, cfg.crt_bits, cfg.base_bits); let injector = CpuInjector::new( params, cfg.injector_input_count(), @@ -487,6 +547,7 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { let log_filter = tracing_subscriber::filter::Targets::new() .with_target("test_gpu_diamond_we", tracing_subscriber::filter::LevelFilter::INFO) .with_target("mxx::we::diamond_we", tracing_subscriber::filter::LevelFilter::INFO) + .with_target("mxx::io::utils::simulation", tracing_subscriber::filter::LevelFilter::INFO) .with_target( "mxx::we::diamond_we::bench_estimator", tracing_subscriber::filter::LevelFilter::DEBUG, @@ -511,6 +572,10 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { let selected = if let Some(selected) = cfg.selected_simulation_from_env() { info!( crt_depth = selected.crt_depth, + log_ring_dim = selected.log_ring_dim, + ring_dim = selected.ring_dim, + achieved_secpar_for_gauss = selected.achieved_secpar_for_gauss, + achieved_secpar_for_cbd = selected.achieved_secpar_for_cbd, noisy_plaintext_error_bits = selected.noisy_plaintext_error_bits, input_injection_error_bits = selected.input_injection_error_bits, "DiamondWE selected simulation parameters provided; skipping error simulation" @@ -521,9 +586,17 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { let search = diamond_we_find_crt_depth( cfg.min_crt_depth, cfg.max_crt_depth, + cfg.min_log_ring_dim, + cfg.max_log_ring_dim, + cfg.security_bits, &cpu_circuit, - |crt_depth| { - build_cpu_diamond_we_for_search(&cfg, crt_depth, search_dir.join("candidate")) + |ring_dim, crt_depth| { + build_cpu_diamond_we_for_search( + &cfg, + ring_dim, + crt_depth, + search_dir.join("candidate"), + ) }, None::<&NoCircuitEvaluator>, None::<&NoCircuitEvaluator>, @@ -531,6 +604,10 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { .expect("DiamondWE CRT-depth search must find a valid benchmark candidate"); let selected = DiamondWEGpuBenchSelectedSimulation { crt_depth: search.crt_depth, + ring_dim: search.ring_dim, + log_ring_dim: search.log_ring_dim, + achieved_secpar_for_gauss: search.achieved_secpar_for_gauss, + achieved_secpar_for_cbd: search.achieved_secpar_for_cbd, noisy_plaintext_error_bits: bigdecimal_bits_ceil( &search.simulation.noisy_plaintext_error.poly_norm.norm, ) as usize, @@ -540,6 +617,10 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { }; info!( crt_depth = selected.crt_depth, + log_ring_dim = selected.log_ring_dim, + ring_dim = selected.ring_dim, + achieved_secpar_for_gauss = selected.achieved_secpar_for_gauss, + achieved_secpar_for_cbd = selected.achieved_secpar_for_cbd, noisy_plaintext_error_bits = selected.noisy_plaintext_error_bits, input_injection_error_bits = selected.input_injection_error_bits, "DiamondWE CRT-depth search selected parameters" @@ -547,7 +628,9 @@ async fn test_gpu_diamond_we_error_search_bench_estimate_and_round_trip() { selected }; - let (_cpu_params, gpu_params) = gpu_params_for_crt_depth(&cfg, selected.crt_depth, gpu_id); + let selected_cfg = DiamondWEGpuBenchConfig { ring_dim: selected.ring_dim, ..cfg.clone() }; + let (_cpu_params, gpu_params) = + gpu_params_for_crt_depth(&selected_cfg, selected.crt_depth, gpu_id); let final_dir = artifact_dir_from_env(temp_dir.path().join("final_estimate")); ensure_dir(&final_dir); info!(