Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1616,28 +1616,47 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> {

self.brillig_context.codegen_branch(left_is_negative.address, |ctx, is_negative| {
if is_negative {
let one = ctx.make_constant_instruction(1_u128.into(), left.bit_size);

// computes 2^right
let two = ctx.make_constant_instruction(2_u128.into(), left.bit_size);
let two_pow = ctx.make_constant_instruction(1_u128.into(), left.bit_size);
let right_u32 = SingleAddrVariable::new(ctx.allocate_register(), 32);
ctx.cast(right_u32, right);
let pow_body = |ctx: &mut BrilligContext<_, _>, _: SingleAddrVariable| {
ctx.binary_instruction(two_pow, two, two_pow, BrilligBinaryOp::Mul);
};
ctx.codegen_for_loop(None, right_u32.address, None, pow_body);

// Right shift using division on 1-complement
ctx.binary_instruction(left, one, result, BrilligBinaryOp::Add);
ctx.convert_signed_division(result, two_pow, result);
ctx.binary_instruction(result, one, result, BrilligBinaryOp::Sub);

// Clean-up
ctx.deallocate_single_addr(one);
ctx.deallocate_single_addr(two);
ctx.deallocate_single_addr(two_pow);
ctx.deallocate_single_addr(right_u32);
// If right value is greater than the left bit size, return 0
let rhs_does_not_overflow = SingleAddrVariable::new(ctx.allocate_register(), 1);
let lhs_bit_size =
ctx.make_constant_instruction(left.bit_size.into(), right.bit_size);
ctx.binary_instruction(
right,
lhs_bit_size,
rhs_does_not_overflow,
BrilligBinaryOp::LessThan,
);

ctx.codegen_branch(rhs_does_not_overflow.address, |ctx, no_overflow| {
if no_overflow {
let one = ctx.make_constant_instruction(1_u128.into(), left.bit_size);

// computes 2^right
let two = ctx.make_constant_instruction(2_u128.into(), left.bit_size);
let two_pow = ctx.make_constant_instruction(1_u128.into(), left.bit_size);
let right_u32 = SingleAddrVariable::new(ctx.allocate_register(), 32);
ctx.cast(right_u32, right);
let pow_body = |ctx: &mut BrilligContext<_, _>, _: SingleAddrVariable| {
ctx.binary_instruction(two_pow, two, two_pow, BrilligBinaryOp::Mul);
};
ctx.codegen_for_loop(None, right_u32.address, None, pow_body);

// Right shift using division on 1-complement
ctx.binary_instruction(left, one, result, BrilligBinaryOp::Add);
ctx.convert_signed_division(result, two_pow, result);
ctx.binary_instruction(result, one, result, BrilligBinaryOp::Sub);

// Clean-up
ctx.deallocate_single_addr(one);
ctx.deallocate_single_addr(two);
ctx.deallocate_single_addr(two_pow);
ctx.deallocate_single_addr(right_u32);
} else {
ctx.const_instruction(result, 0_u128.into());
}
});

ctx.deallocate_single_addr(rhs_does_not_overflow);
} else {
ctx.binary_instruction(left, right, result, BrilligBinaryOp::Shr);
}
Expand Down
66 changes: 19 additions & 47 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,21 @@ impl Context<'_, '_, '_> {
) -> ValueId {
let lhs_typ = self.context.dfg.type_of_value(lhs).unwrap_numeric();
let base = self.field_constant(FieldElement::from(2_u128));
let rhs_typ = self.context.dfg.type_of_value(rhs).unwrap_numeric();
//Check whether rhs is less than the bit_size: if it's not then it will overflow and we will return 0 instead.
let bit_size_value = self.numeric_constant(bit_size as u128, rhs_typ);
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size_value);
let rhs_is_less_than_bit_size_with_rhs_typ =
self.insert_cast(rhs_is_less_than_bit_size, rhs_typ);
// Nullify rhs in case of overflow, to ensure that pow returns a value compatible with lhs
let rhs = self.insert_binary(
rhs_is_less_than_bit_size_with_rhs_typ,
BinaryOp::Mul { unchecked: true },
rhs,
);
let pow = self.pow(base, rhs);
let pow = self.pow_or_max_for_bit_size(pow, rhs, bit_size, lhs_typ);
let pow = self.insert_cast(pow, lhs_typ);
if lhs_typ.is_unsigned() {
let result = if lhs_typ.is_unsigned() {
// unsigned right bit shift is just a normal division
self.insert_binary(lhs, BinaryOp::Div, pow)
} else {
Expand Down Expand Up @@ -198,53 +209,14 @@ impl Context<'_, '_, '_> {
lhs_sign_as_int,
);
self.insert_truncate(shifted, bit_size, bit_size + 1)
}
}

/// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
fn pow_or_max_for_bit_size(
&mut self,
pow: ValueId,
rhs: ValueId,
bit_size: u32,
typ: NumericType,
) -> ValueId {
let max = if typ.is_unsigned() {
if bit_size == 128 { u128::MAX } else { (1_u128 << bit_size) - 1 }
} else {
1_u128 << (bit_size - 1)
};
let max = self.field_constant(FieldElement::from(max));

// Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
// Then we do:
//
// rhs_is_less_than_bit_size = lt rhs, bit_size
// rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
// pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
// pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
// pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
//
// All operations here are unchecked because they work on field types.
let rhs_typ = self.context.dfg.type_of_value(rhs).unwrap_numeric();
let bit_size = self.numeric_constant(bit_size as u128, rhs_typ);
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size);
let rhs_is_not_less_than_bit_size = self.insert_not(rhs_is_less_than_bit_size);
let rhs_is_less_than_bit_size =
self.insert_cast(rhs_is_less_than_bit_size, NumericType::NativeField);
let rhs_is_not_less_than_bit_size =
self.insert_cast(rhs_is_not_less_than_bit_size, NumericType::NativeField);
let pow_when_is_less_than_bit_size =
self.insert_binary(rhs_is_less_than_bit_size, BinaryOp::Mul { unchecked: true }, pow);
let pow_when_is_not_less_than_bit_size = self.insert_binary(
rhs_is_not_less_than_bit_size,
BinaryOp::Mul { unchecked: true },
max,
);
// Returns 0 in case of overflow
let rhs_is_less_than_bit_size_with_lhs_typ =
self.insert_cast(rhs_is_less_than_bit_size, lhs_typ);
self.insert_binary(
pow_when_is_less_than_bit_size,
BinaryOp::Add { unchecked: true },
pow_when_is_not_less_than_bit_size,
rhs_is_less_than_bit_size_with_lhs_typ,
BinaryOp::Mul { unchecked: true },
result,
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
x = 64
y = 1
z = "-769"
u = -1
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
fn main(x: u64, y: u8, z: i16) {
fn main(x: u64, y: u8, z: i16, u: i64) {
// runtime shifts on compile-time known values
assert(64 << y == 128);
assert(64 >> y == 32);
// runtime shifts on runtime values
assert(x << y == 128);
assert(x >> y == 32);
// regression tests for issue #8176
assert(u >> (x as u8) == 0);
assert(z >> (x as u8) == 0);

// Bit-shift with signed integers
let mut a: i8 = y as i8;
Expand Down
Loading
Loading