Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
/revision_logs
/*.log
/*.status
/bench
62 changes: 55 additions & 7 deletions src/decoder/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ fn estimate_summary(estimate: CircuitBenchEstimate) -> CircuitBenchSummary {
}
}

fn scale_independent_summary_biguint(
summary: CircuitBenchSummary,
count: &BigUint,
) -> CircuitBenchSummary {
let scaled = CircuitBenchSummary::from_nanos(
summary.total_time.clone() * count,
summary.latency,
summary.max_parallelism.clone() * count,
);
#[cfg(feature = "gpu")]
{
scaled.with_peak_vram(summary.peak_vram)
}
#[cfg(not(feature = "gpu"))]
{
scaled
}
}

fn sequential_summaries(parts: &[CircuitBenchSummary]) -> CircuitBenchSummary {
let total_time = parts.iter().map(|part| part.total_time.clone()).sum::<BigUint>();
let latency = parts.iter().map(|part| part.latency).sum::<f64>();
Expand Down Expand Up @@ -201,12 +220,9 @@ pub(crate) fn bit_decomposed_polynomial_mask_decrypt_contribution_count(
ring_dim: usize,
mask_bits: usize,
output_count: usize,
) -> usize {
) -> BigUint {
assert!(mask_bits > 0, "mask_bits must be positive");
ring_dim
.checked_mul(mask_bits)
.and_then(|count| count.checked_mul(output_count))
.expect("bit-decomposed polynomial mask decrypt contribution count overflow")
BigUint::from(ring_dim) * BigUint::from(mask_bits) * BigUint::from(output_count)
}

/// Scale a representative one-ciphertext-bit decrypt contribution to a full
Expand All @@ -216,13 +232,13 @@ pub(crate) fn scale_bit_decomposed_polynomial_mask_decrypt_contributions(
ring_dim: usize,
mask_bits: usize,
output_count: usize,
) -> (CircuitBenchSummary, usize) {
) -> (CircuitBenchSummary, BigUint) {
let contribution_count = bit_decomposed_polynomial_mask_decrypt_contribution_count(
ring_dim,
mask_bits,
output_count,
);
(scale_independent_summary(unit, contribution_count), contribution_count)
(scale_independent_summary_biguint(unit, &contribution_count), contribution_count)
}

/// Number of additions needed to reduce `mask_bits` decrypted bit terms.
Expand Down Expand Up @@ -293,3 +309,35 @@ pub(crate) fn bit_decomposed_polynomial_mask_reduction_summary(
}
scale_independent_summary(sequential_summaries(&per_polynomial_parts), polynomial_count)
}

#[cfg(test)]
mod tests {
use super::*;

fn summary(total_time: u64, latency: f64, max_parallelism: u64) -> CircuitBenchSummary {
CircuitBenchSummary::from_nanos(
BigUint::from(total_time),
latency,
BigUint::from(max_parallelism),
)
}

#[test]
fn mask_decrypt_contribution_count_supports_large_aky24_cascade_outputs() {
let output_count = usize::MAX / 2;
let contribution_count =
bit_decomposed_polynomial_mask_decrypt_contribution_count(65_536, 1_418, output_count);
assert!(contribution_count > BigUint::from(usize::MAX));

let (scaled, returned_count) = scale_bit_decomposed_polynomial_mask_decrypt_contributions(
summary(7, 1.25, 3),
65_536,
1_418,
output_count,
);
assert_eq!(returned_count, contribution_count);
assert_eq!(scaled.total_time, &contribution_count * BigUint::from(7u32));
assert_eq!(scaled.latency, 1.25);
assert_eq!(scaled.max_parallelism, contribution_count * BigUint::from(3u32));
}
}
29 changes: 23 additions & 6 deletions src/io/aky24_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ pub struct Aky24IO<
pub output_size: usize,
/// Number of private PRF seed bits encrypted into Ring-GSW ciphertexts.
pub seed_bits: usize,
/// Number of public PRF seed bits, hence PRF seed-refresh rounds.
pub public_prf_seed_bits: usize,
/// Number of input bits processed by one seed-refresh round.
pub prf_batch_bits: usize,
/// Number of bit-decomposed PRF mask output coefficients to compute.
pub prf_mask_output_coeff_bits: usize,
/// Number of low bits retained in the noise-refresh rounding material.
Expand Down Expand Up @@ -113,7 +113,7 @@ where
input_size: usize,
output_size: usize,
seed_bits: usize,
public_prf_seed_bits: usize,
prf_batch_bits: usize,
prf_mask_output_coeff_bits: usize,
noise_refresh_v_bits: usize,
noise_refresh_cbd_n: usize,
Expand All @@ -127,7 +127,16 @@ where
assert!(input_size > 0, "AKY24IO input_size must be positive");
assert!(output_size > 0, "AKY24IO output_size must be positive");
assert!(seed_bits > 0, "AKY24IO seed_bits must be positive");
assert!(public_prf_seed_bits > 0, "AKY24IO public_prf_seed_bits must be positive");
assert!(prf_batch_bits > 0, "AKY24IO prf_batch_bits must be positive");
assert!(
prf_batch_bits < usize::BITS as usize,
"AKY24IO prf_batch_bits must fit in a usize branch count"
);
assert_eq!(
input_size % prf_batch_bits,
0,
"AKY24IO input_size must be divisible by prf_batch_bits"
);
assert!(
prf_mask_output_coeff_bits > 0,
"AKY24IO prf_mask_output_coeff_bits must be positive"
Expand All @@ -146,7 +155,7 @@ where
input_size,
output_size,
seed_bits,
public_prf_seed_bits,
prf_batch_bits,
prf_mask_output_coeff_bits,
noise_refresh_v_bits,
noise_refresh_cbd_n,
Expand All @@ -160,7 +169,15 @@ where
}
}

pub(crate) fn prf_round_count(&self) -> usize {
self.input_size / self.prf_batch_bits
}

pub(crate) fn prf_branch_count(&self) -> usize {
1usize.checked_shl(self.prf_batch_bits as u32).expect("AKY24IO PRF branch count overflow")
}

pub(crate) fn prf_final_round_idx(&self) -> usize {
self.public_prf_seed_bits
self.prf_round_count()
}
}
Loading
Loading