@@ -9,125 +9,118 @@ use subtle::CtOption;
99
1010pub ( crate ) mod karatsuba;
1111
12- /// Implement the core schoolbook multiplication algorithm.
12+ /// Schoolbook multiplication a.k.a. long multiplication, i.e. the traditional method taught in
13+ /// schools.
1314///
14- /// This is implemented as a macro to abstract over `const fn` and boxed use cases, since the latter
15- /// needs mutable references and thus the unstable `const_mut_refs` feature (rust-lang/rust#57349).
16- ///
17- /// It allows us to have a single place (this module) to improve the multiplication implementation
18- /// which will also be reused for `BoxedUint`.
19- // TODO(tarcieri): change this into a `const fn` when `const_mut_refs` is stable
20- macro_rules! impl_schoolbook_multiplication {
21- ( $lhs: expr, $rhs: expr, $lo: expr, $hi: expr) => { {
22- if $lhs. len( ) != $lo. len( ) || $rhs. len( ) != $hi. len( ) {
23- panic!( "schoolbook multiplication length mismatch" ) ;
24- }
25-
26- let mut i = 0 ;
27- while i < $lhs. len( ) {
28- let mut j = 0 ;
29- let mut carry = Limb :: ZERO ;
30- let xi = $lhs[ i] ;
31-
32- while j < $rhs. len( ) {
33- let k = i + j;
15+ /// The most efficient method for small numbers.
16+ const fn schoolbook_multiplication ( lhs : & [ Limb ] , rhs : & [ Limb ] , lo : & mut [ Limb ] , hi : & mut [ Limb ] ) {
17+ if lhs. len ( ) != lo. len ( ) || rhs. len ( ) != hi. len ( ) {
18+ panic ! ( "schoolbook multiplication length mismatch" ) ;
19+ }
3420
35- if k >= $lhs . len ( ) {
36- ( $hi [ k - $ lhs. len( ) ] , carry ) = $hi [ k - $lhs . len ( ) ] . mac ( xi , $rhs [ j ] , carry ) ;
37- } else {
38- ( $lo [ k ] , carry) = $lo [ k ] . mac ( xi , $rhs [ j ] , carry ) ;
39- }
21+ let mut i = 0 ;
22+ while i < lhs. len ( ) {
23+ let mut j = 0 ;
24+ let mut carry = Limb :: ZERO ;
25+ let xi = lhs [ i ] ;
4026
41- j += 1 ;
42- }
27+ while j < rhs . len ( ) {
28+ let k = i + j ;
4329
44- if i + j >= $ lhs. len( ) {
45- $ hi[ i + j - $ lhs. len( ) ] = carry;
30+ if k >= lhs. len ( ) {
31+ ( hi[ k - lhs . len ( ) ] , carry ) = hi [ k - lhs. len ( ) ] . mac ( xi , rhs [ j ] , carry) ;
4632 } else {
47- $ lo[ i + j ] = carry;
33+ ( lo[ k ] , carry ) = lo [ k ] . mac ( xi , rhs [ j ] , carry) ;
4834 }
49- i += 1 ;
35+
36+ j += 1 ;
5037 }
51- } } ;
38+
39+ if i + j >= lhs. len ( ) {
40+ hi[ i + j - lhs. len ( ) ] = carry;
41+ } else {
42+ lo[ i + j] = carry;
43+ }
44+ i += 1 ;
45+ }
5246}
5347
54- /// Implement the schoolbook method for squaring.
48+ /// Schoolbook method of squaring.
5549///
5650/// Like schoolbook multiplication, but only considering half of the multiplication grid.
57- // TODO: change this into a `const fn` when `const_mut_refs` is stable.
58- macro_rules! impl_schoolbook_squaring {
59- ( $limbs: expr, $lo: expr, $hi: expr) => { {
60- // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
61- //
62- // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
63- // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
64-
65- if $limbs. len( ) != $lo. len( ) || $lo. len( ) != $hi. len( ) {
66- panic!( "schoolbook squaring length mismatch" ) ;
67- }
68-
69- let mut i = 1 ;
70- while i < $limbs. len( ) {
71- let mut j = 0 ;
72- let mut carry = Limb :: ZERO ;
73- let xi = $limbs[ i] ;
51+ pub ( crate ) const fn schoolbook_squaring ( limbs : & [ Limb ] , lo : & mut [ Limb ] , hi : & mut [ Limb ] ) {
52+ // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
53+ //
54+ // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
55+ // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
56+
57+ if limbs. len ( ) != lo. len ( ) || lo. len ( ) != hi. len ( ) {
58+ panic ! ( "schoolbook squaring length mismatch" ) ;
59+ }
7460
75- while j < i {
76- let k = i + j;
61+ let mut i = 1 ;
62+ while i < limbs. len ( ) {
63+ let mut j = 0 ;
64+ let mut carry = Limb :: ZERO ;
65+ let xi = limbs[ i] ;
7766
78- if k >= $limbs. len( ) {
79- ( $hi[ k - $limbs. len( ) ] , carry) = $hi[ k - $limbs. len( ) ] . mac( xi, $limbs[ j] , carry) ;
80- } else {
81- ( $lo[ k] , carry) = $lo[ k] . mac( xi, $limbs[ j] , carry) ;
82- }
67+ while j < i {
68+ let k = i + j;
8369
84- j += 1 ;
85- }
86-
87- if ( 2 * i) < $limbs. len( ) {
88- $lo[ 2 * i] = carry;
70+ if k >= limbs. len ( ) {
71+ ( hi[ k - limbs. len ( ) ] , carry) = hi[ k - limbs. len ( ) ] . mac ( xi, limbs[ j] , carry) ;
8972 } else {
90- $hi [ 2 * i - $limbs . len ( ) ] = carry;
73+ ( lo [ k ] , carry ) = lo [ k ] . mac ( xi , limbs [ j ] , carry) ;
9174 }
9275
93- i += 1 ;
76+ j += 1 ;
9477 }
9578
96- // Double the current result, this accounts for the other half of the multiplication grid.
97- // The top word is empty, so we use a special purpose shl.
98- let mut carry = Limb :: ZERO ;
99- let mut i = 0 ;
100- while i < $limbs. len( ) {
101- ( $lo[ i] . 0 , carry) = ( $lo[ i] . 0 << 1 | carry. 0 , $lo[ i] . shr( Limb :: BITS - 1 ) ) ;
102- i += 1 ;
79+ if ( 2 * i) < limbs. len ( ) {
80+ lo[ 2 * i] = carry;
81+ } else {
82+ hi[ 2 * i - limbs. len ( ) ] = carry;
10383 }
104- i = 0 ;
105- while i < $limbs. len( ) - 1 {
106- ( $hi[ i] . 0 , carry) = ( $hi[ i] . 0 << 1 | carry. 0 , $hi[ i] . shr( Limb :: BITS - 1 ) ) ;
107- i += 1 ;
108- }
109- $hi[ $limbs. len( ) - 1 ] = carry;
11084
111- // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
112- let mut carry = Limb :: ZERO ;
113- let mut i = 0 ;
114- while i < $limbs. len( ) {
115- let xi = $limbs[ i] ;
116- if ( i * 2 ) < $limbs. len( ) {
117- ( $lo[ i * 2 ] , carry) = $lo[ i * 2 ] . mac( xi, xi, carry) ;
118- } else {
119- ( $hi[ i * 2 - $limbs. len( ) ] , carry) = $hi[ i * 2 - $limbs. len( ) ] . mac( xi, xi, carry) ;
120- }
85+ i += 1 ;
86+ }
12187
122- if ( i * 2 + 1 ) < $limbs. len( ) {
123- ( $lo[ i * 2 + 1 ] , carry) = $lo[ i * 2 + 1 ] . overflowing_add( carry) ;
124- } else {
125- ( $hi[ i * 2 + 1 - $limbs. len( ) ] , carry) = $hi[ i * 2 + 1 - $limbs. len( ) ] . overflowing_add( carry) ;
126- }
88+ // Double the current result, this accounts for the other half of the multiplication grid.
89+ // The top word is empty, so we use a special purpose shl.
90+ let mut carry = Limb :: ZERO ;
91+ let mut i = 0 ;
92+ while i < limbs. len ( ) {
93+ ( lo[ i] . 0 , carry) = ( lo[ i] . 0 << 1 | carry. 0 , lo[ i] . shr ( Limb :: BITS - 1 ) ) ;
94+ i += 1 ;
95+ }
12796
128- i += 1 ;
97+ let mut i = 0 ;
98+ while i < limbs. len ( ) - 1 {
99+ ( hi[ i] . 0 , carry) = ( hi[ i] . 0 << 1 | carry. 0 , hi[ i] . shr ( Limb :: BITS - 1 ) ) ;
100+ i += 1 ;
101+ }
102+ hi[ limbs. len ( ) - 1 ] = carry;
103+
104+ // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
105+ let mut carry = Limb :: ZERO ;
106+ let mut i = 0 ;
107+ while i < limbs. len ( ) {
108+ let xi = limbs[ i] ;
109+ if ( i * 2 ) < limbs. len ( ) {
110+ ( lo[ i * 2 ] , carry) = lo[ i * 2 ] . mac ( xi, xi, carry) ;
111+ } else {
112+ ( hi[ i * 2 - limbs. len ( ) ] , carry) = hi[ i * 2 - limbs. len ( ) ] . mac ( xi, xi, carry) ;
129113 }
130- } } ;
114+
115+ if ( i * 2 + 1 ) < limbs. len ( ) {
116+ ( lo[ i * 2 + 1 ] , carry) = lo[ i * 2 + 1 ] . overflowing_add ( carry) ;
117+ } else {
118+ ( hi[ i * 2 + 1 - limbs. len ( ) ] , carry) =
119+ hi[ i * 2 + 1 - limbs. len ( ) ] . overflowing_add ( carry) ;
120+ }
121+
122+ i += 1 ;
123+ }
131124}
132125
133126impl < const LIMBS : usize > Uint < LIMBS > {
@@ -316,7 +309,7 @@ pub(crate) const fn uint_mul_limbs<const LIMBS: usize, const RHS_LIMBS: usize>(
316309 debug_assert ! ( lhs. len( ) == LIMBS && rhs. len( ) == RHS_LIMBS ) ;
317310 let mut lo: Uint < LIMBS > = Uint :: < LIMBS > :: ZERO ;
318311 let mut hi = Uint :: < RHS_LIMBS > :: ZERO ;
319- impl_schoolbook_multiplication ! ( lhs, rhs, lo. limbs, hi. limbs) ;
312+ schoolbook_multiplication ( lhs, rhs, & mut lo. limbs , & mut hi. limbs ) ;
320313 ( lo, hi)
321314}
322315
@@ -327,7 +320,7 @@ pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
327320) -> ( Uint < LIMBS > , Uint < LIMBS > ) {
328321 let mut lo = Uint :: < LIMBS > :: ZERO ;
329322 let mut hi = Uint :: < LIMBS > :: ZERO ;
330- impl_schoolbook_squaring ! ( limbs, lo. limbs, hi. limbs) ;
323+ schoolbook_squaring ( limbs, & mut lo. limbs , & mut hi. limbs ) ;
331324 ( lo, hi)
332325}
333326
@@ -336,15 +329,15 @@ pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
336329pub ( crate ) fn mul_limbs ( lhs : & [ Limb ] , rhs : & [ Limb ] , out : & mut [ Limb ] ) {
337330 debug_assert_eq ! ( lhs. len( ) + rhs. len( ) , out. len( ) ) ;
338331 let ( lo, hi) = out. split_at_mut ( lhs. len ( ) ) ;
339- impl_schoolbook_multiplication ! ( lhs, rhs, lo, hi) ;
332+ schoolbook_multiplication ( lhs, rhs, lo, hi) ;
340333}
341334
/// Wrapper function used by `BoxedUint`.
///
/// Squares `limbs` into `out`, which is expected to be exactly twice as long as `limbs`. The
/// first `limbs.len()` limbs of `out` receive the low half of the result, the rest the high
/// half.
#[cfg(feature = "alloc")]
pub(crate) fn square_limbs(limbs: &[Limb], out: &mut [Limb]) {
    let width = limbs.len();
    debug_assert_eq!(width * 2, out.len());

    let (low_half, high_half) = out.split_at_mut(width);
    schoolbook_squaring(limbs, low_half, high_half);
}
349342
350343#[ cfg( test) ]
0 commit comments