@@ -123,6 +123,39 @@ pub(super) fn decompose_constrain(
123123 }
124124 }
125125
126+ ( Value :: NumericConstant { constant, .. } , Value :: Instruction { instruction, .. } )
127+ | ( Value :: Instruction { instruction, .. } , Value :: NumericConstant { constant, .. } ) => {
128+ match dfg[ * instruction] {
129+ Instruction :: Binary ( Binary { lhs, rhs, operator : BinaryOp :: Mul { .. } } )
130+ if constant. is_zero ( ) && lhs == rhs =>
131+ {
132+ // Replace an assertion that a squared value is zero
133+ //
134+ // v1 = mul v0, v0
135+ // constrain v1 == u1 0
136+ //
137+ // with a direct assertion that value being squared is equal to 0
138+ //
139+ // v1 = mul v0, v0
140+ // constrain v0 == u1 0
141+ //
142+ // This is due to the fact that for `v1` to be 0 then `v0` is 0.
143+ //
144+ // Note that this doesn't remove the value `v1` as it may be used in other instructions, but it
145+ // will likely be removed through dead instruction elimination.
146+ //
147+ // This is safe for all numeric types as the underlying field has a prime modulus so squaring
148+ // a non-zero value should never result in zero.
149+
150+ let zero = FieldElement :: zero ( ) ;
151+ let zero = dfg. make_constant ( zero, dfg. type_of_value ( lhs) . unwrap_numeric ( ) ) ;
152+ decompose_constrain ( lhs, zero, msg, dfg)
153+ }
154+
155+ _ => vec ! [ Instruction :: Constrain ( lhs, rhs, msg. clone( ) ) ] ,
156+ }
157+ }
158+
126159 (
127160 Value :: Instruction { instruction : instruction_lhs, .. } ,
128161 Value :: Instruction { instruction : instruction_rhs, .. } ,
@@ -144,3 +177,31 @@ pub(super) fn decompose_constrain(
144177 }
145178 }
146179}
180+
181+ #[ cfg( test) ]
182+ mod tests {
183+ use crate :: ssa:: { opt:: assert_normalized_ssa_equals, ssa_gen:: Ssa } ;
184+
185+ #[ test]
186+ fn simplifies_assertions_that_squared_values_are_equal_to_zero ( ) {
187+ let src = "
188+ acir(inline) fn main f0 {
189+ b0(v0: Field):
190+ v1 = mul v0, v0
191+ constrain v1 == Field 0
192+ return
193+ }
194+ " ;
195+ let ssa = Ssa :: from_str_simplifying ( src) . unwrap ( ) ;
196+
197+ let expected = "
198+ acir(inline) fn main f0 {
199+ b0(v0: Field):
200+ v1 = mul v0, v0
201+ constrain v0 == Field 0
202+ return
203+ }
204+ " ;
205+ assert_normalized_ssa_equals ( ssa, expected) ;
206+ }
207+ }
0 commit comments