1010//! - unsigned division: <https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/udivmodti4.c>
1111
1212use crate :: { int:: I256 , uint:: U256 } ;
13- use core:: mem:: MaybeUninit ;
13+ use core:: { mem:: MaybeUninit , num :: NonZeroU128 } ;
1414
1515#[ inline( always) ]
16- fn udiv256_by_128_to_128 ( u1 : u128 , u0 : u128 , mut v : u128 , r : & mut u128 ) -> u128 {
16+ fn udiv256_by_128_to_128 ( u1 : u128 , u0 : u128 , mut v : NonZeroU128 , r : & mut u128 ) -> u128 {
1717 const N_UDWORD_BITS : u32 = 128 ;
18+
19+ #[ inline]
20+ unsafe fn shl_nz ( x : NonZeroU128 , n : u32 ) -> NonZeroU128 {
21+ debug_assert ! ( n < N_UDWORD_BITS ) ;
22+ let res: u128 = x. get ( ) << n;
23+ debug_assert_ne ! ( res, 0 ) ;
24+ NonZeroU128 :: new_unchecked ( res)
25+ }
26+
27+ #[ inline]
28+ unsafe fn shr_nz ( x : NonZeroU128 , n : u32 ) -> NonZeroU128 {
29+ debug_assert ! ( n < N_UDWORD_BITS ) ;
30+ let res: u128 = x. get ( ) >> n;
31+ debug_assert_ne ! ( res, 0 ) ;
32+ NonZeroU128 :: new_unchecked ( res)
33+ }
34+
1835 const B : u128 = 1 << ( N_UDWORD_BITS / 2 ) ; // Number base (128 bits)
1936 let ( un1, un0) : ( u128 , u128 ) ; // Norm. dividend LSD's
20- let ( vn1, vn0) : ( u128 , u128 ) ; // Norm. divisor digits
37+ let ( vn1, vn0) : ( NonZeroU128 , u128 ) ; // Norm. divisor digits
2138 let ( mut q1, mut q0) : ( u128 , u128 ) ; // Quotient digits
2239 let ( un128, un21, un10) : ( u128 , u128 , u128 ) ; // Dividend digit pairs
2340
41+ debug_assert ! ( v. get( ) > u1) ;
42+
2443 let s = v. leading_zeros ( ) ;
44+ debug_assert_ne ! ( s, N_UDWORD_BITS ) ;
2545 if s > 0 {
2646 // Normalize the divisor.
27- v <<= s ;
47+ v = unsafe { shl_nz ( v , s ) } ;
2848 un128 = ( u1 << s) | ( u0 >> ( N_UDWORD_BITS - s) ) ;
2949 un10 = u0 << s; // Shift dividend left
3050 } else {
31- // Avoid undefined behavior of (u0 >> 64 ).
51+ // Avoid undefined behavior of (u0 >> 128 ).
3252 un128 = u1;
3353 un10 = u0;
3454 }
3555
3656 // Break divisor up into two 64-bit digits.
37- vn1 = v >> ( N_UDWORD_BITS / 2 ) ;
38- vn0 = v & 0xFFFF_FFFF_FFFF_FFFF ;
57+ vn1 = unsafe { shr_nz ( v , N_UDWORD_BITS / 2 ) } ;
58+ vn0 = v. get ( ) & 0xFFFF_FFFF_FFFF_FFFF ;
3959
4060 // Break right half of dividend into two digits.
4161 un1 = un10 >> ( N_UDWORD_BITS / 2 ) ;
4262 un0 = un10 & 0xFFFF_FFFF_FFFF_FFFF ;
4363
4464 // Compute the first quotient digit, q1.
4565 q1 = un128 / vn1;
46- let mut rhat = un128 - q1 * vn1;
66+ let mut rhat = un128 - q1 * vn1. get ( ) ;
4767
4868 // q1 has at most error 2. No more than 2 iterations.
4969 while q1 >= B || q1 * vn0 > B * rhat + un1 {
5070 q1 -= 1 ;
51- rhat += vn1;
71+ rhat += vn1. get ( ) ;
5272 if rhat >= B {
5373 break ;
5474 }
@@ -57,16 +77,16 @@ fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128
5777 un21 = un128
5878 . wrapping_mul ( B )
5979 . wrapping_add ( un1)
60- . wrapping_sub ( q1. wrapping_mul ( v) ) ;
80+ . wrapping_sub ( q1. wrapping_mul ( v. get ( ) ) ) ;
6181
6282 // Compute the second quotient digit.
6383 q0 = un21 / vn1;
64- rhat = un21 - q0 * vn1;
84+ rhat = un21 - q0 * vn1. get ( ) ;
6585
6686 // q0 has at most error 2. No more than 2 iterations.
6787 while q0 >= B || q0 * vn0 > B * rhat + un0 {
6888 q0 -= 1 ;
69- rhat += vn1;
89+ rhat += vn1. get ( ) ;
7090 if rhat >= B {
7191 break ;
7292 }
@@ -75,7 +95,7 @@ fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128
7595 * r = ( un21
7696 . wrapping_mul ( B )
7797 . wrapping_add ( un0)
78- . wrapping_sub ( q0. wrapping_mul ( v) ) )
98+ . wrapping_sub ( q0. wrapping_mul ( v. get ( ) ) ) )
7999 >> s;
80100 q1 * B + q0
81101}
@@ -101,10 +121,10 @@ pub fn udivmod4(
101121 // Unfortunately, there is no 256-bit equivalent on x86_64, but we can still
102122 // shortcut if the high and low values of the operands are 0:
103123 if a. high ( ) | b. high ( ) == 0 {
124+ res. write ( U256 :: from_words ( 0 , a. low ( ) / b. low ( ) ) ) ;
104125 if let Some ( rem) = rem {
105126 rem. write ( U256 :: from_words ( 0 , a. low ( ) % b. low ( ) ) ) ;
106127 }
107- res. write ( U256 :: from_words ( 0 , a. low ( ) / b. low ( ) ) ) ;
108128 return ;
109129 }
110130
@@ -130,7 +150,8 @@ pub fn udivmod4(
130150 udiv256_by_128_to_128 (
131151 * dividend. high ( ) ,
132152 * dividend. low ( ) ,
133- * divisor. low ( ) ,
153+ // SAFETY: dividend.high() < divisor.low()
154+ unsafe { NonZeroU128 :: new_unchecked ( * divisor. low ( ) ) } ,
134155 remainder. low_mut ( ) ,
135156 ) ,
136157 ) ;
@@ -142,7 +163,8 @@ pub fn udivmod4(
142163 udiv256_by_128_to_128 (
143164 dividend. high ( ) % divisor. low ( ) ,
144165 * dividend. low ( ) ,
145- * divisor. low ( ) ,
166+ // SAFETY: dividend.high() / divisor.low()
167+ unsafe { NonZeroU128 :: new_unchecked ( * divisor. low ( ) ) } ,
146168 remainder. low_mut ( ) ,
147169 ) ,
148170 ) ;
@@ -154,7 +176,8 @@ pub fn udivmod4(
154176 return ;
155177 }
156178
157- ( quotient, remainder) = div_mod_knuth ( & dividend, & divisor) ;
179+ // SAFETY: `*divisor.high() != 0`
180+ ( quotient, remainder) = unsafe { div_mod_knuth ( & dividend, & divisor) } ;
158181
159182 if let Some ( rem) = rem {
160183 rem. write ( remainder) ;
@@ -164,9 +187,18 @@ pub fn udivmod4(
164187
165188// See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
166189// https://skanthak.homepage.t-online.de/division.html
190+ // SAFETY: The high word of v (the divisor) must be non-zero.
167191#[ inline]
168- pub fn div_mod_knuth ( u : & U256 , v : & U256 ) -> ( U256 , U256 ) {
192+ unsafe fn div_mod_knuth ( u : & U256 , v : & U256 ) -> ( U256 , U256 ) {
169193 const N_UDWORD_BITS : u32 = 128 ;
194+ debug_assert_ne ! (
195+ * u. high( ) ,
196+ 0 ,
197+ "The second operand must be greater than u128::MAX"
198+ ) ;
199+ if * u. high ( ) == 0 {
200+ unsafe { core:: hint:: unreachable_unchecked ( ) }
201+ }
170202
171203 #[ inline]
172204 fn full_shl ( a : & U256 , shift : u32 ) -> [ u128 ; 3 ] {
@@ -266,7 +298,6 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
266298 let shift = v. high ( ) . leading_zeros ( ) ;
267299 debug_assert ! ( shift < N_UDWORD_BITS ) ;
268300 let v = v << shift;
269- debug_assert ! ( v. high( ) >> ( N_UDWORD_BITS - 1 ) == 1 ) ;
270301 // u will store the remainder (shifted)
271302 let mut u = full_shl ( u, shift) ;
272303
@@ -275,6 +306,14 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
275306 let v_n_1 = * v. high ( ) ;
276307 let v_n_2 = * v. low ( ) ;
277308
309+ if v_n_1 >> ( N_UDWORD_BITS - 1 ) != 1 {
310+ debug_assert ! ( false ) ;
311+
312+ // SAFETY: `v_n_1` must be normalized because input `v` has
313+ // been checked to be non-zero.
314+ unsafe { core:: hint:: unreachable_unchecked ( ) }
315+ }
316+
278317 // D2. D7. - unrolled loop j == 0, n == 2, m == 0 (only one possible iteration)
279318 let mut r_hat: u128 = 0 ;
280319 let u_jn = u[ 2 ] ;
@@ -286,7 +325,12 @@ pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
286325 // Theorem B: q_hat >= q_j >= q_hat - 2
287326 let mut q_hat = if u_jn < v_n_1 {
288327 //let (mut q_hat, mut r_hat) = _div_mod_u128(u_jn, u[j + n - 1], v_n_1);
289- let mut q_hat = udiv256_by_128_to_128 ( u_jn, u[ 1 ] , v_n_1, & mut r_hat) ;
328+ let mut q_hat = udiv256_by_128_to_128 (
329+ u_jn,
330+ u[ 1 ] ,
331+ unsafe { NonZeroU128 :: new_unchecked ( v_n_1) } ,
332+ & mut r_hat,
333+ ) ;
290334 let mut overflow: bool ;
291335 // this loop takes at most 2 iterations
292336 loop {
0 commit comments