Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,10 @@ impl std::ops::Index<BasicBlockId> for FunctionBuilder {
fn validate_numeric_type(typ: &NumericType) {
match &typ {
NumericType::Signed { bit_size } => match bit_size {
8 | 16 | 32 | 64 | 128 => (),
8 | 16 | 32 | 64 => (),
_ => {
panic!(
"Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, 64 or 128."
"Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, or 64."
);
}
},
Expand Down
33 changes: 32 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ pub(super) fn simplify_cast(

if let Value::Instruction { instruction, .. } = &dfg[value] {
if let Instruction::Cast(original_value, _) = &dfg[*instruction] {
return SimplifiedToInstruction(Instruction::Cast(*original_value, dst_typ));
let original_value = *original_value;
return match simplify_cast(original_value, dst_typ, dfg) {
None => SimplifiedToInstruction(Instruction::Cast(original_value, dst_typ)),
simpler => simpler,
};
}
}

Expand Down Expand Up @@ -151,4 +155,31 @@ mod tests {
}
");
}

#[test]
fn simplifies_out_casting_there_and_back() {
// Casting from e.g. i8 to u64 used to go through sign extending to i64,
// which itself first cast to u8, then u64 to do some arithmetic, then
// the result was cast to i64 and back to u64.
let src = "
acir(inline) fn main f0 {
b0(v0: u64, v1: u64):
v2 = unchecked_add v0, v1
v3 = cast v2 as i64
v4 = cast v3 as u64
return v4
}
";

let ssa = Ssa::from_str_simplifying(src).unwrap();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u64, v1: u64):
v2 = unchecked_add v0, v1
v3 = cast v2 as i64
return v2
}
");
}
}
6 changes: 2 additions & 4 deletions compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,7 @@ mod tests {
v12 = eq v11, v8
v13 = unchecked_mul v12, v10
constrain v13 == v10, "attempt to add with overflow"
v14 = cast v3 as i32
return v14
return v3
}
"#);
}
Expand Down Expand Up @@ -538,8 +537,7 @@ mod tests {
v13 = eq v12, v8
v14 = unchecked_mul v13, v11
constrain v14 == v11, "attempt to subtract with overflow"
v15 = cast v3 as i32
return v15
return v3
}
"#);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ mod tests {
v4 = cast v3 as u64
v6 = lt v4, u64 64
constrain v6 == u1 1, "attempt to bit-shift with overflow"
v8 = cast v3 as Field
v8 = cast v1 as Field
v10 = call to_le_bits(v8) -> [u1; 1]
v12 = array_get v10, index u32 0 -> u1
v13 = not v12
Expand Down
177 changes: 95 additions & 82 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ impl<'a> FunctionContext<'a> {
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
pub(super) fn insert_safe_cast(
&mut self,
mut value: ValueId,
value: ValueId,
typ: NumericType,
location: Location,
) -> ValueId {
Expand All @@ -392,68 +392,7 @@ impl<'a> FunctionContext<'a> {
}
std::cmp::Ordering::Equal => value,
std::cmp::Ordering::Greater => {
// If target size is bigger, we do a sign extension:
// When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value
// Sign extension in this case will give `2^t-v`, where `t` is the target bit size
// So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative.
// Casting s-bits signed v0 to t-bits will add the following instructions:
// v1 = cast v0 to 's-bits unsigned'
// v2 = lt v1, 2**(s-1)
// v3 = not(v1)
// v4 = cast v3 to 't-bits unsigned'
// v5 = v3 * (2**t - 2**s)
// v6 = cast v1 to 't-bits unsigned'
// return v6 + v5
let value_as_unsigned = self.insert_safe_cast(
value,
NumericType::unsigned(*incoming_type_size),
location,
);
let half_width = self.builder.numeric_constant(
FieldElement::from(2_u128.pow(incoming_type_size - 1)),
NumericType::unsigned(*incoming_type_size),
);
// value_sign is 1 if the value is positive, 0 otherwise
let value_sign =
self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width);
let max_for_incoming_type_size = if *incoming_type_size == 128 {
u128::MAX
} else {
2_u128.pow(*incoming_type_size) - 1
};
let max_for_target_type_size = if target_type_size == 128 {
u128::MAX
} else {
2_u128.pow(target_type_size) - 1
};
let patch = self.builder.numeric_constant(
FieldElement::from(
max_for_target_type_size - max_for_incoming_type_size,
),
NumericType::unsigned(target_type_size),
);
let mut is_negative_predicate = self.builder.insert_not(value_sign);
is_negative_predicate = self.insert_safe_cast(
is_negative_predicate,
NumericType::unsigned(target_type_size),
location,
);
// multiplication by a boolean cannot overflow
let patch_with_sign_predicate = self.builder.insert_binary(
patch,
BinaryOp::Mul { unchecked: true },
is_negative_predicate,
);
let value_as_unsigned = self.builder.insert_cast(
value_as_unsigned,
NumericType::unsigned(target_type_size),
);
// Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow.
self.builder.insert_binary(
patch_with_sign_predicate,
BinaryOp::Add { unchecked: true },
value_as_unsigned,
)
self.sign_extend(value, *incoming_type_size, target_type_size, location)
}
}
}
Expand All @@ -463,45 +402,55 @@ impl<'a> FunctionContext<'a> {
) => {
// If target size is smaller, we do a truncation
if target_type_size < *incoming_type_size {
value =
self.builder.insert_truncate(value, target_type_size, *incoming_type_size);
self.builder.insert_truncate(value, target_type_size, *incoming_type_size)
} else {
value
}
value
}
// When casting a signed value to u1 we can truncate then cast
(
Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }),
NumericType::Unsigned { bit_size: 1 },
) => self.builder.insert_truncate(value, 1, *incoming_type_size),
// For mixed sign to unsigned or unsigned to sign;
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness

// For mixed singed to unsigned:
(
Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }),
NumericType::Unsigned { bit_size: target_type_size },
) => {
if *incoming_type_size != target_type_size {
value = self.insert_safe_cast(
value,
NumericType::signed(target_type_size),
location,
);
// when going from lower to higher bit size:
// 1. we sign-extend to the target bits
// 2. we are already in the target signedness
if *incoming_type_size < target_type_size {
// By not the casting to a signed type with the target bit size, we avoid potentially going
// through i128, which is not a type we support in the frontend, and would be strange in SSA.
self.sign_extend(value, *incoming_type_size, target_type_size, location)
}
// when the target bit size is not higher than the source:
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness
else if *incoming_type_size != target_type_size {
self.insert_safe_cast(value, NumericType::signed(target_type_size), location)
} else {
value
}
value
}

// For mixed unsigned to signed:
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness
(
Type::Numeric(NumericType::Unsigned { bit_size: incoming_type_size }),
NumericType::Signed { bit_size: target_type_size },
) => {
if *incoming_type_size != target_type_size {
value = self.insert_safe_cast(
value,
NumericType::unsigned(target_type_size),
location,
);
self.insert_safe_cast(value, NumericType::unsigned(target_type_size), location)
} else {
value
}
value
}

// Field to signed/unsigned:
(
Type::Numeric(NumericType::NativeField),
NumericType::Unsigned { bit_size: target_type_size },
Expand All @@ -517,6 +466,70 @@ impl<'a> FunctionContext<'a> {
self.builder.insert_cast(result, typ)
}

/// During casting signed values, if target size is bigger, we do a sign extension:
///
/// When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value.
/// Sign extension in this case will give `2^t-v`, where `t` is the target bit size.
/// So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative.
///
/// Casting s-bits signed v0 to t-bits will add the following instructions:
/// ```ssa
/// v1 = cast v0 to 's-bits unsigned'
/// v2 = lt v1, 2**(s-1)
/// v3 = not(v1)
/// v4 = cast v3 to 't-bits unsigned'
/// v5 = v3 * (2**t - 2**s)
/// v6 = cast v1 to 't-bits unsigned'
/// return v6 + v5
/// ```
///
/// Return an unsigned value that we can cast back to the signed type if we want,
/// or keep it as it is, if we did the sign extension as part of casting e.g. `i8` to `u64`.
fn sign_extend(
&mut self,
value: ValueId,
incoming_type_size: u32,
target_type_size: u32,
location: Location,
) -> ValueId {
let value_as_unsigned =
self.insert_safe_cast(value, NumericType::unsigned(incoming_type_size), location);
let half_width = self.builder.numeric_constant(
FieldElement::from(2_u128.pow(incoming_type_size - 1)),
NumericType::unsigned(incoming_type_size),
);
// value_sign is 1 if the value is positive, 0 otherwise
let value_sign = self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width);
let max_for_incoming_type_size =
if incoming_type_size == 128 { u128::MAX } else { 2_u128.pow(incoming_type_size) - 1 };
let max_for_target_type_size =
if target_type_size == 128 { u128::MAX } else { 2_u128.pow(target_type_size) - 1 };
let patch = self.builder.numeric_constant(
FieldElement::from(max_for_target_type_size - max_for_incoming_type_size),
NumericType::unsigned(target_type_size),
);
let mut is_negative_predicate = self.builder.insert_not(value_sign);
is_negative_predicate = self.insert_safe_cast(
is_negative_predicate,
NumericType::unsigned(target_type_size),
location,
);
// multiplication by a boolean cannot overflow
let patch_with_sign_predicate = self.builder.insert_binary(
patch,
BinaryOp::Mul { unchecked: true },
is_negative_predicate,
);
let value_as_unsigned =
self.builder.insert_cast(value_as_unsigned, NumericType::unsigned(target_type_size));
// Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow.
self.builder.insert_binary(
patch_with_sign_predicate,
BinaryOp::Add { unchecked: true },
value_as_unsigned,
)
}

/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(
&mut self,
Expand Down
Loading