From f2076322afd579c13ab9d175cc1b9a96daf5acc2 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sun, 10 Dec 2023 16:15:48 -0800 Subject: [PATCH 1/4] Make `inv_mod2k(_vartime)` return a CtChoice indicating if the inverse exists --- src/ct_choice.rs | 6 +++ src/modular/dyn_residue.rs | 12 +++--- src/modular/residue/macros.rs | 1 + src/uint/inv_mod.rs | 74 ++++++++++++++++++++++++++++------- tests/uint_proptests.rs | 6 ++- 5 files changed, 76 insertions(+), 23 deletions(-) diff --git a/src/ct_choice.rs b/src/ct_choice.rs index 10c2b2230..205564e63 100644 --- a/src/ct_choice.rs +++ b/src/ct_choice.rs @@ -158,6 +158,12 @@ impl From for bool { } } +impl PartialEq for CtChoice { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + #[cfg(test)] mod tests { use super::CtChoice; diff --git a/src/modular/dyn_residue.rs b/src/modular/dyn_residue.rs index f00901b37..f3fde59cd 100644 --- a/src/modular/dyn_residue.rs +++ b/src/modular/dyn_residue.rs @@ -13,7 +13,7 @@ use super::{ residue::{Residue, ResidueParams}, Retrieve, }; -use crate::{Integer, Limb, Uint, Word}; +use crate::{Limb, Uint, Word}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; /// Parameters to efficiently go to/from the Montgomery form for an odd modulus provided at runtime. @@ -40,11 +40,9 @@ impl DynResidueParams { let r = Uint::MAX.const_rem(modulus).0.wrapping_add(&Uint::ONE); let r2 = Uint::const_rem_wide(r.square_wide(), modulus).0; - // Since we are calculating the inverse modulo (Word::MAX+1), - // we can take the modulo right away and calculate the inverse of the first limb only. - let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]); - let mod_neg_inv = - Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS).limbs[0].0)); + // If the inverse does not exist, it means the modulus is odd. + let (inv_mod_limb, modulus_is_odd) = modulus.inv_mod2k_vartime(Word::BITS); + let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0)); let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv); @@ -56,7 +54,7 @@ impl DynResidueParams { mod_neg_inv, }; - CtOption::new(params, modulus.is_odd()) + CtOption::new(params, modulus_is_odd.into()) } /// Returns the modulus which was used to initialize these parameters. diff --git a/src/modular/residue/macros.rs b/src/modular/residue/macros.rs index de3bb2afd..0f92f9b4e 100644 --- a/src/modular/residue/macros.rs +++ b/src/modular/residue/macros.rs @@ -40,6 +40,7 @@ macro_rules! impl_modulus { $crate::Word::MIN.wrapping_sub( Self::MODULUS .inv_mod2k_vartime($crate::Word::BITS) + .0 .as_limbs()[0] .0, ), diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index 19f9ade57..8d33e95b0 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -5,8 +5,10 @@ impl Uint { /// Computes 1/`self` mod `2^k`. /// This method is constant-time w.r.t. `self` but not `k`. /// - /// Conditions: `self` < 2^k and `self` must be odd - pub const fn inv_mod2k_vartime(&self, k: u32) -> Self { + /// If the inverse does not exist (`k > 0` and `self` is even), + /// returns `CtChoice::FALSE` as the second element of the tuple, + /// otherwise returns `CtChoice::TRUE`. + pub const fn inv_mod2k_vartime(&self, k: u32) -> (Self, CtChoice) { // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k" // by Sadiel de la Fe and Carles Ferrer. // See . @@ -18,6 +20,9 @@ impl Uint { let mut b = Self::ONE; // keeps `b_i` during iterations let mut i = 0; + // The inverse exists either if `k` is 0 or if `self` is odd. + let is_some = CtChoice::from_u32_nonzero(k).not().or(self.ct_is_odd()); + while i < k { // X_i = b_i mod 2 let x_i = b.limbs[0].0 & 1; @@ -30,13 +35,15 @@ impl Uint { i += 1; } - x + (x, is_some) } /// Computes 1/`self` mod `2^k`. /// - /// Conditions: `self` < 2^k and `self` must be odd - pub const fn inv_mod2k(&self, k: u32) -> Self { + /// If the inverse does not exist (`k > 0` and `self` is even), + /// returns `CtChoice::FALSE` as the second element of the tuple, + /// otherwise returns `CtChoice::TRUE`. + pub const fn inv_mod2k(&self, k: u32) -> (Self, CtChoice) { // This is the same algorithm as in `inv_mod2k_vartime()`, // but made constant-time w.r.t `k` as well. @@ -44,6 +51,9 @@ impl Uint { let mut b = Self::ONE; // keeps `b_i` during iterations let mut i = 0; + // The inverse exists either if `k` is 0 or if `self` is odd. + let is_some = CtChoice::from_u32_nonzero(k).not().or(self.ct_is_odd()); + while i < Self::BITS { // Only iterations for i = 0..k need to change `x`, // the rest are dummy ones performed for the sake of constant-timeness. @@ -52,7 +62,7 @@ impl Uint { // X_i = b_i mod 2 let x_i = b.limbs[0].0 & 1; let x_i_choice = CtChoice::from_word_lsb(x_i); - // b_{i+1} = (b_i - a * X_i) / 2 + // b_{i+1} = (b_i - self * X_i) / 2 b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr1(); // Store the X_i bit in the result (x = x | (1 << X_i)) @@ -63,7 +73,7 @@ impl Uint { i += 1; } - x + (x, is_some) } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. @@ -154,16 +164,14 @@ impl Uint { // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` let (a, a_is_some) = self.inv_odd_mod(&s); - let b = self.inv_mod2k(k); - // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd. - let b_is_some = CtChoice::from_u32_nonzero(k).not().or(self.ct_is_odd()); + let (b, b_is_some) = self.inv_mod2k(k); // Restore from RNS: // self^{-1} = a mod s = b mod 2^k // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k) // (essentially one step of the Garner's algorithm for recovery from RNS). - let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists + let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k let mask = Uint::ONE.shl(k).wrapping_sub(&Uint::ONE); @@ -178,7 +186,7 @@ impl Uint { #[cfg(test)] mod tests { - use crate::{U1024, U256, U64}; + use crate::{CtChoice, U1024, U256, U64}; #[test] fn inv_mod2k() { @@ -186,15 +194,53 @@ mod tests { U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"); let e = U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf"); - let a = v.inv_mod2k(256); + let (a, is_some) = v.inv_mod2k(256); + assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); + + let (a, is_some) = v.inv_mod2k_vartime(256); assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); let v = U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); let e = U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1"); - let a = v.inv_mod2k(256); + let (a, is_some) = v.inv_mod2k(256); + assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); + + let (a, is_some) = v.inv_mod2k_vartime(256); + assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); + + // Check that even if the number is >= 2^k, the inverse is still correct. + + let v = + U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); + let e = + U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1"); + let (a, is_some) = v.inv_mod2k(90); + assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); + + let (a, is_some) = v.inv_mod2k_vartime(90); assert_eq!(e, a); + assert_eq!(is_some, CtChoice::TRUE); + + // An inverse of an even number does not exist. + + let (_a, is_some) = U256::from(10u64).inv_mod2k(4); + assert_eq!(is_some, CtChoice::FALSE); + + let (_a, is_some) = U256::from(10u64).inv_mod2k_vartime(4); + assert_eq!(is_some, CtChoice::FALSE); + + // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers. + + let (a, is_some) = U256::from(10u64).inv_mod2k_vartime(0); + assert_eq!(a, U256::ZERO); + assert_eq!(is_some, CtChoice::TRUE); } #[test] diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index a09f34fee..9c884e25d 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -245,9 +245,11 @@ proptest! { let a_bi = to_biguint(&a); let m_bi = BigUint::one() << k as usize; - let actual = a.inv_mod2k(k); - let actual_vartime = a.inv_mod2k_vartime(k); + let (actual, is_some) = a.inv_mod2k(k); + let (actual_vartime, is_some_vartime) = a.inv_mod2k_vartime(k); assert_eq!(actual, actual_vartime); + assert_eq!(is_some, CtChoice::TRUE); + assert_eq!(is_some_vartime, CtChoice::TRUE); if k == 0 { assert_eq!(actual, U256::ZERO); From c5fa880c3ed0f53edec386327b50d53637b70ae1 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 11 Dec 2023 14:31:34 -0800 Subject: [PATCH 2/4] Make `inv_odd_mod()` return a falsy CtChoice if the given modulus is even --- src/uint/inv_mod.rs | 48 +++++++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index 8d33e95b0..286947208 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -81,8 +81,8 @@ impl Uint { /// `bits` and `modulus_bits` are the bounds on the bit size /// of `self` and `modulus`, respectively /// (the inversion speed will be proportional to `bits + modulus_bits`). - /// The second element of the tuple is the truthy value if an inverse exists, - /// otherwise it is a falsy value. + /// The second element of the tuple is the truthy value + /// if `modulus` is odd and an inverse exists, otherwise it is a falsy value. /// /// **Note:** variable time in `bits` and `modulus_bits`. /// @@ -93,8 +93,6 @@ impl Uint { bits: u32, modulus_bits: u32, ) -> (Self, CtChoice) { - debug_assert!(modulus.ct_is_odd().is_true_vartime()); - let mut a = *self; let mut u = Uint::ONE; @@ -105,14 +103,15 @@ impl Uint { // `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum. let bit_size = bits + modulus_bits; - let mut m1hp = *modulus; - let (m1hp_new, carry) = m1hp.shr1_with_overflow(); - debug_assert!(carry.is_true_vartime()); - m1hp = m1hp_new.wrapping_add(&Uint::ONE); + let m1hp = modulus.shr1().wrapping_add(&Uint::ONE); + + let modulus_is_odd = modulus.ct_is_odd(); let mut i = 0; while i < bit_size { - debug_assert!(b.ct_is_odd().is_true_vartime()); + // A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with, + // otherwise this whole thing produces nonsense anyway. + debug_assert!(modulus_is_odd.not().or(b.ct_is_odd()).is_true_vartime()); let self_odd = a.ct_is_odd(); @@ -129,10 +128,10 @@ impl Uint { debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime()); let (new_a, overflow) = a.shr1_with_overflow(); - debug_assert!(!overflow.is_true_vartime()); + debug_assert!(modulus_is_odd.not().or(overflow.not()).is_true_vartime()); let (new_u, cy) = new_u.shr1_with_overflow(); let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy); - debug_assert!(!cy.is_true_vartime()); + debug_assert!(modulus_is_odd.not().or(cy.not()).is_true_vartime()); a = new_a; u = new_u; @@ -141,9 +140,12 @@ impl Uint { i += 1; } - debug_assert!(!a.ct_is_nonzero().is_true_vartime()); + debug_assert!(modulus_is_odd + .not() + .or(a.ct_is_nonzero().not()) + .is_true_vartime()); - (v, Uint::ct_eq(&b, &Uint::ONE)) + (v, Uint::ct_eq(&b, &Uint::ONE).and(modulus_is_odd)) } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. @@ -268,12 +270,32 @@ mod tests { assert!(is_some.is_true_vartime()); assert_eq!(res, expected); + // Check that trying to pass an even modulus causes `is_some` to be falsy + let (_res, is_some) = a.inv_odd_mod(&(m.wrapping_add(&U1024::ONE))); + assert!(!is_some.is_true_vartime()); + // Even though it is less efficient, it still works let (res, is_some) = a.inv_mod(&m); assert!(is_some.is_true_vartime()); assert_eq!(res, expected); } + #[test] + fn test_invert_odd_no_inverse() { + // 2^128 - 159, a prime + let p1 = + U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61"); + // 2^128 - 173, a prime + let p2 = + U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53"); + + let m = p1.wrapping_mul(&p2); + + // `m` is a multiple of `p1`, so no inverse exists + let (_res, is_some) = p1.inv_odd_mod(&m); + assert!(!is_some.is_true_vartime()); + } + #[test] fn test_invert_even() { let a = U1024::from_be_hex(concat![ From 3b5be1769ce7a4844aeba315321b580d29987647 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 11 Dec 2023 15:08:20 -0800 Subject: [PATCH 3/4] Make `BoxedUint::inv_mod2k(_vartime)` return a Choice indicating if the inverse exists --- src/modular/boxed_residue.rs | 11 ++++------- src/uint/boxed/inv_mod.rs | 26 ++++++++++++++------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/modular/boxed_residue.rs b/src/modular/boxed_residue.rs index f37edd4ea..7bc2fd867 100644 --- a/src/modular/boxed_residue.rs +++ b/src/modular/boxed_residue.rs @@ -103,12 +103,9 @@ impl BoxedResidueParams { /// Common functionality of `new` and `new_vartime`. fn new_inner(modulus: BoxedUint, r: BoxedUint, r2: BoxedUint) -> CtOption { - let is_odd = modulus.is_odd(); - - // Since we are calculating the inverse modulo (Word::MAX+1), - // we can take the modulo right away and calculate the inverse of the first limb only. - let modulus_lo = BoxedUint::from(modulus.limbs.get(0).copied().unwrap_or_default()); - let mod_neg_inv = Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS).limbs[0].0)); + // If the inverse exists, it means the modulus is odd. + let (inv_mod_limb, modulus_is_odd) = modulus.inv_mod2k(Word::BITS); + let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0)); let r3 = montgomery_reduction_boxed(&mut r2.square(), &modulus, mod_neg_inv); let params = Self { @@ -119,7 +116,7 @@ impl BoxedResidueParams { mod_neg_inv, }; - CtOption::new(params, is_odd) + CtOption::new(params, modulus_is_odd) } /// Modulus value. diff --git a/src/uint/boxed/inv_mod.rs b/src/uint/boxed/inv_mod.rs index 6f77d1771..008b7a499 100644 --- a/src/uint/boxed/inv_mod.rs +++ b/src/uint/boxed/inv_mod.rs @@ -16,16 +16,14 @@ impl BoxedUint { // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` let (a, a_is_some) = self.inv_odd_mod(&s); - let b = self.inv_mod2k(k); - // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd. - let b_is_some = k.ct_eq(&0) | self.is_odd(); + let (b, b_is_some) = self.inv_mod2k(k); // Restore from RNS: // self^{-1} = a mod s = b mod 2^k // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k) // (essentially one step of the Garner's algorithm for recovery from RNS). - let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists + let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k let mask = Self::one().shl(k).wrapping_sub(&Self::one()); @@ -39,14 +37,16 @@ impl BoxedUint { /// Computes 1/`self` mod `2^k`. /// - /// Conditions: `self` < 2^k and `self` must be odd - pub(crate) fn inv_mod2k(&self, k: u32) -> Self { - // This is the same algorithm as in `inv_mod2k_vartime()`, - // but made constant-time w.r.t `k` as well. - + /// If the inverse does not exist (`k > 0` and `self` is even), + /// returns `CtChoice::FALSE` as the second element of the tuple, + /// otherwise returns `CtChoice::TRUE`. + pub(crate) fn inv_mod2k(&self, k: u32) -> (Self, Choice) { let mut x = Self::zero_with_precision(self.bits_precision()); // keeps `x` during iterations let mut b = Self::one_with_precision(self.bits_precision()); // keeps `b_i` during iterations + // The inverse exists either if `k` is 0 or if `self` is odd. + let is_some = k.ct_eq(&0) | self.is_odd(); + for i in 0..self.bits_precision() { // Only iterations for i = 0..k need to change `x`, // the rest are dummy ones performed for the sake of constant-timeness. @@ -64,7 +64,7 @@ impl BoxedUint { x.set_bit(i, x_i_choice); } - x + (x, is_some) } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. @@ -157,8 +157,9 @@ mod tests { 256, ) .unwrap(); - let a = v.inv_mod2k(256); + let (a, is_some) = v.inv_mod2k(256); assert_eq!(e, a); + assert!(bool::from(is_some)); let v = BoxedUint::from_be_slice( &hex!("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"), @@ -170,7 +171,8 @@ mod tests { 256, ) .unwrap(); - let a = v.inv_mod2k(256); + let (a, is_some) = v.inv_mod2k(256); assert_eq!(e, a); + assert!(bool::from(is_some)); } } From 9db32ed337b76da894a995609c49805f6ab8065e Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 11 Dec 2023 15:54:49 -0800 Subject: [PATCH 4/4] Make `BoxedUint::inv_odd_mod()` return a falsy CtChoice if the given modulus is even --- src/uint/boxed/inv_mod.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/uint/boxed/inv_mod.rs b/src/uint/boxed/inv_mod.rs index 008b7a499..d4c063c5a 100644 --- a/src/uint/boxed/inv_mod.rs +++ b/src/uint/boxed/inv_mod.rs @@ -80,8 +80,8 @@ impl BoxedUint { /// of `self` and `modulus`, respectively. /// /// (the inversion speed will be proportional to `bits + modulus_bits`). - /// The second element of the tuple is the truthy value if an inverse exists, - /// otherwise it is a falsy value. + /// The second element of the tuple is the truthy value + /// if `modulus` is odd and an inverse exists, otherwise it is a falsy value. /// /// **Note:** variable time in `bits` and `modulus_bits`. /// @@ -90,7 +90,6 @@ impl BoxedUint { debug_assert_eq!(self.bits_precision(), modulus.bits_precision()); let bits_precision = self.bits_precision(); - debug_assert!(bool::from(modulus.is_odd())); let mut a = self.clone(); let mut u = Self::one_with_precision(bits_precision); @@ -100,13 +99,16 @@ impl BoxedUint { // `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum. let bit_size = bits + modulus_bits; - let mut m1hp = modulus.clone(); - let (m1hp_new, carry) = m1hp.shr1_with_overflow(); - debug_assert!(bool::from(carry)); - m1hp = m1hp_new.wrapping_add(&Self::one_with_precision(bits_precision)); + let m1hp = modulus + .shr1() + .wrapping_add(&Self::one_with_precision(bits_precision)); + + let modulus_is_odd = modulus.is_odd(); for _ in 0..bit_size { - debug_assert!(bool::from(b.is_odd())); + // A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with, + // otherwise this whole thing produces nonsense anyway. + debug_assert!(bool::from(!modulus_is_odd | b.is_odd())); let self_odd = a.is_odd(); @@ -125,18 +127,18 @@ impl BoxedUint { debug_assert!(bool::from(cy.ct_eq(&cyy))); let (new_a, overflow) = a.shr1_with_overflow(); - debug_assert!(!bool::from(overflow)); + debug_assert!(bool::from(!modulus_is_odd | !overflow)); let (mut new_u, cy) = new_u.shr1_with_overflow(); let cy = new_u.conditional_adc_assign(&m1hp, cy); - debug_assert!(!bool::from(cy)); + debug_assert!(bool::from(!modulus_is_odd | !cy)); a = new_a; u = new_u; v = new_v; } - debug_assert!(bool::from(a.is_zero())); - (v, b.is_one()) + debug_assert!(bool::from(!modulus_is_odd | a.is_zero())); + (v, b.is_one() & modulus_is_odd) } }