Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use super::{
},
opt::pure::FunctionPurities,
ssa_gen::Ssa,
validation::validate_function,
};

/// The per-function context for each ssa function being generated.
Expand Down Expand Up @@ -158,8 +157,6 @@ impl FunctionBuilder {
pub fn finish(mut self) -> Ssa {
self.finished_functions.push(self.current_function);

Self::validate_ssa(&self.finished_functions);

Ssa::new(self.finished_functions, self.error_types)
}

Expand Down Expand Up @@ -524,12 +521,6 @@ impl FunctionBuilder {
pub fn record_error_type(&mut self, selector: ErrorSelector, typ: HirType) {
self.error_types.insert(selector, typ);
}

fn validate_ssa(functions: &[Function]) {
for function in functions {
validate_function(function);
}
}
}

impl std::ops::Index<ValueId> for FunctionBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,10 +1121,12 @@ fn test_range_and_xor_bb() {

acir(inline) fn test_and_xor f2 {
b0(v0: Field, v1: Field):
v2 = cast v0 as u8
v4 = cast v1 as u8
v8 = and v2, v4
v9 = xor v2, v4
v2 = truncate v0 to 8 bits, max_bit_size: 254
v3 = cast v2 as u8
v4 = truncate v1 to 8 bits, max_bit_size: 254
v5 = cast v4 as u8
v8 = and v3, v5
v9 = xor v3, v5
return v9
}

Expand Down
35 changes: 18 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1472,8 +1472,9 @@ mod test {
b2():
v7 = cast v2 as Field
v9 = add v7, Field 1
v10 = cast v9 as u8
store v10 at v6
v10 = truncate v9 to 8 bits, max_bit_size: 254
v11 = cast v10 as u8
store v11 at v6
jmp b3()
b3():
constrain v5 == u1 1
Expand All @@ -1485,7 +1486,6 @@ mod test {
";

let ssa = Ssa::from_str(src).unwrap();

let flattened_ssa = ssa.flatten_cfg();
let main = flattened_ssa.main();

Expand Down Expand Up @@ -1515,21 +1515,22 @@ mod test {
enable_side_effects v5
v8 = cast v2 as Field
v10 = add v8, Field 1
v11 = cast v10 as u8
v12 = load v6 -> u8
v13 = not v5
v14 = cast v4 as u8
v15 = cast v13 as u8
v16 = unchecked_mul v14, v11
v11 = truncate v10 to 8 bits, max_bit_size: 254
v12 = cast v11 as u8
v13 = load v6 -> u8
v14 = not v5
v15 = cast v4 as u8
v16 = cast v14 as u8
v17 = unchecked_mul v15, v12
v18 = unchecked_add v16, v17
store v18 at v6
enable_side_effects v13
v19 = load v6 -> u8
v20 = cast v13 as u8
v21 = cast v4 as u8
v22 = unchecked_mul v21, v19
store v22 at v6
v18 = unchecked_mul v16, v13
v19 = unchecked_add v17, v18
store v19 at v6
enable_side_effects v14
v20 = load v6 -> u8
v21 = cast v14 as u8
v22 = cast v4 as u8
v23 = unchecked_mul v22, v20
store v23 at v6
enable_side_effects u1 1
constrain v5 == u1 1
return
Expand Down
97 changes: 50 additions & 47 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl Context<'_, '_, '_> {
return InsertInstructionResult::SimplifiedTo(zero).first();
}
}
let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ);
let pow = self.field_constant(FieldElement::from(rhs_bit_size_pow_2));

let max_lhs_bits = self.context.dfg.get_value_max_num_bits(lhs);
let max_bit_size = max_lhs_bits + bit_shift_size;
Expand All @@ -128,9 +128,8 @@ impl Context<'_, '_, '_> {
let u8_type = NumericType::unsigned(8);
let bit_size_var = self.numeric_constant(FieldElement::from(bit_size as u128), u8_type);
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ);
let predicate = self.insert_cast(overflow, NumericType::NativeField);
let pow = self.pow(base, rhs);
let pow = self.insert_cast(pow, typ);

// Unchecked mul because `predicate` will be 1 or 0
(
Expand All @@ -140,14 +139,13 @@ impl Context<'_, '_, '_> {
};

if max_bit <= bit_size {
let pow = self.insert_cast(pow, typ);
// Unchecked mul as it can't overflow
self.insert_binary(lhs, BinaryOp::Mul { unchecked: true }, pow)
} else {
let lhs_field = self.insert_cast(lhs, NumericType::NativeField);
let pow_field = self.insert_cast(pow, NumericType::NativeField);
// Unchecked mul as this is a wrapping operation that we later truncate
let result =
self.insert_binary(lhs_field, BinaryOp::Mul { unchecked: true }, pow_field);
let result = self.insert_binary(lhs_field, BinaryOp::Mul { unchecked: true }, pow);
let result = self.insert_truncate(result, bit_size, max_bit);
self.insert_cast(result, typ)
}
Expand Down Expand Up @@ -184,7 +182,7 @@ impl Context<'_, '_, '_> {
if lhs_typ.is_unsigned() {
// unsigned right bit shift is just a normal division
let result = self.insert_binary(lhs, BinaryOp::Div, pow);
// In case of overflow, pow is 1, because rhs was nullified, so we return explicitly 0.
// In case of overflow, pow is 1, because rhs was nullified, so we return explicitly 0.
return self.insert_binary(
rhs_is_less_than_bit_size_with_lhs_typ,
BinaryOp::Mul { unchecked: true },
Expand Down Expand Up @@ -262,38 +260,44 @@ impl Context<'_, '_, '_> {
/// }
fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.context.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}
r
} else {
let Type::Numeric(NumericType::Unsigned { bit_size }) = typ else {
unreachable!("Value must be unsigned in power operation");
};

let to_bits = self.context.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little));
let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), bit_size)];

// A call to ToBits can only be done with a field argument (rhs is always u8 here)
let rhs_as_field = self.insert_cast(rhs, NumericType::NativeField);
let rhs_bits = self.insert_call(to_bits, vec![rhs_as_field], result_types);

let rhs_bits = rhs_bits[0];
let one = self.field_constant(FieldElement::one());
let mut r = one;
// All operations are unchecked as we're acting on Field types (which are always unchecked)
for i in 1..bit_size + 1 {
let idx = self.numeric_constant(
FieldElement::from((bit_size - i) as i128),
NumericType::length_type(),
);
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, NumericType::NativeField);
let not_b = self.insert_cast(not_b, NumericType::NativeField);

let r_squared = self.insert_binary(r, BinaryOp::Mul { unchecked: true }, r);
let r1 = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, not_b);
let a = self.insert_binary(r_squared, BinaryOp::Mul { unchecked: true }, lhs);
let r2 = self.insert_binary(a, BinaryOp::Mul { unchecked: true }, b);
r = self.insert_binary(r1, BinaryOp::Add { unchecked: true }, r2);
}

assert!(
matches!(self.context.dfg.type_of_value(r).unwrap_numeric(), NumericType::NativeField),
"ICE: pow is expected to always return a NativeField"
);

r
}

pub(crate) fn field_constant(&mut self, constant: FieldElement) -> ValueId {
Expand Down Expand Up @@ -452,11 +456,12 @@ mod tests {
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.remove_bit_shifts();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u32, v1: u8):
v3 = lt v1, u8 32
v4 = cast v3 as u32
v4 = cast v3 as Field
v5 = cast v1 as Field
v7 = call to_le_bits(v5) -> [u1; 8]
v9 = array_get v7, index u32 7 -> u1
Expand Down Expand Up @@ -528,15 +533,13 @@ mod tests {
v83 = mul v81, Field 2
v84 = mul v83, v79
v85 = add v82, v84
v86 = cast v85 as u32
v87 = unchecked_mul v4, v86
v88 = cast v0 as Field
v89 = cast v87 as Field
v90 = mul v88, v89
v91 = truncate v90 to 32 bits, max_bit_size: 254
v92 = cast v91 as u32
v93 = truncate v92 to 32 bits, max_bit_size: 33
return v92
v86 = mul v4, v85
v87 = cast v0 as Field
v88 = mul v87, v86
v89 = truncate v88 to 32 bits, max_bit_size: 254
v90 = cast v89 as u32
v91 = truncate v90 to 32 bits, max_bit_size: 33
return v90
}
");
}
Expand Down
19 changes: 15 additions & 4 deletions compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::ssa::{
value::ValueId,
},
opt::pure::FunctionPurities,
ssa_gen::validate_ssa,
};

use super::{
Expand All @@ -24,8 +25,8 @@ use super::{
};

impl ParsedSsa {
pub(crate) fn into_ssa(self, simplify: bool) -> Result<Ssa, SsaError> {
Translator::translate(self, simplify)
pub(crate) fn into_ssa(self, simplify: bool, validate: bool) -> Result<Ssa, SsaError> {
Translator::translate(self, simplify, validate)
}
}

Expand Down Expand Up @@ -61,7 +62,11 @@ struct Translator {
}

impl Translator {
fn translate(mut parsed_ssa: ParsedSsa, simplify: bool) -> Result<Ssa, SsaError> {
fn translate(
mut parsed_ssa: ParsedSsa,
simplify: bool,
validate: bool,
) -> Result<Ssa, SsaError> {
let mut translator = Self::new(&mut parsed_ssa, simplify)?;

// Note that the `new` call above removed the main function,
Expand All @@ -70,7 +75,13 @@ impl Translator {
translator.translate_non_main_function(function)?;
}

Ok(translator.finish())
let ssa = translator.finish();

if validate {
validate_ssa(&ssa);
}

Ok(ssa)
}

fn new(parsed_ssa: &mut ParsedSsa, simplify: bool) -> Result<Self, SsaError> {
Expand Down
13 changes: 9 additions & 4 deletions compiler/noirc_evaluator/src/ssa/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl FromStr for Ssa {
type Err = SsaErrorWithSource;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_str_impl(s, false)
Self::from_str_impl(s, false, true)
}
}

Expand All @@ -48,19 +48,24 @@ impl Ssa {
FromStr::from_str(src)
}

/// Creates an Ssa object from the given string without running SSA validation
pub fn from_str_no_validation(src: &str) -> Result<Ssa, SsaErrorWithSource> {
Self::from_str_impl(src, false, false)
}

/// Creates an Ssa object from the given string but trying to simplify
/// each parsed instruction as it's inserted into the final SSA.
pub fn from_str_simplifying(src: &str) -> Result<Ssa, SsaErrorWithSource> {
Self::from_str_impl(src, true)
Self::from_str_impl(src, true, true)
}

fn from_str_impl(src: &str, simplify: bool) -> Result<Ssa, SsaErrorWithSource> {
fn from_str_impl(src: &str, simplify: bool, validate: bool) -> Result<Ssa, SsaErrorWithSource> {
let mut parser =
Parser::new(src).map_err(|err| SsaErrorWithSource::parse_error(err, src))?;
let parsed_ssa =
parser.parse_ssa().map_err(|err| SsaErrorWithSource::parse_error(err, src))?;
parsed_ssa
.into_ssa(simplify)
.into_ssa(simplify, validate)
.map_err(|error| SsaErrorWithSource { src: src.to_string(), error })
}
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/ssa/parser/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ fn test_cast() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v1 = cast v0 as i32
return v1
v1 = truncate v0 to 32 bits, max_bit_size: 254
v2 = cast v1 as i32
return v2
}
";
assert_ssa_roundtrip(src);
Expand Down
9 changes: 9 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use super::ir::basic_block::BasicBlockId;
use super::ir::dfg::GlobalsGraph;
use super::ir::instruction::{ArrayOffset, ErrorType};
use super::ir::types::NumericType;
use super::validation::validate_function;
use super::{
function_builder::data_bus::DataBus,
ir::{
Expand Down Expand Up @@ -131,9 +132,17 @@ pub fn generate_ssa(program: Program) -> Result<Ssa, RuntimeError> {
}

let ssa = function_context.builder.finish();
validate_ssa(&ssa);

Ok(ssa)
}

pub(crate) fn validate_ssa(ssa: &Ssa) {
for function in ssa.functions.values() {
validate_function(function);
}
}

impl FunctionContext<'_> {
/// Codegen a function's body and set its return value to that of its last parameter.
/// For functions returning nothing, this will be an empty list.
Expand Down
Loading
Loading