diff --git a/jolt-core/benches/e2e_profiling.rs b/jolt-core/benches/e2e_profiling.rs index 8638cff9a..64199c3f2 100644 --- a/jolt-core/benches/e2e_profiling.rs +++ b/jolt-core/benches/e2e_profiling.rs @@ -203,7 +203,7 @@ fn prove_example( let mut tasks = Vec::new(); let mut program = host::Program::new(example_name); let (bytecode, init_memory_state, _) = program.decode(); - let (_, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); + let (_lazy_trace, trace, _, program_io) = program.trace(&serialized_input, &[], &[]); let padded_trace_len = (trace.len() + 1).next_power_of_two(); drop(trace); diff --git a/jolt-core/src/poly/mod.rs b/jolt-core/src/poly/mod.rs index 2f55b2395..6cd948840 100644 --- a/jolt-core/src/poly/mod.rs +++ b/jolt-core/src/poly/mod.rs @@ -7,6 +7,7 @@ pub mod identity_poly; pub mod lagrange_poly; pub mod lt_poly; pub mod multilinear_polynomial; +pub mod multiquadratic_poly; pub mod one_hot_polynomial; pub mod opening_proof; pub mod prefix_suffix; diff --git a/jolt-core/src/poly/multiquadratic_poly.rs b/jolt-core/src/poly/multiquadratic_poly.rs new file mode 100644 index 000000000..faef0beed --- /dev/null +++ b/jolt-core/src/poly/multiquadratic_poly.rs @@ -0,0 +1,386 @@ +use allocative::Allocative; + +use crate::field::JoltField; +use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; + +/// Multiquadratic polynomial represented by its evaluations on the grid +/// {0, 1, ∞}^num_vars in base-3 layout (z_0 least-significant / fastest-varying). +#[derive(Allocative)] +pub struct MultiquadraticPolynomial { + num_vars: usize, + evals: Vec, +} + +impl MultiquadraticPolynomial { + /// Construct a multiquadratic polynomial from its full grid of evaluations. + /// The caller is responsible for ensuring that `evals` is laid out in base-3 + /// order with z_0 as the least-significant digit. + pub fn new(num_vars: usize, evals: Vec) -> Self { + let expected_len = 3usize.pow(num_vars as u32); + debug_assert!( + evals.len() == expected_len, + "MultiquadraticPolynomial: expected {} evals, got {}", + expected_len, + evals.len() + ); + Self { num_vars, evals } + } + + /// Number of variables in the polynomial. + pub fn num_vars(&self) -> usize { + self.num_vars + } + + /// Underlying evaluations on {0, 1, ∞}^num_vars. + pub fn evals(&self) -> &[F] { + &self.evals + } + + /// Given evaluations of a degree-1 multivariate polynomial over {0,1}^dim, + /// expand them to the corresponding multiquadratic grid over {0,1,∞}^dim. + /// + /// The input is a length-2^dim slice `input` containing evaluations on the + /// Boolean hypercube. The caller must provide two length-3^dim buffers: + /// - `output` will contain the final {0,1,∞}^dim values on return + /// - `tmp` is a scratch buffer which this routine may use internally + /// + /// Layout is product-order with the last variable as the fastest-varying + /// coordinate. For each 1D slice (f0, f1) along a new dimension we write + /// (f(0), f(1), f(∞)) = (f0, f1, f1 - f0), so ∞ stores the slope. + /// + /// TODO: special-case dim ∈ {1,2,3} with hand-unrolled code to reduce + /// loop overhead on small windows. 
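+    ///
+    /// A minimal usage sketch (doc-only, not compiled; `F` stands for any field
+    /// type and the values are arbitrary):
+    ///
+    /// ```ignore
+    /// // f00 = 1, f01 = 2, f10 = 3, f11 = 5 (evaluations on {0,1}^2, x1 fastest)
+    /// let input = vec![f00, f01, f10, f11];
+    /// let mut output = vec![F::zero(); 9]; // 3^2 grid
+    /// let mut tmp = vec![F::zero(); 9];
+    /// MultiquadraticPolynomial::<F>::expand_linear_grid_to_multiquadratic(
+    ///     &input, &mut output, &mut tmp, 2,
+    /// );
+    /// // output = [1, 2, 1, 3, 5, 2, 2, 3, 1] in the layout described above;
+    /// // e.g. the (∞, ∞) entry is (5 - 3) - (2 - 1) = 1.
+    /// ```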
+ #[inline(always)] + pub fn expand_linear_grid_to_multiquadratic( + input: &[F], // initial buffer (size 2^dim) + output: &mut [F], // final buffer (size 3^dim) + tmp: &mut [F], // scratch buffer, also (size 3^dim) + dim: usize, + ) { + let in_size = 1usize << dim; + let out_size = 3usize.pow(dim as u32); + + assert_eq!(input.len(), in_size); + assert_eq!(output.len(), out_size); + assert_eq!(tmp.len(), out_size); + + match dim { + 0 => { + if !input.is_empty() { + output[0] = input[0]; + } + return; + } + 1 => { + Self::expand_linear_dim1(input, output); + return; + } + 2 => { + Self::expand_linear_dim2(input, output); + return; + } + 3 => { + Self::expand_linear_dim3(input, output); + return; + } + _ => {} + } + + // Fill output by expanding one dimension at a time. + // We treat slices of increasing "arity" + + // Copy the initial evaluations into the start of either + // tmp or output, depending on parity of dim. + // We'll alternate between tmp and output as we expand dimensions. + let (mut cur, mut next) = if dim % 2 == 1 { + tmp[..input.len()].copy_from_slice(input); + (tmp, output) + } else { + output[..input.len()].copy_from_slice(input); + (output, tmp) + }; + + let mut in_stride = 1usize; + let mut out_stride = 1usize; + let mut blocks = 1 << (dim - 1); + + // sanity checks + assert_eq!(cur.len(), out_size); + assert_eq!(next.len(), out_size); + assert_eq!(input.len(), in_size); + + // start from the smallest subcubes and expand dimension by dimension + for _ in 0..dim { + for b in 0..blocks { + let in_off = b * 2 * in_stride; + let out_off = b * 3 * out_stride; + + for j in 0..in_stride { + // 1d extrapolate + let f0 = cur[in_off + j]; + let f1 = cur[in_off + in_stride + j]; + next[out_off + j] = f0; + next[out_off + out_stride + j] = f1; + next[out_off + 2 * out_stride + j] = f1 - f0; + } + } + // swap buffers + std::mem::swap(&mut cur, &mut next); + in_stride *= 3; + out_stride *= 3; + blocks /= 2; + } + } + + #[inline(always)] + fn expand_linear_dim1(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 2); + debug_assert_eq!(output.len(), 3); + + let f0 = input[0]; + let f1 = input[1]; + + output[0] = f0; + output[1] = f1; + output[2] = f1 - f0; + } + + #[inline(always)] + fn expand_linear_dim2(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 4); + debug_assert_eq!(output.len(), 9); + + let f00 = input[0]; // f(0,0) + let f01 = input[1]; // f(0,1) + let f10 = input[2]; // f(1,0) + let f11 = input[3]; // f(1,1) + + // First extrapolate along the fastest-varying variable (second coordinate). + let a00 = f00; + let a01 = f01; + let a0_inf = f01 - f00; + + let a10 = f10; + let a11 = f11; + let a1_inf = f11 - f10; + + // Then extrapolate along the remaining variable. + let inf0 = a10 - a00; + let inf1 = a11 - a01; + let inf_inf = a1_inf - a0_inf; + + // Layout: index = 3 * enc(x0) + enc(x1), x1 fastest, enc: {0,1,∞} -> {0,1,2}. + output[0] = a00; // (0,0) + output[1] = a01; // (0,1) + output[2] = a0_inf; // (0,∞) + + output[3] = a10; // (1,0) + output[4] = a11; // (1,1) + output[5] = a1_inf; // (1,∞) + + output[6] = inf0; // (∞,0) + output[7] = inf1; // (∞,1) + output[8] = inf_inf; // (∞,∞) + } + + #[inline(always)] + fn expand_linear_dim3(input: &[F], output: &mut [F]) { + debug_assert_eq!(input.len(), 8); + debug_assert_eq!(output.len(), 27); + + // Corner values f(x0, x1, x2) with x2 fastest. 
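+        // (input[i] stores f(x0, x1, x2) with i = 4*x0 + 2*x1 + x2,
+        // so e.g. input[5] = f(1, 0, 1))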
+ let f000 = input[0]; + let f001 = input[1]; + let f010 = input[2]; + let f011 = input[3]; + let f100 = input[4]; + let f101 = input[5]; + let f110 = input[6]; + let f111 = input[7]; + + // Stage 1: extrapolate along x2 (fastest variable) for each (x0, x1). + let g000 = f000; + let g001 = f001; + let g00_inf = f001 - f000; + + let g010 = f010; + let g011 = f011; + let g01_inf = f011 - f010; + + let g100 = f100; + let g101 = f101; + let g10_inf = f101 - f100; + + let g110 = f110; + let g111 = f111; + let g11_inf = f111 - f110; + + // Stage 2: extrapolate along x1 for each (x0, x2). + // x0 = 0 + let h0_0_0 = g000; + let h0_1_0 = g010; + let h0_inf_0 = g010 - g000; + + let h0_0_1 = g001; + let h0_1_1 = g011; + let h0_inf_1 = g011 - g001; + + let h0_0_inf = g00_inf; + let h0_1_inf = g01_inf; + let h0_inf_inf = g01_inf - g00_inf; + + // x0 = 1 + let h1_0_0 = g100; + let h1_1_0 = g110; + let h1_inf_0 = g110 - g100; + + let h1_0_1 = g101; + let h1_1_1 = g111; + let h1_inf_1 = g111 - g101; + + let h1_0_inf = g10_inf; + let h1_1_inf = g11_inf; + let h1_inf_inf = g11_inf - g10_inf; + + // Stage 3: extrapolate along x0 for each (x1, x2). + // Index: idx(x0, x1, x2) = 9 * enc(x0) + 3 * enc(x1) + enc(x2), + // enc: {0,1,∞} -> {0,1,2}, x2 fastest. + + // (x1, x2) = (0, 0) + output[0] = h0_0_0; // (0,0,0) + output[9] = h1_0_0; // (1,0,0) + output[18] = h1_0_0 - h0_0_0; // (∞,0,0) + + // (0, 1) + output[1] = h0_0_1; // (0,0,1) + output[10] = h1_0_1; // (1,0,1) + output[19] = h1_0_1 - h0_0_1; // (∞,0,1) + + // (0, ∞) + output[2] = h0_0_inf; // (0,0,∞) + output[11] = h1_0_inf; // (1,0,∞) + output[20] = h1_0_inf - h0_0_inf; // (∞,0,∞) + + // (1, 0) + output[3] = h0_1_0; // (0,1,0) + output[12] = h1_1_0; // (1,1,0) + output[21] = h1_1_0 - h0_1_0; // (∞,1,0) + + // (1, 1) + output[4] = h0_1_1; // (0,1,1) + output[13] = h1_1_1; // (1,1,1) + output[22] = h1_1_1 - h0_1_1; // (∞,1,1) + + // (1, ∞) + output[5] = h0_1_inf; // (0,1,∞) + output[14] = h1_1_inf; // (1,1,∞) + output[23] = h1_1_inf - h0_1_inf; // (∞,1,∞) + + // (∞, 0) + output[6] = h0_inf_0; // (0,∞,0) + output[15] = h1_inf_0; // (1,∞,0) + output[24] = h1_inf_0 - h0_inf_0; // (∞,∞,0) + + // (∞, 1) + output[7] = h0_inf_1; // (0,∞,1) + output[16] = h1_inf_1; // (1,∞,1) + output[25] = h1_inf_1 - h0_inf_1; // (∞,∞,1) + + // (∞, ∞) + output[8] = h0_inf_inf; // (0,∞,∞) + output[17] = h1_inf_inf; // (1,∞,∞) + output[26] = h1_inf_inf - h0_inf_inf; // (∞,∞,∞) + } + + /// Bind the first (least-significant) variable z_0 := r, reducing the + /// dimension from w to w-1 and keeping the base-3 layout invariant. + /// + /// For each assignment to (z_1, ..., z_{w-1}), we have three stored values + /// f(0, ..), f(1, ..), f(∞, ..) + /// and interpolate the unique quadratic in z_0 that matches them, then + /// evaluate it at z_0 = r. + pub fn bind_first_variable(&mut self, r: F::Challenge) { + let w = self.num_vars; + debug_assert!(w > 0); + + let new_size = 3_usize.pow((w - 1) as u32); + let one = F::one(); + + for new_idx in 0..new_size { + let old_base_idx = new_idx * 3; + let eval_at_0 = self.evals[old_base_idx]; // z_0 = 0 + let eval_at_1 = self.evals[old_base_idx + 1]; // z_0 = 1 + let eval_at_inf = self.evals[old_base_idx + 2]; // z_0 = ∞ + + self.evals[new_idx] = + eval_at_0 * (one - r) + eval_at_1 * r + eval_at_inf * r * (r - one); + } + + self.num_vars -= 1; + self.evals.truncate(new_size); + } + + /// Project t'(z_0, z_1, ..., z_{w-1}) to a univariate in z_0 by summing + /// against `E_active` over the remaining coordinates. 
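+    ///
+    /// Concretely, the return value is
+    ///   Σ_k E_active[k] · t'(z_0 = first_coord_val, z_1 = k_0, ..., z_{w-1} = k_{w-2}),
+    /// where k_0, ..., k_{w-2} are the bits of k (least-significant first).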
+ /// + /// The `E_active` table is interpreted identically to the existing outer + /// Spartan streaming implementation: each index encodes, in binary, which + /// of z_1..z_{w-1} take the "active" value (mapped to base-3 offset 1). + /// `first_coord_val` is the z_0 coordinate in {0, 1, 2}, where 2 encodes ∞. + pub fn project_to_first_variable(&self, E_active: &[F], first_coord_val: usize) -> F { + let w = self.num_vars; + debug_assert!(w >= 1); + + let offset = first_coord_val; // z_0 lives at the units place in base-3 + + E_active + .iter() + .enumerate() + .map(|(eq_active_idx, eq_active_val)| { + let mut index = offset; + let mut temp = eq_active_idx; + let mut power = 3; // start at 3^1 for z_1 + + for _ in 0..(w - 1) { + if temp & 1 == 1 { + index += power; + } + power *= 3; + temp >>= 1; + } + + self.evals[index] * *eq_active_val + }) + .sum() + } +} + +impl PolynomialBinding for MultiquadraticPolynomial { + fn is_bound(&self) -> bool { + self.num_vars == 0 || self.evals.len() == 1 + } + + #[tracing::instrument(skip_all, name = "MultiquadraticPolynomial::bind")] + fn bind(&mut self, r: F::Challenge, order: BindingOrder) { + match order { + BindingOrder::LowToHigh => self.bind_first_variable(r), + BindingOrder::HighToLow => { + // Not currently needed by the outer Spartan streaming code. + unimplemented!( + "HighToLow binding order is not implemented for MultiquadraticPolynomial" + ) + } + } + } + + fn bind_parallel(&mut self, r: F::Challenge, order: BindingOrder) { + // Window sizes are small; fall back to the sequential implementation. + self.bind(r, order); + } + + fn final_sumcheck_claim(&self) -> F { + debug_assert!(self.is_bound()); + debug_assert_eq!(self.evals.len(), 1); + self.evals[0] + } +} diff --git a/jolt-core/src/poly/split_eq_poly.rs b/jolt-core/src/poly/split_eq_poly.rs index c2dd0ee1c..266f4c88c 100644 --- a/jolt-core/src/poly/split_eq_poly.rs +++ b/jolt-core/src/poly/split_eq_poly.rs @@ -33,11 +33,66 @@ pub struct GruenSplitEqPolynomial { pub(crate) w: Vec, pub(crate) E_in_vec: Vec>, pub(crate) E_out_vec: Vec>, + /// Cached `[1]` table used to represent eq over zero variables when a side + /// (head, inner, or active) has no bits. + one_table: Vec, pub(crate) binding_order: BindingOrder, } impl GruenSplitEqPolynomial { #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::new_with_scaling")] + //pub fn new_with_scaling( + // w: &[F::Challenge], + // binding_order: BindingOrder, + // scaling_factor: Option, + //) -> Self { + // match binding_order { + // BindingOrder::LowToHigh => { + // let m = w.len() / 2; + // // w = [w_out, w_in, w_last] + // // ↑ ↑ ↑ + // // | | | + // // | | last element + // // | second half of remaining elements (for E_in) + // // first half of remaining elements (for E_out) + // let (_, wprime) = w.split_last().unwrap(); + // let (w_out, w_in) = wprime.split_at(m); + // let (E_out_vec, E_in_vec) = rayon::join( + // || EqPolynomial::evals_cached(w_out), + // || EqPolynomial::evals_cached(w_in), + // ); + // Self { + // current_index: w.len(), + // current_scalar: scaling_factor.unwrap_or(F::one()), + // w: w.to_vec(), + // E_in_vec, + // E_out_vec, + // binding_order, + // } + // } + // BindingOrder::HighToLow => { + // // For high-to-low binding, we bind from MSB (index 0) to LSB (index n-1). 
+ // // The split should be: w_in = first half, w_out = second half + // // [w_first, w_in, w_out] + // let (_, wprime) = w.split_first().unwrap(); + // let m = w.len() / 2; + // let (w_in, w_out) = wprime.split_at(m); + // let (E_in_vec, E_out_vec) = rayon::join( + // || EqPolynomial::evals_cached_rev(w_in), + // || EqPolynomial::evals_cached_rev(w_out), + // ); + // + // Self { + // current_index: 0, // Start from 0 for high-to-low up to w.len() - 1 + // current_scalar: scaling_factor.unwrap_or(F::one()), + // w: w.to_vec(), + // E_in_vec, + // E_out_vec, + // binding_order, + // } + // } + // } + //} pub fn new_with_scaling( w: &[F::Challenge], binding_order: BindingOrder, @@ -58,12 +113,14 @@ impl GruenSplitEqPolynomial { || EqPolynomial::evals_cached(w_out), || EqPolynomial::evals_cached(w_in), ); + let one_table = vec![F::one()]; Self { current_index: w.len(), current_scalar: scaling_factor.unwrap_or(F::one()), w: w.to_vec(), E_in_vec, E_out_vec, + one_table, binding_order, } } @@ -78,6 +135,7 @@ impl GruenSplitEqPolynomial { || EqPolynomial::evals_cached_rev(w_in), || EqPolynomial::evals_cached_rev(w_out), ); + let one_table = vec![F::one()]; Self { current_index: 0, // Start from 0 for high-to-low up to w.len() - 1 @@ -85,6 +143,7 @@ impl GruenSplitEqPolynomial { w: w.to_vec(), E_in_vec, E_out_vec, + one_table, binding_order, } } @@ -107,6 +166,16 @@ impl GruenSplitEqPolynomial { } } + /// Number of variables that have already been bound into `current_scalar`. + /// For LowToHigh this is `w.len() - current_index`; for HighToLow it is + /// `current_index`. + pub fn num_challenges(&self) -> usize { + match self.binding_order { + BindingOrder::LowToHigh => self.w.len() - self.current_index, + BindingOrder::HighToLow => self.current_index, + } + } + pub fn E_in_current_len(&self) -> usize { self.E_in_vec.last().map_or(0, |v| v.len()) } @@ -125,6 +194,128 @@ impl GruenSplitEqPolynomial { self.E_out_vec.last().map_or(&[], |v| v.as_slice()) } + /// Return the (E_out, E_in) tables corresponding to a streaming window of the + /// given `window_size`, using an explicit slice-based factorisation of the + /// current unbound variables. + /// + /// Semantics (LowToHigh): + /// - Let `num_unbound = current_index` and `remaining_w = w[..num_unbound]`. + /// - For a window of size `window_size >= 1`, define: + /// - `w_window` as the last `window_size` bits of `remaining_w` + /// - `w_head` as the prefix before `w_window` + /// - within `w_window`, the last bit is the current Gruen variable and the + /// preceding `window_size - 1` bits are the "active" window bits + /// - This function returns eq tables over `w_head`, split into two halves + /// `w_out` and `w_in`: + /// - `w_head = [w_out || w_in]` with `w_out` = first `⌊|w_head| / 2⌋` bits + /// - `eq(w_head, (x_out, x_in)) = E_out[x_out] * E_in[x_in]`. + /// + /// The active window bits are handled separately by [`E_active_for_window`]. + /// Together they satisfy, for `BindingOrder::LowToHigh`, + /// log2(|E_out|) + log2(|E_in|) + log2(|E_active|) + 1 = #unbound bits, + /// where the final `+ 1` accounts for the current linear Gruen bit. + /// + /// This helper returns slices and represents "no head bits" as + /// single-entry `[1]` tables, matching `eq((), ()) = 1`. 
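+    ///
+    /// Example (numbers purely illustrative): with 7 unbound cycle bits and
+    /// `window_size = 3`, the window covers the last 3 unbound bits (2 active
+    /// bits plus the current Gruen bit) and the head covers the remaining 4
+    /// bits, so |E_out| · |E_in| = 2^4, |E_active| = 2^2, and
+    /// log2(|E_out|) + log2(|E_in|) + log2(|E_active|) + 1 = 7.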
+ pub fn E_out_in_for_window(&self, window_size: usize) -> (&[F], &[F]) { + if window_size == 0 { + return (&self.one_table, &self.one_table); + } + + match self.binding_order { + BindingOrder::LowToHigh => { + let num_unbound = self.current_index; + if num_unbound == 0 { + return (&self.one_table, &self.one_table); + } + + // Restrict window size to the actually available unbound bits. + let window_size = core::cmp::min(window_size, num_unbound); + let head_len = num_unbound.saturating_sub(window_size); + if head_len == 0 { + // No head bits: represent as eq over zero vars. + return (&self.one_table, &self.one_table); + } + + // The head prefix consists of the earliest `head_len` bits of `w`. + // These live entirely in the original `[w_out || w_in] = w[..n-1]` + // region, so we can factor them via prefixes of `w_out` and `w_in`. + let n = self.w.len(); + let m = n / 2; + + let head_out_bits = core::cmp::min(head_len, m); + let head_in_bits = head_len.saturating_sub(head_out_bits); + + let e_out = if head_out_bits == 0 { + &self.one_table + } else { + debug_assert!( + head_out_bits < self.E_out_vec.len(), + "head_out_bits={} E_out_vec.len()={}", + head_out_bits, + self.E_out_vec.len() + ); + &self.E_out_vec[head_out_bits] + }; + let e_in = if head_in_bits == 0 { + &self.one_table + } else { + debug_assert!( + head_in_bits < self.E_in_vec.len(), + "head_in_bits={} E_in_vec.len()={}", + head_in_bits, + self.E_in_vec.len() + ); + &self.E_in_vec[head_in_bits] + }; + + (e_out, e_in) + } + BindingOrder::HighToLow => { + // Streaming windows are not defined for HighToLow in the current + // Spartan code paths; return neutral head tables. + (&self.one_table, &self.one_table) + } + } + } + + /// Return the equality table over the "active" window bits (all but the + /// last variable in the current streaming window). This is used when + /// projecting the multiquadratic t'(z_0, ..., z_{w-1}) down to a univariate + /// in the first variable by summing against eq(tau_active, ·) over the + /// remaining coordinates. + /// + /// We derive the active slice directly from the unbound portion of `w`. + /// For LowToHigh binding, the unbound variables are `w[..current_index]`; + /// the last `window_size` of these belong to the current window, and all + /// but the final one are "active". + pub fn E_active_for_window(&self, window_size: usize) -> Vec { + if window_size <= 1 { + // No active bits in a size-0/1 window; eq over zero vars is [1]. + return vec![F::one()]; + } + + match self.binding_order { + BindingOrder::LowToHigh => { + let num_unbound = self.current_index; + if window_size > num_unbound { + // Clamp to the maximum meaningful window size at this round. + return vec![F::one()]; + } + let remaining_w = &self.w[..num_unbound]; + let window_start = remaining_w.len() - window_size; + let (_w_body, w_window) = remaining_w.split_at(window_start); + let (w_active, _w_curr_slice) = w_window.split_at(window_size - 1); + // We only need the full eq table over the active window bits. + EqPolynomial::::evals(w_active) + } + BindingOrder::HighToLow => { + // Not used for the outer Spartan streaming code. 
+ vec![F::one()] + } + } + } + #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::bind")] pub fn bind(&mut self, r: F::Challenge) { match self.binding_order { @@ -473,4 +664,99 @@ mod tests { assert_eq!(regular_eq.Z[..regular_eq.len()], merged.Z[..merged.len()]); } } + + /// For window_size = 1, `E_out_in_for_window` should factor the eq polynomial + /// over the head bits `w[..current_index-1]` into a product of two tables. + #[test] + fn window_size_one_matches_current() { + const NUM_VARS: usize = 10; + let mut rng = test_rng(); + let w: Vec<::Challenge> = + std::iter::repeat_with(|| ::Challenge::random(&mut rng)) + .take(NUM_VARS) + .collect(); + + let mut split_eq: GruenSplitEqPolynomial = + GruenSplitEqPolynomial::new(&w, BindingOrder::LowToHigh); + + for _round in 0..NUM_VARS { + let num_unbound = split_eq.current_index; + if num_unbound <= 1 { + break; + } + + // Factor head = w[..num_unbound-1] into (E_out, E_in). + let (e_out_window, e_in_window) = split_eq.E_out_in_for_window(1); + let w_head = &split_eq.w[..num_unbound - 1]; + let head_evals = EqPolynomial::evals(w_head); + + let num_x_out = e_out_window.len(); + let num_x_in = e_in_window.len(); + assert_eq!(num_x_out * num_x_in, head_evals.len()); + + let x_in_bits = num_x_in.log_2(); + for x_out in 0..num_x_out { + for x_in in 0..num_x_in { + let idx = (x_out << x_in_bits) | x_in; + assert_eq!( + e_out_window[x_out] * e_in_window[x_in], + head_evals[idx], + "factorisation mismatch at round={_round}, x_out={x_out}, x_in={x_in}", + ); + } + } + + let r = ::Challenge::random(&mut rng); + split_eq.bind(r); + } + } + + /// Check basic bit-accounting invariants for the streaming factorisation: + /// log2(|E_out|) + log2(|E_in|) + log2(|E_active|) + 1 = number of unbound variables + /// for all window sizes and all rounds (LowToHigh). + #[test] + fn window_bit_accounting_invariants() { + const NUM_VARS: usize = 8; + let mut rng = test_rng(); + let w: Vec<::Challenge> = + std::iter::repeat_with(|| ::Challenge::random(&mut rng)) + .take(NUM_VARS) + .collect(); + + let mut split_eq: GruenSplitEqPolynomial = + GruenSplitEqPolynomial::new(&w, BindingOrder::LowToHigh); + + // Walk through all rounds, checking all window sizes that are + // meaningful at that point (at least one unbound variable). + for _round in 0..NUM_VARS { + let num_unbound = split_eq.len().log_2(); + if num_unbound == 0 { + break; + } + + for window_size in 1..=num_unbound { + let (e_out, e_in) = split_eq.E_out_in_for_window(window_size); + let e_active = split_eq.E_active_for_window(window_size); + // By construction, each side represents at least one entry. + debug_assert!(!e_out.is_empty()); + debug_assert!(!e_in.is_empty()); + debug_assert!(!e_active.is_empty()); + + let bits_out = e_out.len().log_2(); + let bits_in = e_in.len().log_2(); + let bits_active = e_active.len().log_2(); + + // One bit is reserved for the current variable in the Gruen + // cubic (the eq polynomial is linear in that bit). 
+ assert_eq!( + bits_out + bits_in + bits_active + 1, + num_unbound, + "bit accounting failed for window_size={window_size} (bits_out={bits_out}, bits_in={bits_in}, bits_active={bits_active}, num_unbound={num_unbound})", + ); + } + + let r = ::Challenge::random(&mut rng); + split_eq.bind(r); + } + } } diff --git a/jolt-core/src/subprotocols/mod.rs b/jolt-core/src/subprotocols/mod.rs index d7cb4465f..5b10836bb 100644 --- a/jolt-core/src/subprotocols/mod.rs +++ b/jolt-core/src/subprotocols/mod.rs @@ -1,6 +1,7 @@ pub mod booleanity; pub mod hamming_weight; pub mod mles_product_sum; +pub mod streaming_schedule; pub mod sumcheck; pub mod sumcheck_prover; pub mod sumcheck_verifier; diff --git a/jolt-core/src/subprotocols/streaming_schedule.rs b/jolt-core/src/subprotocols/streaming_schedule.rs new file mode 100644 index 000000000..fcf2a2148 --- /dev/null +++ b/jolt-core/src/subprotocols/streaming_schedule.rs @@ -0,0 +1,445 @@ +use allocative::Allocative; + +// TODO: Clean up this streaming schedule docstring. +pub trait StreamingSchedule: Send + Sync { + fn is_streaming(&self, round: usize) -> bool; + /// Returns true if we are starting a new streaming window. + /// This will lead to recomputation of the streaming data structure + /// storing the multi-variable polynomial used to computer prover messages. + fn is_window_start(&self, round: usize) -> bool; + /// Returns true of round is the first round of linear proving + fn is_first_linear(&self, round: usize) -> bool; + /// Get the total number of rounds for sumcheck + fn num_rounds(&self) -> usize; + /// Get the number of unbound variables in given round + /// If still in streaming phase, this should be in terms of how many rounds left in + /// given window. + fn num_unbound_vars(&self, round: usize) -> usize; +} + +#[derive(Debug, Clone, Allocative)] +pub struct HalfSplitSchedule { + num_rounds: usize, + constant_window_width: usize, + linear_start: usize, + window_starts: Vec, +} + +impl HalfSplitSchedule { + pub fn new(num_rounds: usize, window_width: usize) -> Self { + let linear_start = num_rounds.div_ceil(2); + + let window_starts = (0..linear_start).step_by(window_width).collect(); + + Self { + num_rounds, + constant_window_width: window_width, + linear_start, + window_starts, + } + } +} + +impl StreamingSchedule for HalfSplitSchedule { + fn is_streaming(&self, round: usize) -> bool { + round < self.linear_start + } + + fn is_window_start(&self, round: usize) -> bool { + self.window_starts.contains(&round) + } + + fn is_first_linear(&self, round: usize) -> bool { + round == self.linear_start + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + fn num_unbound_vars(&self, round: usize) -> usize { + if round >= self.num_rounds { + return 0; + } + if self.is_streaming(round) { + // Find which window this round belongs to + // and how many rounds are left in that window + + // Find the next window start after this round (or linear_start) + let next_boundary = self + .window_starts + .iter() + .find(|&&start| start > round) + .copied() + .unwrap_or(self.linear_start); + + // Number of unbound vars = rounds left in current window + next_boundary - round + } else { + // In linear phase: standard sumcheck, one variable at a time + // Number of unbound = total remaining rounds + self.num_rounds - round + } + } +} + +/// Streaming schedule where window sizes increase as 1, 2, 3, ... until the +/// streaming phase (first half of the rounds) is filled. 
The final window is +/// truncated so that the total number of streaming rounds is still roughly half. +#[derive(Debug, Clone, Allocative)] +pub struct IncreasingWindowSchedule { + pub(crate) num_rounds: usize, + pub(crate) linear_start: usize, + pub(crate) window_starts: Vec, +} + +impl IncreasingWindowSchedule { + pub fn new(num_rounds: usize) -> Self { + let linear_start = num_rounds.div_ceil(2); + + let mut window_starts = Vec::new(); + let mut round = 0usize; + let mut width = 1usize; + while round < linear_start { + window_starts.push(round); + let remaining = linear_start - round; + let w = core::cmp::min(width, remaining); + round += w; + width += 1; + } + + Self { + num_rounds, + linear_start, + window_starts, + } + } +} + +impl StreamingSchedule for IncreasingWindowSchedule { + fn is_streaming(&self, round: usize) -> bool { + round < self.linear_start + } + + fn is_window_start(&self, round: usize) -> bool { + self.window_starts.contains(&round) + } + + fn is_first_linear(&self, round: usize) -> bool { + round == self.linear_start + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn num_unbound_vars(&self, round: usize) -> usize { + if round >= self.num_rounds { + return 0; + } + if self.is_streaming(round) { + let next_boundary = self + .window_starts + .iter() + .find(|&&start| start > round) + .copied() + .unwrap_or(self.linear_start); + next_boundary - round + } else { + self.num_rounds - round + } + } +} + +/// A schedule that disables streaming and runs all sumcheck rounds in +/// the linear-time mode. +#[derive(Debug, Clone, Allocative)] +pub struct LinearOnlySchedule { + num_rounds: usize, +} + +impl LinearOnlySchedule { + pub fn new(num_rounds: usize) -> Self { + Self { num_rounds } + } +} + +impl StreamingSchedule for LinearOnlySchedule { + fn is_streaming(&self, _round: usize) -> bool { + false + } + + fn is_window_start(&self, _round: usize) -> bool { + false + } + + fn is_first_linear(&self, round: usize) -> bool { + round == 0 + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn num_unbound_vars(&self, round: usize) -> usize { + self.num_rounds.saturating_sub(round) + //if round >= self.num_rounds { + // 0 + //} else { + // self.num_rounds - round + //} + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Test basic schedule properties with 10 rounds + #[test] + fn test_half_split_schedule_basic() { + let schedule = HalfSplitSchedule::new(10, 2); + + // First 5 rounds are streaming + assert!(schedule.is_streaming(0)); + assert!(schedule.is_streaming(4)); + assert!(!schedule.is_streaming(5)); + assert!(!schedule.is_streaming(9)); + + // Window starts at 0, 2, 4 + assert!(schedule.is_window_start(0)); + assert!(schedule.is_window_start(2)); + assert!(schedule.is_window_start(4)); + assert!(!schedule.is_window_start(1)); + assert!(!schedule.is_window_start(3)); + + // First linear round is 5 + assert!(schedule.is_first_linear(5)); + assert!(!schedule.is_first_linear(4)); + assert!(!schedule.is_first_linear(6)); + } + + /// Test schedule with 8 rounds and window width 2 + /// + /// # Schedule Layout (8 rounds, window_width=2): + /// ```text + /// Rounds: 0 1 2 3 4 5 6 7 + /// Phase: [--Streaming--] [---Linear---] + /// Windows: [W0---] [W1---] + /// Window#: 0 1 + /// ``` + /// + /// Streaming phase: rounds 0-3 (first half) + /// Linear phase: rounds 4-7 (second half) + #[test] + fn test_8_rounds_window_2() { + let schedule = HalfSplitSchedule::new(8, 2); + + assert_eq!(schedule.num_rounds(), 8); + assert_eq!(schedule.linear_start, 
4); + assert_eq!(schedule.window_starts, vec![0, 2]); + + // Phase checks + for round in 0..=3 { + assert!( + schedule.is_streaming(round), + "Round {round} should be streaming", + ); + } + for round in 4..=7 { + assert!( + !schedule.is_streaming(round), + "Round {round} should be linear", + ); + } + + // Window start checks + assert!(schedule.is_window_start(0), "Round 0 starts window 0"); + assert!(!schedule.is_window_start(1), "Round 1 is mid-window"); + assert!(schedule.is_window_start(2), "Round 2 starts window 1"); + assert!(!schedule.is_window_start(3), "Round 3 is mid-window"); + + // First linear check + assert!(schedule.is_first_linear(4), "Round 4 is first linear"); + for round in [0, 1, 2, 3, 5, 6, 7] { + assert!( + !schedule.is_first_linear(round), + "Round {round} is not first linear", + ); + } + } + + /// Test num_unbound_vars with 8 rounds and window width 2 + /// + /// # Unbound Variables Per Round: + /// ```text + /// Round 0 (streaming, window start): 2 unbound (rounds 0,1 left in window) + /// Round 1 (streaming): 1 unbound (round 1 left in window) + /// Round 2 (streaming, window start): 2 unbound (rounds 2,3 left in window) + /// Round 3 (streaming): 1 unbound (round 3 left in window) + /// Round 4 (linear): 4 unbound (4 rounds left total) + /// Round 5 (linear): 3 unbound (3 rounds left total) + /// Round 6 (linear): 2 unbound (2 rounds left total) + /// Round 7 (linear): 1 unbound (1 round left total) + /// ``` + #[test] + fn test_num_unbound_vars_8_rounds_window_3() { + let schedule = HalfSplitSchedule::new(8, 2); + + // Streaming phase unbound vars + assert_eq!( + schedule.num_unbound_vars(0), + 2, + "Round 0: start of window 0" + ); + assert_eq!(schedule.num_unbound_vars(1), 1, "Round 1: end of window 0"); + assert_eq!( + schedule.num_unbound_vars(2), + 2, + "Round 2: start of window 1" + ); + assert_eq!(schedule.num_unbound_vars(3), 1, "Round 3: end of window 1"); + + // Linear phase unbound vars + assert_eq!( + schedule.num_unbound_vars(4), + 4, + "Round 4: 4 rounds remaining" + ); + assert_eq!( + schedule.num_unbound_vars(5), + 3, + "Round 5: 3 rounds remaining" + ); + assert_eq!( + schedule.num_unbound_vars(6), + 2, + "Round 6: 2 rounds remaining" + ); + assert_eq!( + schedule.num_unbound_vars(7), + 1, + "Round 7: 1 round remaining" + ); + } + + /// Test schedule with 8 rounds and window width 3 + /// + /// # Schedule Layout: + /// ```text + /// Rounds: 0 1 2 3 4 5 6 7 + /// Phase: [--Streaming--] [---Linear---] + /// Windows: [W0-------] [W1] + /// ``` + /// + /// Window 0: rounds 0-2 (3 rounds) + /// Window 1: round 3 only (truncated at linear boundary) + #[test] + fn test_8_rounds_window_3() { + let schedule = HalfSplitSchedule::new(8, 3); + + assert_eq!(schedule.window_starts, vec![0, 3]); + + // Unbound vars in streaming phase + assert_eq!( + schedule.num_unbound_vars(0), + 3, + "Round 0: 3 rounds in window" + ); + assert_eq!(schedule.num_unbound_vars(1), 2, "Round 1: 2 rounds left"); + assert_eq!(schedule.num_unbound_vars(2), 1, "Round 2: 1 round left"); + assert_eq!(schedule.num_unbound_vars(3), 1, "Round 3: truncated window"); + + // Unbound vars in linear phase + assert_eq!(schedule.num_unbound_vars(4), 4); + assert_eq!(schedule.num_unbound_vars(7), 1); + } + + /// Test edge case: window width equals streaming phase length + #[test] + fn test_single_window() { + let schedule = HalfSplitSchedule::new(8, 4); + + // Only one window in streaming phase + assert_eq!(schedule.window_starts, vec![0]); + + assert!(schedule.is_window_start(0)); + 
assert!(!schedule.is_window_start(1)); + assert!(!schedule.is_window_start(2)); + assert!(!schedule.is_window_start(3)); + + // All streaming rounds in single window + assert_eq!(schedule.num_unbound_vars(0), 4); + assert_eq!(schedule.num_unbound_vars(1), 3); + assert_eq!(schedule.num_unbound_vars(2), 2); + assert_eq!(schedule.num_unbound_vars(3), 1); + } + + /// Test edge case: window width larger than streaming phase + #[test] + fn test_oversized_window() { + let schedule = HalfSplitSchedule::new(8, 10); + + // Still only one window, even though width > streaming phase + assert_eq!(schedule.window_starts, vec![0]); + + // Unbound vars limited by linear_start boundary + assert_eq!(schedule.num_unbound_vars(0), 4, "Limited by linear_start"); + assert_eq!(schedule.num_unbound_vars(1), 3); + assert_eq!(schedule.num_unbound_vars(2), 2); + assert_eq!(schedule.num_unbound_vars(3), 1); + } + + /// Test odd number of rounds + #[test] + fn test_odd_rounds() { + let schedule = HalfSplitSchedule::new(7, 2); + + // (7 + 1) / 2 = 4, so streaming is 0-3, linear is 4-6 + assert_eq!(schedule.linear_start, 4); + + assert!(schedule.is_streaming(3)); + assert!(!schedule.is_streaming(4)); + assert_eq!(schedule.num_rounds(), 7); + } + + /// Test schedule properties are consistent + #[test] + fn test_schedule_invariants() { + for num_rounds in [4, 8, 16, 32] { + for window_width in [1, 2, 3, 4, 8] { + let schedule = HalfSplitSchedule::new(num_rounds, window_width); + + // First window always starts at 0 + assert!(schedule.is_window_start(0)); + + // Linear start is roughly half + assert_eq!(schedule.linear_start, num_rounds.div_ceil(2)); + //assert_eq!(schedule.linear_start, (num_rounds + 1) / 2); + + // All window starts are in streaming phase + for &start in &schedule.window_starts { + assert!(schedule.is_streaming(start)); + } + + // Unbound vars decrease monotonically within windows + for round in 0..schedule.linear_start - 1 { + if !schedule.is_window_start(round + 1) { + assert!( + schedule.num_unbound_vars(round) > schedule.num_unbound_vars(round + 1), + "Unbound vars should decrease within window" + ); + } + } + + // Linear phase unbound vars decrease by 1 each round + for round in schedule.linear_start..num_rounds - 1 { + assert_eq!( + schedule.num_unbound_vars(round) - schedule.num_unbound_vars(round + 1), + 1, + "Linear phase should decrease by 1 each round" + ); + } + } + } + } +} diff --git a/jolt-core/src/utils/expanding_table.rs b/jolt-core/src/utils/expanding_table.rs index bec7ac712..8ab24abb0 100644 --- a/jolt-core/src/utils/expanding_table.rs +++ b/jolt-core/src/utils/expanding_table.rs @@ -22,6 +22,10 @@ impl ExpandingTable { self.len } + pub fn order(&self) -> BindingOrder { + self.binding_order + } + /// Initializes an `ExpandingTable` with the given `capacity`. #[tracing::instrument(skip_all, name = "ExpandingTable::new")] pub fn new(capacity: usize, binding_order: BindingOrder) -> Self { @@ -52,6 +56,7 @@ impl ExpandingTable { /// Updates this table (expanding it by a factor of 2) to incorporate /// the new random challenge `r_j`. + /// TODO: this is bad parallelisation. 
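+    /// After `j` updates with challenges r_1, ..., r_j, the table holds the 2^j
+    /// products formed by choosing either (1 - r_i) or r_i for each challenge,
+    /// i.e. the eq table over the challenges bound so far (the entry order
+    /// depends on `binding_order`).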
#[tracing::instrument(skip_all, name = "ExpandingTable::update")] pub fn update(&mut self, r_j: F::Challenge) { match self.binding_order { diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index e0f4c76b5..ae1735b16 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1,3 +1,4 @@ +use crate::subprotocols::streaming_schedule::LinearOnlySchedule; use std::{ collections::HashMap, fs::File, @@ -97,6 +98,8 @@ pub struct JoltCpuProver< pub program_io: JoltDevice, pub lazy_trace: LazyTraceIterator, pub trace: Arc>, + pub checkpoints: Vec>, + pub checkpoint_interval: usize, pub advice: JoltAdvice, pub twist_sumcheck_switch_index: usize, pub unpadded_trace_len: usize, @@ -108,7 +111,6 @@ pub struct JoltCpuProver< pub final_ram_state: Vec, pub one_hot_params: OneHotParams, } - impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscript: Transcript> JoltCpuProver<'a, F, PCS, ProofTranscript> { @@ -130,6 +132,21 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip program_size: Some(preprocessing.memory_layout.program_size), }; + // TODO: Currently we're manifesting the entire trace + // There is some debate on how to stream this efficiently + // We can move forward with streaming implementations assuming this will be + // fixed in the coming days + + //let checkpoint_interval = 256; + //let (checkpoints, _jolt_device) = trace_checkpoints( + // elf_contents, + // inputs, + // untrusted_advice, + // trusted_advice, + // &memory_config, + // checkpoint_interval, + //); + let (lazy_trace, trace, final_memory_state, program_io) = { let _pprof_trace = pprof_scope!("trace"); guest::program::trace( @@ -141,6 +158,37 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip &memory_config, ) }; + + //#[cfg(debug_assertions)] + //{ + // for (time_step_idx, expected_cycle) in trace.iter().enumerate() { + // // Calculate which checkpoint and offset + // let checkpoint_idx = time_step_idx / checkpoint_interval; + // let offset = time_step_idx % checkpoint_interval; + // + // // Clone the checkpoint and advance to target + // let mut iter = checkpoints[checkpoint_idx].clone(); + // + // // Skip offset cycles + // for _ in 0..offset { + // iter.next(); + // } + // + // // Get the cycle from checkpoint + // let checkpoint_cycle = iter.next().expect("checkpoint should have cycle"); + // + // // Assert they match + // assert_eq!( + // &checkpoint_cycle, expected_cycle, + // "Mismatch at cycle {time_step_idx}: checkpoint != trace", + // ); + // } + // println!( + // "✓ All {} cycles match between checkpoints and full trace", + // trace.len() + // ); + //} + // let num_riscv_cycles: usize = trace .par_iter() .map(|cycle| { @@ -163,14 +211,20 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip trace.len(), ); - Self::gen_from_trace( + let mut prover = Self::gen_from_trace( preprocessing, lazy_trace, trace, program_io, trusted_advice_commitment, final_memory_state, - ) + ); + + // Set checkpoints after construction + // Vec> + prover.checkpoints = Vec::new(); + prover.checkpoint_interval = 0; + prover } pub fn gen_from_trace( @@ -240,6 +294,8 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, ProofTranscrip program_io, lazy_trace, trace: trace.into(), + checkpoints: Vec::new(), // Empty by default + checkpoint_interval: 0, // Default value advice: JoltAdvice { untrusted_advice_polynomial: None, trusted_advice_commitment, @@ -500,10 +556,17 @@ impl<'a, F: JoltField, PCS: StreamingCommitmentScheme, 
ProofTranscrip &mut self.transcript, ); + // Every sum-check with num_rounds > 1 requires a schedule + // which dictates the compute_message and bind methods + let schedule = LinearOnlySchedule::new(uni_skip_state.tau.len() - 1); + //let schedule = HalfSplitSchedule::new(uni_skip_state.tau.len() - 1, 3); let mut spartan_outer_remaining = OuterRemainingSumcheckProver::gen( Arc::clone(&self.trace), + &self.checkpoints, + self.checkpoint_interval, &self.preprocessing.bytecode, &uni_skip_state, + schedule, ); let (sumcheck_proof, _r_stage1) = BatchedSumcheck::prove( @@ -1426,7 +1489,6 @@ mod tests { init_memory_state, 1 << 16, ); - let prover = RV64IMACProver::gen_from_trace( &preprocessing, lazy_trace, diff --git a/jolt-core/src/zkvm/r1cs/evaluation.rs b/jolt-core/src/zkvm/r1cs/evaluation.rs index 9b6b2767b..f1af9c456 100644 --- a/jolt-core/src/zkvm/r1cs/evaluation.rs +++ b/jolt-core/src/zkvm/r1cs/evaluation.rs @@ -138,6 +138,29 @@ pub struct AzFirstGroup { pub must_start_sequence: bool, // NextIsVirtual && !NextIsFirstInSequence } +impl AzFirstGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `az_at_r_first_group` + /// but keeps the result in an `Acc5U` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc5U, + ) { + acc.fmadd(&w[0], &self.not_load_store); + acc.fmadd(&w[1], &self.load_a); + acc.fmadd(&w[2], &self.load_b); + acc.fmadd(&w[3], &self.store); + acc.fmadd(&w[4], &self.add_sub_mul); + acc.fmadd(&w[5], &self.not_add_sub_mul); + acc.fmadd(&w[6], &self.assert_flag); + acc.fmadd(&w[7], &self.should_jump); + acc.fmadd(&w[8], &self.virtual_instruction); + acc.fmadd(&w[9], &self.must_start_sequence); + } +} + /// Magnitudes for the first group (kept small: bool/u64/S64) #[derive(Clone, Copy, Debug)] pub struct BzFirstGroup { @@ -153,6 +176,29 @@ pub struct BzFirstGroup { pub one_minus_do_not_update_unexpanded_pc: bool, // 1 - DoNotUpdateUnexpandedPC } +impl BzFirstGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `bz_at_r_first_group` + /// but keeps the result in an `Acc6S` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc6S, + ) { + acc.fmadd(&w[0], &self.ram_addr); + acc.fmadd(&w[1], &self.ram_read_minus_ram_write); + acc.fmadd(&w[2], &self.ram_read_minus_rd_write); + acc.fmadd(&w[3], &self.rs2_minus_ram_write); + acc.fmadd(&w[4], &self.left_lookup); + acc.fmadd(&w[5], &self.left_lookup_minus_left_input); + acc.fmadd(&w[6], &self.lookup_output_minus_one); + acc.fmadd(&w[7], &self.next_unexp_pc_minus_lookup_output); + acc.fmadd(&w[8], &self.next_pc_minus_pc_plus_one); + acc.fmadd(&w[9], &self.one_minus_do_not_update_unexpanded_pc); + } +} + /// Guards for the second group (all booleans except two u8 flags) #[derive(Clone, Copy, Debug)] pub struct AzSecondGroup { @@ -167,6 +213,28 @@ pub struct AzSecondGroup { pub not_jump_or_branch: bool, // !(Jump || ShouldBranch) } +impl AzSecondGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `az_at_r_second_group` + /// but keeps the result in an `Acc5U` accumulator without reducing. 
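+    /// Equivalently: `acc += Σ_i w[i] · x_i` over this group's fields, matching
+    /// `az_at_r_second_group` once the accumulator is finally reduced.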
+ #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc5U, + ) { + acc.fmadd(&w[0], &self.load_or_store); + acc.fmadd(&w[1], &self.add); + acc.fmadd(&w[2], &self.sub); + acc.fmadd(&w[3], &self.mul); + acc.fmadd(&w[4], &self.not_add_sub_mul_advice); + acc.fmadd(&w[5], &self.write_lookup_to_rd); + acc.fmadd(&w[6], &self.write_pc_to_rd); + acc.fmadd(&w[7], &self.should_branch); + acc.fmadd(&w[8], &self.not_jump_or_branch); + } +} + /// Magnitudes for the second group (mixed precision up to S160) #[derive(Clone, Copy, Debug)] pub struct BzSecondGroup { @@ -181,6 +249,28 @@ pub struct BzSecondGroup { pub next_unexp_pc_minus_expected: S64, // NextUnexpandedPC - (UnexpandedPC + const) } +impl BzSecondGroup { + /// Fused multiply-add into an unreduced accumulator using Lagrange weights `w` + /// over the univariate-skip base window. This mirrors `bz_at_r_second_group` + /// but keeps the result in an `Acc7S` accumulator without reducing. + #[inline(always)] + pub fn fmadd_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc: &mut Acc7S, + ) { + acc.fmadd(&w[0], &self.ram_addr_minus_rs1_plus_imm); + acc.fmadd(&w[1], &self.right_lookup_minus_add_result); + acc.fmadd(&w[2], &self.right_lookup_minus_sub_result); + acc.fmadd(&w[3], &self.right_lookup_minus_product); + acc.fmadd(&w[4], &self.right_lookup_minus_right_input); + acc.fmadd(&w[5], &self.rd_write_minus_lookup_output); + acc.fmadd(&w[6], &self.rd_write_minus_pc_plus_const); + acc.fmadd(&w[7], &self.next_unexp_pc_minus_pc_plus_imm); + acc.fmadd(&w[8], &self.next_unexp_pc_minus_expected); + } +} + /// Unified evaluator wrapper with typed accessors for both groups #[derive(Clone, Copy, Debug)] pub struct R1CSEval<'a, F: JoltField> { @@ -293,6 +383,22 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc.barrett_reduce() } + /// Fused accumulate of first-group Az and Bz into unreduced accumulators using + /// Lagrange weights `w`. This keeps everything in unreduced form; callers are + /// responsible for reducing at the end. + #[inline(always)] + pub fn fmadd_first_group_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc_az: &mut Acc5U, + acc_bz: &mut Acc6S, + ) { + let az = self.eval_az_first_group(); + let bz = self.eval_bz_first_group(); + az.fmadd_at_r(w, acc_az); + bz.fmadd_at_r(w, acc_bz); + } + /// Product Az·Bz at the j-th extended uniskip target for the first group (uses precomputed weights). pub fn extended_azbz_product_first_group(&self, j: usize) -> S192 { let coeffs_i32: &[i32; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE] = &COEFFS_PER_J[j]; @@ -519,6 +625,21 @@ impl<'a, F: JoltField> R1CSEval<'a, F> { acc.barrett_reduce() } + /// Fused accumulate of second-group Az and Bz into unreduced accumulators + /// using Lagrange weights `w`. This keeps everything in unreduced form; callers + /// are responsible for reducing at the end. + #[inline(always)] + pub fn fmadd_second_group_at_r( + &self, + w: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], + acc_az: &mut Acc5U, + acc_bz: &mut Acc7S, + ) { + let az = self.eval_az_second_group(); + let bz = self.eval_bz_second_group(); + az.fmadd_at_r(w, acc_az); + bz.fmadd_at_r(w, acc_bz); + } /// Product Az·Bz at the j-th extended uniskip target for the second group (uses precomputed weights). 
pub fn extended_azbz_product_second_group(&self, j: usize) -> S192 { #[cfg(test)] diff --git a/jolt-core/src/zkvm/r1cs/inputs.rs b/jolt-core/src/zkvm/r1cs/inputs.rs index 7369f2782..b61332166 100644 --- a/jolt-core/src/zkvm/r1cs/inputs.rs +++ b/jolt-core/src/zkvm/r1cs/inputs.rs @@ -25,6 +25,7 @@ use ark_ff::biginteger::{S128, S64}; use common::constants::XLEN; use std::fmt::Debug; use tracer::instruction::Cycle; +use tracer::LazyTraceIterator; use strum::IntoEnumIterator; @@ -264,6 +265,149 @@ pub struct R1CSCycleInputs { } impl R1CSCycleInputs { + /// Build directly from checkpoints and preprocessing, + /// mirroring the optimized semantics used in `compute_claimed_r1cs_input_evals`. + pub fn from_checkpoints( + bytecode_preprocessing: &BytecodePreprocessing, + checkpoints: &[std::iter::Take], + checkpoint_interval: usize, + t: usize, + ) -> Self + where + F: JoltField, + { + // Calculate checkpoint and offset + let checkpoint_idx = t / checkpoint_interval; + let offset = t % checkpoint_interval; + + // Clone and advance to target cycle + let mut iter = checkpoints[checkpoint_idx].clone(); + for _ in 0..offset { + iter.next(); + } + let cycle = iter.next().expect("cycle should exist"); + + // Get next cycle if needed + let next_cycle = { + let mut next_iter = iter.clone(); + next_iter.next() + }; + + let instr = cycle.instruction(); + let flags_view = instr.circuit_flags(); + let instruction_flags = instr.instruction_flags(); + let norm = instr.normalize(); + + // Instruction inputs and product + let (left_input, right_i128) = LookupQuery::::to_instruction_inputs(&cycle); + let left_s64: S64 = S64::from_u64(left_input); + let right_mag = right_i128.unsigned_abs(); + debug_assert!( + right_mag <= u64::MAX as u128, + "RightInstructionInput overflow at row {t}: |{right_i128}| > 2^64-1" + ); + let right_input = S64::from_u64_with_sign(right_mag as u64, right_i128 >= 0); + let right_s128: S128 = S128::from_i128(right_i128); + let product: S128 = left_s64.mul_trunc::<2, 2>(&right_s128); + + // Lookup operands and output + let (left_lookup, right_lookup) = LookupQuery::::to_lookup_operands(&cycle); + let lookup_output = LookupQuery::::to_lookup_output(&cycle); + + // Registers + let rs1_read_value = cycle.rs1_read().1; + let rs2_read_value = cycle.rs2_read().1; + let rd_write_value = cycle.rd_write().2; + + // RAM + let ram_addr = cycle.ram_access().address() as u64; + let (ram_read_value, ram_write_value) = match cycle.ram_access() { + tracer::instruction::RAMAccess::Read(r) => (r.value, r.value), + tracer::instruction::RAMAccess::Write(w) => (w.pre_value, w.post_value), + tracer::instruction::RAMAccess::NoOp => (0u64, 0u64), + }; + + // PCs + let pc = bytecode_preprocessing.get_pc(&cycle) as u64; + let next_pc = if let Some(nc) = &next_cycle { + bytecode_preprocessing.get_pc(nc) as u64 + } else { + 0u64 + }; + let unexpanded_pc = norm.address as u64; + let next_unexpanded_pc = if let Some(nc) = &next_cycle { + nc.instruction().normalize().address as u64 + } else { + 0u64 + }; + + // Immediate + let imm_i128 = norm.operands.imm; + let imm_mag = imm_i128.unsigned_abs(); + debug_assert!( + imm_mag <= u64::MAX as u128, + "Imm overflow at row {t}: |{imm_i128}| > 2^64-1" + ); + let imm = S64::from_u64_with_sign(imm_mag as u64, imm_i128 >= 0); + + // Flags and derived booleans + let mut flags = [false; NUM_CIRCUIT_FLAGS]; + for flag in CircuitFlags::iter() { + flags[flag] = flags_view[flag]; + } + let next_is_noop = if let Some(nc) = &next_cycle { + 
nc.instruction().instruction_flags()[InstructionFlags::IsNoop] + } else { + false // There is no next cycle, so cannot be a noop + }; + let should_jump = flags_view[CircuitFlags::Jump] && !next_is_noop; + let should_branch = instruction_flags[InstructionFlags::Branch] && (lookup_output == 1); + + // Write-to-Rd selectors (masked by flags) + let write_lookup_output_to_rd_addr = flags_view[CircuitFlags::WriteLookupOutputToRD] + && instruction_flags[InstructionFlags::IsRdNotZero]; + let write_pc_to_rd_addr = + flags_view[CircuitFlags::Jump] && instruction_flags[InstructionFlags::IsRdNotZero]; + + let (next_is_virtual, next_is_first_in_sequence) = if let Some(nc) = &next_cycle { + let flags = nc.instruction().circuit_flags(); + ( + flags[CircuitFlags::VirtualInstruction], + flags[CircuitFlags::IsFirstInSequence], + ) + } else { + (false, false) + }; + + Self { + left_input, + right_input, + product, + left_lookup, + right_lookup, + lookup_output, + rs1_read_value, + rs2_read_value, + rd_write_value, + ram_addr, + ram_read_value, + ram_write_value, + pc, + next_pc, + unexpanded_pc, + next_unexpanded_pc, + imm, + flags, + next_is_noop, + should_jump, + should_branch, + write_lookup_output_to_rd_addr, + write_pc_to_rd_addr, + next_is_virtual, + next_is_first_in_sequence, + } + } + /// Build directly from the execution trace and preprocessing, /// mirroring the optimized semantics used in `compute_claimed_r1cs_input_evals`. pub fn from_trace( diff --git a/jolt-core/src/zkvm/spartan/outer.rs b/jolt-core/src/zkvm/spartan/outer.rs index 6b08d6692..17b8fca40 100644 --- a/jolt-core/src/zkvm/spartan/outer.rs +++ b/jolt-core/src/zkvm/spartan/outer.rs @@ -4,25 +4,30 @@ use allocative::Allocative; use ark_std::Zero; use rayon::prelude::*; use tracer::instruction::Cycle; +use tracer::LazyTraceIterator; +use crate::field::BarrettReduce; use crate::field::{FMAdd, JoltField, MontgomeryReduce}; use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::eq_poly::EqPolynomial; use crate::poly::lagrange_poly::LagrangePolynomial; -use crate::poly::multilinear_polynomial::BindingOrder; +use crate::poly::multilinear_polynomial::{BindingOrder, PolynomialBinding}; +use crate::poly::multiquadratic_poly::MultiquadraticPolynomial; use crate::poly::opening_proof::{ OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }; use crate::poly::split_eq_poly::GruenSplitEqPolynomial; use crate::poly::unipoly::UniPoly; +use crate::subprotocols::streaming_schedule::StreamingSchedule; use crate::subprotocols::sumcheck_prover::{ SumcheckInstanceProver, UniSkipFirstRoundInstanceProver, }; use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::subprotocols::univariate_skip::{build_uniskip_first_round_poly, UniSkipState}; use crate::transcripts::Transcript; -use crate::utils::accumulation::Acc8S; +use crate::utils::accumulation::{Acc5U, Acc6S, Acc7S, Acc8S}; +use crate::utils::expanding_table::ExpandingTable; use crate::utils::math::Math; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; @@ -43,6 +48,7 @@ use allocative::FlameGraphBuilder; /// Degree bound of the sumcheck round polynomials for [`OuterRemainingSumcheckVerifier`]. 
const OUTER_REMAINING_DEGREE_BOUND: usize = 3; +const INFINITY: usize = 2; // 2 represents ∞ in base-3 // Spartan Outer sumcheck // (with univariate-skip first round on Z, and no Cz term given all eq conditional constraints) @@ -204,29 +210,60 @@ impl UniSkipFirstRoundInstanceProver /// SumcheckInstance for Spartan outer rounds after the univariate-skip first round. /// Round 0 in this instance corresponds to the "streaming" round; subsequent rounds /// use the remaining linear-time algorithm over cycle variables. +//#[derive(Allocative)] +//pub struct OuterRemainingSumcheckProver { +// #[allocative(skip)] +// bytecode_preprocessing: BytecodePreprocessing, +// #[allocative(skip)] +// trace: Arc>, +// split_eq_poly: GruenSplitEqPolynomial, +// az: Option>, +// bz: Option>, +// t_prime_poly: Option>, // multiquadratic polynomial used to answer queries in a streaming window +// /// The first round evals (t0, t_inf) computed from a streaming pass over the trace +// first_round_evals: (F, F), +// #[allocative(skip)] +// params: OuterRemainingSumcheckParams, +//} +// #[derive(Allocative)] -pub struct OuterRemainingSumcheckProver { +pub struct OuterRemainingSumcheckProver<'a, F: JoltField, S: StreamingSchedule + Allocative> { #[allocative(skip)] bytecode_preprocessing: BytecodePreprocessing, #[allocative(skip)] trace: Arc>, + /// Split-eq instance used for both streaming and linear phases of the + /// outer Spartan sumcheck over cycle variables. split_eq_poly: GruenSplitEqPolynomial, - az: DensePolynomial, - bz: DensePolynomial, + az: Option>, + bz: Option>, + t_prime_poly: Option>, // multiquadratic polynomial used to answer queries in a streaming window + r_grid: ExpandingTable, // hadamard product of (1 - r_j, r_j) for bound variables so far to help with streaming /// The first round evals (t0, t_inf) computed from a streaming pass over the trace - first_round_evals: (F, F), #[allocative(skip)] params: OuterRemainingSumcheckParams, + lagrange_evals_r0: [F; 10], + schedule: S, + t_0: Option, + t_inf: Option, + #[allocative(skip)] + checkpoints: &'a [std::iter::Take], + checkpoint_interval: usize, } -impl OuterRemainingSumcheckProver { +impl<'a, F: JoltField, S: StreamingSchedule + Allocative> OuterRemainingSumcheckProver<'a, F, S> { #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::gen")] pub fn gen( trace: Arc>, + checkpoints: &'a [std::iter::Take], // Add lifetime 'a, use slice + checkpoint_interval: usize, bytecode_preprocessing: &BytecodePreprocessing, uni: &UniSkipState, + schedule: S, ) -> Self { let bytecode_preprocessing = bytecode_preprocessing.clone(); + let n_cycle_vars = trace.len().log_2(); + let outer_params = OuterRemainingSumcheckParams::new(n_cycle_vars, uni); let lagrange_evals_r = LagrangePolynomial::::evals::< F::Challenge, @@ -248,26 +285,273 @@ impl OuterRemainingSumcheckProver { Some(lagrange_tau_r0), ); - let (t0, t_inf, az_bound, bz_bound) = Self::compute_first_quadratic_evals_and_bound_polys( - &bytecode_preprocessing, - &trace, - &lagrange_evals_r, - &split_eq_poly, - ); - - let n_cycle_vars = trace.len().ilog2() as usize; + // NOTE: The API changed recently: Both binding orders will technically pass + // based on current implementation. 
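+        // r_grid accumulates the tensor product of (1 - r_j, r_j) over the cycle
+        // variables bound so far; it starts as the trivial table for zero bound
+        // variables and grows by a factor of two each time another cycle
+        // variable is bound.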
+ let mut r_grid = ExpandingTable::new(1 << n_cycle_vars, BindingOrder::LowToHigh); + r_grid.reset(F::one()); Self { split_eq_poly, bytecode_preprocessing, trace, - az: az_bound, - bz: bz_bound, - first_round_evals: (t0, t_inf), - params: OuterRemainingSumcheckParams::new(n_cycle_vars, uni), + checkpoints, + checkpoint_interval, + az: None, + bz: None, + t_prime_poly: None, + r_grid, + params: outer_params, + lagrange_evals_r0: lagrange_evals_r, + schedule, + t_0: None, + t_inf: None, } } + // gets the evaluations of az(x, {0,1}^log(jlen), r) and bz(x, {0,1}^log(jlen), r) + // where x is determined by the bit decomposition of offset + // and r is log(klen) variables + // this is used both in window computation (jlen is window size) + // and in converting to linear time (offset is 0, log(jlen) is the number of unbound variables) + // The caller must pass in `scaled_w`, the tensor product of the Lagrange weights + // at r0 with the current `r_grid` weights: + // scaled_w[k][t] = lagrange_evals_r0[t] * r_grid[k] (for klen > 1) + // and scaled_w[0][t] = lagrange_evals_r0[t] when klen == 1 (no r_grid factor). + #[allow(clippy::too_many_arguments)] + fn build_grids( + &self, + grid_az: &mut [F], + grid_bz: &mut [F], + jlen: usize, + klen: usize, + offset: usize, + parallel: bool, + scaled_w: &[[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]], + ) { + let preprocess = &self.bytecode_preprocessing; + let checkpoints = &self.checkpoints; + let checkpoint_interval = self.checkpoint_interval; + let trace = &self.trace; + debug_assert_eq!(scaled_w.len(), klen); + debug_assert_eq!(grid_az.len(), jlen); + debug_assert_eq!(grid_bz.len(), jlen); + + // Unreduced accumulators per j for Az and the two Bz groups. + let mut acc_az = vec![Acc5U::::zero(); jlen]; + let mut acc_bz_first = vec![Acc6S::::zero(); jlen]; + let mut acc_bz_second = vec![Acc7S::::zero(); jlen]; + + if !parallel { + // Sequential traversal: iterate over j first and then k so that we + // walk consecutive cycles in memory (full_idx increases by 1 inside + // the inner loop). + for j in 0..jlen { + for k in 0..klen { + let full_idx = offset + j * klen + k; + let current_step_idx = full_idx >> 1; + let selector = (full_idx & 1) == 1; + + // TODO: use the lazy trace iterator here instead of indexing directly into the + // trace that is all that needs to change for now + let row_inputs_prime = R1CSCycleInputs::from_checkpoints::( + preprocess, + checkpoints, + checkpoint_interval, + current_step_idx, + ); + let _row_inputs = + R1CSCycleInputs::from_trace::(preprocess, trace, current_step_idx); + let eval = R1CSEval::::from_cycle_inputs(&row_inputs_prime); + let w_k = &scaled_w[k]; + + if !selector { + eval.fmadd_first_group_at_r(w_k, &mut acc_az[j], &mut acc_bz_first[j]); + } else { + eval.fmadd_second_group_at_r(w_k, &mut acc_az[j], &mut acc_bz_second[j]); + } + } + } + } else { + // Parallel traversal over j for the linear-time prover. + // Each worker owns disjoint accumulators for a fixed j, so there + // are no data races. We reuse the precomputed scaled Lagrange weights + // per k from `scaled_w`, avoiding redundant tensor products. 
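+            // As in the sequential branch, full_idx = offset + j * klen + k selects
+            // the R1CS row: cycle index full_idx >> 1, with even rows contributing
+            // to the first constraint group and odd rows to the second.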
+ acc_az + .par_iter_mut() + .with_min_len(4096) + .zip(acc_bz_first.par_iter_mut()) + .zip(acc_bz_second.par_iter_mut()) + .enumerate() + .for_each(|(j, ((acc_az_j, acc_bz_first_j), acc_bz_second_j))| { + for k in 0..klen { + let full_idx = offset + j * klen + k; + let current_step_idx = full_idx >> 1; + let selector = (full_idx & 1) == 1; + + let row_inputs = + R1CSCycleInputs::from_trace::(preprocess, trace, current_step_idx); + let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + let w_k = &scaled_w[k]; + + if !selector { + eval.fmadd_first_group_at_r(w_k, acc_az_j, acc_bz_first_j); + } else { + eval.fmadd_second_group_at_r(w_k, acc_az_j, acc_bz_second_j); + } + } + }); + } + + // Final reductions once per j. + //for j in 0..jlen { + // let az_j = acc_az[j].barrett_reduce(); + // let bz_first_j = acc_bz_first[j].barrett_reduce(); + // let bz_second_j = acc_bz_second[j].barrett_reduce(); + // grid_az[j] = az_j; + // grid_bz[j] = bz_first_j + bz_second_j; + //} + + let grid_az_ptr = grid_az.as_mut_ptr() as usize; + let grid_bz_ptr = grid_bz.as_mut_ptr() as usize; + let chunk_size = 4096; + // jlen + chunk_size - 1) / chunk_size + let num_chunks = jlen.div_ceil(chunk_size); + (0..num_chunks).into_par_iter().for_each(move |chunk_idx| { + let start = chunk_idx * chunk_size; + let end = (start + chunk_size).min(jlen); + + let az_ptr = grid_az_ptr as *mut F; + let bz_ptr = grid_bz_ptr as *mut F; + + for j in start..end { + let az_j = acc_az[j].barrett_reduce(); + let bz_first_j = acc_bz_first[j].barrett_reduce(); + let bz_second_j = acc_bz_second[j].barrett_reduce(); + + unsafe { + *az_ptr.add(j) = az_j; + *bz_ptr.add(j) = bz_first_j + bz_second_j; + } + } + }); + } + + // returns the grid of evaluations on {0,1,inf}^window_size + // touches each cycle of the trace exactly once and in order! + fn get_grid_gen(&mut self, window_size: usize) { + // Use the split-eq instance to derive the current window + // factorisation of Eq over the unbound cycle bits. This keeps the + // semantics in one place (see `split_eq_poly::E_out_in_for_window`). + let split_eq = &self.split_eq_poly; + + // helper constants + let three_pow_dim = 3_usize.pow(window_size as u32); + let jlen = 1 << window_size; + let klen = 1 << split_eq.num_challenges(); + + // Precompute the tensor product of the Lagrange weights at r0 with the + // current r_grid weights so that all calls into `build_grids` can reuse + // these scaled tables. + let lagrange_evals_r = &self.lagrange_evals_r0; + let r_grid = &self.r_grid; + let mut scaled_w = vec![[F::zero(); OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]; klen]; + if klen > 1 { + debug_assert_eq!(klen, r_grid.len()); + for k in 0..klen { + let weight = r_grid[k]; + let row = &mut scaled_w[k]; + for t in 0..OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE { + row[t] = lagrange_evals_r[t] * weight; + } + } + } else { + debug_assert_eq!(klen, 1); + scaled_w[0].copy_from_slice(lagrange_evals_r); + } + + // Head-factor eq tables for this window. + let (e_out, e_in) = split_eq.E_out_in_for_window(window_size); + let e_in_len = e_in.len(); + + // main logic: parallelize outer sum over E_out_current; for each x_out, + // perform an inner unreduced accumulation over E_in_current and only + // reduce once per grid cell, then multiply by E_out unreduced. + let res_unr = e_out + .par_iter() + .enumerate() + .map(|(out_idx, out_val)| { + // Local unreduced accumulators and scratch buffers for this out_idx. 
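+                // Buffer roles for this x_out:
+                //   local_res_unr  : 3^window_size unreduced accumulators, one per point of t'(z)
+                //   grid_a, grid_b : 2^window_size Boolean-cube evals of Az/Bz filled by build_grids
+                //   buff_a, buff_b : 3^window_size extrapolated {0,1,∞}-grid evals of Az/Bz
+                //   tmp            : scratch space for the extrapolation
+                // For each x_in we build the grids, extrapolate, multiply pointwise and
+                // accumulate scaled by E_in without reducing; E_out is folded in (with a
+                // single reduction per cell) only after the x_in loop.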
+ let mut local_res_unr = vec![F::Unreduced::<9>::zero(); three_pow_dim]; + let mut buff_a: Vec = vec![F::zero(); three_pow_dim]; + let mut buff_b = vec![F::zero(); three_pow_dim]; + let mut tmp = vec![F::zero(); three_pow_dim]; + let mut grid_a = vec![F::zero(); jlen]; + let mut grid_b = vec![F::zero(); jlen]; + + for (in_idx, in_val) in e_in.iter().enumerate() { + let i = out_idx * e_in_len + in_idx; + + // Reuse the same grid buffers across all x_in for this x_out. + grid_a.fill(F::zero()); + grid_b.fill(F::zero()); + // Keep this call sequential to avoid nested rayon parallelism. + self.build_grids( + &mut grid_a, + &mut grid_b, + jlen, + klen, + i * jlen * klen, + false, + &scaled_w, + ); + + // Extrapolate grid_a and grid_b from {0,1}^window_size to {0,1,∞}^window_size. + MultiquadraticPolynomial::::expand_linear_grid_to_multiquadratic( + &grid_a, + &mut buff_a, + &mut tmp, + window_size, + ); + MultiquadraticPolynomial::::expand_linear_grid_to_multiquadratic( + &grid_b, + &mut buff_b, + &mut tmp, + window_size, + ); + + let e_in_val = *in_val; + for idx in 0..three_pow_dim { + let val = buff_a[idx] * buff_b[idx]; + local_res_unr[idx] += e_in_val.mul_unreduced::<9>(val); + } + } + + // Fold in E_out for this x_out. + let e_out_val = *out_val; + for idx in 0..three_pow_dim { + let inner_red = F::from_montgomery_reduce::<9>(local_res_unr[idx]); + local_res_unr[idx] = e_out_val.mul_unreduced::<9>(inner_red); + } + local_res_unr + }) + .reduce( + || vec![F::Unreduced::<9>::zero(); three_pow_dim], + |mut acc, local| { + for idx in 0..three_pow_dim { + acc[idx] += local[idx]; + } + acc + }, + ); + + // Final reduction over all (x_out, x_in) + let res: Vec = res_unr + .into_iter() + .map(|unr| F::from_montgomery_reduce::<9>(unr)) + .collect(); + self.t_prime_poly = Some(MultiquadraticPolynomial::new(window_size, res)); + } + /// Compute the quadratic evaluations for the streaming round (right after univariate skip). /// /// This uses the streaming algorithm to compute the sum-check polynomial for the round @@ -294,7 +578,7 @@ impl OuterRemainingSumcheckProver { /// /// (and the eval at ∞ is computed as (eval at 1) - (eval at 0)) #[inline] - fn compute_first_quadratic_evals_and_bound_polys( + fn _compute_first_quadratic_evals_and_bound_polys( bytecode_preprocessing: &BytecodePreprocessing, trace: &[Cycle], lagrange_evals_r: &[F; OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE], @@ -366,8 +650,486 @@ impl OuterRemainingSumcheckProver { ) } - // No special binding path needed; az/bz hold interleaved [lo,hi] ready for binding + // TODO: No small value optimisation in this function currently. 
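+    // The streaming-to-linear-time conversions below (parallel and serial variants)
+    // all use the same index layout over the not-yet-bound variables:
+    // full_idx = x_out || x_in || X || r (MSB to LSB), where the lowest bit of
+    // full_idx selects the constraint group and the remaining bits give the cycle
+    // index in the trace. A minimal sketch of that decomposition (the helper name
+    // is illustrative and not used elsewhere in this module):
+    #[allow(dead_code)]
+    fn full_index_sketch(
+        x_out: usize,
+        x_in: usize,
+        x_bit: usize,
+        r_idx: usize,
+        num_x_in_bits: usize,
+        num_r_bits: usize,
+    ) -> (usize, bool) {
+        let full_idx = (x_out << (num_x_in_bits + 1 + num_r_bits))
+            | (x_in << (1 + num_r_bits))
+            | (x_bit << num_r_bits)
+            | r_idx;
+        // (cycle index in the trace, which of the two constraint groups to use)
+        (full_idx >> 1, (full_idx & 1) == 1)
+    }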
+ // TODO: Put meaningful doc strings + fn stream_to_linear_time_parallel(&mut self) { + let num_x_out_vals = self.split_eq_poly.E_out_current_len(); + let num_x_in_vals = self.split_eq_poly.E_in_current_len(); + let r_grid = &self.r_grid; + let num_r_vals = r_grid.len(); + + // Output arrays are sized by (x_out, x_in) pairs + let output_size = num_x_out_vals * num_x_in_vals; + let mut az_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + let mut bz_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + + let num_r_bits = num_r_vals.log_2(); + let num_x_in_bits = num_x_in_vals.log_2(); + + // Dynamic chunking for parallelization + let num_threads = rayon::current_num_threads(); + let target_chunks = num_threads * 4; + let min_chunk_pairs = 16; + //let pairs_per_chunk = + //((output_size + target_chunks - 1) / target_chunks).max(min_chunk_pairs); + let pairs_per_chunk = output_size.div_ceil(target_chunks).max(min_chunk_pairs); + let chunk_size = pairs_per_chunk * 2; + + // Parallel computation with reduction + let (t0_acc, t_inf_acc) = az_bound + .par_chunks_mut(chunk_size) + .zip(bz_bound.par_chunks_mut(chunk_size)) + .enumerate() + .fold( + || (F::zero(), F::zero()), + |(mut t0_local, mut t_inf_local), (chunk_idx, (az_chunk, bz_chunk))| { + let start_pair = chunk_idx * pairs_per_chunk; + let end_pair = (start_pair + pairs_per_chunk).min(output_size); + + for pair_idx in start_pair..end_pair { + let x_in_val = pair_idx % num_x_in_vals; + let x_out_val = pair_idx / num_x_in_vals; + + let mut az0_sum = F::zero(); // For X=0 + let mut az1_sum = F::zero(); // For X=1 + let mut bz0_sum = F::zero(); // For X=0 + let mut bz1_sum = F::zero(); // For X=1 + + // Single loop over r values, computing both X=0 and X=1 + for r_idx in 0..num_r_vals { + let r_eval = r_grid[r_idx]; + + // Build indices for both X=0 and X=1 + let base_idx = (x_out_val << (num_x_in_bits + 1 + num_r_bits)) + | (x_in_val << (1 + num_r_bits)); + + let full_idx_x0 = base_idx | (0 << num_r_bits) | r_idx; + let full_idx_x1 = base_idx | (1 << num_r_bits) | r_idx; + + // Process X=0 + let step_idx_x0 = full_idx_x0 >> 1; + let selector_x0 = (full_idx_x0 & 1) == 1; + + let row_inputs_x0 = R1CSCycleInputs::from_trace::( + &self.bytecode_preprocessing, + &self.trace, + step_idx_x0, + ); + let eval_x0 = R1CSEval::::from_cycle_inputs(&row_inputs_x0); + + let (az_x0, bz_x0) = if !selector_x0 { + ( + eval_x0.az_at_r_first_group(&self.lagrange_evals_r0), + eval_x0.bz_at_r_first_group(&self.lagrange_evals_r0), + ) + } else { + ( + eval_x0.az_at_r_second_group(&self.lagrange_evals_r0), + eval_x0.bz_at_r_second_group(&self.lagrange_evals_r0), + ) + }; + + // Process X=1 + let step_idx_x1 = full_idx_x1 >> 1; + let selector_x1 = (full_idx_x1 & 1) == 1; + + let row_inputs_x1 = R1CSCycleInputs::from_trace::( + &self.bytecode_preprocessing, + &self.trace, + step_idx_x1, + ); + let eval_x1 = R1CSEval::::from_cycle_inputs(&row_inputs_x1); + + let (az_x1, bz_x1) = if !selector_x1 { + ( + eval_x1.az_at_r_first_group(&self.lagrange_evals_r0), + eval_x1.bz_at_r_first_group(&self.lagrange_evals_r0), + ) + } else { + ( + eval_x1.az_at_r_second_group(&self.lagrange_evals_r0), + eval_x1.bz_at_r_second_group(&self.lagrange_evals_r0), + ) + }; + + // Accumulate both with the same r_eval + az0_sum += az_x0 * r_eval; + bz0_sum += bz_x0 * r_eval; + az1_sum += az_x1 * r_eval; + bz1_sum += bz_x1 * r_eval; + } + + // Store in chunk-relative position + let buffer_offset = 2 * (pair_idx - start_pair); + az_chunk[buffer_offset] = az0_sum; + 
az_chunk[buffer_offset + 1] = az1_sum; + bz_chunk[buffer_offset] = bz0_sum; + bz_chunk[buffer_offset + 1] = bz1_sum; + + // Local accumulation for t_0 and t_inf + let e_in = self.split_eq_poly.E_in_current()[x_in_val]; + let e_out = self.split_eq_poly.E_out_current()[x_out_val]; + let p0 = az0_sum * bz0_sum; + let slope = (az1_sum - az0_sum) * (bz1_sum - bz0_sum); + + t0_local += e_out * e_in * p0; + t_inf_local += e_out * e_in * slope; + } + + (t0_local, t_inf_local) + }, + ) + .reduce( + || (F::zero(), F::zero()), + |(t0_a, t_inf_a), (t0_b, t_inf_b)| (t0_a + t0_b, t_inf_a + t_inf_b), + ); + + self.az = Some(DensePolynomial::new(az_bound)); + self.bz = Some(DensePolynomial::new(bz_bound)); + self.t_0 = Some(t0_acc); + self.t_inf = Some(t_inf_acc); + } + //fn stream_to_linear_time_parallel(&mut self) { + // let num_x_out_vals = (&self.split_eq_poly).E_out_current_len(); + // let num_x_in_vals = (&self.split_eq_poly).E_in_current_len(); + // let r_grid = &self.r_grid; + // let num_r_vals = r_grid.len(); + // + // // Output arrays are sized by (x_out, x_in) pairs + // let output_size = num_x_out_vals * num_x_in_vals; + // let mut az_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + // let mut bz_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + // + // let num_r_bits = num_r_vals.log_2(); + // let num_x_in_bits = num_x_in_vals.log_2(); + // + // // Dynamic chunking for parallelization + // let num_threads = rayon::current_num_threads(); + // let target_chunks = num_threads * 4; // 4x oversubscription + // let min_chunk_pairs = 16; // Minimum pairs per chunk to avoid overhead + // let pairs_per_chunk = + // ((output_size + target_chunks - 1) / target_chunks).max(min_chunk_pairs); + // let chunk_size = pairs_per_chunk * 2; // *2 for [X=0, X=1] storage + // + // // Parallel computation with reduction + // let (t0_acc, t_inf_acc) = az_bound + // .par_chunks_mut(chunk_size) + // .zip(bz_bound.par_chunks_mut(chunk_size)) + // .enumerate() + // .fold( + // || (F::zero(), F::zero()), + // |(mut t0_local, mut t_inf_local), (chunk_idx, (az_chunk, bz_chunk))| { + // let start_pair = chunk_idx * pairs_per_chunk; + // let end_pair = (start_pair + pairs_per_chunk).min(output_size); + // + // for pair_idx in start_pair..end_pair { + // // Decompose pair index into x_out, x_in + // let x_in_val = pair_idx % num_x_in_vals; + // let x_out_val = pair_idx / num_x_in_vals; + // + // // Initialize accumulators for this (x_out, x_in) pair + // let mut az0_sum = F::zero(); // For X=0 + // let mut az1_sum = F::zero(); // For X=1 + // let mut bz0_sum = F::zero(); // For X=0 + // let mut bz1_sum = F::zero(); // For X=1 + // + // // Iterate over X bit and r values + // for x_bit in 0..2 { + // for r_idx in 0..num_r_vals { + // // Build the full index: x_out || x_in || X || x_r + // let full_idx = (x_out_val << (num_x_in_bits + 1 + num_r_bits)) + // | (x_in_val << (1 + num_r_bits)) + // | (x_bit << num_r_bits) + // | r_idx; + // + // // Extract step index and selector + // let current_step_idx = full_idx >> 1; + // let selector = (full_idx & 1) == 1; + // + // let row_inputs = R1CSCycleInputs::from_trace::( + // &self.bytecode_preprocessing, + // &self.trace, + // current_step_idx, + // ); + // + // let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + // + // let (az, bz) = if !selector { + // // First group (selector = 0) + // ( + // eval.az_at_r_first_group(&self.lagrange_evals_r0), + // eval.bz_at_r_first_group(&self.lagrange_evals_r0), + // ) + // } else { + // // Second group (selector = 1) + // 
( + // eval.az_at_r_second_group(&self.lagrange_evals_r0), + // eval.bz_at_r_second_group(&self.lagrange_evals_r0), + // ) + // }; + // + // let r_eval = r_grid[r_idx]; + // + // if x_bit == 0 { + // az0_sum += az * r_eval; + // bz0_sum += bz * r_eval; + // } else { + // az1_sum += az * r_eval; + // bz1_sum += bz * r_eval; + // } + // } + // } + // + // // Store in chunk-relative position + // let buffer_offset = 2 * (pair_idx - start_pair); + // az_chunk[buffer_offset] = az0_sum; + // az_chunk[buffer_offset + 1] = az1_sum; + // bz_chunk[buffer_offset] = bz0_sum; + // bz_chunk[buffer_offset + 1] = bz1_sum; + // + // // Local accumulation for t_0 and t_inf + // let e_in = (&self.split_eq_poly).E_in_current()[x_in_val]; + // let e_out = (&self.split_eq_poly).E_out_current()[x_out_val]; + // let p0 = az0_sum * bz0_sum; + // let slope = (az1_sum - az0_sum) * (bz1_sum - bz0_sum); + // + // t0_local += e_out * e_in * p0; + // t_inf_local += e_out * e_in * slope; + // } + // + // (t0_local, t_inf_local) + // }, + // ) + // .reduce( + // || (F::zero(), F::zero()), + // |(t0_a, t_inf_a), (t0_b, t_inf_b)| (t0_a + t0_b, t_inf_a + t_inf_b), + // ); + // + // self.az = Some(DensePolynomial::new(az_bound)); + // self.bz = Some(DensePolynomial::new(bz_bound)); + // self.t_0 = Some(t0_acc); + // self.t_inf = Some(t_inf_acc); + //} + // + fn _stream_to_linear_time_serial(&mut self) { + let num_x_out_vals = self.split_eq_poly.E_out_current_len(); + let num_x_in_vals = self.split_eq_poly.E_in_current_len(); + let r_grid = &self.r_grid; + let num_r_vals = r_grid.len(); + + // Output arrays are sized by (x_out, x_in) pairs + let output_size = num_x_out_vals * num_x_in_vals; + let mut az_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + let mut bz_bound: Vec = unsafe_allocate_zero_vec(2 * output_size); + + let num_r_bits = num_r_vals.log_2(); + let num_x_in_bits = num_x_in_vals.log_2(); + let mut t0_acc = F::zero(); + let mut t_inf_acc = F::zero(); + + // Serial iteration over all (x_out, x_in) pairs + for x_out_val in 0..num_x_out_vals { + for x_in_val in 0..num_x_in_vals { + // Initialize accumulators for this (x_out, x_in) pair + let mut az0_sum = F::zero(); // For X=0 + let mut az1_sum = F::zero(); // For X=1 + let mut bz0_sum = F::zero(); // For X=0 + let mut bz1_sum = F::zero(); // For X=1 + + // Iterate over X bit and r values + for x_bit in 0..2 { + for r_idx in 0..num_r_vals { + // Build the full index: x_out || x_in || X || x_r + let full_idx = (x_out_val << (num_x_in_bits + 1 + num_r_bits)) + | (x_in_val << (1 + num_r_bits)) + | (x_bit << num_r_bits) + | r_idx; + + // Extract step index and selector + let current_step_idx = full_idx >> 1; + let selector = (full_idx & 1) == 1; + + let row_inputs = R1CSCycleInputs::from_trace::( + &self.bytecode_preprocessing, + &self.trace, + current_step_idx, + ); + + let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + + let (az, bz) = if !selector { + // First group (selector = 0) + ( + eval.az_at_r_first_group(&self.lagrange_evals_r0), + eval.bz_at_r_first_group(&self.lagrange_evals_r0), + ) + } else { + // Second group (selector = 1) + ( + eval.az_at_r_second_group(&self.lagrange_evals_r0), + eval.bz_at_r_second_group(&self.lagrange_evals_r0), + ) + }; + + let r_eval = r_grid[r_idx]; + + if x_bit == 0 { + az0_sum += az * r_eval; + bz0_sum += bz * r_eval; + } else { + az1_sum += az * r_eval; + bz1_sum += bz * r_eval; + } + } + } + + // Store the summed values in Az and Bz arrays + let pair_idx = x_out_val * num_x_in_vals + x_in_val; + let buffer_offset 
= 2 * pair_idx; + az_bound[buffer_offset] = az0_sum; // A(x_out, x_in, 0, r2, r1) + az_bound[buffer_offset + 1] = az1_sum; // A(x_out, x_in, 1, r2, r1) + bz_bound[buffer_offset] = bz0_sum; // B(x_out, x_in, 0, r2, r1) + bz_bound[buffer_offset + 1] = bz1_sum; // B(x_out, x_in, 1, r2, r1) + + // For t_0 and t_inf, apply eq polynomials + let e_in = self.split_eq_poly.E_in_current()[x_in_val]; + let e_out = self.split_eq_poly.E_out_current()[x_out_val]; + let p0 = az0_sum * bz0_sum; + let slope = (az1_sum - az0_sum) * (bz1_sum - bz0_sum); + + t0_acc += e_out * e_in * p0; + t_inf_acc += e_out * e_in * slope; + } + } + + self.az = Some(DensePolynomial::new(az_bound)); + self.bz = Some(DensePolynomial::new(bz_bound)); + } + + #[tracing::instrument( + skip_all, + name = "OuterRemainingSumcheckProver::stream_to_linear_time_helper" + )] + + // If the first round of the sumcheck is linear -- then manifesting Az and Bz + // is significantly simpler. + fn stream_to_linear_time_round_zero(&mut self) { + let num_x_out_vals = self.split_eq_poly.E_out_current_len(); + let num_x_in_vals = self.split_eq_poly.E_in_current_len(); + let iter_num_x_in_vars = num_x_in_vals.log_2(); + + let groups_exact = num_x_out_vals + .checked_mul(num_x_in_vals) + .expect("overflow computing groups_exact"); + + // Preallocate interleaved buffers once ([lo, hi] per entry) + let mut az_bound: Vec = unsafe_allocate_zero_vec(2 * groups_exact); + let mut bz_bound: Vec = unsafe_allocate_zero_vec(2 * groups_exact); + // Parallel over x_out groups using exact-sized mutable chunks, with per-worker fold + let (t0_acc_unr, t_inf_acc_unr) = az_bound + .par_chunks_exact_mut(2 * num_x_in_vals) + .zip(bz_bound.par_chunks_exact_mut(2 * num_x_in_vals)) + .enumerate() + .fold( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |(mut acc0, mut acci), (x_out_val, (az_chunk, bz_chunk))| { + let mut inner_sum0 = F::Unreduced::<9>::zero(); + let mut inner_sum_inf = F::Unreduced::<9>::zero(); + for x_in_val in 0..num_x_in_vals { + let current_step_idx = (x_out_val << iter_num_x_in_vars) | x_in_val; + let row_inputs = R1CSCycleInputs::from_trace::( + &self.bytecode_preprocessing, + &self.trace, + current_step_idx, + ); + let eval = R1CSEval::::from_cycle_inputs(&row_inputs); + let az0 = eval.az_at_r_first_group(&self.lagrange_evals_r0); + let bz0 = eval.bz_at_r_first_group(&self.lagrange_evals_r0); + let az1 = eval.az_at_r_second_group(&self.lagrange_evals_r0); + let bz1 = eval.bz_at_r_second_group(&self.lagrange_evals_r0); + let p0 = az0 * bz0; + let slope = (az1 - az0) * (bz1 - bz0); + let e_in = self.split_eq_poly.E_in_current()[x_in_val]; + inner_sum0 += e_in.mul_unreduced::<9>(p0); + inner_sum_inf += e_in.mul_unreduced::<9>(slope); + let off = 2 * x_in_val; + az_chunk[off] = az0; + az_chunk[off + 1] = az1; + bz_chunk[off] = bz0; + bz_chunk[off + 1] = bz1; + } + let e_out = self.split_eq_poly.E_out_current()[x_out_val]; + let reduced0 = F::from_montgomery_reduce::<9>(inner_sum0); + let reduced_inf = F::from_montgomery_reduce::<9>(inner_sum_inf); + acc0 += e_out.mul_unreduced::<9>(reduced0); + acci += e_out.mul_unreduced::<9>(reduced_inf); + (acc0, acci) + }, + ) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + + self.az = Some(DensePolynomial::new(az_bound)); + self.bz = Some(DensePolynomial::new(bz_bound)); + self.t_0 = Some(F::from_montgomery_reduce::<9>(t0_acc_unr)); + self.t_inf = Some(F::from_montgomery_reduce::<9>(t_inf_acc_unr)) + } + + // TODO:(ari) This is 2.5x 
slower than it needs to be right now. + // Currently this is binding Az and Bz -- but I can fuse this. + #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::stream_to_linear_time")] + fn stream_to_linear_time(&mut self) { + let split_eq_poly = &self.split_eq_poly; + // helper constants + let klen = 1 << split_eq_poly.num_challenges(); + // Precompute scaled Lagrange weights for all k so the parallel + // conversion reuses them instead of recomputing per (j, k). + + if klen > 1 { + // Uncomment the following lines for the more structured bind + // with small value optimisations + + //let mut scaled_w = vec![[F::zero(); OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]; klen]; + //let lagrange_evals_r = &self.lagrange_evals_r0; + //let r_grid = &self.r_grid; + //debug_assert_eq!(klen, r_grid.len()); + //let jlen = 1 << (split_eq_poly.get_num_vars() - split_eq_poly.num_challenges()); + //for k in 0..klen { + // let weight = r_grid[k]; + // let row = &mut scaled_w[k]; + // for t in 0..OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE { + // row[t] = lagrange_evals_r[t] * weight; + // } + //} + //let mut ret_az = unsafe_allocate_zero_vec(jlen); + //let mut ret_bz = unsafe_allocate_zero_vec(jlen); + //// Parallelize over j for the linear-time conversion. + //self.build_grids(&mut ret_az, &mut ret_bz, jlen, klen, 0, true, &scaled_w); + //self.az = Some(DensePolynomial::new(ret_az)); + //self.bz = Some(DensePolynomial::new(ret_bz)); + + // A simpler parallel version without small value optimisations + // And no unsafe code + self.stream_to_linear_time_parallel(); + } else { + //debug_assert_eq!(klen, 1); + + //let jlen = 1 << (split_eq_poly.get_num_vars() - split_eq_poly.num_challenges()); + //let mut scaled_w = vec![[F::zero(); OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE]; klen]; + //let lagrange_evals_r = &self.lagrange_evals_r0; + //scaled_w[0].copy_from_slice(lagrange_evals_r); + //let mut ret_az = unsafe_allocate_zero_vec(jlen); + //let mut ret_bz = unsafe_allocate_zero_vec(jlen); + //self.build_grids(&mut ret_az, &mut ret_bz, jlen, klen, 0, true, &scaled_w); + //self.az = Some(DensePolynomial::new(ret_az)); + //self.bz = Some(DensePolynomial::new(ret_bz)); + // + self.stream_to_linear_time_round_zero(); + } + } /// Compute the polynomial for each of the remaining rounds, using the /// linear-time algorithm with split-eq optimizations. 
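+    /// Concretely, with az/bz stored interleaved as [lo, hi] per group g and eq(g)
+    /// the split-eq weight of that group:
+    ///   t(0) = Σ_g eq(g) * Az(g, 0) * Bz(g, 0)
+    ///   t(∞) = Σ_g eq(g) * (Az(g, 1) - Az(g, 0)) * (Bz(g, 1) - Bz(g, 0))
+    /// The degree-3 round polynomial is then recovered from (t(0), t(∞)) and the
+    /// previous round's claim via `gruen_poly_deg_3`.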
/// @@ -384,23 +1146,98 @@ impl OuterRemainingSumcheckProver { /// /// (ordering of indices is MSB to LSB, so x_out is the MSB and x_in is the LSB) #[inline] - fn remaining_quadratic_evals(&self) -> (F, F) { - let n = self.az.len(); - debug_assert_eq!(n, self.bz.len()); - let [t0, tinf] = self.split_eq_poly.par_fold_out_in_unreduced::<9, 2>(&|g| { - let az0 = self.az[2 * g]; - let az1 = self.az[2 * g + 1]; - let bz0 = self.bz[2 * g]; - let bz1 = self.bz[2 * g + 1]; - let p0 = az0 * bz0; - let slope = (az1 - az0) * (bz1 - bz0); - [p0, slope] - }); - (t0, tinf) + fn remaining_quadratic_evals(&mut self) -> (F, F) { + if self.t_0.is_some() { + let t_0 = self.t_0.unwrap(); + let t_inf = self.t_inf.unwrap(); + self.t_0 = None; + self.t_inf = None; + return (t_0, t_inf); + } + let eq_poly = &self.split_eq_poly; + + let n = self.az.as_ref().expect("az should be initialized").len(); + let az = self.az.as_ref().expect("az should be initialized"); + let bz = self.bz.as_ref().expect("bz should be initialized"); + + debug_assert_eq!(n, bz.len()); + if eq_poly.E_in_current_len() == 1 { + // groups are pairs (0,1) + let groups = n / 2; + let (t0_unr, tinf_unr) = (0..groups) + .into_par_iter() + .map(|g| { + let az0 = az[2 * g]; + let az1 = az[2 * g + 1]; + let bz0 = bz[2 * g]; + let bz1 = bz[2 * g + 1]; + let eq = eq_poly.E_out_current()[g]; + let p0 = az0 * bz0; + let slope = (az1 - az0) * (bz1 - bz0); + let t0_unr = eq.mul_unreduced::<9>(p0); + let tinf_unr = eq.mul_unreduced::<9>(slope); + (t0_unr, tinf_unr) + }) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + ( + F::from_montgomery_reduce::<9>(t0_unr), + F::from_montgomery_reduce::<9>(tinf_unr), + ) + } else { + let num_x1_bits = eq_poly.E_in_current_len().log_2(); + let x1_len = eq_poly.E_in_current_len(); + let x2_len = eq_poly.E_out_current_len(); + let (sum0_unr, suminf_unr) = (0..x2_len) + .into_par_iter() + .map(|x2| { + let mut inner0_unr = F::Unreduced::<9>::zero(); + let mut inner_inf_unr = F::Unreduced::<9>::zero(); + for x1 in 0..x1_len { + let g = (x2 << num_x1_bits) | x1; + let az0 = az[2 * g]; + let az1 = az[2 * g + 1]; + let bz0 = bz[2 * g]; + let bz1 = bz[2 * g + 1]; + let e_in = eq_poly.E_in_current()[x1]; + let p0 = az0 * bz0; + let slope = (az1 - az0) * (bz1 - bz0); + inner0_unr += e_in.mul_unreduced::<9>(p0); + inner_inf_unr += e_in.mul_unreduced::<9>(slope); + } + let e_out = eq_poly.E_out_current()[x2]; + let inner0_red = F::from_montgomery_reduce::<9>(inner0_unr); + let inner_inf_red = F::from_montgomery_reduce::<9>(inner_inf_unr); + let t0_unr = e_out.mul_unreduced::<9>(inner0_red); + let tinf_unr = e_out.mul_unreduced::<9>(inner_inf_red); + (t0_unr, tinf_unr) + }) + .reduce( + || (F::Unreduced::<9>::zero(), F::Unreduced::<9>::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ); + ( + F::from_montgomery_reduce::<9>(sum0_unr), + F::from_montgomery_reduce::<9>(suminf_unr), + ) + } + } + + pub fn final_sumcheck_evals(&self) -> [F; 2] { + let az = self.az.as_ref().expect("az should be initialized"); + let bz = self.bz.as_ref().expect("bz should be initialized"); + + let az0 = if !az.is_empty() { az[0] } else { F::zero() }; + let bz0 = if !bz.is_empty() { bz[0] } else { F::zero() }; + [az0, bz0] } } -impl SumcheckInstanceProver for OuterRemainingSumcheckProver { +impl SumcheckInstanceProver + for OuterRemainingSumcheckProver<'_, F, S> +{ fn degree(&self) -> usize { OUTER_REMAINING_DEGREE_BOUND } @@ -413,26 +1250,85 @@ impl SumcheckInstanceProver for OuterRemainin 
self.params.input_claim } + //#[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::compute_message")] + //fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { + // let (t0, t_inf) = if round == 0 { + // self.first_round_evals + // } else { + // self.remaining_quadratic_evals() + // }; + // self.split_eq_poly + // .gruen_poly_deg_3(t0, t_inf, previous_claim) + //} #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::compute_message")] fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly { - let (t0, t_inf) = if round == 0 { - self.first_round_evals + let (t0, t_inf) = if self.schedule.is_streaming(round) { + let num_unbound_vars = self.schedule.num_unbound_vars(round); + + if self.schedule.is_window_start(round) { + // Build the multiquadratic t'(z) for this window using the + // slice-based Eq factorisation provided by the simple + // split-eq instance (head vs window bits). + self.get_grid_gen(num_unbound_vars); + } + // Use the multiquadratic polynomial to compute the message + let t_prime_poly = self + .t_prime_poly + .as_ref() + .expect("t_prime_poly should be initialized"); + // Equality weights over the active window bits (all but the first). + let e_active = self.split_eq_poly.E_active_for_window(num_unbound_vars); + let t_prime_0 = t_prime_poly.project_to_first_variable(&e_active, 0); + let t_prime_inf = t_prime_poly.project_to_first_variable(&e_active, INFINITY); + + (t_prime_0, t_prime_inf) } else { - self.remaining_quadratic_evals() + // LINEAR PHASE + //println!("In Linear phase| Round: {:?}", round); + if self.schedule.is_first_linear(round) { + self.stream_to_linear_time(); + } + // For now, just use quadratic evals + let (t0, t_inf) = self.remaining_quadratic_evals(); + (t0, t_inf) }; + // Compute the Gruen cubic using the split-eq implementation. self.split_eq_poly .gruen_poly_deg_3(t0, t_inf, previous_claim) + //vec![evals[0], evals[1], evals[2]] } #[tracing::instrument(skip_all, name = "OuterRemainingSumcheckProver::ingest_challenge")] - fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { - rayon::join( - || self.az.bind_parallel(r_j, BindingOrder::LowToHigh), - || self.bz.bind_parallel(r_j, BindingOrder::LowToHigh), - ); - - // Bind eq_poly for next round + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize) { self.split_eq_poly.bind(r_j); + + if self.schedule.is_streaming(round) { + let t_prime_poly = self + .t_prime_poly + .as_mut() + .expect("t_prime_poly should be initialized"); + t_prime_poly.bind(r_j, BindingOrder::LowToHigh); + self.r_grid.update(r_j); + } else { + // TODO: Unless this is the last round I should also + // manifest evals for next round : Fused bind + eval; + // Bind the split-eq instance in lock-step with the outer sumcheck. 
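+            // Binding az/bz LowToHigh folds each interleaved [lo, hi] pair into one
+            // entry: new[g] = az[2g] + r_j * (az[2g + 1] - az[2g]) (likewise for bz),
+            // halving both polynomials ahead of the next round.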
+            // TODO: fusing that bind with the next round's evaluations would need a
+            // new bind_parallel variant.
+            rayon::join(
+                || {
+                    self.az
+                        .as_mut()
+                        .expect("az should be initialized")
+                        .bind_parallel(r_j, BindingOrder::LowToHigh)
+                },
+                || {
+                    self.bz
+                        .as_mut()
+                        .expect("bz should be initialized")
+                        .bind_parallel(r_j, BindingOrder::LowToHigh)
+                },
+            );
+        }
+    }

     fn cache_openings(
diff --git a/tracer/src/lib.rs b/tracer/src/lib.rs
index 36f845db2..1a7d70628 100644
--- a/tracer/src/lib.rs
+++ b/tracer/src/lib.rs
@@ -86,7 +86,9 @@ pub fn trace(
         memory_config,
     ));
     let lazy_trace_iter_ = lazy_trace_iter.clone();
+    // NOTE: this will materialise the trace in full
     let trace: Vec<Cycle> = lazy_trace_iter.by_ref().collect();
+    //let trace = lazy_trace_iter.by_ref();
     let final_memory_state = std::mem::take(lazy_trace_iter.final_memory_state.as_mut().unwrap());
     (
         lazy_trace_iter_,