Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion benches/bench.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use criterion::{
criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion,
};
use crypto_bigint::{NonZero, Random, Reciprocal, Uint};
use crypto_bigint::{
modular::{
runtime_mod::{DynResidue, DynResidueParams},
PowResidue,
},
Encoding, NonZero, Random, Reciprocal, Uint, U256,
};
use rand_core::OsRng;

fn bench_division<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
Expand Down Expand Up @@ -72,9 +78,42 @@ fn bench_division<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
});
}

fn bench_modpow<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
const TEST_SET: usize = 10;
let xs = (0..TEST_SET)
.map(|_| U256::random(&mut OsRng))
.collect::<Vec<_>>();
let moduli = (0..TEST_SET)
.map(|_| U256::random(&mut OsRng) | U256::ONE)
.collect::<Vec<_>>();
let powers = (0..TEST_SET)
.map(|_| U256::random(&mut OsRng) | (U256::ONE << (U256::BIT_SIZE - 1)))
.collect::<Vec<_>>();

let params = moduli
.iter()
.map(|modulus| DynResidueParams::new(*modulus))
.collect::<Vec<_>>();
let xs_m = xs
.iter()
.zip(params.iter())
.map(|(x, p)| DynResidue::new(*x, *p))
.collect::<Vec<_>>();

group.bench_function("modpow, 4^4", |b| {
b.iter(|| {
xs_m.iter()
.zip(powers.iter())
.map(|(x, p)| x.pow(&p))
.for_each(drop)
})
});
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
bench_modpow(&mut group);
group.finish();
}

Expand Down
18 changes: 18 additions & 0 deletions src/limb/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ impl Limb {
(gt as SignedWord) - (lt as SignedWord)
}

/// Returns `Word::MAX` if `lhs == rhs` and `0` otherwise.
#[inline]
pub(crate) const fn ct_eq(lhs: Self, rhs: Self) -> Word {
let x = lhs.0;
let y = rhs.0;

// c == 0 if and only if x == y
let c = x ^ y;

// If c == 0, then c and -c are both equal to zero;
// otherwise, one or both will have its high bit set.
let d = (c | c.wrapping_neg()) >> (Limb::BIT_SIZE - 1);

// Result is the opposite of the high bit (now shifted to low).
// Convert 1 to Word::MAX.
(d ^ 1).wrapping_neg()
}

/// Returns `Word::MAX` if `lhs < rhs` and `0` otherwise.
#[inline]
pub(crate) const fn ct_lt(lhs: Self, rhs: Self) -> Word {
Expand Down
11 changes: 8 additions & 3 deletions src/uint/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ where
self.pow_specific(exponent, LIMBS * Word::BITS as usize)
}

/// Computes the (reduced) exponentiation of a residue, here `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern.
/// Computes the (reduced) exponentiation of a residue,
/// here `exponent_bits` represents the number of bits to take into account for the exponent.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
fn pow_specific(self, exponent: &Uint<LIMBS>, exponent_bits: usize) -> Self;
}

Expand All @@ -53,11 +56,13 @@ pub trait InvResidue
where
Self: Sized,
{
/// Computes the (reduced) multiplicative inverse of the residue. Returns CtOption, which is `None` if the residue was not invertible.
/// Computes the (reduced) multiplicative inverse of the residue. Returns CtOption,
/// which is `None` if the residue was not invertible.
fn inv(self) -> CtOption<Self>;
}

/// The `GenericResidue` trait provides a consistent API for dealing with residues with a constant modulus.
/// The `GenericResidue` trait provides a consistent API
/// for dealing with residues with a constant modulus.
pub trait GenericResidue<const LIMBS: usize>:
AddResidue + MulResidue + PowResidue<LIMBS> + InvResidue
{
Expand Down
76 changes: 58 additions & 18 deletions src/uint/modular/pow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use crate::{Limb, Uint, Word};

use super::mul::{mul_montgomery_form, square_montgomery_form};

/// Performs modular exponentiation using Montgomery's ladder. `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern.
/// Performs modular exponentiation using Montgomery's ladder.
/// `exponent_bits` represents the number of bits to take into account for the exponent.
///
/// NOTE: this value is leaked in the time pattern.
pub const fn pow_montgomery_form<const LIMBS: usize>(
x: Uint<LIMBS>,
exponent: &Uint<LIMBS>,
Expand All @@ -11,29 +14,66 @@ pub const fn pow_montgomery_form<const LIMBS: usize>(
r: Uint<LIMBS>,
mod_neg_inv: Limb,
) -> Uint<LIMBS> {
let mut x1: Uint<LIMBS> = r;
let mut x2: Uint<LIMBS> = x;
if exponent_bits == 0 {
return r; // 1 in Montgomery form
}

// Shift the exponent all the way to the left so the leftmost bit is the MSB of the `Uint`
let mut n: Uint<LIMBS> = exponent.shl_vartime((LIMBS * Word::BITS as usize) - exponent_bits);
const WINDOW: usize = 4;
const WINDOW_MASK: Word = (1 << WINDOW) - 1;

let mut i = 0;
while i < exponent_bits {
// Peel off one bit at a time from the left side
let (next_n, overflow) = n.shl_1();
n = next_n;
// powers[i] contains x^i
let mut powers = [r; 1 << WINDOW];
powers[1] = x;
let mut i = 2;
while i < powers.len() {
powers[i] = mul_montgomery_form(&powers[i - 1], &x, modulus, mod_neg_inv);
i += 1;
}

let mut product: Uint<LIMBS> = x1;
product = mul_montgomery_form(&product, &x2, modulus, mod_neg_inv);
let starting_limb = (exponent_bits - 1) / Limb::BIT_SIZE;
let starting_bit_in_limb = (exponent_bits - 1) % Limb::BIT_SIZE;
let starting_window = starting_bit_in_limb / WINDOW;
let starting_window_mask = (1 << (starting_bit_in_limb % WINDOW + 1)) - 1;

let mut square = Uint::ct_select(x1, x2, overflow);
square = square_montgomery_form(&square, modulus, mod_neg_inv);
let mut z = r; // 1 in Montgomery form

x1 = Uint::<LIMBS>::ct_select(square, product, overflow);
x2 = Uint::<LIMBS>::ct_select(product, square, overflow);
let mut limb_num = starting_limb + 1;
while limb_num > 0 {
limb_num -= 1;
let w = exponent.as_limbs()[limb_num].0;

i += 1;
let mut window_num = if limb_num == starting_limb {
starting_window + 1
} else {
Limb::BIT_SIZE / WINDOW
};
while window_num > 0 {
window_num -= 1;

let mut idx = (w >> (window_num * WINDOW)) & WINDOW_MASK;

if limb_num == starting_limb && window_num == starting_window {
idx &= starting_window_mask;
} else {
let mut i = 0;
while i < WINDOW {
i += 1;
z = square_montgomery_form(&z, modulus, mod_neg_inv);
}
}

// Constant-time lookup in the array of powers
let mut power = powers[0];
let mut i = 1;
while i < 1 << WINDOW {
let choice = Limb::ct_eq(Limb(i as Word), Limb(idx));
power = Uint::<LIMBS>::ct_select(power, powers[i], choice);
i += 1;
}

z = mul_montgomery_form(&z, &power, modulus, mod_neg_inv);
}
}

x1
z
}
8 changes: 8 additions & 0 deletions src/uint/modular/runtime_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ impl<const LIMBS: usize> DynResidue<LIMBS> {
residue_params,
}
}

/// Instantiates a new `Residue` that represents 1.
pub const fn one(residue_params: DynResidueParams<LIMBS>) -> Self {
Self {
montgomery_form: residue_params.r,
residue_params,
}
}
}

impl<const LIMBS: usize> GenericResidue<LIMBS> for DynResidue<LIMBS> {
Expand Down
5 changes: 4 additions & 1 deletion src/uint/modular/runtime_mod/runtime_pow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ impl<const LIMBS: usize> PowResidue<LIMBS> for DynResidue<LIMBS> {
}

impl<const LIMBS: usize> DynResidue<LIMBS> {
/// Computes the (reduced) exponentiation of a residue, here `exponent_bits` represents the number of bits to take into account for the exponent. Note that this value is leaked in the time pattern.
/// Computes the (reduced) exponentiation of a residue,
/// here `exponent_bits` represents the number of bits to take into account for the exponent.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
pub const fn pow_specific(self, exponent: &Uint<LIMBS>, exponent_bits: usize) -> Self {
Self {
montgomery_form: pow_montgomery_form(
Expand Down
43 changes: 41 additions & 2 deletions tests/proptests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
//! Equivalence tests between `num-bigint` and `crypto-bigint`

use crypto_bigint::{Encoding, Limb, NonZero, Word, U256};
use crypto_bigint::{
modular::{
runtime_mod::{DynResidue, DynResidueParams},
PowResidue,
},
Encoding, Limb, NonZero, Word, U256,
};
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::identities::Zero;
Expand Down Expand Up @@ -245,5 +251,38 @@ proptest! {
let mut bytes = a.to_le_bytes();
bytes.reverse();
assert_eq!(a, U256::from_be_bytes(bytes));
}
}

#[test]
fn residue_pow(a in uint_mod_p(P), b in uint()) {
let a_bi = to_biguint(&a);
let b_bi = to_biguint(&b);
let p_bi = to_biguint(&P);

let expected = to_uint(a_bi.modpow(&b_bi, &p_bi));

let params = DynResidueParams::new(P);
let a_m = DynResidue::new(a, params);
let actual = a_m.pow(&b).retrieve();

assert_eq!(expected, actual);
}

#[test]
fn residue_pow_specific(a in uint_mod_p(P), b in uint(), exponent_bits in any::<u8>()) {

let b_masked = b & (U256::ONE << exponent_bits.into()).wrapping_sub(&U256::ONE);

let a_bi = to_biguint(&a);
let b_bi = to_biguint(&b_masked);
let p_bi = to_biguint(&P);

let expected = to_uint(a_bi.modpow(&b_bi, &p_bi));

let params = DynResidueParams::new(P);
let a_m = DynResidue::new(a, params);
let actual = a_m.pow_specific(&b, exponent_bits.into()).retrieve();

assert_eq!(expected, actual);
}
}