Skip to content

Commit 4d8fe28

Browse files
author
jfecher
authored
fix: Try to move constant terms to one side for arithmetic generics (#6008)
# Description ## Problem\* Resolves #6006 ## Summary\* Previously we were failing for constraints like `a + 3 = b + 1` when we could instead move the constant terms to one side: `a + 2 = b` then solve with `b := a + 2`. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings.
1 parent 21425de commit 4d8fe28

4 files changed

Lines changed: 86 additions & 3 deletions

File tree

compiler/noirc_frontend/src/hir_def/types.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,8 +1630,18 @@ impl Type {
16301630

16311631
(InfixExpr(lhs_a, op_a, rhs_a), InfixExpr(lhs_b, op_b, rhs_b)) => {
16321632
if op_a == op_b {
1633-
lhs_a.try_unify(lhs_b, bindings)?;
1634-
rhs_a.try_unify(rhs_b, bindings)
1633+
// We need to preserve the original bindings since if syntactic equality
1634+
// fails we fall back to other equality strategies.
1635+
let mut new_bindings = bindings.clone();
1636+
let lhs_result = lhs_a.try_unify(lhs_b, &mut new_bindings);
1637+
let rhs_result = rhs_a.try_unify(rhs_b, &mut new_bindings);
1638+
1639+
if lhs_result.is_ok() && rhs_result.is_ok() {
1640+
*bindings = new_bindings;
1641+
Ok(())
1642+
} else {
1643+
lhs.try_unify_by_moving_constant_terms(&rhs, bindings)
1644+
}
16351645
} else {
16361646
Err(UnificationError)
16371647
}

compiler/noirc_frontend/src/hir_def/types/arithmetic.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::BTreeSet;
22

3-
use crate::{BinaryTypeOperator, Type};
3+
use crate::{BinaryTypeOperator, Type, TypeBindings, UnificationError};
44

55
impl Type {
66
/// Try to canonicalize the representation of this type.
@@ -212,4 +212,44 @@ impl Type {
212212
_ => None,
213213
}
214214
}
215+
216+
/// Try to unify equations like `(..) + 3 = (..) + 1`
217+
/// by transforming them to `(..) + 2 = (..)`
218+
pub(super) fn try_unify_by_moving_constant_terms(
219+
&self,
220+
other: &Type,
221+
bindings: &mut TypeBindings,
222+
) -> Result<(), UnificationError> {
223+
if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self {
224+
if let Some(inverse) = op_a.inverse() {
225+
if let Some(rhs_a) = rhs_a.evaluate_to_u32() {
226+
let rhs_a = Box::new(Type::Constant(rhs_a));
227+
let new_other = Type::InfixExpr(Box::new(other.clone()), inverse, rhs_a);
228+
229+
let mut tmp_bindings = bindings.clone();
230+
if lhs_a.try_unify(&new_other, &mut tmp_bindings).is_ok() {
231+
*bindings = tmp_bindings;
232+
return Ok(());
233+
}
234+
}
235+
}
236+
}
237+
238+
if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other {
239+
if let Some(inverse) = op_b.inverse() {
240+
if let Some(rhs_b) = rhs_b.evaluate_to_u32() {
241+
let rhs_b = Box::new(Type::Constant(rhs_b));
242+
let new_self = Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b);
243+
244+
let mut tmp_bindings = bindings.clone();
245+
if new_self.try_unify(lhs_b, &mut tmp_bindings).is_ok() {
246+
*bindings = tmp_bindings;
247+
return Ok(());
248+
}
249+
}
250+
}
251+
}
252+
253+
Err(UnificationError)
254+
}
215255
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[package]
2+
name = "arithmetic_generics_move_constant_terms"
3+
type = "bin"
4+
authors = [""]
5+
compiler_version = ">=0.33.0"
6+
7+
[dependencies]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
trait FromCallData<let N: u32, let M: u32> {
2+
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; M]);
3+
}
4+
5+
struct Point { x: Field, y: Field }
6+
7+
impl <let N: u32> FromCallData<N, N - 1> for Field {
8+
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 1)]) {
9+
let slice = calldata.as_slice();
10+
let (value, slice) = slice.pop_front();
11+
(value, slice.as_array())
12+
}
13+
}
14+
15+
impl <let N: u32> FromCallData<N, N - 2> for Point {
16+
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 2)]) {
17+
let (x, calldata) = FromCallData::from_calldata(calldata);
18+
let (y, calldata) = FromCallData::from_calldata(calldata);
19+
(Self { x, y }, calldata)
20+
}
21+
}
22+
23+
fn main() {
24+
let calldata = [1, 2];
25+
let _: (Point, _) = FromCallData::from_calldata(calldata);
26+
}

0 commit comments

Comments
 (0)