@@ -21,10 +21,272 @@ pub const ZERO: Uint = Uint::from_u64(0);
2121pub const ONE : Uint = Uint :: from_u64 ( 1 ) ;
2222
2323impl Uint {
24+ const N_WORDS : usize = 4 ;
25+
2426 /// Convert a [`u64`] to a [`Uint`].
2527 pub const fn from_u64 ( x : u64 ) -> Uint {
2628 Uint ( [ x. to_le ( ) , 0 , 0 , 0 ] )
2729 }
30+
31+ /// Return the least number of bits needed to represent the number
32+ #[ inline]
33+ pub fn bits_512 ( arr : & [ u64 ; 2 * Self :: N_WORDS ] ) -> usize {
34+ for i in 1 ..arr. len ( ) {
35+ if arr[ arr. len ( ) - i] > 0 {
36+ return ( 0x40 * ( arr. len ( ) - i + 1 ) )
37+ - arr[ arr. len ( ) - i] . leading_zeros ( ) as usize ;
38+ }
39+ }
40+ 0x40 - arr[ 0 ] . leading_zeros ( ) as usize
41+ }
42+
43+ fn div_mod_small_512 (
44+ mut slf : [ u64 ; 2 * Self :: N_WORDS ] ,
45+ other : u64 ,
46+ ) -> ( [ u64 ; 2 * Self :: N_WORDS ] , Self ) {
47+ let mut rem = 0u64 ;
48+ slf. iter_mut ( ) . rev ( ) . for_each ( |d| {
49+ let ( q, r) = Self :: div_mod_word ( rem, * d, other) ;
50+ * d = q;
51+ rem = r;
52+ } ) ;
53+ ( slf, rem. into ( ) )
54+ }
55+
56+ fn shr_512 (
57+ original : [ u64 ; 2 * Self :: N_WORDS ] ,
58+ shift : u32 ,
59+ ) -> [ u64 ; 2 * Self :: N_WORDS ] {
60+ let shift = shift as usize ;
61+ let mut ret = [ 0u64 ; 2 * Self :: N_WORDS ] ;
62+ let word_shift = shift / 64 ;
63+ let bit_shift = shift % 64 ;
64+
65+ // shift
66+ for i in word_shift..original. len ( ) {
67+ ret[ i - word_shift] = original[ i] >> bit_shift;
68+ }
69+
70+ // Carry
71+ if bit_shift > 0 {
72+ for i in word_shift + 1 ..original. len ( ) {
73+ ret[ i - word_shift - 1 ] += original[ i] << ( 64 - bit_shift) ;
74+ }
75+ }
76+
77+ ret
78+ }
79+
80+ fn full_shl_512 (
81+ slf : [ u64 ; 2 * Self :: N_WORDS ] ,
82+ shift : u32 ,
83+ ) -> [ u64 ; 2 * Self :: N_WORDS + 1 ] {
84+ debug_assert ! ( shift < Self :: WORD_BITS as u32 ) ;
85+ let mut u = [ 0u64 ; 2 * Self :: N_WORDS + 1 ] ;
86+ let u_lo = slf[ 0 ] << shift;
87+ let u_hi = Self :: shr_512 ( slf, Self :: WORD_BITS as u32 - shift) ;
88+ u[ 0 ] = u_lo;
89+ u[ 1 ..] . copy_from_slice ( & u_hi[ ..] ) ;
90+ u
91+ }
92+
93+ fn full_shr_512 (
94+ u : [ u64 ; 2 * Self :: N_WORDS + 1 ] ,
95+ shift : u32 ,
96+ ) -> [ u64 ; 2 * Self :: N_WORDS ] {
97+ debug_assert ! ( shift < Self :: WORD_BITS as u32 ) ;
98+ let mut res = [ 0 ; 2 * Self :: N_WORDS ] ;
99+ for i in 0 ..res. len ( ) {
100+ res[ i] = u[ i] >> shift;
101+ }
102+ // carry
103+ if shift > 0 {
104+ for i in 1 ..=res. len ( ) {
105+ res[ i - 1 ] |= u[ i] << ( Self :: WORD_BITS as u32 - shift) ;
106+ }
107+ }
108+ res
109+ }
110+
111+ // See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
112+ fn div_mod_knuth_512 (
113+ slf : [ u64 ; 2 * Self :: N_WORDS ] ,
114+ mut v : Self ,
115+ n : usize ,
116+ m : usize ,
117+ ) -> ( [ u64 ; 2 * Self :: N_WORDS ] , Self ) {
118+ debug_assert ! ( Self :: bits_512( & slf) >= v. bits( ) && !v. fits_word( ) ) ;
119+ debug_assert ! ( n + m <= slf. len( ) ) ;
120+ // D1.
121+ // Make sure 64th bit in v's highest word is set.
122+ // If we shift both self and v, it won't affect the quotient
123+ // and the remainder will only need to be shifted back.
124+ let shift = v. 0 [ n - 1 ] . leading_zeros ( ) ;
125+ v <<= shift;
126+ // u will store the remainder (shifted)
127+ let mut u = Self :: full_shl_512 ( slf, shift) ;
128+
129+ // quotient
130+ let mut q = [ 0 ; 2 * Self :: N_WORDS ] ;
131+ let v_n_1 = v. 0 [ n - 1 ] ;
132+ let v_n_2 = v. 0 [ n - 2 ] ;
133+
134+ // D2. D7.
135+ // iterate from m downto 0
136+ for j in ( 0 ..=m) . rev ( ) {
137+ let u_jn = u[ j + n] ;
138+
139+ // D3.
140+ // q_hat is our guess for the j-th quotient digit
141+ // q_hat = min(b - 1, (u_{j+n} * b + u_{j+n-1}) / v_{n-1})
142+ // b = 1 << WORD_BITS
143+ // Theorem B: q_hat >= q_j >= q_hat - 2
144+ let mut q_hat = if u_jn < v_n_1 {
145+ let ( mut q_hat, mut r_hat) =
146+ Self :: div_mod_word ( u_jn, u[ j + n - 1 ] , v_n_1) ;
147+ // this loop takes at most 2 iterations
148+ loop {
149+ // check if q_hat * v_{n-2} > b * r_hat + u_{j+n-2}
150+ let ( hi, lo) =
151+ Self :: split_u128 ( u128:: from ( q_hat) * u128:: from ( v_n_2) ) ;
152+ if ( hi, lo) <= ( r_hat, u[ j + n - 2 ] ) {
153+ break ;
154+ }
155+ // then iterate till it doesn't hold
156+ q_hat -= 1 ;
157+ let ( new_r_hat, overflow) = r_hat. overflowing_add ( v_n_1) ;
158+ r_hat = new_r_hat;
159+ // if r_hat overflowed, we're done
160+ if overflow {
161+ break ;
162+ }
163+ }
164+ q_hat
165+ } else {
166+ // here q_hat >= q_j >= q_hat - 1
167+ u64:: max_value ( )
168+ } ;
169+
170+ // ex. 20:
171+ // since q_hat * v_{n-2} <= b * r_hat + u_{j+n-2},
172+ // either q_hat == q_j, or q_hat == q_j + 1
173+
174+ // D4.
175+ // let's assume optimistically q_hat == q_j
176+ // subtract (q_hat * v) from u[j..]
177+ let q_hat_v = v. full_mul_u64 ( q_hat) ;
178+ // u[j..] -= q_hat_v;
179+ let c = Self :: sub_slice ( & mut u[ j..] , & q_hat_v[ ..n + 1 ] ) ;
180+
181+ // D6.
182+ // actually, q_hat == q_j + 1 and u[j..] has overflowed
183+ // highly unlikely ~ (1 / 2^63)
184+ if c {
185+ q_hat -= 1 ;
186+ // add v to u[j..]
187+ let c = Self :: add_slice ( & mut u[ j..] , & v. 0 [ ..n] ) ;
188+ u[ j + n] = u[ j + n] . wrapping_add ( u64:: from ( c) ) ;
189+ }
190+
191+ // D5.
192+ q[ j] = q_hat;
193+ }
194+
195+ // D8.
196+ let remainder = Self :: full_shr_512 ( u, shift) ;
197+ // The remainder should never exceed the capacity of Self
198+ debug_assert ! (
199+ Self :: bits_512( & remainder) <= Self :: N_WORDS * Self :: WORD_BITS
200+ ) ;
201+ ( q, Self ( remainder[ ..Self :: N_WORDS ] . try_into ( ) . unwrap ( ) ) )
202+ }
203+
204+ /// Returns a pair `(self / other, self % other)`.
205+ ///
206+ /// # Panics
207+ ///
208+ /// Panics if `other` is zero.
209+ pub fn div_mod_512 (
210+ slf : [ u64 ; 2 * Self :: N_WORDS ] ,
211+ other : Self ,
212+ ) -> ( [ u64 ; 2 * Self :: N_WORDS ] , Self ) {
213+ let my_bits = Self :: bits_512 ( & slf) ;
214+ let your_bits = other. bits ( ) ;
215+
216+ assert ! ( your_bits != 0 , "division by zero" ) ;
217+
218+ // Early return in case we are dividing by a larger number than us
219+ if my_bits < your_bits {
220+ return (
221+ [ 0 ; 2 * Self :: N_WORDS ] ,
222+ Self ( slf[ ..Self :: N_WORDS ] . try_into ( ) . unwrap ( ) ) ,
223+ ) ;
224+ }
225+
226+ if your_bits <= Self :: WORD_BITS {
227+ return Self :: div_mod_small_512 ( slf, other. low_u64 ( ) ) ;
228+ }
229+
230+ let ( n, m) = {
231+ let my_words = Self :: words ( my_bits) ;
232+ let your_words = Self :: words ( your_bits) ;
233+ ( your_words, my_words - your_words)
234+ } ;
235+
236+ Self :: div_mod_knuth_512 ( slf, other, n, m)
237+ }
238+
239+ /// Returns a pair `(Some((self * num) / denom), (self * num) % denom)` if
240+ /// the quotient fits into Self. Otherwise `(None, (self * num) % denom)` is
241+ /// returned.
242+ ///
243+ /// # Panics
244+ ///
245+ /// Panics if `denom` is zero.
246+ pub fn checked_mul_div (
247+ & self ,
248+ num : Self ,
249+ denom : Self ,
250+ ) -> ( Option < Self > , Self ) {
251+ let prod = uint:: uint_full_mul_reg!( Uint , 4 , self , num) ;
252+ let ( quotient, remainder) = Self :: div_mod_512 ( prod, denom) ;
253+ // The compiler WILL NOT inline this if you remove this annotation.
254+ #[ inline( always) ]
255+ fn any_nonzero ( arr : & [ u64 ] ) -> bool {
256+ use uint:: unroll;
257+ unroll ! {
258+ for i in 0 ..4 {
259+ if arr[ i] != 0 {
260+ return true ;
261+ }
262+ }
263+ }
264+
265+ false
266+ }
267+ (
268+ if any_nonzero ( & quotient[ Self :: N_WORDS ..] ) {
269+ None
270+ } else {
271+ Some ( Self ( quotient[ 0 ..Self :: N_WORDS ] . try_into ( ) . unwrap ( ) ) )
272+ } ,
273+ remainder,
274+ )
275+ }
276+
277+ /// Returns a pair `((self * num) / denom, (self * num) % denom)`.
278+ ///
279+ /// # Panics
280+ ///
281+ /// Panics if `denom` is zero.
282+ pub fn mul_div ( & self , num : Self , denom : Self ) -> ( Self , Self ) {
283+ let prod = uint:: uint_full_mul_reg!( Uint , 4 , self , num) ;
284+ let ( quotient, remainder) = Self :: div_mod_512 ( prod, denom) ;
285+ (
286+ Self ( quotient[ 0 ..Self :: N_WORDS ] . try_into ( ) . unwrap ( ) ) ,
287+ remainder,
288+ )
289+ }
28290}
29291
30292construct_uint ! {
@@ -171,10 +433,9 @@ impl Uint {
171433 /// * `self` * 10^(`denom`) overflows 256 bits
172434 /// * `other` is zero (`checked_div` will return `None`).
173435 pub fn fixed_precision_div ( & self , rhs : & Self , denom : u8 ) -> Option < Self > {
174- let lhs = Uint :: from ( 10 )
436+ Uint :: from ( 10 )
175437 . checked_pow ( Uint :: from ( denom) )
176- . and_then ( |res| res. checked_mul ( * self ) ) ?;
177- lhs. checked_div ( * rhs)
438+ . and_then ( |res| res. checked_mul_div ( * self , * rhs) . 0 )
178439 }
179440
180441 /// Compute the two's complement of a number.
@@ -710,4 +971,39 @@ mod test_uint {
710971 let amount: Result < Uint , _ > = serde_json:: from_str ( r#""1000000000.2""# ) ;
711972 assert ! ( amount. is_err( ) ) ;
712973 }
974+
975+ #[ test]
976+ fn test_mul_div ( ) {
977+ use std:: str:: FromStr ;
978+ let a: Uint = Uint :: from_str (
979+ "0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ,
980+ ) . unwrap ( ) ;
981+ let b: Uint = Uint :: from_str (
982+ "0x8000000000000000000000000000000000000000000000000000000000000000" ,
983+ ) . unwrap ( ) ;
984+ let c: Uint = Uint :: from_str (
985+ "0x4000000000000000000000000000000000000000000000000000000000000000" ,
986+ ) . unwrap ( ) ;
987+ let d: Uint = Uint :: from_str (
988+ "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ,
989+ ) . unwrap ( ) ;
990+ let e: Uint = Uint :: from_str (
991+ "0x0000000000000000000000000000000000000000000000000000000000000001" ,
992+ ) . unwrap ( ) ;
993+ let f: Uint = Uint :: from_str (
994+ "0x0000000000000000000000000000000000000000000000000000000000000000" ,
995+ ) . unwrap ( ) ;
996+ assert_eq ! ( a. mul_div( a, a) , ( a, Uint :: zero( ) ) ) ;
997+ assert_eq ! ( b. mul_div( c, b) , ( c, Uint :: zero( ) ) ) ;
998+ assert_eq ! ( a. mul_div( c, b) , ( d, c) ) ;
999+ assert_eq ! ( a. mul_div( e, e) , ( a, Uint :: zero( ) ) ) ;
1000+ assert_eq ! ( e. mul_div( c, b) , ( Uint :: zero( ) , c) ) ;
1001+ assert_eq ! ( f. mul_div( a, e) , ( Uint :: zero( ) , Uint :: zero( ) ) ) ;
1002+ assert_eq ! ( a. checked_mul_div( a, a) , ( Some ( a) , Uint :: zero( ) ) ) ;
1003+ assert_eq ! ( b. checked_mul_div( c, b) , ( Some ( c) , Uint :: zero( ) ) ) ;
1004+ assert_eq ! ( a. checked_mul_div( c, b) , ( Some ( d) , c) ) ;
1005+ assert_eq ! ( a. checked_mul_div( e, e) , ( Some ( a) , Uint :: zero( ) ) ) ;
1006+ assert_eq ! ( e. checked_mul_div( c, b) , ( Some ( Uint :: zero( ) ) , c) ) ;
1007+ assert_eq ! ( d. checked_mul_div( a, e) , ( None , Uint :: zero( ) ) ) ;
1008+ }
7131009}
0 commit comments