diff --git a/benches/bench.rs b/benches/bench.rs index 67767c772..cc7425bee 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,7 +1,13 @@ use criterion::{ criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, }; -use crypto_bigint::{NonZero, Random, Reciprocal, Uint}; +use crypto_bigint::{ + modular::{ + runtime_mod::{DynResidue, DynResidueParams}, + PowResidue, + }, + Encoding, NonZero, Random, Reciprocal, Uint, U256, +}; use rand_core::OsRng; fn bench_division<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { @@ -72,9 +78,42 @@ fn bench_division<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { }); } +fn bench_modpow<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { + const TEST_SET: usize = 10; + let xs = (0..TEST_SET) + .map(|_| U256::random(&mut OsRng)) + .collect::>(); + let moduli = (0..TEST_SET) + .map(|_| U256::random(&mut OsRng) | U256::ONE) + .collect::>(); + let powers = (0..TEST_SET) + .map(|_| U256::random(&mut OsRng) | (U256::ONE << (U256::BIT_SIZE - 1))) + .collect::>(); + + let params = moduli + .iter() + .map(|modulus| DynResidueParams::new(*modulus)) + .collect::>(); + let xs_m = xs + .iter() + .zip(params.iter()) + .map(|(x, p)| DynResidue::new(*x, *p)) + .collect::>(); + + group.bench_function("modpow, 4^4", |b| { + b.iter(|| { + xs_m.iter() + .zip(powers.iter()) + .map(|(x, p)| x.pow(&p)) + .for_each(drop) + }) + }); +} + fn bench_wrapping_ops(c: &mut Criterion) { let mut group = c.benchmark_group("wrapping ops"); bench_division(&mut group); + bench_modpow(&mut group); group.finish(); } diff --git a/src/limb/cmp.rs b/src/limb/cmp.rs index 8e61bbc3e..910d6e0f7 100644 --- a/src/limb/cmp.rs +++ b/src/limb/cmp.rs @@ -42,6 +42,24 @@ impl Limb { (gt as SignedWord) - (lt as SignedWord) } + /// Returns `Word::MAX` if `lhs == rhs` and `0` otherwise. + #[inline] + pub(crate) const fn ct_eq(lhs: Self, rhs: Self) -> Word { + let x = lhs.0; + let y = rhs.0; + + // c == 0 if and only if x == y + let c = x ^ y; + + // If c == 0, then c and -c are both equal to zero; + // otherwise, one or both will have its high bit set. + let d = (c | c.wrapping_neg()) >> (Limb::BIT_SIZE - 1); + + // Result is the opposite of the high bit (now shifted to low). + // Convert 1 to Word::MAX. + (d ^ 1).wrapping_neg() + } + /// Returns `Word::MAX` if `lhs < rhs` and `0` otherwise. #[inline] pub(crate) const fn ct_lt(lhs: Self, rhs: Self) -> Word { diff --git a/src/uint/modular.rs b/src/uint/modular.rs index 9b38b32bd..a93d772f5 100644 --- a/src/uint/modular.rs +++ b/src/uint/modular.rs @@ -44,7 +44,10 @@ where self.pow_specific(exponent, LIMBS * Word::BITS as usize) } - /// Computes the (reduced) exponentiation of a residue, here `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. + /// Computes the (reduced) exponentiation of a residue, + /// here `exponent_bits` represents the number of bits to take into account for the exponent. + /// + /// NOTE: `exponent_bits` is leaked in the time pattern. fn pow_specific(self, exponent: &Uint, exponent_bits: usize) -> Self; } @@ -53,11 +56,13 @@ pub trait InvResidue where Self: Sized, { - /// Computes the (reduced) multiplicative inverse of the residue. Returns CtOption, which is `None` if the residue was not invertible. + /// Computes the (reduced) multiplicative inverse of the residue. Returns CtOption, + /// which is `None` if the residue was not invertible. fn inv(self) -> CtOption; } -/// The `GenericResidue` trait provides a consistent API for dealing with residues with a constant modulus. +/// The `GenericResidue` trait provides a consistent API +/// for dealing with residues with a constant modulus. pub trait GenericResidue: AddResidue + MulResidue + PowResidue + InvResidue { diff --git a/src/uint/modular/pow.rs b/src/uint/modular/pow.rs index 17dd17a85..c8e797297 100644 --- a/src/uint/modular/pow.rs +++ b/src/uint/modular/pow.rs @@ -2,7 +2,10 @@ use crate::{Limb, Uint, Word}; use super::mul::{mul_montgomery_form, square_montgomery_form}; -/// Performs modular exponentiation using Montgomery's ladder. `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. +/// Performs modular exponentiation using Montgomery's ladder. +/// `exponent_bits` represents the number of bits to take into account for the exponent. +/// +/// NOTE: this value is leaked in the time pattern. pub const fn pow_montgomery_form( x: Uint, exponent: &Uint, @@ -11,29 +14,66 @@ pub const fn pow_montgomery_form( r: Uint, mod_neg_inv: Limb, ) -> Uint { - let mut x1: Uint = r; - let mut x2: Uint = x; + if exponent_bits == 0 { + return r; // 1 in Montgomery form + } - // Shift the exponent all the way to the left so the leftmost bit is the MSB of the `Uint` - let mut n: Uint = exponent.shl_vartime((LIMBS * Word::BITS as usize) - exponent_bits); + const WINDOW: usize = 4; + const WINDOW_MASK: Word = (1 << WINDOW) - 1; - let mut i = 0; - while i < exponent_bits { - // Peel off one bit at a time from the left side - let (next_n, overflow) = n.shl_1(); - n = next_n; + // powers[i] contains x^i + let mut powers = [r; 1 << WINDOW]; + powers[1] = x; + let mut i = 2; + while i < powers.len() { + powers[i] = mul_montgomery_form(&powers[i - 1], &x, modulus, mod_neg_inv); + i += 1; + } - let mut product: Uint = x1; - product = mul_montgomery_form(&product, &x2, modulus, mod_neg_inv); + let starting_limb = (exponent_bits - 1) / Limb::BIT_SIZE; + let starting_bit_in_limb = (exponent_bits - 1) % Limb::BIT_SIZE; + let starting_window = starting_bit_in_limb / WINDOW; + let starting_window_mask = (1 << (starting_bit_in_limb % WINDOW + 1)) - 1; - let mut square = Uint::ct_select(x1, x2, overflow); - square = square_montgomery_form(&square, modulus, mod_neg_inv); + let mut z = r; // 1 in Montgomery form - x1 = Uint::::ct_select(square, product, overflow); - x2 = Uint::::ct_select(product, square, overflow); + let mut limb_num = starting_limb + 1; + while limb_num > 0 { + limb_num -= 1; + let w = exponent.as_limbs()[limb_num].0; - i += 1; + let mut window_num = if limb_num == starting_limb { + starting_window + 1 + } else { + Limb::BIT_SIZE / WINDOW + }; + while window_num > 0 { + window_num -= 1; + + let mut idx = (w >> (window_num * WINDOW)) & WINDOW_MASK; + + if limb_num == starting_limb && window_num == starting_window { + idx &= starting_window_mask; + } else { + let mut i = 0; + while i < WINDOW { + i += 1; + z = square_montgomery_form(&z, modulus, mod_neg_inv); + } + } + + // Constant-time lookup in the array of powers + let mut power = powers[0]; + let mut i = 1; + while i < 1 << WINDOW { + let choice = Limb::ct_eq(Limb(i as Word), Limb(idx)); + power = Uint::::ct_select(power, powers[i], choice); + i += 1; + } + + z = mul_montgomery_form(&z, &power, modulus, mod_neg_inv); + } } - x1 + z } diff --git a/src/uint/modular/runtime_mod.rs b/src/uint/modular/runtime_mod.rs index 758ecaa39..5ce7afd6c 100644 --- a/src/uint/modular/runtime_mod.rs +++ b/src/uint/modular/runtime_mod.rs @@ -84,6 +84,14 @@ impl DynResidue { residue_params, } } + + /// Instantiates a new `Residue` that represents 1. + pub const fn one(residue_params: DynResidueParams) -> Self { + Self { + montgomery_form: residue_params.r, + residue_params, + } + } } impl GenericResidue for DynResidue { diff --git a/src/uint/modular/runtime_mod/runtime_pow.rs b/src/uint/modular/runtime_mod/runtime_pow.rs index 52b215c2a..903668a40 100644 --- a/src/uint/modular/runtime_mod/runtime_pow.rs +++ b/src/uint/modular/runtime_mod/runtime_pow.rs @@ -12,7 +12,10 @@ impl PowResidue for DynResidue { } impl DynResidue { - /// Computes the (reduced) exponentiation of a residue, here `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern. + /// Computes the (reduced) exponentiation of a residue, + /// here `exponent_bits` represents the number of bits to take into account for the exponent. + /// + /// NOTE: `exponent_bits` is leaked in the time pattern. pub const fn pow_specific(self, exponent: &Uint, exponent_bits: usize) -> Self { Self { montgomery_form: pow_montgomery_form( diff --git a/tests/proptests.rs b/tests/proptests.rs index f12903349..33bac8fe0 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -1,6 +1,12 @@ //! Equivalence tests between `num-bigint` and `crypto-bigint` -use crypto_bigint::{Encoding, Limb, NonZero, Word, U256}; +use crypto_bigint::{ + modular::{ + runtime_mod::{DynResidue, DynResidueParams}, + PowResidue, + }, + Encoding, Limb, NonZero, Word, U256, +}; use num_bigint::BigUint; use num_integer::Integer; use num_traits::identities::Zero; @@ -245,5 +251,38 @@ proptest! { let mut bytes = a.to_le_bytes(); bytes.reverse(); assert_eq!(a, U256::from_be_bytes(bytes)); -} + } + + #[test] + fn residue_pow(a in uint_mod_p(P), b in uint()) { + let a_bi = to_biguint(&a); + let b_bi = to_biguint(&b); + let p_bi = to_biguint(&P); + + let expected = to_uint(a_bi.modpow(&b_bi, &p_bi)); + + let params = DynResidueParams::new(P); + let a_m = DynResidue::new(a, params); + let actual = a_m.pow(&b).retrieve(); + + assert_eq!(expected, actual); + } + + #[test] + fn residue_pow_specific(a in uint_mod_p(P), b in uint(), exponent_bits in any::()) { + + let b_masked = b & (U256::ONE << exponent_bits.into()).wrapping_sub(&U256::ONE); + + let a_bi = to_biguint(&a); + let b_bi = to_biguint(&b_masked); + let p_bi = to_biguint(&P); + + let expected = to_uint(a_bi.modpow(&b_bi, &p_bi)); + + let params = DynResidueParams::new(P); + let a_m = DynResidue::new(a, params); + let actual = a_m.pow_specific(&b, exponent_bits.into()).retrieve(); + + assert_eq!(expected, actual); + } }