Skip to content

Commit 14b8909

Browse files
committed
Add BoxedBernsteinYangInverter
Adds a boxed equivalent of `BernsteinYangInverter` which operates on values whose number of limbs is determined at compile time. It's largely copy-pasted from the stack-allocated equivalent, but the long-term goal will be to rewrite it to use in-place operations rather than allocating.
1 parent fcc0933 commit 14b8909

4 files changed

Lines changed: 453 additions & 90 deletions

File tree

src/modular.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ mod sub;
3232
pub(crate) mod boxed_residue;
3333

3434
pub use self::{
35-
bernstein_yang::BernsteinYangInverter,
35+
bernstein_yang::{boxed::BoxedBernsteinYangInverter, BernsteinYangInverter},
3636
dyn_residue::{inv::DynResidueInverter, DynResidue, DynResidueParams},
3737
reduction::montgomery_reduction,
3838
residue::{inv::ResidueInverter, Residue, ResidueParams},

src/modular/bernstein_yang.rs

Lines changed: 76 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
#[macro_use]
1414
mod macros;
1515

16+
#[cfg(feature = "alloc")]
17+
pub(super) mod boxed;
18+
1619
use crate::{ConstChoice, ConstCtOption, Inverter, Limb, Uint, Word};
1720
use subtle::CtOption;
1821

19-
/// Type of the modular multiplicative inverter based on the Bernstein-Yang method.
22+
/// Modular multiplicative inverter based on the Bernstein-Yang method.
23+
///
2024
/// The inverter can be created for a specified modulus M and adjusting parameter A
2125
/// to compute the adjusted multiplicative inverses of positive integers, i.e. for
2226
/// computing (1 / x) * A (mod M) for a positive integer x.
@@ -59,8 +63,7 @@ impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize>
5963
{
6064
/// Creates the inverter for specified modulus and adjusting parameter.
6165
///
62-
/// Modulus must be odd. Returns `ConstChoice::FALSE` if it is not.
63-
#[allow(trivial_numeric_casts)]
66+
/// Modulus must be odd. Returns `None` if it is not.
6467
pub const fn new(modulus: &Uint<SAT_LIMBS>, adjuster: &Uint<SAT_LIMBS>) -> ConstCtOption<Self> {
6568
let ret = Self {
6669
modulus: Uint62L::from_uint(modulus),
@@ -71,7 +74,7 @@ impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize>
7174
ConstCtOption::new(ret, modulus.is_odd())
7275
}
7376

74-
/// Returns either the adjusted modular multiplicative inverse for the argument or None
77+
/// Returns either the adjusted modular multiplicative inverse for the argument or `None`
7578
/// depending on invertibility of the argument, i.e. its coprimality with the modulus
7679
pub const fn inv(&self, value: &Uint<SAT_LIMBS>) -> ConstCtOption<Uint<SAT_LIMBS>> {
7780
let (mut d, mut e) = (Uint62L::ZERO, self.adjuster);
@@ -147,8 +150,8 @@ impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize>
147150
t: Matrix,
148151
) -> (Uint62L<UNSAT_LIMBS>, Uint62L<UNSAT_LIMBS>) {
149152
(
150-
f.mul(t[0][0]).add(&g.mul(t[0][1])).shift(),
151-
f.mul(t[1][0]).add(&g.mul(t[1][1])).shift(),
153+
f.mul(t[0][0]).add(&g.mul(t[0][1])).shr(),
154+
f.mul(t[1][0]).add(&g.mul(t[1][1])).shr(),
152155
)
153156
}
154157

@@ -187,7 +190,7 @@ impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize>
187190
.add(&e.mul(t[1][1]))
188191
.add(&self.modulus.mul(me));
189192

190-
(cd.shift(), ce.shift())
193+
(cd.shr(), ce.shr())
191194
}
192195

193196
/// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the
@@ -260,7 +263,7 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
260263
/// Number of bits in each limb.
261264
pub const LIMB_BITS: usize = 62;
262265

263-
/// Mask, in which the B lowest bits are 1 and only they.
266+
/// Mask, in which the 62 lowest bits are 1.
264267
pub const MASK: u64 = u64::MAX >> (64 - Self::LIMB_BITS);
265268

266269
/// Representation of -1.
@@ -277,7 +280,7 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
277280
};
278281

279282
/// Convert from 64-bit saturated representation used by `Uint` to the 62-bit unsaturated representation used by
280-
/// `Uint62`.
283+
/// `Uint62L`.
281284
///
282285
/// Returns a big unsigned integer as an array of 62-bit chunks, which is equal modulo 2 ^ (62 * S) to the input big
283286
/// unsigned integer stored as an array of 64-bit chunks.
@@ -289,17 +292,13 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
289292
panic!("incorrect number of limbs");
290293
}
291294

292-
Self(impl_limb_convert!(
293-
Word,
294-
Word::BITS as usize,
295-
u64,
296-
62,
297-
LIMBS,
298-
input.as_words()
299-
))
295+
let mut output = [0; LIMBS];
296+
impl_limb_convert!(Word, Word::BITS as usize, input.as_words(), u64, 62, output);
297+
298+
Self(output)
300299
}
301300

302-
/// Convert from 62-bit unsaturated representation used by `Uint62` to the 64-bit saturated representation used by
301+
/// Convert from 62-bit unsaturated representation used by `Uint62L` to the 64-bit saturated representation used by
303302
/// `Uint`.
304303
///
305304
/// Returns a big unsigned integer as an array of 64-bit chunks, which is equal modulo 2 ^ (64 * S) to the input big
@@ -312,53 +311,9 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
312311
panic!("incorrect number of limbs");
313312
}
314313

315-
Uint::from_words(impl_limb_convert!(
316-
u64,
317-
62,
318-
Word,
319-
Word::BITS as usize,
320-
SAT_LIMBS,
321-
&self.0
322-
))
323-
}
324-
325-
/// Returns the result of applying 62-bit right arithmetical shift to the current number.
326-
pub const fn shift(&self) -> Self {
327-
let mut ret = Self::ZERO;
328-
if self.is_negative() {
329-
ret.0[LIMBS - 1] = Self::MASK;
330-
}
331-
332-
let mut i = 0;
333-
while i < LIMBS - 1 {
334-
ret.0[i] = self.0[i + 1];
335-
i += 1;
336-
}
337-
338-
ret
339-
}
340-
341-
/// Returns the lowest 62 bits of the current number.
342-
pub const fn lowest(&self) -> u64 {
343-
self.0[0]
344-
}
345-
346-
/// Returns "true" iff the current number is negative.
347-
pub const fn is_negative(&self) -> bool {
348-
self.0[LIMBS - 1] > (Self::MASK >> 1)
349-
}
350-
351-
/// Const fn equivalent for `PartialEq::eq`.
352-
pub const fn eq(&self, other: &Self) -> bool {
353-
let mut ret = true;
354-
let mut i = 0;
355-
356-
while i < LIMBS {
357-
ret &= self.0[i] == other.0[i];
358-
i += 1;
359-
}
360-
361-
ret
314+
let mut ret = [0 as Word; SAT_LIMBS];
315+
impl_limb_convert!(u64, 62, &self.0, Word, Word::BITS as usize, ret);
316+
Uint::from_words(ret)
362317
}
363318

364319
/// Const fn equivalent for `Add::add`.
@@ -376,23 +331,6 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
376331
ret
377332
}
378333

379-
/// Const fn equivalent for `Neg::neg`.
380-
pub const fn neg(&self) -> Self {
381-
// For the two's complement code the additive negation is the result
382-
// of adding 1 to the bitwise inverted argument's representation
383-
let (mut ret, mut carry) = (Self::ZERO, 1);
384-
let mut i = 0;
385-
386-
while i < LIMBS {
387-
let sum = (self.0[i] ^ Self::MASK) + carry;
388-
ret.0[i] = sum & Self::MASK;
389-
carry = sum >> Self::LIMB_BITS;
390-
i += 1;
391-
}
392-
393-
ret
394-
}
395-
396334
/// Const fn equivalent for `Mul::<i64>::mul`.
397335
pub const fn mul(&self, other: i64) -> Self {
398336
let mut ret = Self::ZERO;
@@ -424,4 +362,60 @@ impl<const LIMBS: usize> Uint62L<LIMBS> {
424362

425363
ret
426364
}
365+
366+
/// Const fn equivalent for `Neg::neg`.
367+
pub const fn neg(&self) -> Self {
368+
// For the two's complement code the additive negation is the result
369+
// of adding 1 to the bitwise inverted argument's representation
370+
let (mut ret, mut carry) = (Self::ZERO, 1);
371+
let mut i = 0;
372+
373+
while i < LIMBS {
374+
let sum = (self.0[i] ^ Self::MASK) + carry;
375+
ret.0[i] = sum & Self::MASK;
376+
carry = sum >> Self::LIMB_BITS;
377+
i += 1;
378+
}
379+
380+
ret
381+
}
382+
383+
/// Returns the result of applying 62-bit right arithmetical shift to the current number.
384+
pub const fn shr(&self) -> Self {
385+
let mut ret = Self::ZERO;
386+
if self.is_negative() {
387+
ret.0[LIMBS - 1] = Self::MASK;
388+
}
389+
390+
let mut i = 0;
391+
while i < LIMBS - 1 {
392+
ret.0[i] = self.0[i + 1];
393+
i += 1;
394+
}
395+
396+
ret
397+
}
398+
399+
/// Const fn equivalent for `PartialEq::eq`.
400+
pub const fn eq(&self, other: &Self) -> bool {
401+
let mut ret = true;
402+
let mut i = 0;
403+
404+
while i < LIMBS {
405+
ret &= self.0[i] == other.0[i];
406+
i += 1;
407+
}
408+
409+
ret
410+
}
411+
412+
/// Returns "true" iff the current number is negative.
413+
pub const fn is_negative(&self) -> bool {
414+
self.0[LIMBS - 1] > (Self::MASK >> 1)
415+
}
416+
417+
/// Returns the lowest 62 bits of the current number.
418+
pub const fn lowest(&self) -> u64 {
419+
self.0[0]
420+
}
427421
}

0 commit comments

Comments
 (0)