Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/const_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ impl ConstChoice {
Self(self.0 & other.0)
}

/// Combine two choices by XOR-ing their inner values.
#[inline]
pub(crate) const fn xor(&self, other: Self) -> Self {
    let mixed = self.0 ^ other.0;
    Self(mixed)
}

/// Return `b` if `self` is truthy, otherwise return `a`.
#[inline]
pub(crate) const fn select_word(&self, a: Word, b: Word) -> Word {
Expand Down
49 changes: 46 additions & 3 deletions src/uint/boxed/mul.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! [`BoxedUint`] multiplication operations.

use crate::{
uint::mul::{mul_limbs, square_limbs},
uint::mul::{
karatsuba::{karatsuba_mul_limbs, karatsuba_square_limbs, KARATSUBA_MIN_STARTING_LIMBS},
mul_limbs, square_limbs,
},
BoxedUint, CheckedMul, Limb, WideningMul, Wrapping, WrappingMul, Zero,
};
use core::ops::{Mul, MulAssign};
Expand All @@ -12,7 +15,18 @@ impl BoxedUint {
///
/// Returns a widened output with a limb count equal to the sums of the input limb counts.
///
/// When the shorter operand has at least `KARATSUBA_MIN_STARTING_LIMBS` limbs the product is
/// computed by `karatsuba_mul_limbs`; otherwise the schoolbook `mul_limbs` routine is used.
pub fn mul(&self, rhs: &Self) -> Self {
    let size = self.nlimbs() + rhs.nlimbs();
    let overlap = self.nlimbs().min(rhs.nlimbs());

    if overlap >= KARATSUBA_MIN_STARTING_LIMBS {
        // A single allocation holds the `size` output limbs followed by `2 * overlap`
        // limbs of scratch space that is handed to the Karatsuba routine.
        let mut limbs = vec![Limb::ZERO; size + overlap * 2];
        let (out, scratch) = limbs.as_mut_slice().split_at_mut(size);
        karatsuba_mul_limbs(&self.limbs, &rhs.limbs, out, scratch);
        // Drop the scratch tail so only the product limbs remain.
        limbs.truncate(size);
        return limbs.into();
    }

    // Schoolbook path: the buffer is exactly the size of the widened product.
    let mut limbs = vec![Limb::ZERO; size];
    mul_limbs(&self.limbs, &rhs.limbs, &mut limbs);
    limbs.into()
}
Expand All @@ -24,7 +38,17 @@ impl BoxedUint {

/// Multiply `self` by itself.
///
/// Returns a widened output with twice the limb count of `self`. Inputs with at least
/// `KARATSUBA_MIN_STARTING_LIMBS * 2` limbs are squared by `karatsuba_square_limbs`; smaller
/// inputs use the schoolbook `square_limbs` routine.
pub fn square(&self) -> Self {
    let size = self.nlimbs() * 2;

    if self.nlimbs() >= KARATSUBA_MIN_STARTING_LIMBS * 2 {
        // A single allocation holds the `size` output limbs followed by an equally sized
        // scratch area that is handed to the Karatsuba routine.
        let mut limbs = vec![Limb::ZERO; size * 2];
        let (out, scratch) = limbs.as_mut_slice().split_at_mut(size);
        karatsuba_square_limbs(&self.limbs, out, scratch);
        // Drop the scratch tail so only the squared value remains.
        limbs.truncate(size);
        return limbs.into();
    }

    // Schoolbook path: the buffer is exactly the size of the widened square.
    let mut limbs = vec![Limb::ZERO; size];
    square_limbs(&self.limbs, &mut limbs);
    limbs.into()
}
Expand Down Expand Up @@ -144,4 +168,23 @@ mod tests {
}
}
}

#[cfg(feature = "rand_core")]
#[test]
fn mul_cmp() {
    use crate::RandomBits;
    use rand_core::SeedableRng;

    // Seeded generator so every run samples the same operands.
    let mut prng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

    // Multiplying a value by itself must agree with the dedicated squaring routine.
    (0..50).for_each(|_| {
        let a = BoxedUint::random_bits(&mut prng, 4096);
        let product = a.mul(&a);
        let squared = a.square();
        assert_eq!(product, squared, "a = {a}");
    });

    // Multiplication of differently sized operands must not depend on argument order.
    (0..50).for_each(|_| {
        let a = BoxedUint::random_bits(&mut prng, 4096);
        let b = BoxedUint::random_bits(&mut prng, 5000);
        let forward = a.mul(&b);
        let reversed = b.mul(&a);
        assert_eq!(forward, reversed, "a={a}, b={b}");
    });
}
}
117 changes: 83 additions & 34 deletions src/uint/mul.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
//! [`Uint`] multiplication operations.

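// Karatsuba multiplication lives in the `karatsuba` submodule and is used for selected limb
// counts; the schoolbook routines below handle every other size.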

use self::karatsuba::UintKaratsubaMul;
use crate::{
Checked, CheckedMul, Concat, ConcatMixed, Limb, Uint, WideningMul, Wrapping, WrappingMul, Zero,
};
use core::ops::{Mul, MulAssign};
use subtle::CtOption;

pub(crate) mod karatsuba;

/// Implement the core schoolbook multiplication algorithm.
///
/// This is implemented as a macro to abstract over `const fn` and boxed use cases, since the latter
Expand All @@ -26,18 +27,15 @@ macro_rules! impl_schoolbook_multiplication {
while i < $lhs.len() {
let mut j = 0;
let mut carry = Limb::ZERO;
let xi = $lhs[i];

while j < $rhs.len() {
let k = i + j;

if k >= $lhs.len() {
let (n, c) = $hi[k - $lhs.len()].mac($lhs[i], $rhs[j], carry);
$hi[k - $lhs.len()] = n;
carry = c;
($hi[k - $lhs.len()], carry) = $hi[k - $lhs.len()].mac(xi, $rhs[j], carry);
} else {
let (n, c) = $lo[k].mac($lhs[i], $rhs[j], carry);
$lo[k] = n;
carry = c;
($lo[k], carry) = $lo[k].mac(xi, $rhs[j], carry);
}

j += 1;
Expand Down Expand Up @@ -72,18 +70,15 @@ macro_rules! impl_schoolbook_squaring {
while i < $limbs.len() {
let mut j = 0;
let mut carry = Limb::ZERO;
let xi = $limbs[i];

while j < i {
let k = i + j;

if k >= $limbs.len() {
let (n, c) = $hi[k - $limbs.len()].mac($limbs[i], $limbs[j], carry);
$hi[k - $limbs.len()] = n;
carry = c;
($hi[k - $limbs.len()], carry) = $hi[k - $limbs.len()].mac(xi, $limbs[j], carry);
} else {
let (n, c) = $lo[k].mac($limbs[i], $limbs[j], carry);
$lo[k] = n;
carry = c;
($lo[k], carry) = $lo[k].mac(xi, $limbs[j], carry);
}

j += 1;
Expand Down Expand Up @@ -117,24 +112,17 @@ macro_rules! impl_schoolbook_squaring {
let mut carry = Limb::ZERO;
let mut i = 0;
while i < $limbs.len() {
let xi = $limbs[i];
if (i * 2) < $limbs.len() {
let (n, c) = $lo[i * 2].mac($limbs[i], $limbs[i], carry);
$lo[i * 2] = n;
carry = c;
($lo[i * 2], carry) = $lo[i * 2].mac(xi, xi, carry);
} else {
let (n, c) = $hi[i * 2 - $limbs.len()].mac($limbs[i], $limbs[i], carry);
$hi[i * 2 - $limbs.len()] = n;
carry = c;
($hi[i * 2 - $limbs.len()], carry) = $hi[i * 2 - $limbs.len()].mac(xi, xi, carry);
}

if (i * 2 + 1) < $limbs.len() {
let (n, c) = $lo[i * 2 + 1].overflowing_add(carry);
$lo[i * 2 + 1] = n;
carry = c;
($lo[i * 2 + 1], carry) = $lo[i * 2 + 1].overflowing_add(carry);
} else {
let (n, c) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry);
$hi[i * 2 + 1 - $limbs.len()] = n;
carry = c;
($hi[i * 2 + 1 - $limbs.len()], carry) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry);
}

i += 1;
Expand All @@ -161,10 +149,27 @@ impl<const LIMBS: usize> Uint<LIMBS> {
&self,
rhs: &Uint<RHS_LIMBS>,
) -> (Self, Uint<RHS_LIMBS>) {
let mut lo = Self::ZERO;
let mut hi = Uint::<RHS_LIMBS>::ZERO;
impl_schoolbook_multiplication!(&self.limbs, &rhs.limbs, lo.limbs, hi.limbs);
(lo, hi)
if LIMBS == RHS_LIMBS {
if LIMBS == 128 {
let (a, b) = UintKaratsubaMul::<128>::multiply(&self.limbs, &rhs.limbs);
// resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
return (a.resize(), b.resize());
}
if LIMBS == 64 {
let (a, b) = UintKaratsubaMul::<64>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
if LIMBS == 32 {
let (a, b) = UintKaratsubaMul::<32>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
if LIMBS == 16 {
let (a, b) = UintKaratsubaMul::<16>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
}

uint_mul_limbs(&self.limbs, &rhs.limbs)
}

/// Perform wrapping multiplication, discarding overflow.
Expand All @@ -180,10 +185,17 @@ impl<const LIMBS: usize> Uint<LIMBS> {

/// Square self, returning a "wide" result in two parts as (lo, hi).
///
/// The 128- and 64-limb sizes are dispatched to `UintKaratsubaMul`; every other size falls
/// back to the schoolbook routine in `uint_square_limbs`.
pub const fn square_wide(&self) -> (Self, Self) {
    if LIMBS == 128 {
        let (a, b) = UintKaratsubaMul::<128>::square(&self.limbs);
        // resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
        return (a.resize(), b.resize());
    }
    if LIMBS == 64 {
        let (a, b) = UintKaratsubaMul::<64>::square(&self.limbs);
        // Same as above: resize() only converts Uint<64> back to the generic Uint<LIMBS>.
        return (a.resize(), b.resize());
    }

    uint_square_limbs(&self.limbs)
}
}

Expand Down Expand Up @@ -295,6 +307,30 @@ impl<const LIMBS: usize> WrappingMul for Uint<LIMBS> {
}
}

/// Schoolbook multiplication of two limb slices, returning the wide product as `(lo, hi)`.
///
/// `lhs` is expected to hold exactly `LIMBS` limbs and `rhs` exactly `RHS_LIMBS` limbs; this is
/// checked in debug builds only.
#[inline]
pub(crate) const fn uint_mul_limbs<const LIMBS: usize, const RHS_LIMBS: usize>(
    lhs: &[Limb],
    rhs: &[Limb],
) -> (Uint<LIMBS>, Uint<RHS_LIMBS>) {
    debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS);

    // Accumulators for the low and high parts of the product, filled in by the macro.
    let mut high = Uint::<RHS_LIMBS>::ZERO;
    let mut low = Uint::<LIMBS>::ZERO;
    impl_schoolbook_multiplication!(lhs, rhs, low.limbs, high.limbs);

    (low, high)
}

/// Helper method to perform schoolbook squaring.
///
/// Squares the value held in `limbs` and returns the wide result as `(lo, hi)`. `limbs` is
/// expected to hold exactly `LIMBS` limbs; this is checked in debug builds only, mirroring
/// `uint_mul_limbs`.
#[inline]
pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
    limbs: &[Limb],
) -> (Uint<LIMBS>, Uint<LIMBS>) {
    // A mismatched length would index past `lo`/`hi` or split the result at the wrong limb.
    debug_assert!(limbs.len() == LIMBS);
    let mut lo = Uint::<LIMBS>::ZERO;
    let mut hi = Uint::<LIMBS>::ZERO;
    impl_schoolbook_squaring!(limbs, lo.limbs, hi.limbs);
    (lo, hi)
}

/// Wrapper function used by `BoxedUint`
#[cfg(feature = "alloc")]
pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
Expand Down Expand Up @@ -402,4 +438,17 @@ mod tests {
assert_eq!(lo, U256::ONE);
assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
}

#[cfg(feature = "rand_core")]
#[test]
fn mul_cmp() {
    use crate::{Random, U4096};
    use rand_core::SeedableRng;

    // Seeded generator so every run samples the same operands.
    let mut prng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

    // Multiplying a value by itself must agree with the dedicated wide-squaring routine.
    (0..50).for_each(|_| {
        let a = U4096::random(&mut prng);
        let product = a.split_mul(&a);
        let squared = a.square_wide();
        assert_eq!(product, squared, "a = {a}");
    });
}
}
Loading