diff --git a/src/limb/rand.rs b/src/limb/rand.rs index 0bc8af31a..0c89cf93f 100644 --- a/src/limb/rand.rs +++ b/src/limb/rand.rs @@ -23,20 +23,15 @@ impl RandomMod for Limb { fn random_mod(mut rng: impl CryptoRng + RngCore, modulus: &NonZero) -> Self { let mut bytes = ::Repr::default(); - // TODO(tarcieri): use `div_ceil` when available - // See: https://github.com/rust-lang/rust/issues/88581 - let mut n_bytes = modulus.bits() / 8; - - // Ensure the randomly generated value can always be larger than - // the modulus in order to ensure a uniform distribution - if n_bytes < Self::BYTE_SIZE { - n_bytes += 1; - } + let n_bits = modulus.bits(); + let n_bytes = (n_bits + 7) / 8; + let mask = 0xff >> (8 * n_bytes - n_bits); loop { rng.fill_bytes(&mut bytes[..n_bytes]); - let n = Limb::from_le_bytes(bytes); + bytes[0] = bytes[0] & mask; + let n = Limb::from_be_bytes(bytes); if n.ct_lt(modulus).into() { return n; } diff --git a/src/uint/rand.rs b/src/uint/rand.rs index e88280afa..14fb98572 100644 --- a/src/uint/rand.rs +++ b/src/uint/rand.rs @@ -35,28 +35,15 @@ impl RandomMod for UInt { fn random_mod(mut rng: impl CryptoRng + RngCore, modulus: &NonZero) -> Self { let mut n = Self::ZERO; - // TODO(tarcieri): use `div_ceil` when available - // See: https://github.com/rust-lang/rust/issues/88581 - let mut n_limbs = modulus.bits() / Limb::BIT_SIZE; - if n_limbs < LIMBS { - n_limbs += 1; - } - - // Compute the highest limb of `modulus` as a `NonZero`. - // Add one to ensure `Limb::random_mod` returns values inclusive of this limb. - let modulus_hi = - NonZero::new(modulus.limbs[n_limbs.saturating_sub(1)].saturating_add(Limb::ONE)) - .unwrap(); // Always at least one due to `saturating_add` + let n_bits = modulus.bits(); + let n_limbs = (n_bits + Limb::BIT_SIZE - 1) / Limb::BIT_SIZE; + let mask = Limb(Limb::MAX.0 >> (Limb::BIT_SIZE * n_limbs - n_bits)); loop { for i in 0..n_limbs { - n.limbs[i] = if (i + 1 == n_limbs) && (*modulus_hi != Limb::MAX) { - // Highest limb - Limb::random_mod(&mut rng, &modulus_hi) - } else { - Limb::random(&mut rng) - } + n.limbs[i] = Limb::random(&mut rng); } + n.limbs[n_limbs - 1] = n.limbs[n_limbs - 1].bitand(mask); if n.ct_lt(modulus).into() { return n;