Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::ssa::{
types::{NumericType, Type},
value::ValueId,
},
ssa_gen::Ssa,
ssa_gen::Ssa
};

use super::simple_optimization::SimpleOptimizationContext;
Expand Down Expand Up @@ -82,6 +82,9 @@ impl Function {

#[cfg(debug_assertions)]
remove_bit_shifts_post_check(self);

// println!("{}", self);
// validate_function(self);
}
}

Expand Down Expand Up @@ -130,6 +133,7 @@ impl Context<'_, '_, '_> {
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ);
let pow = self.pow(base, rhs);
let pow = self.insert_truncate(pow, typ.bit_size(), FieldElement::max_num_bits());
let pow = self.insert_cast(pow, typ);

// Unchecked mul because `predicate` will be 1 or 0
Expand Down Expand Up @@ -179,6 +183,7 @@ impl Context<'_, '_, '_> {
rhs,
);
let pow = self.pow(base, rhs);
let pow = self.insert_truncate(pow, lhs_typ.bit_size(), FieldElement::max_num_bits());
let pow = self.insert_cast(pow, lhs_typ);

if lhs_typ.is_unsigned() {
Expand All @@ -205,6 +210,7 @@ impl Context<'_, '_, '_> {
// Performs the division on the 1-complement (or the operand if positive)
let shifted_complement = self.insert_binary(one_complement, BinaryOp::Div, pow);
// Convert back to 2-complement representation if operand is negative
let lhs_sign = self.insert_truncate(lhs_sign, lhs_typ.bit_size(), lhs_typ.bit_size() + 1);
let lhs_sign_as_int = self.insert_cast(lhs_sign, lhs_typ);

// The requirements for this to underflow are all of these:
Expand Down Expand Up @@ -262,38 +268,41 @@ impl Context<'_, '_, '_> {
/// }
fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.context.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}
r
} else {
let Type::Numeric(NumericType::Unsigned { bit_size }) = typ else {
unreachable!("Value must be unsigned in power operation");
};

let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}

assert!(matches!(self.context.dfg.type_of_value(r).unwrap_numeric(), NumericType::NativeField), "ICE: pow is expected to always return a NativeField");

r
}

pub(crate) fn field_constant(&mut self, constant: FieldElement) -> ValueId {
Expand Down
87 changes: 85 additions & 2 deletions compiler/noirc_evaluator/src/ssa/validation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
//! - Check that the input values of certain instructions matches that instruction's constraint
//! At the moment, only [Instruction::Binary], [Instruction::ArrayGet], and [Instruction::ArraySet]
//! are type checked.
use fxhash::FxHashSet as HashSet;
use acvm::{AcirField, FieldElement};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::ssa::ir::instruction::TerminatorInstruction;

Expand All @@ -30,6 +31,12 @@ struct Validator<'f> {
// State for truncate-after-signed-sub validation
// Stores: Option<(bit_size, result)>
signed_binary_op: Option<PendingSignedOverflowOp>,

// State for valid Field to integer casts
// Range checks are laid down in isolation and can make for safe casts
// If they occurred before the value being cast to a smaller type
// Stores: A set of (value being range constrained, the value's max bit size)
range_checks: HashMap<ValueId, u32>
}

#[derive(Debug)]
Expand All @@ -40,7 +47,7 @@ enum PendingSignedOverflowOp {

impl<'f> Validator<'f> {
fn new(function: &'f Function) -> Self {
Self { function, signed_binary_op: None }
Self { function, signed_binary_op: None, range_checks: HashMap::default(), }
}

/// Validates that any checked signed add/sub/mul are followed by the appropriate instructions.
Expand Down Expand Up @@ -138,6 +145,81 @@ impl<'f> Validator<'f> {
}
}

/// Enforces that every cast from Field -> unsigned/signed integer must obey the following invariants:
/// The value being cast is either:
/// 1. A truncate instruction that ensures the cast is valid
/// 2. A constant value known to be in-range
/// 3. A division or other operation whose result is known to fit within the target bit size
///
/// Our initial SSA gen only generates preceding truncates for safe casts.
/// The cases accepted here are extended past what we perform during our initial SSA gen
/// to mirror the instruction simplifier and other logic that could be accepted as a safe cast.
fn validate_field_to_integer_cast_invariant(&mut self, instruction_id: InstructionId) {
let dfg = &self.function.dfg;

let (cast_input, typ) = match &dfg[instruction_id] {
Instruction::Cast(cast_input, typ) => {
(*cast_input, *typ)
}
Instruction::RangeCheck { value, max_bit_size, .. } => {
self.range_checks.insert(*value, *max_bit_size);
return
}
_ => return,
};

if !matches!(dfg.type_of_value(cast_input), Type::Numeric(NumericType::NativeField)) {
return;
}

let (NumericType::Signed { bit_size: target_type_size }
| NumericType::Unsigned { bit_size: target_type_size }) = typ
else {
return;
};

// If the cast input has already been range constrained to a bit size that fits
// in the destination type, we have a safe cast.
if let Some(max_bit_size) = self.range_checks.get(&cast_input) {
assert!(*max_bit_size <= target_type_size);
return;
}

match &dfg[cast_input] {
Value::Instruction { instruction, .. } => match &dfg[*instruction] {
Instruction::Truncate { value: _, bit_size, max_bit_size } => {
assert_eq!(*bit_size, target_type_size);
assert!(*max_bit_size <= FieldElement::max_num_bits());
}
Instruction::Binary(Binary { lhs, rhs, operator: BinaryOp::Div, .. }) if dfg.is_constant(*rhs) => {
let numerator_bits = dfg.type_of_value(*lhs).bit_size();
let divisor = dfg.get_numeric_constant(*rhs).unwrap();
let divisor_bits = divisor.num_bits();
let max_quotient_bits = numerator_bits - divisor_bits;

assert!(
max_quotient_bits <= target_type_size,
"Cast from field after div could exceed bit size: expected ≤ {target_type_size}, got {max_quotient_bits}"
);
}
_ => {
dbg!(&dfg[*instruction]);
panic!("Invalid cast from Field, must be truncated or provably safe");
}
}
Value::NumericConstant { constant, .. } => {
let max_val_bits = constant.num_bits();
assert!(
max_val_bits <= target_type_size,
"Constant too large for cast target: {max_val_bits} bits > {target_type_size}"
);
}
_ => {
panic!("Invalid cast from Field, not preceded by valid truncation or known safe value");
}
}
}

// Validates there is exactly one return block
fn validate_single_return_block(&self) {
let reachable_blocks = self.function.reachable_blocks();
Expand Down Expand Up @@ -239,6 +321,7 @@ impl<'f> Validator<'f> {
for block in self.function.reachable_blocks() {
for instruction in self.function.dfg[block].instructions() {
self.validate_signed_op_overflow_pattern(*instruction);
self.validate_field_to_integer_cast_invariant(*instruction);
self.type_check_instruction(*instruction);
}
}
Expand Down
Loading
Loading