Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1102,10 +1102,12 @@ fn test_range_and_xor_bb() {

acir(inline) fn test_and_xor f2 {
b0(v0: Field, v1: Field):
v2 = cast v0 as u8
v4 = cast v1 as u8
v8 = and v2, v4
v9 = xor v2, v4
v2 = truncate v0 to 8 bits, max_bit_size: 254
v3 = cast v2 as u8
v4 = truncate v1 to 8 bits, max_bit_size: 254
v5 = cast v4 as u8
v8 = and v3, v5
v9 = xor v3, v5
return v9
}

Expand Down
35 changes: 18 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1472,8 +1472,9 @@ mod test {
b2():
v7 = cast v2 as Field
v9 = add v7, Field 1
v10 = cast v9 as u8
store v10 at v6
v10 = truncate v9 to 8 bits, max_bit_size: 254
v11 = cast v10 as u8
store v11 at v6
jmp b3()
b3():
constrain v5 == u1 1
Expand All @@ -1485,7 +1486,6 @@ mod test {
";

let ssa = Ssa::from_str(src).unwrap();

let flattened_ssa = ssa.flatten_cfg();
let main = flattened_ssa.main();

Expand Down Expand Up @@ -1515,21 +1515,22 @@ mod test {
enable_side_effects v5
v8 = cast v2 as Field
v10 = add v8, Field 1
v11 = cast v10 as u8
v12 = load v6 -> u8
v13 = not v5
v14 = cast v4 as u8
v15 = cast v13 as u8
v16 = unchecked_mul v14, v11
v11 = truncate v10 to 8 bits, max_bit_size: 254
v12 = cast v11 as u8
v13 = load v6 -> u8
v14 = not v5
v15 = cast v4 as u8
v16 = cast v14 as u8
v17 = unchecked_mul v15, v12
v18 = unchecked_add v16, v17
store v18 at v6
enable_side_effects v13
v19 = load v6 -> u8
v20 = cast v13 as u8
v21 = cast v4 as u8
v22 = unchecked_mul v21, v19
store v22 at v6
v18 = unchecked_mul v16, v13
v19 = unchecked_add v17, v18
store v19 at v6
enable_side_effects v14
v20 = load v6 -> u8
v21 = cast v14 as u8
v22 = cast v4 as u8
v23 = unchecked_mul v22, v20
store v23 at v6
enable_side_effects u1 1
constrain v5 == u1 1
return
Expand Down
92 changes: 53 additions & 39 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ impl Function {

#[cfg(debug_assertions)]
remove_bit_shifts_post_check(self);

// println!("{}", self);
// validate_function(self);
}
}

Expand Down Expand Up @@ -130,6 +133,7 @@ impl Context<'_, '_, '_> {
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ);
let pow = self.pow(base, rhs);
let pow = self.insert_truncate(pow, typ.bit_size(), FieldElement::max_num_bits());
let pow = self.insert_cast(pow, typ);

// Unchecked mul because `predicate` will be 1 or 0
Expand Down Expand Up @@ -179,6 +183,7 @@ impl Context<'_, '_, '_> {
rhs,
);
let pow = self.pow(base, rhs);
let pow = self.insert_truncate(pow, lhs_typ.bit_size(), FieldElement::max_num_bits());
let pow = self.insert_cast(pow, lhs_typ);

if lhs_typ.is_unsigned() {
Expand All @@ -205,6 +210,7 @@ impl Context<'_, '_, '_> {
// Performs the division on the 1-complement (or the operand if positive)
let shifted_complement = self.insert_binary(one_complement, BinaryOp::Div, pow);
// Convert back to 2-complement representation if operand is negative
let lhs_sign = self.insert_truncate(lhs_sign, lhs_typ.bit_size(), lhs_typ.bit_size() + 1);
let lhs_sign_as_int = self.insert_cast(lhs_sign, lhs_typ);

// The requirements for this to underflow are all of these:
Expand Down Expand Up @@ -262,38 +268,44 @@ impl Context<'_, '_, '_> {
/// }
fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.context.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}
r
} else {
let Type::Numeric(NumericType::Unsigned { bit_size }) = typ else {
unreachable!("Value must be unsigned in power operation");
};

let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}

assert!(
matches!(self.context.dfg.type_of_value(r).unwrap_numeric(), NumericType::NativeField),
"ICE: pow is expected to always return a NativeField"
);

r
}

pub(crate) fn field_constant(&mut self, constant: FieldElement) -> ValueId {
Expand Down Expand Up @@ -452,6 +464,7 @@ mod tests {
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.remove_bit_shifts();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u32, v1: u8):
Expand Down Expand Up @@ -528,15 +541,16 @@ mod tests {
v83 = mul v81, Field 2
v84 = mul v83, v79
v85 = add v82, v84
v86 = cast v85 as u32
v87 = unchecked_mul v4, v86
v88 = cast v0 as Field
v89 = cast v87 as Field
v90 = mul v88, v89
v91 = truncate v90 to 32 bits, max_bit_size: 254
v92 = cast v91 as u32
v93 = truncate v92 to 32 bits, max_bit_size: 33
return v92
v86 = truncate v85 to 32 bits, max_bit_size: 254
v87 = cast v86 as u32
v88 = unchecked_mul v4, v87
v89 = cast v0 as Field
v90 = cast v88 as Field
v91 = mul v89, v90
v92 = truncate v91 to 32 bits, max_bit_size: 254
v93 = cast v92 as u32
v94 = truncate v93 to 32 bits, max_bit_size: 33
return v93
}
");
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/ssa/parser/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,9 @@
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v1 = cast v0 as i32
return v1
v1 = truncate v0 to 32 bits, max_bit_size: 254
v2 = cast v1 as i32
return v2
}
";
assert_ssa_roundtrip(src);
Expand Down Expand Up @@ -786,7 +787,7 @@
v2 = add Field 1, Field 2
v4 = add v2, Field 3
jmp b1()
}

Check warning on line 790 in compiler/noirc_evaluator/src/ssa/parser/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (syntantically)
";
assert_ssa_roundtrip(src);
}
Expand All @@ -804,7 +805,7 @@
jmp b3()
b3():
v3 = add Field 1, Field 2
v5 = add v3, Field 3

Check warning on line 808 in compiler/noirc_evaluator/src/ssa/parser/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (syntantically)
jmp b1()
}
";
Expand Down
89 changes: 87 additions & 2 deletions compiler/noirc_evaluator/src/ssa/validation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
//! - Check that the input values of certain instructions matches that instruction's constraint
//! At the moment, only [Instruction::Binary], [Instruction::ArrayGet], and [Instruction::ArraySet]
//! are type checked.
use fxhash::FxHashSet as HashSet;
use acvm::{AcirField, FieldElement};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::ssa::ir::instruction::TerminatorInstruction;

Expand All @@ -30,6 +31,12 @@ struct Validator<'f> {
// State for truncate-after-signed-sub validation
// Stores: Option<(bit_size, result)>
signed_binary_op: Option<PendingSignedOverflowOp>,

// State for valid Field to integer casts
// Range checks are laid down in isolation and can make for safe casts
// If they occurred before the value being cast to a smaller type
// Stores: A set of (value being range constrained, the value's max bit size)
range_checks: HashMap<ValueId, u32>,
}

#[derive(Debug)]
Expand All @@ -40,7 +47,7 @@ enum PendingSignedOverflowOp {

impl<'f> Validator<'f> {
fn new(function: &'f Function) -> Self {
Self { function, signed_binary_op: None }
Self { function, signed_binary_op: None, range_checks: HashMap::default() }
}

/// Validates that any checked signed add/sub/mul are followed by the appropriate instructions.
Expand Down Expand Up @@ -138,6 +145,83 @@ impl<'f> Validator<'f> {
}
}

/// Enforces that every cast from Field -> unsigned/signed integer must obey the following invariants:
/// The value being cast is either:
/// 1. A truncate instruction that ensures the cast is valid
/// 2. A constant value known to be in-range
/// 3. A division or other operation whose result is known to fit within the target bit size
///
/// Our initial SSA gen only generates preceding truncates for safe casts.
/// The cases accepted here are extended past what we perform during our initial SSA gen
/// to mirror the instruction simplifier and other logic that could be accepted as a safe cast.
fn validate_field_to_integer_cast_invariant(&mut self, instruction_id: InstructionId) {
let dfg = &self.function.dfg;

let (cast_input, typ) = match &dfg[instruction_id] {
Instruction::Cast(cast_input, typ) => (*cast_input, *typ),
Instruction::RangeCheck { value, max_bit_size, .. } => {
self.range_checks.insert(*value, *max_bit_size);
return;
}
_ => return,
};

if !matches!(dfg.type_of_value(cast_input), Type::Numeric(NumericType::NativeField)) {
return;
}

let (NumericType::Signed { bit_size: target_type_size }
| NumericType::Unsigned { bit_size: target_type_size }) = typ
else {
return;
};

// If the cast input has already been range constrained to a bit size that fits
// in the destination type, we have a safe cast.
if let Some(max_bit_size) = self.range_checks.get(&cast_input) {
assert!(*max_bit_size <= target_type_size);
return;
}

match &dfg[cast_input] {
Value::Instruction { instruction, .. } => match &dfg[*instruction] {
Instruction::Truncate { value: _, bit_size, max_bit_size } => {
assert_eq!(*bit_size, target_type_size);
assert!(*max_bit_size <= FieldElement::max_num_bits());
}
Instruction::Binary(Binary { lhs, rhs, operator: BinaryOp::Div, .. })
if dfg.is_constant(*rhs) =>
{
let numerator_bits = dfg.type_of_value(*lhs).bit_size();
let divisor = dfg.get_numeric_constant(*rhs).unwrap();
let divisor_bits = divisor.num_bits();
let max_quotient_bits = numerator_bits - divisor_bits;

assert!(
max_quotient_bits <= target_type_size,
"Cast from field after div could exceed bit size: expected ≤ {target_type_size}, got {max_quotient_bits}"
);
}
_ => {
dbg!(&dfg[*instruction]);
panic!("Invalid cast from Field, must be truncated or provably safe");
}
},
Value::NumericConstant { constant, .. } => {
let max_val_bits = constant.num_bits();
assert!(
max_val_bits <= target_type_size,
"Constant too large for cast target: {max_val_bits} bits > {target_type_size}"
);
}
_ => {
panic!(
"Invalid cast from Field, not preceded by valid truncation or known safe value"
);
}
}
}

// Validates there is exactly one return block
fn validate_single_return_block(&self) {
let reachable_blocks = self.function.reachable_blocks();
Expand Down Expand Up @@ -239,6 +323,7 @@ impl<'f> Validator<'f> {
for block in self.function.reachable_blocks() {
for instruction in self.function.dfg[block].instructions() {
self.validate_signed_op_overflow_pattern(*instruction);
self.validate_field_to_integer_cast_invariant(*instruction);
self.type_check_instruction(*instruction);
}
}
Expand Down
Loading
Loading