@@ -161,10 +161,21 @@ impl Context<'_, '_, '_> {
161161 ) -> ValueId {
162162 let lhs_typ = self . context . dfg . type_of_value ( lhs) . unwrap_numeric ( ) ;
163163 let base = self . field_constant ( FieldElement :: from ( 2_u128 ) ) ;
164+ let rhs_typ = self . context . dfg . type_of_value ( rhs) . unwrap_numeric ( ) ;
165+ //Check whether rhs is less than the bit_size: if it's not then it will overflow and we will return 0 instead.
166+ let bit_size_value = self . numeric_constant ( bit_size as u128 , rhs_typ) ;
167+ let rhs_is_less_than_bit_size = self . insert_binary ( rhs, BinaryOp :: Lt , bit_size_value) ;
168+ let rhs_is_less_than_bit_size_with_rhs_typ =
169+ self . insert_cast ( rhs_is_less_than_bit_size, rhs_typ) ;
170+ // Nullify rhs in case of overflow, to ensure that pow returns a value compatible with lhs
171+ let rhs = self . insert_binary (
172+ rhs_is_less_than_bit_size_with_rhs_typ,
173+ BinaryOp :: Mul { unchecked : true } ,
174+ rhs,
175+ ) ;
164176 let pow = self . pow ( base, rhs) ;
165- let pow = self . pow_or_max_for_bit_size ( pow, rhs, bit_size, lhs_typ) ;
166177 let pow = self . insert_cast ( pow, lhs_typ) ;
167- if lhs_typ. is_unsigned ( ) {
178+ let result = if lhs_typ. is_unsigned ( ) {
168179 // unsigned right bit shift is just a normal division
169180 self . insert_binary ( lhs, BinaryOp :: Div , pow)
170181 } else {
@@ -198,53 +209,14 @@ impl Context<'_, '_, '_> {
198209 lhs_sign_as_int,
199210 ) ;
200211 self . insert_truncate ( shifted, bit_size, bit_size + 1 )
201- }
202- }
203-
204- /// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
205- fn pow_or_max_for_bit_size (
206- & mut self ,
207- pow : ValueId ,
208- rhs : ValueId ,
209- bit_size : u32 ,
210- typ : NumericType ,
211- ) -> ValueId {
212- let max = if typ. is_unsigned ( ) {
213- if bit_size == 128 { u128:: MAX } else { ( 1_u128 << bit_size) - 1 }
214- } else {
215- 1_u128 << ( bit_size - 1 )
216212 } ;
217- let max = self . field_constant ( FieldElement :: from ( max) ) ;
218-
219- // Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
220- // Then we do:
221- //
222- // rhs_is_less_than_bit_size = lt rhs, bit_size
223- // rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
224- // pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
225- // pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
226- // pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
227- //
228- // All operations here are unchecked because they work on field types.
229- let rhs_typ = self . context . dfg . type_of_value ( rhs) . unwrap_numeric ( ) ;
230- let bit_size = self . numeric_constant ( bit_size as u128 , rhs_typ) ;
231- let rhs_is_less_than_bit_size = self . insert_binary ( rhs, BinaryOp :: Lt , bit_size) ;
232- let rhs_is_not_less_than_bit_size = self . insert_not ( rhs_is_less_than_bit_size) ;
233- let rhs_is_less_than_bit_size =
234- self . insert_cast ( rhs_is_less_than_bit_size, NumericType :: NativeField ) ;
235- let rhs_is_not_less_than_bit_size =
236- self . insert_cast ( rhs_is_not_less_than_bit_size, NumericType :: NativeField ) ;
237- let pow_when_is_less_than_bit_size =
238- self . insert_binary ( rhs_is_less_than_bit_size, BinaryOp :: Mul { unchecked : true } , pow) ;
239- let pow_when_is_not_less_than_bit_size = self . insert_binary (
240- rhs_is_not_less_than_bit_size,
241- BinaryOp :: Mul { unchecked : true } ,
242- max,
243- ) ;
213+ // Returns 0 in case of overflow
214+ let rhs_is_less_than_bit_size_with_lhs_typ =
215+ self . insert_cast ( rhs_is_less_than_bit_size, lhs_typ) ;
244216 self . insert_binary (
245- pow_when_is_less_than_bit_size ,
246- BinaryOp :: Add { unchecked : true } ,
247- pow_when_is_not_less_than_bit_size ,
217+ rhs_is_less_than_bit_size_with_lhs_typ ,
218+ BinaryOp :: Mul { unchecked : true } ,
219+ result ,
248220 )
249221 }
250222
0 commit comments