diff --git a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs index 30bf745635e..2a23f92f2a1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs @@ -324,6 +324,8 @@ mod tests { b0(v0: u1, v1: i32): v2 = cast v0 as i32 v3 = mul v2, v1 + v4 = cast v3 as u64 + v6 = truncate v4 to 32 bits, max_bit_size: 64 return v2 } "; diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index b208bdb68fd..2d21173930b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -1093,7 +1093,7 @@ mod test { v9 = unchecked_add v2, i32 1 jmp b1(v9) b6(): - v10 = mul v0, v1 + v10 = unchecked_mul v0, v1 constrain v10 == i32 6 v12 = unchecked_add v3, i32 1 jmp b4(v12) @@ -1110,7 +1110,7 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: i32, v1: i32): - v4 = mul v0, v1 + v4 = unchecked_mul v0, v1 constrain v4 == i32 6 jmp b1(i32 0) b1(v2: i32): diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 857d596e044..d0e9a6eb56f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -1490,9 +1490,9 @@ mod tests { v11 = array_get v0, index v10 -> u64 v12 = add v11, u64 1 v13 = array_set v9, index v10, value v12 - v15 = add v1, {idx_type} 1 + v15 = unchecked_add v1, {idx_type} 1 store v13 at v4 - v16 = add v1, {idx_type} 1 // duplicate + v16 = unchecked_add v1, {idx_type} 1 // duplicate jmp b1(v16) b2(): v8 = load v4 -> [u64; 6] diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index cff5c45ec7b..e8a78c63ad0 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -16,7 +16,6 @@ use crate::ssa::{ value::ValueId, }, opt::pure::FunctionPurities, - validation::validate_function, }; use super::{ @@ -558,10 +557,6 @@ impl Translator { // before each print. ssa.normalize_ids(); - for function in ssa.functions.values() { - validate_function(function); - } - ssa } diff --git a/compiler/noirc_evaluator/src/ssa/validation/mod.rs b/compiler/noirc_evaluator/src/ssa/validation/mod.rs index 55015c98898..e237c24bae1 100644 --- a/compiler/noirc_evaluator/src/ssa/validation/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/validation/mod.rs @@ -29,7 +29,13 @@ struct Validator<'f> { function: &'f Function, // State for truncate-after-signed-sub validation // Stores: Option<(bit_size, result)> - signed_binary_op: Option<(u32, ValueId)>, + signed_binary_op: Option, +} + +#[derive(Debug)] +enum PendingSignedOverflowOp { + AddOrSub { bit_size: u32, result: ValueId }, + Mul { bit_size: u32, mul_result: ValueId, cast_result: Option }, } impl<'f> Validator<'f> { @@ -37,47 +43,100 @@ impl<'f> Validator<'f> { Self { function, signed_binary_op: None } } - /// Validates that any checked signed add/sub is followed by the expected truncate. - fn validate_truncate_after_signed_sub(&mut self, instruction: InstructionId) { + /// Validates that any checked signed add/sub/mul are followed by the appropriate instructions. + /// Signed overflow is many instructions but we validate up to the initial truncate. + /// + /// Expects the following SSA form for signed checked operations: + /// Add/Sub -> Truncate + /// Mul -> Cast -> Truncate + fn validate_signed_op_overflow_pattern(&mut self, instruction: InstructionId) { let dfg = &self.function.dfg; match &dfg[instruction] { Instruction::Binary(binary) => { - self.signed_binary_op = None; - - match binary.operator { - // Only validating checked addition/subtraction - BinaryOp::Add { unchecked: false } | BinaryOp::Sub { unchecked: false } => {} - // Otherwise, move onto the next instruction - _ => return, + // Only reset if we are starting a new tracked op. + // We do not reset on unrelated ops. If we already an op pending, we have an ill formed signed op. + if self.signed_binary_op.is_some() { + panic!("Signed binary operation does not follow overflow pattern"); } // Assumes rhs_type is the same as lhs_type let lhs_type = dfg.type_of_value(binary.lhs); - if let Type::Numeric(NumericType::Signed { bit_size }) = lhs_type { - let results = dfg.instruction_results(instruction); - self.signed_binary_op = Some((bit_size, results[0])); + let Type::Numeric(NumericType::Signed { bit_size }) = lhs_type else { + return; + }; + + let result = dfg.instruction_results(instruction)[0]; + match binary.operator { + BinaryOp::Mul { unchecked: false } => { + self.signed_binary_op = Some(PendingSignedOverflowOp::Mul { + bit_size, + mul_result: result, + cast_result: None, + }); + } + BinaryOp::Add { unchecked: false } | BinaryOp::Sub { unchecked: false } => { + self.signed_binary_op = + Some(PendingSignedOverflowOp::AddOrSub { bit_size, result }); + } + _ => {} } } Instruction::Truncate { value, bit_size, max_bit_size } => { - let Some((signed_op_bit_size, signed_op_res)) = self.signed_binary_op.take() else { - return; - }; - assert_eq!( - *bit_size, signed_op_bit_size, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); - assert_eq!( - *max_bit_size, - *bit_size + 1, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); - assert_eq!( - *value, signed_op_res, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); + // Only a truncate can reset the signed binary op state + match self.signed_binary_op.take() { + Some(PendingSignedOverflowOp::AddOrSub { + bit_size: expected_bit_size, + result, + }) => { + assert_eq!(*bit_size, expected_bit_size); + assert_eq!(*max_bit_size, expected_bit_size + 1); + assert_eq!(*value, result); + } + Some(PendingSignedOverflowOp::Mul { + bit_size: expected_bit_size, + cast_result: Some(cast), + .. + }) => { + assert_eq!(*bit_size, expected_bit_size); + assert_eq!(*max_bit_size, 2 * expected_bit_size); + assert_eq!(*value, cast); + } + Some(PendingSignedOverflowOp::Mul { + cast_result: None, + .. + }) => { + panic!("Truncate not matched to signed overflow pattern"); + } + None => { + // Do nothing as there is no overflow op pending + } + } + } + Instruction::Cast(value, typ) => { + match &mut self.signed_binary_op { + Some(PendingSignedOverflowOp::AddOrSub { .. }) => { + panic!( + "Invalid cast inserted after signed checked Add/Sub. It must be followed immediately by truncate" + ); + } + Some(PendingSignedOverflowOp::Mul { + bit_size: expected_bit_size, + mul_result, + cast_result, + }) => { + assert_eq!(typ.bit_size(), 2 * *expected_bit_size); + assert_eq!(*value, *mul_result); + *cast_result = Some(dfg.instruction_results(instruction)[0]); + } + None => { + // Do nothing as there is no overflow op pending + } + } } _ => { - self.signed_binary_op = None; + if self.signed_binary_op.is_some() { + panic!("Signed binary operation does not follow overflow pattern"); + } } } } @@ -182,13 +241,13 @@ impl<'f> Validator<'f> { for block in self.function.reachable_blocks() { for instruction in self.function.dfg[block].instructions() { - self.validate_truncate_after_signed_sub(*instruction); + self.validate_signed_op_overflow_pattern(*instruction); self.type_check_instruction(*instruction); } } if self.signed_binary_op.is_some() { - panic!("ICE: Truncate must follow the result of a checked signed add/sub"); + panic!("Signed binary operation does not follow overflow pattern"); } } } @@ -206,7 +265,7 @@ mod tests { use crate::ssa::ssa_gen::Ssa; #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] fn lone_signed_sub_acir() { let src = r" acir(inline) pure fn main f0 { @@ -220,7 +279,7 @@ mod tests { } #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] fn lone_signed_sub_brillig() { // This matches the test above we just want to make sure it holds in the Brillig runtime as well as ACIR let src = r" @@ -235,7 +294,7 @@ mod tests { } #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] fn lone_signed_add_acir() { let src = r" acir(inline) pure fn main f0 { @@ -249,7 +308,7 @@ mod tests { } #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] fn lone_signed_add_brillig() { let src = r" brillig(inline) pure fn main f0 { @@ -263,9 +322,7 @@ mod tests { } #[test] - #[should_panic( - expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" - )] + #[should_panic(expected = "assertion `left == right` failed")] fn signed_sub_bad_truncate_bit_size() { let src = r" acir(inline) pure fn main f0 { @@ -280,9 +337,7 @@ mod tests { } #[test] - #[should_panic( - expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" - )] + #[should_panic(expected = "assertion `left == right` failed")] fn signed_sub_bad_truncate_max_bit_size() { let src = r" acir(inline) pure fn main f0 { @@ -352,6 +407,160 @@ mod tests { let _ = Ssa::from_str(src); } + #[test] + #[should_panic( + expected = "Invalid cast inserted after signed checked Add/Sub. It must be followed immediately by truncate" + )] + fn cast_and_truncate_follows_signed_add() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = cast v2 as i32 + v4 = truncate v2 to 16 bits, max_bit_size: 17 + return v4 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn signed_mul_followed_by_binary() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: Field): + v1 = truncate v0 to 16 bits, max_bit_size: 254 + v2 = cast v1 as i16 + v3 = mul v2, v2 + v4 = div v3, v2 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn signed_mul_followed_by_cast_and_truncate() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_cast() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v0 as u16 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_cast_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u16 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_truncate_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 32 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_truncate_max_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 16 bits, max_bit_size: 33 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_mul() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = mul v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Truncate not matched to signed overflow pattern")] + fn signed_mul_followed_by_truncate_but_no_cast() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = mul v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 33 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn lone_truncate() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16): + v1 = truncate v0 to 8 bits, max_bit_size: 8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } + #[test] #[should_panic(expected = "Cannot use `lt` with field elements")] fn disallows_comparing_fields_with_lt() {