diff --git a/src/const_choice.rs b/src/const_choice.rs index d45f6f245..60fbae5e3 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -166,6 +166,11 @@ impl ConstChoice { Self(self.0 & other.0) } + #[inline] + pub(crate) const fn xor(&self, other: Self) -> Self { + Self(self.0 ^ other.0) + } + /// Return `b` if `self` is truthy, otherwise return `a`. #[inline] pub(crate) const fn select_word(&self, a: Word, b: Word) -> Word { diff --git a/src/uint/boxed/mul.rs b/src/uint/boxed/mul.rs index 11778f3b8..089d28365 100644 --- a/src/uint/boxed/mul.rs +++ b/src/uint/boxed/mul.rs @@ -1,7 +1,10 @@ //! [`BoxedUint`] multiplication operations. use crate::{ - uint::mul::{mul_limbs, square_limbs}, + uint::mul::{ + karatsuba::{karatsuba_mul_limbs, karatsuba_square_limbs, KARATSUBA_MIN_STARTING_LIMBS}, + mul_limbs, square_limbs, + }, BoxedUint, CheckedMul, Limb, WideningMul, Wrapping, WrappingMul, Zero, }; use core::ops::{Mul, MulAssign}; @@ -12,7 +15,18 @@ impl BoxedUint { /// /// Returns a widened output with a limb count equal to the sums of the input limb counts. pub fn mul(&self, rhs: &Self) -> Self { - let mut limbs = vec![Limb::ZERO; self.nlimbs() + rhs.nlimbs()]; + let size = self.nlimbs() + rhs.nlimbs(); + let overlap = self.nlimbs().min(rhs.nlimbs()); + + if self.nlimbs().min(rhs.nlimbs()) >= KARATSUBA_MIN_STARTING_LIMBS { + let mut limbs = vec![Limb::ZERO; size + overlap * 2]; + let (out, scratch) = limbs.as_mut_slice().split_at_mut(size); + karatsuba_mul_limbs(&self.limbs, &rhs.limbs, out, scratch); + limbs.truncate(size); + return limbs.into(); + } + + let mut limbs = vec![Limb::ZERO; size]; mul_limbs(&self.limbs, &rhs.limbs, &mut limbs); limbs.into() } @@ -24,7 +38,17 @@ impl BoxedUint { /// Multiply `self` by itself. pub fn square(&self) -> Self { - let mut limbs = vec![Limb::ZERO; self.nlimbs() * 2]; + let size = self.nlimbs() * 2; + + if self.nlimbs() >= KARATSUBA_MIN_STARTING_LIMBS * 2 { + let mut limbs = vec![Limb::ZERO; size * 2]; + let (out, scratch) = limbs.as_mut_slice().split_at_mut(size); + karatsuba_square_limbs(&self.limbs, out, scratch); + limbs.truncate(size); + return limbs.into(); + } + + let mut limbs = vec![Limb::ZERO; size]; square_limbs(&self.limbs, &mut limbs); limbs.into() } @@ -144,4 +168,23 @@ mod tests { } } } + + #[cfg(feature = "rand_core")] + #[test] + fn mul_cmp() { + use crate::RandomBits; + use rand_core::SeedableRng; + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1); + + for _ in 0..50 { + let a = BoxedUint::random_bits(&mut rng, 4096); + assert_eq!(a.mul(&a), a.square(), "a = {a}"); + } + + for _ in 0..50 { + let a = BoxedUint::random_bits(&mut rng, 4096); + let b = BoxedUint::random_bits(&mut rng, 5000); + assert_eq!(a.mul(&b), b.mul(&a), "a={a}, b={b}"); + } + } } diff --git a/src/uint/mul.rs b/src/uint/mul.rs index d04c5b353..4a122e963 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -1,13 +1,14 @@ //! [`Uint`] multiplication operations. -// TODO(tarcieri): use Karatsuba for better performance - +use self::karatsuba::UintKaratsubaMul; use crate::{ Checked, CheckedMul, Concat, ConcatMixed, Limb, Uint, WideningMul, Wrapping, WrappingMul, Zero, }; use core::ops::{Mul, MulAssign}; use subtle::CtOption; +pub(crate) mod karatsuba; + /// Implement the core schoolbook multiplication algorithm. /// /// This is implemented as a macro to abstract over `const fn` and boxed use cases, since the latter @@ -26,18 +27,15 @@ macro_rules! impl_schoolbook_multiplication { while i < $lhs.len() { let mut j = 0; let mut carry = Limb::ZERO; + let xi = $lhs[i]; while j < $rhs.len() { let k = i + j; if k >= $lhs.len() { - let (n, c) = $hi[k - $lhs.len()].mac($lhs[i], $rhs[j], carry); - $hi[k - $lhs.len()] = n; - carry = c; + ($hi[k - $lhs.len()], carry) = $hi[k - $lhs.len()].mac(xi, $rhs[j], carry); } else { - let (n, c) = $lo[k].mac($lhs[i], $rhs[j], carry); - $lo[k] = n; - carry = c; + ($lo[k], carry) = $lo[k].mac(xi, $rhs[j], carry); } j += 1; @@ -72,18 +70,15 @@ macro_rules! impl_schoolbook_squaring { while i < $limbs.len() { let mut j = 0; let mut carry = Limb::ZERO; + let xi = $limbs[i]; while j < i { let k = i + j; if k >= $limbs.len() { - let (n, c) = $hi[k - $limbs.len()].mac($limbs[i], $limbs[j], carry); - $hi[k - $limbs.len()] = n; - carry = c; + ($hi[k - $limbs.len()], carry) = $hi[k - $limbs.len()].mac(xi, $limbs[j], carry); } else { - let (n, c) = $lo[k].mac($limbs[i], $limbs[j], carry); - $lo[k] = n; - carry = c; + ($lo[k], carry) = $lo[k].mac(xi, $limbs[j], carry); } j += 1; @@ -117,24 +112,17 @@ macro_rules! impl_schoolbook_squaring { let mut carry = Limb::ZERO; let mut i = 0; while i < $limbs.len() { + let xi = $limbs[i]; if (i * 2) < $limbs.len() { - let (n, c) = $lo[i * 2].mac($limbs[i], $limbs[i], carry); - $lo[i * 2] = n; - carry = c; + ($lo[i * 2], carry) = $lo[i * 2].mac(xi, xi, carry); } else { - let (n, c) = $hi[i * 2 - $limbs.len()].mac($limbs[i], $limbs[i], carry); - $hi[i * 2 - $limbs.len()] = n; - carry = c; + ($hi[i * 2 - $limbs.len()], carry) = $hi[i * 2 - $limbs.len()].mac(xi, xi, carry); } if (i * 2 + 1) < $limbs.len() { - let (n, c) = $lo[i * 2 + 1].overflowing_add(carry); - $lo[i * 2 + 1] = n; - carry = c; + ($lo[i * 2 + 1], carry) = $lo[i * 2 + 1].overflowing_add(carry); } else { - let (n, c) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry); - $hi[i * 2 + 1 - $limbs.len()] = n; - carry = c; + ($hi[i * 2 + 1 - $limbs.len()], carry) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry); } i += 1; @@ -161,10 +149,27 @@ impl Uint { &self, rhs: &Uint, ) -> (Self, Uint) { - let mut lo = Self::ZERO; - let mut hi = Uint::::ZERO; - impl_schoolbook_multiplication!(&self.limbs, &rhs.limbs, lo.limbs, hi.limbs); - (lo, hi) + if LIMBS == RHS_LIMBS { + if LIMBS == 128 { + let (a, b) = UintKaratsubaMul::<128>::multiply(&self.limbs, &rhs.limbs); + // resize() should be a no-op, but the compiler can't infer that Uint is Uint<128> + return (a.resize(), b.resize()); + } + if LIMBS == 64 { + let (a, b) = UintKaratsubaMul::<64>::multiply(&self.limbs, &rhs.limbs); + return (a.resize(), b.resize()); + } + if LIMBS == 32 { + let (a, b) = UintKaratsubaMul::<32>::multiply(&self.limbs, &rhs.limbs); + return (a.resize(), b.resize()); + } + if LIMBS == 16 { + let (a, b) = UintKaratsubaMul::<16>::multiply(&self.limbs, &rhs.limbs); + return (a.resize(), b.resize()); + } + } + + uint_mul_limbs(&self.limbs, &rhs.limbs) } /// Perform wrapping multiplication, discarding overflow. @@ -180,10 +185,17 @@ impl Uint { /// Square self, returning a "wide" result in two parts as (lo, hi). pub const fn square_wide(&self) -> (Self, Self) { - let mut lo = Self::ZERO; - let mut hi = Self::ZERO; - impl_schoolbook_squaring!(&self.limbs, lo.limbs, hi.limbs); - (lo, hi) + if LIMBS == 128 { + let (a, b) = UintKaratsubaMul::<128>::square(&self.limbs); + // resize() should be a no-op, but the compiler can't infer that Uint is Uint<128> + return (a.resize(), b.resize()); + } + if LIMBS == 64 { + let (a, b) = UintKaratsubaMul::<64>::square(&self.limbs); + return (a.resize(), b.resize()); + } + + uint_square_limbs(&self.limbs) } } @@ -295,6 +307,30 @@ impl WrappingMul for Uint { } } +/// Helper method to perform schoolbook multiplication +#[inline] +pub(crate) const fn uint_mul_limbs( + lhs: &[Limb], + rhs: &[Limb], +) -> (Uint, Uint) { + debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS); + let mut lo: Uint = Uint::::ZERO; + let mut hi = Uint::::ZERO; + impl_schoolbook_multiplication!(lhs, rhs, lo.limbs, hi.limbs); + (lo, hi) +} + +/// Helper method to perform schoolbook multiplication +#[inline] +pub(crate) const fn uint_square_limbs( + limbs: &[Limb], +) -> (Uint, Uint) { + let mut lo = Uint::::ZERO; + let mut hi = Uint::::ZERO; + impl_schoolbook_squaring!(limbs, lo.limbs, hi.limbs); + (lo, hi) +} + /// Wrapper function used by `BoxedUint` #[cfg(feature = "alloc")] pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) { @@ -402,4 +438,17 @@ mod tests { assert_eq!(lo, U256::ONE); assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE)); } + + #[cfg(feature = "rand_core")] + #[test] + fn mul_cmp() { + use crate::{Random, U4096}; + use rand_core::SeedableRng; + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1); + + for _ in 0..50 { + let a = U4096::random(&mut rng); + assert_eq!(a.split_mul(&a), a.square_wide(), "a = {a}"); + } + } } diff --git a/src/uint/mul/karatsuba.rs b/src/uint/mul/karatsuba.rs new file mode 100644 index 000000000..32b2f9310 --- /dev/null +++ b/src/uint/mul/karatsuba.rs @@ -0,0 +1,419 @@ +//! Karatsuba multiplication +//! +//! This is a method which reduces the complexity of multiplication from O(n^2) to O(n^1.585). +//! For smaller numbers, it is best to stick to schoolbook multiplication, taking advantage +//! of better cache locality and avoiding recursion. +//! +//! In general, we consider the multiplication of two numbers of an equal size, `n` bits. +//! Setting b = 2^(n/2), then we can decompose the values: +//! x•y = (x0 + x1•b)(y0 + y1•b) +//! +//! This equation is equivalent to a linear combination of three products of size `n/2`, which +//! may each be reduced by applying the same optimization. +//! Setting z0 = x0•y0, z1 = (x0-x1)(y1-y0), z2 = x1•y1: +//! x•y = z0 + (z0 - z1 + z2)•b + z2•b^2 +//! +//! Considering each sub-product as a tuple of integers `(lo, hi)`, the product is calculated as +//! follows (with appropriate carries): +//! [z0.0, z0.0 + z0.1 - z1.0 + z2.0, z0.1 - z1.1 + z2.0 + z2.1, z2.1] +//! + +use super::{uint_mul_limbs, uint_square_limbs}; +use crate::{ConstChoice, Limb, Uint}; + +#[cfg(feature = "alloc")] +use super::square_limbs; +#[cfg(feature = "alloc")] +use crate::{WideWord, Word}; + +#[cfg(feature = "alloc")] +pub const KARATSUBA_MIN_STARTING_LIMBS: usize = 32; +#[cfg(feature = "alloc")] +pub const KARATSUBA_MAX_REDUCE_LIMBS: usize = 24; + +/// A helper struct for performing Karatsuba multiplication on Uints. +pub(crate) struct UintKaratsubaMul; + +macro_rules! impl_uint_karatsuba_multiplication { + // TODO: revisit when `const_mut_refs` is stable + (reduce $full_size:expr, $half_size:expr) => { + impl UintKaratsubaMul<$full_size> { + pub(crate) const fn multiply( + lhs: &[Limb], + rhs: &[Limb], + ) -> (Uint<$full_size>, Uint<$full_size>) { + let (x0, x1) = lhs.split_at($half_size); + let (y0, y1) = rhs.split_at($half_size); + + // Calculate z1 = (x0 - x1)(y1 - y0) + let mut l0 = Uint::<$half_size>::ZERO; + let mut l1 = Uint::<$half_size>::ZERO; + let mut l0b = Limb::ZERO; + let mut l1b = Limb::ZERO; + let mut i = 0; + while i < $half_size { + (l0.limbs[i], l0b) = x0[i].sbb(x1[i], l0b); + (l1.limbs[i], l1b) = y1[i].sbb(y0[i], l1b); + i += 1; + } + l0 = Uint::select( + &l0, + &l0.wrapping_neg(), + ConstChoice::from_word_mask(l0b.0), + ); + l1 = Uint::select( + &l1, + &l1.wrapping_neg(), + ConstChoice::from_word_mask(l1b.0), + ); + let z1 = UintKaratsubaMul::<$half_size>::multiply(&l0.limbs, &l1.limbs); + let z1_neg = ConstChoice::from_word_mask(l0b.0) + .xor(ConstChoice::from_word_mask(l1b.0)); + + // Conditionally add or subtract z1•b depending on its sign + let mut res = (Uint::ZERO, z1.0, z1.1, Uint::ZERO); + res.0 = Uint::select(&res.0, &res.0.not(), z1_neg); + res.1 = Uint::select(&res.1, &res.1.not(), z1_neg); + res.2 = Uint::select(&res.2, &res.2.not(), z1_neg); + res.3 = Uint::select(&res.3, &res.3.not(), z1_neg); + + // Calculate z0 = x0•y0 + let z0 = UintKaratsubaMul::<$half_size>::multiply(&x0, &y0); + // Calculate z2 = x1•y1 + let z2 = UintKaratsubaMul::<$half_size>::multiply(&x1, &y1); + + // Add z0 + (z0 + z2)•b + z2•b^2 + let mut carry = Limb::select(Limb::ZERO, Limb::ONE, z1_neg); + (res.0, carry) = res.0.adc(&z0.0, carry); + (res.1, carry) = res.1.adc(&z0.1, carry); + let mut carry2; + (res.1, carry2) = res.1.adc(&z0.0, Limb::ZERO); + (res.2, carry) = res.2.adc(&z0.1, carry.wrapping_add(carry2)); + (res.1, carry2) = res.1.adc(&z2.0, Limb::ZERO); + (res.2, carry2) = res.2.adc(&z2.1, carry2); + carry = carry.wrapping_add(carry2); + (res.2, carry2) = res.2.adc(&z2.0, Limb::ZERO); + (res.3, _) = res.3.adc(&z2.1, carry.wrapping_add(carry2)); + + (res.0.concat(&res.1), res.2.concat(&res.3)) + } + } + }; + ($small_size:expr) => { + impl UintKaratsubaMul<$small_size> { + #[inline] + pub(crate) const fn multiply(lhs: &[Limb], rhs: &[Limb]) -> (Uint<$small_size>, Uint<$small_size>) { + uint_mul_limbs(lhs, rhs) + } + } + }; + ($full_size:tt, $half_size:tt $(,$rest:tt)*) => { + impl_uint_karatsuba_multiplication!{reduce $full_size, $half_size} + impl_uint_karatsuba_multiplication!{$half_size $(,$rest)*} + } +} + +macro_rules! impl_uint_karatsuba_squaring { + (reduce $full_size:expr, $half_size:expr) => { + impl UintKaratsubaMul<$full_size> { + pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$full_size>, Uint<$full_size>) { + let (x0, x1) = limbs.split_at($half_size); + let z0 = UintKaratsubaMul::<$half_size>::square(&x0); + let z2 = UintKaratsubaMul::<$half_size>::square(&x1); + + // Calculate z0 + (z0 + z2)•b + z2•b^2 + let mut res = (z0.0, z0.1, Uint::<$half_size>::ZERO, Uint::<$half_size>::ZERO); + let mut carry; + (res.1, carry) = res.1.adc(&z0.0, Limb::ZERO); + (res.2, carry) = z0.1.adc(&z2.0, carry); + let mut carry2; + (res.1, carry2) = res.1.adc(&z2.0, Limb::ZERO); + (res.2, carry2) = res.2.adc(&z2.1, carry2); + (res.3, _) = z2.1.adc(&Uint::ZERO, carry.wrapping_add(carry2)); + + // Calculate z1 = (x0 - x1)^2 + let mut l0 = Uint::<$half_size>::ZERO; + let mut l0b = Limb::ZERO; + let mut i = 0; + while i < $half_size { + (l0.limbs[i], l0b) = x0[i].sbb(x1[i], l0b); + i += 1; + } + l0 = Uint::select( + &l0, + &l0.wrapping_neg(), + ConstChoice::from_word_mask(l0b.0), + ); + + let z1 = UintKaratsubaMul::<$half_size>::square(&l0.limbs); + + // Subtract z1•b + carry = Limb::ZERO; + (res.1, carry) = res.1.sbb(&z1.0, carry); + (res.2, carry) = res.2.sbb(&z1.1, carry); + (res.3, _) = res.3.sbb(&Uint::ZERO, carry); + + (res.0.concat(&res.1), res.2.concat(&res.3)) + } + } + }; + ($small_size:expr) => { + impl UintKaratsubaMul<$small_size> { + #[inline] + pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$small_size>, Uint<$small_size>) { + uint_square_limbs(limbs) + } + } + }; + ($full_size:tt, $half_size:tt $(,$rest:tt)*) => { + impl_uint_karatsuba_squaring!{reduce $full_size, $half_size} + impl_uint_karatsuba_squaring!{$half_size $(,$rest)*} + } +} + +#[cfg(feature = "alloc")] +#[inline(never)] +pub(crate) fn karatsuba_mul_limbs( + lhs: &[Limb], + rhs: &[Limb], + out: &mut [Limb], + scratch: &mut [Limb], +) { + let size = { + let overlap = lhs.len().min(rhs.len()); + if (overlap & 1) == 1 { + overlap.saturating_sub(1) + } else { + overlap + } + }; + if size <= KARATSUBA_MAX_REDUCE_LIMBS { + out.fill(Limb::ZERO); + adc_mul_limbs(lhs, rhs, out); + return; + } + if lhs.len() + rhs.len() != out.len() || scratch.len() < 2 * size { + panic!("invalid arguments to karatsuba_mul_limbs"); + } + let half = size / 2; + let (scratch, ext_scratch) = scratch.split_at_mut(size); + + let (x, xt) = lhs.split_at(size); + let (y, yt) = rhs.split_at(size); + let (x0, x1) = x.split_at(half); + let (y0, y1) = y.split_at(half); + + // Initialize output buffer + out.fill(Limb::ZERO); + + // Calculate abs(x0 - x1) and abs(y1 - y0) + let mut i = 0; + let mut borrow0 = Limb::ZERO; + let mut borrow1 = Limb::ZERO; + while i < half { + (scratch[i], borrow0) = x0[i].sbb(x1[i], borrow0); + (scratch[i + half], borrow1) = y1[i].sbb(y0[i], borrow1); + i += 1; + } + // Conditionally negate terms depending whether they borrowed + conditional_wrapping_neg_assign(&mut scratch[..half], ConstChoice::from_word_mask(borrow0.0)); + conditional_wrapping_neg_assign( + &mut scratch[half..size], + ConstChoice::from_word_mask(borrow1.0), + ); + + // Calculate abs(z1) = abs(x0 - x1)•abs(y1 - y0) + karatsuba_mul_limbs( + &scratch[..half], + &scratch[half..size], + &mut out[half..size + half], + ext_scratch, + ); + let z1_neg = ConstChoice::from_word_mask(borrow0.0).xor(ConstChoice::from_word_mask(borrow1.0)); + // Conditionally negate the output + conditional_wrapping_neg_assign(&mut out[..2 * size], z1_neg); + + // Calculate z0 = x0•y0 into scratch + karatsuba_mul_limbs(x0, y0, scratch, ext_scratch); + // Add z0•(1 + b) to output + let mut carry = Limb::ZERO; + let mut carry2 = Limb::ZERO; + i = 0; + while i < size { + (out[i], carry) = out[i].adc(scratch[i], carry); // add z0 + i += 1; + } + i = 0; + while i < half { + (out[i + half], carry2) = out[i + half].adc(scratch[i], carry2); // add z0.0 + i += 1; + } + carry = carry.wrapping_add(carry2); + while i < size { + (out[i + half], carry) = out[i + half].adc(scratch[i], carry); // add z0.1 + i += 1; + } + + // Calculate z2 = x1•y1 into scratch + karatsuba_mul_limbs(x1, y1, scratch, ext_scratch); + // Add z2•(b + b^2) to output + carry2 = Limb::ZERO; + i = 0; + while i < size { + (out[i + half], carry2) = out[i + half].adc(scratch[i], carry2); // add z2 + i += 1; + } + carry = carry.wrapping_add(carry2); + carry2 = Limb::ZERO; + i = 0; + while i < half { + (out[i + size], carry2) = out[i + size].adc(scratch[i], carry2); // add z2.0 + i += 1; + } + carry = carry.wrapping_add(carry2); + while i < size { + (out[i + size], carry) = out[i + size].adc(scratch[i], carry); // add z2.1 + i += 1; + } + + // Handle trailing limbs + if !xt.is_empty() { + adc_mul_limbs(xt, rhs, &mut out[size..]); + } + if !yt.is_empty() { + let end_pos = 2 * size + yt.len(); + carry = adc_mul_limbs(yt, x, &mut out[size..end_pos]); + i = end_pos; + while i < out.len() { + (out[i], carry) = out[i].adc(Limb::ZERO, carry); + i += 1; + } + } +} + +#[cfg(feature = "alloc")] +#[inline(never)] +pub(crate) fn karatsuba_square_limbs(limbs: &[Limb], out: &mut [Limb], scratch: &mut [Limb]) { + let size = limbs.len(); + if size <= KARATSUBA_MAX_REDUCE_LIMBS * 2 || (size & 1) == 1 { + out.fill(Limb::ZERO); + square_limbs(limbs, out); + return; + } + if 2 * size != out.len() || scratch.len() < out.len() { + panic!("invalid arguments to karatsuba_square_limbs"); + } + let half = size / 2; + let (scratch, ext_scratch) = scratch.split_at_mut(size); + let (x0, x1) = limbs.split_at(half); + + // Initialize output buffer + out[..2 * size].fill(Limb::ZERO); + + // Calculate x0 - x1 + let mut i = 0; + let mut borrow = Limb::ZERO; + while i < half { + (scratch[i], borrow) = x0[i].sbb(x1[i], borrow); + i += 1; + } + // Conditionally negate depending whether subtraction borrowed + conditional_wrapping_neg_assign(&mut scratch[..half], ConstChoice::from_word_mask(borrow.0)); + // Calculate z1 = (x0 - x1)^2 into output + karatsuba_square_limbs(&scratch[..half], &mut out[half..3 * half], ext_scratch); + // Negate the output (will add 1 to produce the wrapping negative) + i = 0; + while i < 2 * size { + out[i] = !out[i]; + i += 1; + } + + // Calculate z0 = x0^2 into scratch + karatsuba_square_limbs(x0, scratch, ext_scratch); + // Add z0•(1 + b) to output + let mut carry = Limb::ONE; // add 1 to complete wrapping negative + let mut carry2 = Limb::ZERO; + i = 0; + while i < size { + (out[i], carry) = out[i].adc(scratch[i], carry); // add z0 + i += 1; + } + i = 0; + while i < half { + (out[i + half], carry2) = out[i + half].adc(scratch[i], carry2); // add z0.0 + i += 1; + } + carry = carry.wrapping_add(carry2); + while i < size { + (out[i + half], carry) = out[i + half].adc(scratch[i], carry); // add z0.1 + i += 1; + } + + // Calculate z2 = x1^2 into scratch + karatsuba_square_limbs(x1, scratch, ext_scratch); + // Add z2•(b + b^2) to output + carry2 = Limb::ZERO; + i = 0; + while i < size { + (out[i + half], carry2) = out[i + half].adc(scratch[i], carry2); // add z2 + i += 1; + } + carry = carry.wrapping_add(carry2); + carry2 = Limb::ZERO; + i = 0; + while i < half { + (out[i + size], carry2) = out[i + size].adc(scratch[i], carry2); // add z2.0 + i += 1; + } + carry = carry.wrapping_add(carry2); + while i < size { + (out[i + size], carry) = out[i + size].adc(scratch[i], carry); // add z2.1 + i += 1; + } +} + +#[cfg(feature = "alloc")] +/// Conditionally replace the contents of a mutable limb slice with its wrapping negation. +#[inline] +fn conditional_wrapping_neg_assign(limbs: &mut [Limb], choice: ConstChoice) { + let mut carry = choice.select_word(0, 1) as WideWord; + let mut r; + let mut i = 0; + while i < limbs.len() { + r = (choice.select_word(limbs[i].0, !limbs[i].0) as WideWord) + carry; + limbs[i].0 = r as Word; + carry = r >> Word::BITS; + i += 1; + } +} + +/// Add the schoolbook product of two limb slices to a limb slice, returning the carry. +#[cfg(feature = "alloc")] +fn adc_mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb { + if lhs.len() + rhs.len() != out.len() { + panic!("adc_mul_limbs length mismatch"); + } + + let mut carry = Limb::ZERO; + let mut i = 0; + while i < lhs.len() { + let mut j = 0; + let mut carry2 = Limb::ZERO; + let xi = lhs[i]; + + while j < rhs.len() { + let k = i + j; + (out[k], carry2) = out[k].mac(xi, rhs[j], carry2); + j += 1; + } + + carry = carry.wrapping_add(carry2); + (out[i + j], carry) = out[i + j].adc(Limb::ZERO, carry); + i += 1; + } + + carry +} + +impl_uint_karatsuba_multiplication!(128, 64, 32, 16, 8); +impl_uint_karatsuba_squaring!(128, 64, 32); diff --git a/tests/uint.rs b/tests/uint.rs index f798db815..a74b53a45 100644 --- a/tests/uint.rs +++ b/tests/uint.rs @@ -5,7 +5,7 @@ mod common; use common::to_biguint; use crypto_bigint::{ modular::{MontyForm, MontyParams}, - Encoding, Integer, Limb, NonZero, Odd, Uint, Word, U256, + Encoding, Integer, Limb, NonZero, Odd, Uint, Word, U256, U4096, U8192, }; use num_bigint::BigUint; use num_integer::Integer as _; @@ -26,11 +26,25 @@ fn to_uint(big_uint: BigUint) -> U256 { U256::from_le_slice(&input) } +fn to_uint_xlarge(big_uint: BigUint) -> U8192 { + let mut input = [0u8; U8192::BYTES]; + let encoded = big_uint.to_bytes_le(); + let l = encoded.len().min(U8192::BYTES); + input[..l].copy_from_slice(&encoded[..l]); + + U8192::from_le_slice(&input) +} + prop_compose! { fn uint()(bytes in any::<[u8; 32]>()) -> U256 { U256::from_le_slice(&bytes) } } +prop_compose! { + fn uint_large()(bytes in any::<[u8; 512]>()) -> U4096 { + U4096::from_le_slice(&bytes) + } +} prop_compose! { fn uint_mod_p(p: Odd)(a in uint()) -> U256 { a.wrapping_rem_vartime(&p) @@ -250,6 +264,28 @@ proptest! { } } + #[test] + fn widening_mul_large(a in uint_large(), b in uint_large()) { + let a_bi = to_biguint(&a); + let b_bi = to_biguint(&b); + + let expected = to_uint_xlarge(a_bi * b_bi); + let actual = a.widening_mul(&b); + + assert_eq!(expected, actual); + } + + + #[test] + fn square_large(a in uint_large()) { + let a_bi = to_biguint(&a); + + let expected = to_uint_xlarge(&a_bi * &a_bi); + let actual = a.square(); + + assert_eq!(expected, actual); + } + #[test] fn div_rem(a in uint(), b in uint()) { let a_bi = to_biguint(&a);