diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 9e70acf6622..7a6aa468ae3 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -28,7 +28,6 @@ use super::{ }, opt::pure::FunctionPurities, ssa_gen::Ssa, - validation::validate_function, }; /// The per-function context for each ssa function being generated. @@ -158,8 +157,6 @@ impl FunctionBuilder { pub fn finish(mut self) -> Ssa { self.finished_functions.push(self.current_function); - Self::validate_ssa(&self.finished_functions); - Ssa::new(self.finished_functions, self.error_types) } @@ -524,12 +521,6 @@ impl FunctionBuilder { pub fn record_error_type(&mut self, selector: ErrorSelector, typ: HirType) { self.error_types.insert(selector, typ); } - - fn validate_ssa(functions: &[Function]) { - for function in functions { - validate_function(function); - } - } } impl std::ops::Index for FunctionBuilder { diff --git a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs index 6b52388fbdc..6227ca744f1 100644 --- a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs +++ b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs @@ -1121,10 +1121,12 @@ fn test_range_and_xor_bb() { acir(inline) fn test_and_xor f2 { b0(v0: Field, v1: Field): - v2 = cast v0 as u8 - v4 = cast v1 as u8 - v8 = and v2, v4 - v9 = xor v2, v4 + v2 = truncate v0 to 8 bits, max_bit_size: 254 + v3 = cast v2 as u8 + v4 = truncate v1 to 8 bits, max_bit_size: 254 + v5 = cast v4 as u8 + v8 = and v3, v5 + v9 = xor v3, v5 return v9 } diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 7cddafa824e..ba7b5309e9e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -1472,8 +1472,9 @@ mod test { b2(): v7 = cast v2 as Field v9 = add v7, Field 1 - v10 = cast v9 as u8 - store v10 at v6 + v10 = truncate v9 to 8 bits, max_bit_size: 254 + v11 = cast v10 as u8 + store v11 at v6 jmp b3() b3(): constrain v5 == u1 1 @@ -1485,7 +1486,6 @@ mod test { "; let ssa = Ssa::from_str(src).unwrap(); - let flattened_ssa = ssa.flatten_cfg(); let main = flattened_ssa.main(); @@ -1515,21 +1515,22 @@ mod test { enable_side_effects v5 v8 = cast v2 as Field v10 = add v8, Field 1 - v11 = cast v10 as u8 - v12 = load v6 -> u8 - v13 = not v5 - v14 = cast v4 as u8 - v15 = cast v13 as u8 - v16 = unchecked_mul v14, v11 + v11 = truncate v10 to 8 bits, max_bit_size: 254 + v12 = cast v11 as u8 + v13 = load v6 -> u8 + v14 = not v5 + v15 = cast v4 as u8 + v16 = cast v14 as u8 v17 = unchecked_mul v15, v12 - v18 = unchecked_add v16, v17 - store v18 at v6 - enable_side_effects v13 - v19 = load v6 -> u8 - v20 = cast v13 as u8 - v21 = cast v4 as u8 - v22 = unchecked_mul v21, v19 - store v22 at v6 + v18 = unchecked_mul v16, v13 + v19 = unchecked_add v17, v18 + store v19 at v6 + enable_side_effects v14 + v20 = load v6 -> u8 + v21 = cast v14 as u8 + v22 = cast v4 as u8 + v23 = unchecked_mul v22, v20 + store v23 at v6 enable_side_effects u1 1 constrain v5 == u1 1 return diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 0169b0e1eff..005a944acd3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -115,7 +115,7 @@ impl Context<'_, '_, '_> { return InsertInstructionResult::SimplifiedTo(zero).first(); } } - let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ); + let pow = self.field_constant(FieldElement::from(rhs_bit_size_pow_2)); let max_lhs_bits = self.context.dfg.get_value_max_num_bits(lhs); let max_bit_size = max_lhs_bits + bit_shift_size; @@ -128,9 +128,8 @@ impl Context<'_, '_, '_> { let u8_type = NumericType::unsigned(8); let bit_size_var = self.numeric_constant(FieldElement::from(bit_size as u128), u8_type); let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var); - let predicate = self.insert_cast(overflow, typ); + let predicate = self.insert_cast(overflow, NumericType::NativeField); let pow = self.pow(base, rhs); - let pow = self.insert_cast(pow, typ); // Unchecked mul because `predicate` will be 1 or 0 ( @@ -140,14 +139,13 @@ impl Context<'_, '_, '_> { }; if max_bit <= bit_size { + let pow = self.insert_cast(pow, typ); // Unchecked mul as it can't overflow self.insert_binary(lhs, BinaryOp::Mul { unchecked: true }, pow) } else { let lhs_field = self.insert_cast(lhs, NumericType::NativeField); - let pow_field = self.insert_cast(pow, NumericType::NativeField); // Unchecked mul as this is a wrapping operation that we later truncate - let result = - self.insert_binary(lhs_field, BinaryOp::Mul { unchecked: true }, pow_field); + let result = self.insert_binary(lhs_field, BinaryOp::Mul { unchecked: true }, pow); let result = self.insert_truncate(result, bit_size, max_bit); self.insert_cast(result, typ) } @@ -184,7 +182,7 @@ impl Context<'_, '_, '_> { if lhs_typ.is_unsigned() { // unsigned right bit shift is just a normal division let result = self.insert_binary(lhs, BinaryOp::Div, pow); - // In case of overflow, pow is 1, because rhs was nullified, so we return explicitly 0. + // In case of overflow, pow is 1, because rhs was nullified, so we return explicitly 0. return self.insert_binary( rhs_is_less_than_bit_size_with_lhs_typ, BinaryOp::Mul { unchecked: true }, @@ -262,38 +260,44 @@ impl Context<'_, '_, '_> { /// } fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { let typ = self.context.dfg.type_of_value(rhs); - if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ { - let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little)); - let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)]; - - // A call to ToBits can only be done with a field argument (rhs is always u8 here) - let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField); - let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types); - - let rhs_bits = rhs_bits[0]; - let one = self.field_constant(FieldElement::one()); - let mut r = one; - // All operations are unchecked as we're acting on Field types (which are always unchecked) - for i in 1..bit_size + 1 { - let idx = self.numeric_constant( - FieldElement::from((bit_size - i) as i128), - NumericType::length_type(), - ); - let b = self.insert_array_get(rhs_bits, idx, Type::bool()); - let not_b = self.insert_not(b); - let b = self.insert_cast(b, NumericType::NativeField); - let not_b = self.insert_cast(not_b, NumericType::NativeField); - - let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r); - let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b); - let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs); - let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b); - r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2); - } - r - } else { + let Type::Numeric(NumericType::Unsigned { bit_size }) = typ else { unreachable!("Value must be unsigned in power operation"); + }; + + let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little)); + let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)]; + + // A call to ToBits can only be done with a field argument (rhs is always u8 here) + let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField); + let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types); + + let rhs_bits = rhs_bits[0]; + let one = self.field_constant(FieldElement::one()); + let mut r = one; + // All operations are unchecked as we're acting on Field types (which are always unchecked) + for i in 1..bit_size + 1 { + let idx = self.numeric_constant( + FieldElement::from((bit_size - i) as i128), + NumericType::length_type(), + ); + let b = self.insert_array_get(rhs_bits, idx, Type::bool()); + let not_b = self.insert_not(b); + let b = self.insert_cast(b, NumericType::NativeField); + let not_b = self.insert_cast(not_b, NumericType::NativeField); + + let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r); + let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b); + let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs); + let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b); + r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2); } + + assert!( + matches!(self.context.dfg.type_of_value(r).unwrap_numeric(), NumericType::NativeField), + "ICE: pow is expected to always return a NativeField" + ); + + r } pub(crate) fn field_constant(&mut self, constant: FieldElement) -> ValueId { @@ -452,11 +456,12 @@ mod tests { "; let ssa = Ssa::from_str(src).unwrap(); let ssa = ssa.remove_bit_shifts(); + assert_ssa_snapshot!(ssa, @r" acir(inline) fn main f0 { b0(v0: u32, v1: u8): v3 = lt v1, u8 32 - v4 = cast v3 as u32 + v4 = cast v3 as Field v5 = cast v1 as Field v7 = call to_le_bits(v5) -> [u1; 8] v9 = array_get v7, index u32 7 -> u1 @@ -528,15 +533,13 @@ mod tests { v83 = mul v81, Field 2 v84 = mul v83, v79 v85 = add v82, v84 - v86 = cast v85 as u32 - v87 = unchecked_mul v4, v86 - v88 = cast v0 as Field - v89 = cast v87 as Field - v90 = mul v88, v89 - v91 = truncate v90 to 32 bits, max_bit_size: 254 - v92 = cast v91 as u32 - v93 = truncate v92 to 32 bits, max_bit_size: 33 - return v92 + v86 = mul v4, v85 + v87 = cast v0 as Field + v88 = mul v87, v86 + v89 = truncate v88 to 32 bits, max_bit_size: 254 + v90 = cast v89 as u32 + v91 = truncate v90 to 32 bits, max_bit_size: 33 + return v90 } "); } diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index 6efd0e423e9..db4f6f24720 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -16,6 +16,7 @@ use crate::ssa::{ value::ValueId, }, opt::pure::FunctionPurities, + ssa_gen::validate_ssa, }; use super::{ @@ -24,8 +25,8 @@ use super::{ }; impl ParsedSsa { - pub(crate) fn into_ssa(self, simplify: bool) -> Result { - Translator::translate(self, simplify) + pub(crate) fn into_ssa(self, simplify: bool, validate: bool) -> Result { + Translator::translate(self, simplify, validate) } } @@ -61,7 +62,11 @@ struct Translator { } impl Translator { - fn translate(mut parsed_ssa: ParsedSsa, simplify: bool) -> Result { + fn translate( + mut parsed_ssa: ParsedSsa, + simplify: bool, + validate: bool, + ) -> Result { let mut translator = Self::new(&mut parsed_ssa, simplify)?; // Note that the `new` call above removed the main function, @@ -70,7 +75,13 @@ impl Translator { translator.translate_non_main_function(function)?; } - Ok(translator.finish()) + let ssa = translator.finish(); + + if validate { + validate_ssa(&ssa); + } + + Ok(ssa) } fn new(parsed_ssa: &mut ParsedSsa, simplify: bool) -> Result { diff --git a/compiler/noirc_evaluator/src/ssa/parser/mod.rs b/compiler/noirc_evaluator/src/ssa/parser/mod.rs index 3bd0563e5e6..d940bd22a4b 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/mod.rs @@ -37,7 +37,7 @@ impl FromStr for Ssa { type Err = SsaErrorWithSource; fn from_str(s: &str) -> Result { - Self::from_str_impl(s, false) + Self::from_str_impl(s, false, true) } } @@ -48,19 +48,24 @@ impl Ssa { FromStr::from_str(src) } + /// Creates an Ssa object from the given string without running SSA validation + pub fn from_str_no_validation(src: &str) -> Result { + Self::from_str_impl(src, false, false) + } + /// Creates an Ssa object from the given string but trying to simplify /// each parsed instruction as it's inserted into the final SSA. pub fn from_str_simplifying(src: &str) -> Result { - Self::from_str_impl(src, true) + Self::from_str_impl(src, true, true) } - fn from_str_impl(src: &str, simplify: bool) -> Result { + fn from_str_impl(src: &str, simplify: bool, validate: bool) -> Result { let mut parser = Parser::new(src).map_err(|err| SsaErrorWithSource::parse_error(err, src))?; let parsed_ssa = parser.parse_ssa().map_err(|err| SsaErrorWithSource::parse_error(err, src))?; parsed_ssa - .into_ssa(simplify) + .into_ssa(simplify, validate) .map_err(|error| SsaErrorWithSource { src: src.to_string(), error }) } } diff --git a/compiler/noirc_evaluator/src/ssa/parser/tests.rs b/compiler/noirc_evaluator/src/ssa/parser/tests.rs index d40d1e3f277..0ac080991b2 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/tests.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/tests.rs @@ -288,8 +288,9 @@ fn test_cast() { let src = " acir(inline) fn main f0 { b0(v0: Field): - v1 = cast v0 as i32 - return v1 + v1 = truncate v0 to 32 bits, max_bit_size: 254 + v2 = cast v1 as i32 + return v2 } "; assert_ssa_roundtrip(src); diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index c361803ba98..97430e649c5 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -31,6 +31,7 @@ use super::ir::basic_block::BasicBlockId; use super::ir::dfg::GlobalsGraph; use super::ir::instruction::{ArrayOffset, ErrorType}; use super::ir::types::NumericType; +use super::validation::validate_function; use super::{ function_builder::data_bus::DataBus, ir::{ @@ -131,9 +132,17 @@ pub fn generate_ssa(program: Program) -> Result { } let ssa = function_context.builder.finish(); + validate_ssa(&ssa); + Ok(ssa) } +pub(crate) fn validate_ssa(ssa: &Ssa) { + for function in ssa.functions.values() { + validate_function(function); + } +} + impl FunctionContext<'_> { /// Codegen a function's body and set its return value to that of its last parameter. /// For functions returning nothing, this will be an empty list. diff --git a/compiler/noirc_evaluator/src/ssa/validation/mod.rs b/compiler/noirc_evaluator/src/ssa/validation/mod.rs index 6f0d8774812..eac124bd03a 100644 --- a/compiler/noirc_evaluator/src/ssa/validation/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/validation/mod.rs @@ -12,7 +12,8 @@ //! - Check that the input values of certain instructions matches that instruction's constraint //! At the moment, only [Instruction::Binary], [Instruction::ArrayGet], and [Instruction::ArraySet] //! are type checked. -use fxhash::FxHashSet as HashSet; +use acvm::{AcirField, FieldElement}; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use crate::ssa::ir::instruction::TerminatorInstruction; @@ -30,6 +31,12 @@ struct Validator<'f> { // State for truncate-after-signed-sub validation // Stores: Option<(bit_size, result)> signed_binary_op: Option, + + // State for valid Field to integer casts + // Range checks are laid down in isolation and can make for safe casts + // If they occurred before the value being cast to a smaller type + // Stores: A set of (value being range constrained, the value's max bit size) + range_checks: HashMap, } #[derive(Debug)] @@ -40,7 +47,7 @@ enum PendingSignedOverflowOp { impl<'f> Validator<'f> { fn new(function: &'f Function) -> Self { - Self { function, signed_binary_op: None } + Self { function, signed_binary_op: None, range_checks: HashMap::default() } } /// Validates that any checked signed add/sub/mul are followed by the appropriate instructions. @@ -138,6 +145,82 @@ impl<'f> Validator<'f> { } } + /// Enforces that every cast from Field -> unsigned/signed integer must obey the following invariants: + /// The value being cast is either: + /// 1. A truncate instruction that ensures the cast is valid + /// 2. A constant value known to be in-range + /// 3. A division or other operation whose result is known to fit within the target bit size + /// + /// Our initial SSA gen only generates preceding truncates for safe casts. + /// The cases accepted here are extended past what we perform during our initial SSA gen + /// to mirror the instruction simplifier and other logic that could be accepted as a safe cast. + fn validate_field_to_integer_cast_invariant(&mut self, instruction_id: InstructionId) { + let dfg = &self.function.dfg; + + let (cast_input, typ) = match &dfg[instruction_id] { + Instruction::Cast(cast_input, typ) => (*cast_input, *typ), + Instruction::RangeCheck { value, max_bit_size, .. } => { + self.range_checks.insert(*value, *max_bit_size); + return; + } + _ => return, + }; + + if !matches!(dfg.type_of_value(cast_input), Type::Numeric(NumericType::NativeField)) { + return; + } + + let (NumericType::Signed { bit_size: target_type_size } + | NumericType::Unsigned { bit_size: target_type_size }) = typ + else { + return; + }; + + // If the cast input has already been range constrained to a bit size that fits + // in the destination type, we have a safe cast. + if let Some(max_bit_size) = self.range_checks.get(&cast_input) { + assert!(*max_bit_size <= target_type_size); + return; + } + + match &dfg[cast_input] { + Value::Instruction { instruction, .. } => match &dfg[*instruction] { + Instruction::Truncate { value: _, bit_size, max_bit_size } => { + assert!(*bit_size <= target_type_size); + assert!(*max_bit_size <= FieldElement::max_num_bits()); + } + Instruction::Binary(Binary { lhs, rhs, operator: BinaryOp::Div, .. }) + if dfg.is_constant(*rhs) => + { + let numerator_bits = dfg.type_of_value(*lhs).bit_size(); + let divisor = dfg.get_numeric_constant(*rhs).unwrap(); + let divisor_bits = divisor.num_bits(); + let max_quotient_bits = numerator_bits - divisor_bits; + + assert!( + max_quotient_bits <= target_type_size, + "Cast from field after div could exceed bit size: expected ≤ {target_type_size}, got {max_quotient_bits}" + ); + } + _ => { + panic!("Invalid cast from Field, must be truncated or provably safe"); + } + }, + Value::NumericConstant { constant, .. } => { + let max_val_bits = constant.num_bits(); + assert!( + max_val_bits <= target_type_size, + "Constant too large for cast target: {max_val_bits} bits > {target_type_size}" + ); + } + _ => { + panic!( + "Invalid cast from Field, not preceded by valid truncation or known safe value" + ); + } + } + } + // Validates there is exactly one return block fn validate_single_return_block(&self) { let reachable_blocks = self.function.reachable_blocks(); @@ -248,6 +331,7 @@ impl<'f> Validator<'f> { for block in self.function.reachable_blocks() { for instruction in self.function.dfg[block].instructions() { self.validate_signed_op_overflow_pattern(*instruction); + self.validate_field_to_integer_cast_invariant(*instruction); self.type_check_instruction(*instruction); } } @@ -714,4 +798,83 @@ mod tests { "; let _ = Ssa::from_str(src); } + + #[test] + fn cast_from_field_constant_in_range() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = cast Field 42 as u8 + return v0 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn cast_from_field_constant_out_of_range_with_truncate() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = truncate Field 123456 to 8 bits, max_bit_size: 16 + v1 = cast v0 as u8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn cast_from_field_division_safe() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = div u16 256, u16 256 + v1 = cast v0 as u8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Constant too large")] + fn cast_from_field_constant_too_large() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = cast Field 300 as u8 + return v0 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Invalid cast from Field")] + fn cast_from_raw_field() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = add Field 255, Field 1 + v1 = cast v0 as u8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion")] + fn cast_after_unsafe_truncate() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = truncate Field 1000 to 16 bits, max_bit_size: 16 + v1 = cast v0 as u8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } } diff --git a/tooling/ast_fuzzer/tests/parser.rs b/tooling/ast_fuzzer/tests/parser.rs index 20c114f0ab4..76b4d15891d 100644 --- a/tooling/ast_fuzzer/tests/parser.rs +++ b/tooling/ast_fuzzer/tests/parser.rs @@ -76,7 +76,7 @@ fn arb_ssa_roundtrip() { ssa1.normalize_ids(); // Print to str and parse back. - let ssa2 = Ssa::from_str(&ssa1.to_string()).unwrap_or_else(|e| { + let ssa2 = Ssa::from_str_no_validation(&ssa1.to_string()).unwrap_or_else(|e| { let msg = passes.last().map(|p| p.msg()).unwrap_or("Initial SSA"); print_ast_and_panic(&format!( "Could not parse SSA after step {last_pass} ({msg}): \n{e:?}"