diff --git a/src/lair/bytecode.rs b/src/lair/bytecode.rs index b2940681..50235de4 100644 --- a/src/lair/bytecode.rs +++ b/src/lair/bytecode.rs @@ -138,6 +138,7 @@ pub struct Func { pub(crate) input_size: usize, pub(crate) output_size: usize, pub(crate) body: Block, + pub(crate) rc: usize, } impl Func { diff --git a/src/lair/execute.rs b/src/lair/execute.rs index fa112190..69dfb3b8 100644 --- a/src/lair/execute.rs +++ b/src/lair/execute.rs @@ -107,8 +107,25 @@ impl<'a, F: PrimeField32> Shard<'a, F> { self.queries.expect("Missing query record reference") } - pub fn get_func_range(&self, func_index: usize) -> Range { - let num_func_queries = self.queries().func_queries[func_index].len(); + pub fn get_func_range_rc(&self, func: &Func, rc_index: usize) -> Range { + let num_func_queries = self.queries().func_queries[func.index].len(); + let shard_idx = self.index as usize; + let shard_chunk = self.shard_config.max_shard_size as usize * func.rc; + let start = shard_idx * shard_chunk; + let end = ((shard_idx + 1) * shard_chunk).min(num_func_queries); + let len = (start..end).len(); + if len % func.rc == 0 { + let chunk = len / func.rc; + start + chunk * rc_index..start + chunk * (rc_index + 1) + } else { + let chunk = (len / func.rc) + 1; + let end = start + chunk * (rc_index + 1); + start + chunk * rc_index..end.min(num_func_queries) + } + } + + pub fn get_func_range(&self, func: &Func) -> Range { + let num_func_queries = self.queries().func_queries[func.index].len(); let shard_idx = self.index as usize; let max_shard_size = self.shard_config.max_shard_size as usize; shard_idx * max_shard_size..((shard_idx + 1) * max_shard_size).min(num_func_queries) diff --git a/src/lair/expr.rs b/src/lair/expr.rs index d0ea5056..3880ef10 100644 --- a/src/lair/expr.rs +++ b/src/lair/expr.rs @@ -133,9 +133,10 @@ pub struct CasesE { #[derive(Clone, Debug, Eq, PartialEq)] pub struct FuncE { pub name: Name, - pub invertible: bool, pub partial: bool, + pub invertible: bool, pub input_params: VarList, pub output_size: usize, pub body: BlockE, + pub rc: usize, } diff --git a/src/lair/lair_chip.rs b/src/lair/lair_chip.rs index a9eff79b..8efe5220 100644 --- a/src/lair/lair_chip.rs +++ b/src/lair/lair_chip.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; @@ -20,7 +22,10 @@ use super::{ }; pub enum LairChip<'a, F, H: Chipset> { - Func(FuncChip<'a, F, H>), + Func { + func_chip: Arc>, + rc_index: usize, + }, Mem(MemChip), Bytes(BytesChip), Entrypoint { @@ -54,7 +59,7 @@ impl<'a, F: PrimeField32, H: Chipset> EventLens> for Shard impl<'a, F: Field + Sync, H: Chipset> BaseAir for LairChip<'a, F, H> { fn width(&self) -> usize { match self { - Self::Func(func_chip) => func_chip.width(), + Self::Func { func_chip, .. } => func_chip.width(), Self::Mem(mem_chip) => mem_chip.width(), Self::Bytes(bytes_chip) => bytes_chip.width(), Self::Entrypoint { @@ -78,7 +83,7 @@ impl<'a, F: PrimeField32, H: Chipset> MachineAir for LairChip<'a, F, H> { fn name(&self) -> String { match self { - Self::Func(func_chip) => format!("Func[{}]", func_chip.func.name), + Self::Func { func_chip, .. } => format!("Func[{}]", func_chip.func.name), Self::Mem(mem_chip) => format!("Mem[{}-wide]", mem_chip.len), Self::Entrypoint { func_idx, .. } => format!("Entrypoint[{func_idx}]"), // the following is required by sphinx @@ -93,7 +98,10 @@ impl<'a, F: PrimeField32, H: Chipset> MachineAir for LairChip<'a, F, H> { _: &mut Self::Record, ) -> RowMajorMatrix { match self { - Self::Func(func_chip) => func_chip.generate_trace(shard.events()), + Self::Func { + func_chip, + rc_index, + } => func_chip.generate_trace_rc(shard.events(), *rc_index), Self::Mem(mem_chip) => mem_chip.generate_trace(shard.events()), Self::Bytes(bytes_chip) => { // TODO: Shard the byte events differently? @@ -117,8 +125,8 @@ impl<'a, F: PrimeField32, H: Chipset> MachineAir for LairChip<'a, F, H> { fn included(&self, shard: &Self::Record) -> bool { match self { - Self::Func(func_chip) => { - let range = shard.get_func_range(func_chip.func.index); + Self::Func { func_chip, .. } => { + let range = shard.get_func_range(func_chip.func); !range.is_empty() } Self::Mem(_mem_chip) => { @@ -154,7 +162,7 @@ where { fn eval(&self, builder: &mut AB) { match self { - Self::Func(func_chip) => func_chip.eval(builder), + Self::Func { func_chip, .. } => func_chip.eval(builder), Self::Mem(mem_chip) => mem_chip.eval(builder), Self::Bytes(bytes_chip) => bytes_chip.eval(builder), Self::Entrypoint { @@ -195,7 +203,14 @@ pub fn build_lair_chip_vector<'a, F: PrimeField32, H: Chipset>( let mut chip_vector = Vec::with_capacity(2 + toplevel.map.size() + MEM_TABLE_SIZES.len()); chip_vector.push(LairChip::entrypoint(func)); for func_chip in FuncChip::from_toplevel(toplevel) { - chip_vector.push(LairChip::Func(func_chip)); + let func_chip = Arc::new(func_chip); + for rc_index in 0..func_chip.func.rc { + let func_chip = func_chip.clone(); + chip_vector.push(LairChip::Func { + func_chip, + rc_index, + }); + } } for mem_len in MEM_TABLE_SIZES { chip_vector.push(LairChip::Mem(MemChip::new(mem_len))); diff --git a/src/lair/macros.rs b/src/lair/macros.rs index 852fbcb5..b20d6109 100644 --- a/src/lair/macros.rs +++ b/src/lair/macros.rs @@ -1,6 +1,12 @@ #[macro_export] macro_rules! func { - (fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ + (#[RC = $rc:expr] $($x:tt)*) => { $crate::func_body!($rc, $($x)*) }; + ($($x:tt)*) => { $crate::func_body!(1, $($x)*) }; +} + +#[macro_export] +macro_rules! func_body { + ($rc:expr, fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ $(let $in = $crate::var!($in $(, $in_size)?);)* $crate::lair::expr::FuncE { name: $crate::lair::Name(stringify!($name)), @@ -9,9 +15,10 @@ macro_rules! func { input_params: [$($crate::var!($in $(, $in_size)?)),*].into(), output_size: $size, body: $crate::block_init!($lair), + rc: $rc, } }}; - (partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ + ($rc:expr, partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ $(let $in = $crate::var!($in $(, $in_size)?);)* $crate::lair::expr::FuncE { name: $crate::lair::Name(stringify!($name)), @@ -20,9 +27,10 @@ macro_rules! func { input_params: [$($crate::var!($in $(, $in_size)?)),*].into(), output_size: $size, body: $crate::block_init!($lair), + rc: $rc, } }}; - (invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ + ($rc:expr, invertible fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ $(let $in = $crate::var!($in $(, $in_size)?);)* $crate::lair::expr::FuncE { name: $crate::lair::Name(stringify!($name)), @@ -31,9 +39,10 @@ macro_rules! func { input_params: [$($crate::var!($in $(, $in_size)?)),*].into(), output_size: $size, body: $crate::block_init!($lair), + rc: $rc, } }}; - (invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ + ($rc:expr, invertible partial fn $name:ident($( $in:ident $(: [$in_size:expr])? ),*): [$size:expr] $lair:tt) => {{ $(let $in = $crate::var!($in $(, $in_size)?);)* $crate::lair::expr::FuncE { name: $crate::lair::Name(stringify!($name)), @@ -42,6 +51,7 @@ macro_rules! func { input_params: [$($crate::var!($in $(, $in_size)?)),*].into(), output_size: $size, body: $crate::block_init!($lair), + rc: $rc, } }}; } diff --git a/src/lair/toplevel.rs b/src/lair/toplevel.rs index 6f72bfce..75f37461 100644 --- a/src/lair/toplevel.rs +++ b/src/lair/toplevel.rs @@ -182,6 +182,7 @@ impl FuncE { body, input_size: self.input_params.total_size(), output_size: self.output_size, + rc: self.rc, } } } diff --git a/src/lair/trace.rs b/src/lair/trace.rs index df89c328..053bbac0 100644 --- a/src/lair/trace.rs +++ b/src/lair/trace.rs @@ -70,10 +70,74 @@ impl<'a, T> ColumnMutSlice<'a, T> { } impl<'a, F: PrimeField32, H: Chipset> FuncChip<'a, F, H> { + /// Per-row parallel trace generation + pub fn generate_trace_rc(&self, shard: &Shard<'_, F>, rc_index: usize) -> RowMajorMatrix { + let func_queries = &shard.queries().func_queries()[self.func.index]; + let range = shard.get_func_range_rc(self.func, rc_index); + let offset = range.start; + let width = self.width(); + let non_dummy_height = range.len(); + let height = non_dummy_height.next_power_of_two(); + let mut rows = vec![F::zero(); height * width]; + // initializing nonces + rows.chunks_mut(width) + .enumerate() + .for_each(|(i, row)| row[0] = F::from_canonical_usize(i + offset)); + let non_dummies = &mut rows[0..non_dummy_height * width]; + non_dummies + .par_chunks_mut(width) + .enumerate() + .for_each(|(i, row)| { + let (args, result) = func_queries.get_index(i + offset).unwrap(); + let index = &mut ColumnIndex::default(); + let slice = &mut ColumnMutSlice::from_slice(row, self.layout_sizes); + let requires = result.requires.iter(); + let mut depth_requires = result.depth_requires.iter(); + let queries = shard.queries(); + let query_map = &queries.func_queries()[self.func.index]; + let lookup = query_map + .get(args) + .expect("Cannot find query result") + .provide; + let provide = lookup.into_provide(); + result + .output + .as_ref() + .unwrap() + .iter() + .for_each(|&o| slice.push_output(index, o)); + slice.push_aux(index, provide.last_nonce); + slice.push_aux(index, provide.last_count); + // provenance and range check + if self.func.partial { + let num_requires = (DEPTH_W / 2) + (DEPTH_W % 2); + let depth: [u8; DEPTH_W] = result.depth.to_le_bytes(); + for b in depth { + slice.push_aux(index, F::from_canonical_u8(b)); + } + for _ in 0..num_requires { + let lookup = depth_requires.next().expect("Not enough require hints"); + slice.push_require(index, lookup.into_require()); + } + } + self.func.populate_row( + args, + index, + slice, + queries, + requires, + self.toplevel, + result.depth, + depth_requires, + ); + }); + RowMajorMatrix::new(rows, width) + } + /// Per-row parallel trace generation pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix { let func_queries = &shard.queries().func_queries()[self.func.index]; - let range = shard.get_func_range(self.func.index); + let range = shard.get_func_range(self.func); let width = self.width(); let non_dummy_height = range.len(); let height = non_dummy_height.next_power_of_two(); diff --git a/src/lurk/eval.rs b/src/lurk/eval.rs index e266742e..2b834b71 100644 --- a/src/lurk/eval.rs +++ b/src/lurk/eval.rs @@ -143,6 +143,11 @@ impl EvalErr { } } +const EVAL_RC: usize = 8; +const APPLY_RC: usize = 4; +const BINOP_RC: usize = 4; +const LOOKUP_RC: usize = 4; + pub fn lurk_main() -> FuncE { func!( partial fn lurk_main(full_expr_tag: [8], expr_digest: [8], env_digest: [8]): [16] { @@ -301,6 +306,7 @@ fn ingress_builtin(builtins: &BuiltinMemo<'_, F>) -> Fun input_params: [input_var].into(), output_size: 1, body: BlockE { ops, ctrl }, + rc: 1, } } @@ -435,6 +441,7 @@ fn egress_builtin(builtins: &BuiltinMemo<'_, F>) -> Func input_params: [input_var].into(), output_size: 8, body: BlockE { ops, ctrl }, + rc: 1, } } @@ -565,6 +572,7 @@ pub fn big_num_lessthan() -> FuncE { pub fn eval(builtins: &BuiltinMemo<'_, F>) -> FuncE { func!( + #[RC = EVAL_RC] partial fn eval(expr_tag, expr, env): [2] { // Constants, tags, etc let t = builtins.index("t"); @@ -1131,6 +1139,7 @@ pub fn eval_begin(builtins: &BuiltinMemo<'_, F>) -> Func pub fn eval_binop_num(builtins: &BuiltinMemo<'_, F>) -> FuncE { func!( + #[RC = BINOP_RC] partial fn eval_binop_num(head, exp1_tag, exp1, exp2_tag, exp2, env): [2] { let err_tag = Tag::Err; let num_tag = Tag::Num; @@ -1685,6 +1694,7 @@ pub fn eval_letrec() -> FuncE { pub fn apply() -> FuncE { func!( + #[RC = APPLY_RC] partial fn apply(head_tag, head, args_tag, args, args_env): [2] { // Constants, tags, etc let err_tag = Tag::Err; @@ -1765,6 +1775,7 @@ pub fn apply() -> FuncE { pub fn env_lookup() -> FuncE { func!( + #[RC = LOOKUP_RC] fn env_lookup(x_digest: [8], env): [2] { if !env { let err_tag = Tag::Err;