diff --git a/whir/examples/pareto_frontier.rs b/whir/examples/pareto_frontier.rs
index 452fa0ad35..bace2c0462 100644
--- a/whir/examples/pareto_frontier.rs
+++ b/whir/examples/pareto_frontier.rs
@@ -13,12 +13,15 @@
 //! 256 polynomials of size 2^(m-8), opened at one common point), and plots two
 //! log-log views with the same y-axis (argument size = postcard proof bytes):
 //!
-//! - x = total committed oracle length (Σ codeword sizes) — an analytic,
-//!   noise-free proxy for prover work (LDE FFTs + Merkle hashing).
+//! - x = a **modelled prover cost** (base-field multiply-equivalents): per
+//!   committed codeword, the encode (FFT) and Merkle-hash terms, plus WHIR's
+//!   open-phase sumcheck and the claim-batching term. This extends the older
+//!   "total committed oracle length" proxy, which priced only the commit
+//!   phase and so under-counted WHIR (see the cost-model constants below).
 //! - x = measured prover wall-clock (commit + open).
 //!
-//! Each panel draws dotted iso-knob curves, shaded by value: WHIR grouped by k
-//! (blues), FRI grouped by log_blowup (reds). Lower-left is better.
+//! Each panel draws the full sample cloud plus one per-protocol Pareto frontier
+//! (lower-left is better). The raw oracle length is still emitted to the CSV.
 //!
 //! # Knobs swept
 //!
@@ -104,6 +107,22 @@ const FRI_LOG_FINAL_POLY_LEN: usize = 0;
 /// lengths to base-field elements.
 const EXT_DEGREE: u128 = 4;
 
+// --- Prover-cost model constants (Steps 1-2 of the cost model). -------------
+// Everything is expressed in base-field multiply-equivalents, so the FFT,
+// Merkle-hash, sumcheck and batching terms can be summed into one number. The
+// constants are rough order-of-magnitude values: only the *relative* shape of
+// the modelled cost matters here. Calibrating them against a handful of
+// measured points (model "Step 3") would make the absolute scale meaningful.
+
+/// Cost of one extension-field multiply, in base-field multiplies (~d² with
+/// schoolbook for a degree-d extension; ~9 with Karatsuba for d = 4).
+const EXT_MUL_COST: f64 = 16.0;
+/// Cost of one Poseidon permutation, in base-field multiplies (rough).
+const PERM_COST: f64 = 200.0;
+/// Degree of WHIR's eq-weighted sumcheck constraint (a constant multiplier on
+/// the per-cell sumcheck work).
+const SUMCHECK_DEG: f64 = 2.0;
+
 /// Minimum FRI query count reaching `SECURITY_LEVEL` under the capacity formula
 /// `log_blowup * queries + query_pow >= security_level`.
 fn fri_min_queries(log_blowup: usize, query_pow_bits: usize) -> usize {
@@ -163,6 +182,10 @@ struct WhirBuilt, Ch> {
     /// overlooked term in argument size: high folding factors make every queried
     /// leaf much wider, even if they reduce the number of rounds.
     query_widths: Vec,
+    /// Modelled prover cost (base-field multiply-equivalents): encode (FFT) +
+    /// Merkle hash, per committed codeword, plus the open-phase sumcheck and
+    /// claim-batching terms the raw oracle length omits.
+    model_cost: f64,
 }
 
 /// Build a WHIR rig for `(m, log_width)` with an arbitrary folding strategy and
@@ -231,6 +254,37 @@ where
             .at_round(config.round_parameters.len()),
     );
 
+    // ---- Prover-cost model (Steps 1-2): encode + hash + sumcheck + batch. ----
+    let mut model_cost = 0.0f64;
+    // Commit, initial base-field codeword: one FFT (encode) + Merkle hash.
+    {
+        let n = (1u128 << (num_variables + starting)) as f64;
+        model_cost += n * n.log2(); // T_encode (base field, w_mul = 1)
+        // Leaves = codeword / coset; clamp so an oversized coset gives one leaf.
+        let leaves =
+            (1u128 << (num_variables + starting).saturating_sub(config.folding_factor.at_round(0))) as f64;
+        model_cost += (n + leaves) * PERM_COST; // T_hash
+    }
+    // Commit, each per-round extension-field codeword. Unlike FRI, WHIR
+    // re-encodes (a fresh FFT) every round.
+    for r in &config.round_parameters {
+        let n = (1u128 << (r.num_variables + r.log_inv_rate)) as f64;
+        model_cost += n * n.log2() * EXT_MUL_COST; // T_encode (extension field)
+        let n_base = n * EXT_DEGREE as f64;
+        let leaves =
+            (1u128 << (r.num_variables + r.log_inv_rate).saturating_sub(r.folding_factor)) as f64;
+        model_cost += (n_base + leaves) * PERM_COST; // T_hash
+    }
+    // Open: multilinear sumcheck over the 2^m hypercube. The cube halves each
+    // round, so total prover work ≈ 2·2^m extension-field multiply-adds. FRI
+    // has no analogue of this term.
+    model_cost += SUMCHECK_DEG * 2.0 * (1u128 << num_variables) as f64 * EXT_MUL_COST;
+    // Open: batching the `2^log_width` evaluation claims into one constraint.
+    // The opening points share the trailing coordinates, so the eq-combination
+    // factorises to ~one pass over the cube; generic (non-shared) points would
+    // scale this with the claim count.
+    model_cost += (1u128 << num_variables) as f64 * EXT_MUL_COST;
+
     let seed = BENCH_SEED
         ^ ((num_variables as u64) << 16)
         ^ ((log_width as u64) << 8)
@@ -268,6 +322,7 @@ where
         oracle_len,
         queries,
         query_widths,
+        model_cost,
     })
 }
 
@@ -391,6 +446,36 @@ fn fri_oracle_len(
     total
 }
 
+/// Modelled FRI prover cost (Steps 1-2), in base-field multiply-equivalents.
+///
+/// Unlike WHIR, FRI does a single FFT (the input-matrix LDE) and then folds for
+/// free, so only the input oracle carries an encode term; the commit-phase
+/// codewords are hashed but not re-encoded, and there is no sumcheck.
+fn fri_cost(num_variables: usize, log_width: usize, log_blowup: usize, max_log_arity: usize) -> f64 {
+    let per_col_log = (num_variables - log_width + log_blowup) as f64;
+    let n_input = (1u128 << (num_variables + log_blowup)) as f64;
+
+    // Commit, input matrix: one (per-column) FFT over the base field + hash.
+    let mut model = n_input * per_col_log; // T_encode (base field)
+    let input_leaves = (1u128 << (num_variables - log_width + log_blowup)) as f64;
+    model += (n_input + input_leaves) * PERM_COST; // T_hash
+
+    // Commit phase: hash only (FFT-free folding), extension field.
+    let floor = log_blowup + FRI_LOG_FINAL_POLY_LEN;
+    let mut h = num_variables - log_width + log_blowup;
+    while h > floor {
+        let arity = max_log_arity.min(h - floor).max(1); // >= 1, so `h` always shrinks
+        let n_base = EXT_DEGREE as f64 * (1u128 << h) as f64;
+        let leaves = (1u128 << (h - arity)) as f64;
+        model += (n_base + leaves) * PERM_COST; // T_hash, no T_encode
+        h -= arity;
+    }
+
+    // Batching: forming g = Σ αⁱ·fᵢ — one extension-field pass over the matrix.
+    model += n_input * EXT_MUL_COST; // T_batch
+    model
+}
+
 fn fri_build(
     num_variables: usize,
     log_width: usize,
@@ -864,8 +949,13 @@ struct Record {
     /// This config's knob values; currently emitted only for possible future CSV/plot labels.
     #[allow(dead_code)]
     knobs: Vec<(&'static str, usize)>,
-    /// Total committed oracle length in field elements (prover-cost proxy).
+    /// Total committed oracle length in base-field elements (the original,
+    /// commit-only proxy). Kept in the CSV for reference.
     oracle_len: f64,
+    /// Modelled prover cost (base-field multiply-equivalents): commit (encode +
+    /// hash) plus the open-phase sumcheck + batching terms. This is the x-axis
+    /// of the left panel.
+    model_cost: f64,
     /// Argument size: postcard-serialised proof bytes.
     proof_bytes: f64,
     /// Measured prover wall-clock (commit + open), milliseconds.
@@ -962,6 +1052,7 @@ fn main() {
             label: cand.name,
             knobs: vec![("k", cand.group), ("start", cand.starting)],
             oracle_len: built.oracle_len as f64,
+            model_cost: built.model_cost,
             proof_bytes: proof_bytes as f64,
             prove_ms: prove_ms as f64,
             verify_us: verify_us as f64,
@@ -1005,6 +1096,7 @@ fn main() {
             .expect("postcard FRI proof + openings")
             .len();
         let oracle_len = fri_oracle_len(m, log_width, log_blowup, max_log_arity);
+        let model_cost = fri_cost(m, log_width, log_blowup, max_log_arity);
 
         let arities: Vec = proof
             .query_proofs
@@ -1040,6 +1132,7 @@ fn main() {
             label: format!("ρ⁻¹=2^{log_blowup} a=2^{max_log_arity} q={num_queries}"),
             knobs: vec![("blowup", log_blowup), ("arity", max_log_arity)],
             oracle_len: oracle_len as f64,
+            model_cost,
             proof_bytes: proof_bytes as f64,
             prove_ms: prove_ms as f64,
             verify_us: verify_us as f64,
@@ -1067,14 +1160,17 @@ fn csv_escape(s: &str) -> String {
 
 fn write_csv(records: &[Record]) {
     let mut s = String::new();
-    s.push_str("protocol,label,oracle_len_elems,proof_bytes,prove_ms,verify_us,is_default\n");
+    s.push_str(
+        "protocol,label,oracle_len_elems,model_cost,proof_bytes,prove_ms,verify_us,is_default\n",
+    );
     for r in records {
         let _ = writeln!(
             s,
-            "{},{},{:.0},{:.0},{:.0},{:.0},{}",
+            "{},{},{:.0},{:.0},{:.0},{:.0},{:.0},{}",
             csv_escape(r.protocol),
             csv_escape(&r.label),
             r.oracle_len,
+            r.model_cost,
             r.proof_bytes,
             r.prove_ms,
             r.verify_us,
@@ -1306,7 +1402,7 @@ fn write_svg(m: usize, log_width: usize, records: &[Record]) {
     );
     let _ = writeln!(
         out,
-        r##"m=2^{m}, 256 polys of 2^{}, {SECURITY_LEVEL}-bit capacity-regime, Poseidon1 — solid lines are per-protocol Pareto frontiers · ◯ = PR default · lower-left is better"##,
+        r##"m=2^{m}, 256 polys of 2^{}, {SECURITY_LEVEL}-bit capacity-regime, Poseidon1 — left x = modelled prover cost (encode+hash+sumcheck+batch), right x = measured wall-clock · solid lines = per-protocol Pareto frontiers · ◯ = PR default · lower-left is better"##,
         w / 2.0,
         m - log_width,
     );
@@ -1344,9 +1440,9 @@ fn write_svg(m: usize, log_width: usize, records: &[Record]) {
         panel_top,
         panel_w,
         panel_h,
-        "Theoretical proxy",
-        "total committed oracle length, elements (log scale)",
-        &|r: &Record| r.oracle_len,
+        "Modelled prover cost",
+        "encode + hash + sumcheck + batch (base-mult-equiv, log scale)",
+        &|r: &Record| r.model_cost,
         &fmt_count,
         records,
     );