diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index 4ebe9bc46cc..6114fd35a38 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -185,7 +185,7 @@ mod tests { fn create_test_environment() -> (Ssa, FunctionContext, BrilligContext) { let mut builder = FunctionBuilder::new("main".to_string(), Id::test_new(0)); builder.set_runtime(RuntimeType::Brillig(InlineType::default())); - + builder.terminate_with_return(vec![]); let ssa = builder.finish(); let mut brillig_context = create_context(ssa.main_id); brillig_context.enter_context(Label::block(ssa.main_id, Id::test_new(0))); diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index cfcb80c212a..c6a07ee5905 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -155,6 +155,9 @@ impl FunctionBuilder { /// Consume the FunctionBuilder returning all the functions it has generated. pub fn finish(mut self) -> Ssa { self.finished_functions.push(self.current_function); + + Self::validate_ssa(&self.finished_functions); + Ssa::new(self.finished_functions, self.error_types) } @@ -518,6 +521,12 @@ impl FunctionBuilder { pub fn record_error_type(&mut self, selector: ErrorSelector, typ: HirType) { self.error_types.insert(selector, typ); } + + fn validate_ssa(functions: &[Function]) { + for function in functions { + function.assert_valid(); + } + } } impl std::ops::Index for FunctionBuilder { diff --git a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs index eb4b05200b4..4272a3cb787 100644 --- a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs +++ b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs @@ -23,12 +23,12 @@ fn add() { " acir(inline) fn main f0 { b0(): - v0 = add i32 2, i32 100 + v0 = add u32 2, u32 100 return v0 } ", ); - assert_eq!(value, Value::Numeric(NumericValue::I32(102))); + assert_eq!(value, Value::Numeric(NumericValue::U32(102))); } #[test] @@ -64,12 +64,12 @@ fn sub() { " acir(inline) fn main f0 { b0(): - v0 = sub i32 10101, i32 101 + v0 = sub u32 10101, u32 101 return v0 } ", ); - assert_eq!(value, Value::Numeric(NumericValue::I32(10000))); + assert_eq!(value, Value::Numeric(NumericValue::U32(10000))); } #[test] @@ -79,7 +79,8 @@ fn sub_underflow() { acir(inline) fn main f0 { b0(): v0 = sub i8 136, i8 10 // -120 - 10 - return v0 + v1 = truncate v0 to 8 bits, max_bit_size: 9 + return v1 } ", ); diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index a6cd515bc15..dd85cfa936f 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use super::basic_block::BasicBlockId; use super::dfg::{DataFlowGraph, GlobalsGraph}; -use super::instruction::TerminatorInstruction; +use super::instruction::{BinaryOp, Instruction, TerminatorInstruction}; use super::map::Id; use super::types::{NumericType, Type}; use super::value::{Value, ValueId}; @@ -236,6 +236,12 @@ impl Function { /// /// Panics on malformed functions. pub(crate) fn assert_valid(&self) { + self.assert_single_return_block(); + self.validate_signed_arithmetic_invariants(); + } + + /// Checks that the function has only one return block. + fn assert_single_return_block(&self) { let reachable_blocks = self.reachable_blocks(); // We assume that all functions have a single block which terminates with a `return` instruction. @@ -253,6 +259,61 @@ impl Function { panic!("Function {} has multiple return blocks {return_blocks:?}", self.id()) } } + + /// Validates that any checked signed add/sub is followed by the expected truncate. + fn validate_signed_arithmetic_invariants(&self) { + // State for tracking the last signed binary addition/subtraction + let mut signed_binary_op = None; + for block in self.reachable_blocks() { + for instruction in self.dfg[block].instructions() { + match &self.dfg[*instruction] { + Instruction::Binary(binary) => { + signed_binary_op = None; + + match binary.operator { + // We are only validating addition/subtraction + BinaryOp::Add { unchecked: false } + | BinaryOp::Sub { unchecked: false } => {} + // Otherwise, move onto the next instruction + _ => continue, + } + + // Assume rhs_type is the same as lhs_type + let lhs_type = self.dfg.type_of_value(binary.lhs); + if let Type::Numeric(NumericType::Signed { bit_size }) = lhs_type { + let results = self.dfg.instruction_results(*instruction); + signed_binary_op = Some((bit_size, results[0])); + } + } + Instruction::Truncate { value, bit_size, max_bit_size } => { + let Some((signed_op_bit_size, signed_op_res)) = signed_binary_op.take() + else { + continue; + }; + assert_eq!( + *bit_size, signed_op_bit_size, + "ICE: Correct truncate must follow the result of a checked signed add/sub" + ); + assert_eq!( + *max_bit_size, + *bit_size + 1, + "ICE: Correct truncate must follow the result of a checked signed add/sub" + ); + assert_eq!( + *value, signed_op_res, + "ICE: Correct truncate must follow the result of a checked signed add/sub" + ); + } + _ => { + signed_binary_op = None; + } + } + } + } + if signed_binary_op.is_some() { + panic!("ICE: Truncate must follow the result of a checked signed add/sub"); + } + } } impl Clone for Function { @@ -295,3 +356,155 @@ fn sign_smoke() { signature.params.push(Type::Numeric(NumericType::NativeField)); signature.returns.push(Type::Numeric(NumericType::Unsigned { bit_size: 32 })); } + +#[cfg(test)] +mod validation { + use crate::ssa::ssa_gen::Ssa; + + #[test] + #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + fn lone_signed_sub_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + fn lone_signed_sub_brillig() { + // This matches the test above we just want to make sure it holds in the Brillig runtime as well as ACIR + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + fn lone_signed_add_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + fn lone_signed_add_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" + )] + fn signed_sub_bad_truncate_bit_size() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 32 bits, max_bit_size: 33 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" + )] + fn signed_sub_bad_truncate_max_bit_size() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 18 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_sub_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_sub_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_add_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_add_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } +} diff --git a/compiler/noirc_evaluator/src/ssa/ir/post_order.rs b/compiler/noirc_evaluator/src/ssa/ir/post_order.rs index 183d2ce4b49..89adb42ec96 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/post_order.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/post_order.rs @@ -153,6 +153,9 @@ mod tests { builder.switch_to_block(block_c_id); builder.terminate_with_jmp(block_f_id, vec![]); + builder.switch_to_block(block_f_id); + builder.terminate_with_return(vec![]); + let ssa = builder.finish(); let func = ssa.main(); let post_order = PostOrder::with_function(func); diff --git a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs index a2804703cc7..30bf745635e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs @@ -292,7 +292,8 @@ mod tests { v2 = cast v0 as i32 v3 = cast v1 as i32 v4 = add v2, v3 - return v4 + v5 = truncate v4 to 32 bits, max_bit_size: 33 + return v5 } "; let ssa = Ssa::from_str(src).unwrap(); @@ -307,7 +308,8 @@ mod tests { b0(v0: i16): v1 = cast v0 as i32 v2 = sub i32 65536, v1 - return v2 + v3 = truncate v2 to 32 bits, max_bit_size: 33 + return v3 } "; let ssa = Ssa::from_str(src).unwrap(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index dd6d07c2037..fb69a01f9cd 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -1482,7 +1482,8 @@ mod test { brillig(inline) fn one f1 { b0(v0: i32, v1: i32): v2 = add v0, v1 - return v2 + v3 = truncate v2 to 32 bits, max_bit_size: 33 + return v3 } "; let ssa = Ssa::from_str(src).unwrap(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 66c5bfeabe8..e5e5f88b494 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -1403,37 +1403,37 @@ mod test { // uses a checked add in `b3`. let src = " brillig(inline) fn main f0 { - b0(v0: i32, v1: i32): - jmp b1(i32 0) - b1(v2: i32): - v5 = lt v2, i32 4 + b0(v0: u32, v1: u32): + jmp b1(u32 0) + b1(v2: u32): + v5 = lt v2, u32 4 jmpif v5 then: b3, else: b2 b2(): return b3(): v6 = mul v0, v1 - constrain v6 == i32 6 - v8 = add v2, i32 1 + constrain v6 == u32 6 + v8 = add v2, u32 1 jmp b1(v8) } "; let ssa = Ssa::from_str(src).unwrap(); - // `v8 = add v2, i32 1` in b3 should now be `v9 = unchecked_add v2, i32 1` in b3 + // `v8 = add v2, u32 1` in b3 should now be `v9 = unchecked_add v2, u32 1` in b3 let expected = " brillig(inline) fn main f0 { - b0(v0: i32, v1: i32): + b0(v0: u32, v1: u32): v3 = mul v0, v1 - constrain v3 == i32 6 - jmp b1(i32 0) - b1(v2: i32): - v7 = lt v2, i32 4 + constrain v3 == u32 6 + jmp b1(u32 0) + b1(v2: u32): + v7 = lt v2, u32 4 jmpif v7 then: b3, else: b2 b2(): return b3(): - v9 = unchecked_add v2, i32 1 + v9 = unchecked_add v2, u32 1 jmp b1(v9) } "; diff --git a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs index e61c20f2e24..cf3697c683a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs @@ -724,7 +724,7 @@ mod test { v2 = lt i16 3, v0 jmpif v2 then: b1, else: b2 b1(): - v4 = add i16 1, v0 + v4 = unchecked_add i16 1, v0 jmp b3() b2(): jmp b3()