diff --git a/README.md b/README.md index 551058b1..9027659c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ A portable RSA implementation in pure Rust. ## Example ```rust -use rsa::{Pkcs1v15Encrypt, PublicKey, RsaPrivateKey, RsaPublicKey}; +use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; let mut rng = rand::thread_rng(); let bits = 2048; diff --git a/src/algorithms.rs b/src/algorithms.rs index 8bdc1d9a..9d803b11 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,6 +1,10 @@ //! Useful algorithms related to RSA. -use digest::{Digest, DynDigest, FixedOutputReset}; +mod mgf; +pub(crate) mod oaep; +pub(crate) mod pkcs1v15; +pub(crate) mod pss; + use num_bigint::traits::ModInverse; use num_bigint::{BigUint, RandPrime}; #[allow(unused_imports)] @@ -134,75 +138,3 @@ pub fn generate_multi_prime_key_with_exp( RsaPrivateKey::from_components(n_final, exp.clone(), d_final, primes) } - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - let mut digest_input = vec![0u8; seed.len() + 4]; - digest_input[0..seed.len()].copy_from_slice(seed); - digest_input[seed.len()..].copy_from_slice(&counter); - - digest.update(digest_input.as_slice()); - let digest_output = &*digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) -where - D: Digest + FixedOutputReset, -{ - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - Digest::update(digest, seed); - Digest::update(digest, counter); - - let digest_output = digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} -fn inc_counter(counter: &mut [u8; 4]) { - for i in (0..4).rev() { - counter[i] = counter[i].wrapping_add(1); - if counter[i] != 0 { - // No overflow - return; - } - } -} diff --git a/src/algorithms/mgf.rs b/src/algorithms/mgf.rs new file mode 100644 index 00000000..aa8fb2a3 --- /dev/null +++ b/src/algorithms/mgf.rs @@ -0,0 +1,75 @@ +//! Mask generation function common to both PSS and OAEP padding + +use digest::{Digest, DynDigest, FixedOutputReset}; + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + let mut digest_input = vec![0u8; seed.len() + 4]; + digest_input[0..seed.len()].copy_from_slice(seed); + digest_input[seed.len()..].copy_from_slice(&counter); + + digest.update(digest_input.as_slice()); + let digest_output = &*digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) +where + D: Digest + FixedOutputReset, +{ + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + Digest::update(digest, seed); + Digest::update(digest, counter); + + let digest_output = digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} +fn inc_counter(counter: &mut [u8; 4]) { + for i in (0..4).rev() { + counter[i] = counter[i].wrapping_add(1); + if counter[i] != 0 { + // No overflow + return; + } + } +} diff --git a/src/algorithms/oaep.rs b/src/algorithms/oaep.rs new file mode 100644 index 00000000..0ba2de9d --- /dev/null +++ b/src/algorithms/oaep.rs @@ -0,0 +1,246 @@ +//! Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). +//! +use alloc::string::String; +use alloc::vec::Vec; + +use digest::{Digest, DynDigest, FixedOutputReset}; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; +use zeroize::Zeroizing; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +// 2**61 -1 (pow is not const yet) +// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. +const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; + +#[inline] +fn encrypt_internal( + rng: &mut R, + msg: &[u8], + p_hash: &[u8], + h_size: usize, + k: usize, + mut mgf: MGF, +) -> Result>> { + if msg.len() + 2 * h_size + 2 > k { + return Err(Error::MessageTooLong); + } + + let mut em = Zeroizing::new(vec![0u8; k]); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + rng.fill_bytes(seed); + + // Data block DB = pHash || PS || 01 || M + let db_len = k - h_size - 1; + + db[0..h_size].copy_from_slice(p_hash); + db[db_len - msg.len() - 1] = 1; + db[db_len - msg.len()..].copy_from_slice(msg); + + mgf(seed, db); + + Ok(em) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt( + rng: &mut R, + msg: &[u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result>> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + digest.update(label.as_bytes()); + let p_hash = digest.finalize_reset(); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + mgf1_xor(db, mgf_digest, seed); + mgf1_xor(seed, mgf_digest, db); + }) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt_digest< + R: CryptoRngCore + ?Sized, + D: Digest, + MGD: Digest + FixedOutputReset, +>( + rng: &mut R, + msg: &[u8], + label: Option, + k: usize, +) -> Result>> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let p_hash = D::digest(label.as_bytes()); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(db, &mut mgf_digest, seed); + mgf1_xor_digest(seed, &mut mgf_digest, db); + }) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt( + em: &mut [u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::Decryption); + } + + digest.update(label.as_bytes()); + + let expected_p_hash = digest.finalize_reset(); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + mgf1_xor(seed, mgf_digest, db); + mgf1_xor(db, mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt_digest( + em: &mut [u8], + label: Option, + k: usize, +) -> Result> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let expected_p_hash = D::digest(label.as_bytes()); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(seed, &mut mgf_digest, db); + mgf1_xor_digest(db, &mut mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +/// Decrypts OAEP padding. It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. +#[inline] +fn decrypt_inner( + em: &mut [u8], + h_size: usize, + expected_p_hash: &[u8], + k: usize, + mut mgf: MGF, +) -> Result, u32)>> { + if k < 11 { + return Err(Error::Decryption); + } + + if k < h_size * 2 + 2 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + + mgf(seed, db); + + let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); + + // The remainder of the plaintext must be zero or more 0x00, followed + // by 0x01, followed by the message. + // looking_for_index: 1 if we are still looking for the 0x01 + // index: the offset of the first 0x01 byte + // zero_before_one: 1 if we saw a non-zero byte before the 1 + let mut looking_for_index = Choice::from(1u8); + let mut index = 0u32; + let mut nonzero_before_one = Choice::from(0u8); + + for (i, el) in db.iter().skip(h_size).enumerate() { + let equals0 = el.ct_eq(&0u8); + let equals1 = el.ct_eq(&1u8); + index.conditional_assign(&(i as u32), looking_for_index & equals1); + looking_for_index &= !equals1; + nonzero_before_one |= looking_for_index & !equals0; + } + + let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; + + Ok(CtOption::new( + (em.to_vec(), index + 2 + (h_size * 2) as u32), + valid, + )) +} diff --git a/src/algorithms/pkcs1v15.rs b/src/algorithms/pkcs1v15.rs new file mode 100644 index 00000000..c1f0779a --- /dev/null +++ b/src/algorithms/pkcs1v15.rs @@ -0,0 +1,198 @@ +//! PKCS#1 v1.5 support as described in [RFC8017 § 8.2]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pkcs1-v15-signatures). +//! +//! [RFC8017 § 8.2]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.2 + +use alloc::vec::Vec; +use digest::Digest; +use pkcs8::AssociatedOid; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; +use zeroize::Zeroizing; + +use crate::errors::{Error, Result}; + +/// Fills the provided slice with random values, which are guaranteed +/// to not be zero. +#[inline] +fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { + rng.fill_bytes(data); + + for el in data { + if *el == 0u8 { + // TODO: break after a certain amount of time + while *el == 0u8 { + rng.fill_bytes(core::slice::from_mut(el)); + } + } + } +} + +/// Applied the padding scheme from PKCS#1 v1.5 for encryption. The message must be no longer than +/// the length of the public modulus minus 11 bytes. +pub(crate) fn pkcs1v15_encrypt_pad( + rng: &mut R, + msg: &[u8], + k: usize, +) -> Result>> +where + R: CryptoRngCore + ?Sized, +{ + if msg.len() > k - 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x02 || PS || 0x00 || M + let mut em = Zeroizing::new(vec![0u8; k]); + em[1] = 2; + non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); + em[k - msg.len() - 1] = 0; + em[k - msg.len()..].copy_from_slice(msg); + Ok(em) +} + +/// Removes the encryption padding scheme from PKCS#1 v1.5. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. See +/// `decrypt_session_key` for a way of solving this problem. +#[inline] +pub(crate) fn pkcs1v15_encrypt_unpad(em: Vec, k: usize) -> Result> { + let (valid, out, index) = decrypt_inner(em, k)?; + if valid == 0 { + return Err(Error::Decryption); + } + + Ok(out[index as usize..].to_vec()) +} + +/// Removes the PKCS1v15 padding It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. In either case, the plaintext is +/// returned in em so that it may be read independently of whether it was valid +/// in order to maintain constant memory access patterns. If the plaintext was +/// valid then index contains the index of the original message in em. +#[inline] +fn decrypt_inner(em: Vec, k: usize) -> Result<(u8, Vec, u32)> { + if k < 11 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + let second_byte_is_two = em[1].ct_eq(&2u8); + + // The remainder of the plaintext must be a string of non-zero random + // octets, followed by a 0, followed by the message. + // looking_for_index: 1 iff we are still looking for the zero. + // index: the offset of the first zero byte. + let mut looking_for_index = 1u8; + let mut index = 0u32; + + for (i, el) in em.iter().enumerate().skip(2) { + let equals0 = el.ct_eq(&0u8); + index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); + looking_for_index.conditional_assign(&0u8, equals0); + } + + // The PS padding must be at least 8 bytes long, and it starts two + // bytes into em. + // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: + // Ref: https://github.com/dalek-cryptography/subtle/issues/20 + // This is currently copy & paste from the constant time impl in + // go, but very likely not sufficient. + let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); + let valid = + first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; + index = u32::conditional_select(&0, &(index + 1), valid); + + Ok((valid.unwrap_u8(), em, index)) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_pad(prefix: &[u8], hashed: &[u8], k: usize) -> Result> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut em = vec![0xff; k]; + em[0] = 0; + em[1] = 1; + em[k - t_len - 1] = 0; + em[k - t_len..k - hash_len].copy_from_slice(prefix); + em[k - hash_len..k].copy_from_slice(hashed); + + Ok(em) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_unpad(prefix: &[u8], hashed: &[u8], em: &[u8], k: usize) -> Result<()> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::Verification); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut ok = em[0].ct_eq(&0u8); + ok &= em[1].ct_eq(&1u8); + ok &= em[k - hash_len..k].ct_eq(hashed); + ok &= em[k - t_len..k - hash_len].ct_eq(prefix); + ok &= em[k - t_len - 1].ct_eq(&0u8); + + for el in em.iter().skip(2).take(k - t_len - 3) { + ok &= el.ct_eq(&0xff) + } + + if ok.unwrap_u8() != 1 { + return Err(Error::Verification); + } + + Ok(()) +} + +/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 +#[inline] +pub(crate) fn pkcs1v15_generate_prefix() -> Vec +where + D: Digest + AssociatedOid, +{ + let oid = D::OID.as_bytes(); + let oid_len = oid.len() as u8; + let digest_len = ::output_size() as u8; + let mut v = vec![ + 0x30, + oid_len + 8 + digest_len, + 0x30, + oid_len + 4, + 0x6, + oid_len, + ]; + v.extend_from_slice(oid); + v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); + v +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; + + #[test] + fn test_non_zero_bytes() { + for _ in 0..10 { + let mut rng = ChaCha8Rng::from_seed([42; 32]); + let mut b = vec![0u8; 512]; + non_zero_random_bytes(&mut rng, &mut b); + for el in &b { + assert_ne!(*el, 0u8); + } + } + } +} diff --git a/src/algorithms/pss.rs b/src/algorithms/pss.rs new file mode 100644 index 00000000..db58584d --- /dev/null +++ b/src/algorithms/pss.rs @@ -0,0 +1,334 @@ +//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS. +//! +//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pss-signatures). +//! +//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme +//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 + +use alloc::vec::Vec; +use digest::{Digest, DynDigest, FixedOutputReset}; +use subtle::{Choice, ConstantTimeEq}; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +pub(crate) fn emsa_pss_encode( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], + hash: &mut dyn DynDigest, +) -> Result> { + // See [1], section 9.1.1 + let h_len = hash.output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + hash.update(&prefix); + hash.update(m_hash); + hash.update(salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor(db, hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +pub(crate) fn emsa_pss_encode_digest( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], +) -> Result> +where + D: Digest + FixedOutputReset, +{ + // See [1], section 9.1.1 + let h_len = ::output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + let mut hash = D::new(); + + Digest::update(&mut hash, prefix); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor_digest(db, &mut hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +fn emsa_pss_verify_pre<'a>( + m_hash: &[u8], + em: &'a mut [u8], + em_bits: usize, + s_len: usize, + h_len: usize, +) -> Result<(&'a mut [u8], &'a mut [u8])> { + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen + if m_hash.len() != h_len { + return Err(Error::Verification); + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + let em_len = em.len(); //(em_bits + 7) / 8; + if em_len < h_len + s_len + 2 { + return Err(Error::Verification); + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[em.len() - 1] != 0xBC { + return Err(Error::Verification); + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..h_len]; + + // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + if db[0] + & (0xFF_u8 + .checked_shl(8 - (8 * em_len - em_bits) as u32) + .unwrap_or(0)) + != 0 + { + return Err(Error::Verification); + } + + Ok((db, h)) +} + +fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + let valid: Choice = zeroes + .iter() + .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); + + valid & rest[0].ct_eq(&0x01) +} + +pub(crate) fn emsa_pss_verify( + m_hash: &[u8], + em: &mut [u8], + s_len: usize, + hash: &mut dyn DynDigest, + key_bits: usize, +) -> Result<()> { + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = hash.output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor(db, hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + hash.update(&prefix[..]); + hash.update(m_hash); + hash.update(salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} + +pub(crate) fn emsa_pss_verify_digest( + m_hash: &[u8], + em: &mut [u8], + s_len: usize, + key_bits: usize, +) -> Result<()> +where + D: Digest + FixedOutputReset, +{ + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = ::output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + let mut hash = D::new(); + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor_digest::(db, &mut hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + Digest::update(&mut hash, &prefix[..]); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} diff --git a/src/internals.rs b/src/internals.rs index 90ae2181..1c3ba6be 100644 --- a/src/internals.rs +++ b/src/internals.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use num_bigint::{BigInt, BigUint, IntoBigInt, IntoBigUint, ModInverse, RandBigInt, ToBigInt}; use num_traits::{One, Signed, Zero}; use rand_core::CryptoRngCore; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::errors::{Error, Result}; use crate::key::{PublicKeyParts, RsaPrivateKey}; @@ -18,7 +18,7 @@ pub fn encrypt(key: &K, m: &BigUint) -> BigUint { /// Performs raw RSA decryption with no padding, resulting in a plaintext `BigUint`. /// Peforms RSA blinding if an `Rng` is passed. #[inline] -pub fn decrypt( +fn decrypt( mut rng: Option<&mut R>, priv_key: &RsaPrivateKey, c: &BigUint, @@ -127,7 +127,7 @@ pub fn decrypt_and_check( } /// Returns the blinded c, along with the unblinding factor. -pub fn blind( +fn blind( rng: &mut R, key: &K, c: &BigUint, @@ -168,13 +168,13 @@ pub fn blind( } /// Given an m and and unblinding factor, unblind the m. -pub fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> BigUint { +fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> BigUint { (m * unblinder) % key.n() } /// Returns a new vector of the given length, with 0s left padded. #[inline] -pub fn left_pad(input: &[u8], padded_len: usize) -> Result> { +fn left_pad(input: &[u8], padded_len: usize) -> Result> { if input.len() > padded_len { return Err(Error::InvalidPadLen); } @@ -184,6 +184,20 @@ pub fn left_pad(input: &[u8], padded_len: usize) -> Result> { Ok(out) } +/// Converts input to the new vector of the given length, using BE and with 0s left padded. +#[inline] +pub fn uint_to_be_pad(input: BigUint, padded_len: usize) -> Result> { + left_pad(&input.to_bytes_be(), padded_len) +} + +/// Converts input to the new vector of the given length, using BE and with 0s left padded. +#[inline] +pub fn uint_to_zeroizing_be_pad(input: BigUint, padded_len: usize) -> Result> { + let m = Zeroizing::new(input); + let m = Zeroizing::new(m.to_bytes_be()); + left_pad(&m, padded_len) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/key.rs b/src/key.rs index 87784395..0f24d497 100644 --- a/src/key.rs +++ b/src/key.rs @@ -12,9 +12,9 @@ use zeroize::Zeroize; use crate::algorithms::{generate_multi_prime_key, generate_multi_prime_key_with_exp}; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; +use crate::internals; use crate::padding::{PaddingScheme, SignatureScheme}; -use crate::raw::{DecryptionPrimitive, EncryptionPrimitive}; /// Components of an RSA public key. pub trait PublicKeyParts { @@ -31,8 +31,6 @@ pub trait PublicKeyParts { } } -pub trait PrivateKey: DecryptionPrimitive + PublicKeyParts {} - /// Represents the public part of an RSA key. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -161,25 +159,6 @@ impl From<&RsaPrivateKey> for RsaPublicKey { } } -/// Generic trait for operations on a public key. -pub trait PublicKey: EncryptionPrimitive + PublicKeyParts { - /// Encrypt the given message. - fn encrypt( - &self, - rng: &mut R, - padding: P, - msg: &[u8], - ) -> Result>; - - /// Verify a signed message. - /// - /// `hashed` must be the result of hashing the input using the hashing function - /// passed in through `hash`. - /// - /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. - fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()>; -} - impl PublicKeyParts for RsaPublicKey { fn n(&self) -> &BigUint { &self.n @@ -190,8 +169,9 @@ impl PublicKeyParts for RsaPublicKey { } } -impl PublicKey for RsaPublicKey { - fn encrypt( +impl RsaPublicKey { + /// Encrypt the given message. + pub fn encrypt( &self, rng: &mut R, padding: P, @@ -200,7 +180,13 @@ impl PublicKey for RsaPublicKey { padding.encrypt(rng, self, msg) } - fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()> { + /// Verify a signed message. + /// + /// `hashed` must be the result of hashing the input using the hashing function + /// passed in through `hash`. + /// + /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. + pub fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()> { scheme.verify(self, hashed, sig) } } @@ -239,6 +225,10 @@ impl RsaPublicKey { pub fn new_unchecked(n: BigUint, e: BigUint) -> Self { Self { n, e } } + + pub(crate) fn raw_int_encryption_primitive(&self, plaintext: &BigUint) -> Result { + Ok(internals::encrypt(self, plaintext)) + } } impl PublicKeyParts for RsaPrivateKey { @@ -251,8 +241,6 @@ impl PublicKeyParts for RsaPrivateKey { } } -impl PrivateKey for RsaPrivateKey {} - impl RsaPrivateKey { /// Generate a new Rsa key pair of the given bit size using the passed in `rng`. pub fn new(rng: &mut R, bit_size: usize) -> Result { @@ -461,6 +449,15 @@ impl RsaPrivateKey { ) -> Result> { padding.sign(Some(rng), self, digest_in) } + + /// Do NOT use directly! Only for implementors. + pub(crate) fn raw_int_decryption_primitive( + &self, + rng: Option<&mut R>, + ciphertext: &BigUint, + ) -> Result { + internals::decrypt_and_check(rng, self, ciphertext) + } } /// Check that the public key is well formed and has an exponent within acceptable bounds. @@ -495,7 +492,6 @@ fn check_public_with_max_size(public_key: &impl PublicKeyParts, max_size: usize) #[cfg(test)] mod tests { use super::*; - use crate::internals; use hex_literal::hex; use num_traits::{FromPrimitive, ToPrimitive}; @@ -528,12 +524,16 @@ mod tests { let pub_key: RsaPublicKey = private_key.clone().into(); let m = BigUint::from_u64(42).expect("invalid 42"); - let c = internals::encrypt(&pub_key, &m); - let m2 = internals::decrypt::(None, private_key, &c) + let c = pub_key + .raw_int_encryption_primitive(&m) + .expect("encryption successfull"); + let m2 = private_key + .raw_int_decryption_primitive::(None, &c) .expect("unable to decrypt without blinding"); assert_eq!(m, m2); let mut rng = ChaCha8Rng::from_seed([42; 32]); - let m3 = internals::decrypt(Some(&mut rng), private_key, &c) + let m3 = private_key + .raw_int_decryption_primitive(Some(&mut rng), &c) .expect("unable to decrypt with blinding"); assert_eq!(m, m3); } diff --git a/src/lib.rs b/src/lib.rs index 76956fe6..5efcb86b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ //! //! ## PKCS#1 v1.5 encryption //! ``` -//! use rsa::{PublicKey, RsaPrivateKey, RsaPublicKey, Pkcs1v15Encrypt}; +//! use rsa::{RsaPrivateKey, RsaPublicKey, Pkcs1v15Encrypt}; //! //! let mut rng = rand::thread_rng(); //! @@ -34,7 +34,7 @@ //! //! ## OAEP encryption //! ``` -//! use rsa::{PublicKey, RsaPrivateKey, RsaPublicKey, Oaep}; +//! use rsa::{RsaPrivateKey, RsaPublicKey, Oaep}; //! //! let mut rng = rand::thread_rng(); //! @@ -233,7 +233,6 @@ mod dummy_rng; mod encoding; mod key; mod padding; -mod raw; pub use pkcs1; pub use pkcs8; @@ -241,7 +240,7 @@ pub use pkcs8; pub use sha2; pub use crate::{ - key::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}, + key::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}, oaep::Oaep, padding::{PaddingScheme, SignatureScheme}, pkcs1v15::{Pkcs1v15Encrypt, Pkcs1v15Sign}, diff --git a/src/oaep.rs b/src/oaep.rs index 309a0cd1..07755cd9 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -5,27 +5,23 @@ //! See [code example in the toplevel rustdoc](../index.html#oaep-encryption). use alloc::boxed::Box; use alloc::string::{String, ToString}; -use alloc::vec; use alloc::vec::Vec; use core::fmt; use core::marker::PhantomData; use rand_core::CryptoRngCore; use digest::{Digest, DynDigest, FixedOutputReset}; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; +use num_bigint::BigUint; use zeroize::Zeroizing; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::oaep::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; -use crate::key::{self, PrivateKey, PublicKey, RsaPrivateKey, RsaPublicKey}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; +use crate::key::{self, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::padding::PaddingScheme; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; -// 2**61 -1 (pow is not const yet) -// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. -const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; - /// Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). /// /// - `digest` is used to hash the label. The maximum possible plaintext length is `m = k - 2 * h_len - 2`, @@ -55,7 +51,7 @@ impl Oaep { /// ``` /// use sha1::Sha1; /// use sha2::Sha256; - /// use rsa::{BigUint, RsaPublicKey, Oaep, PublicKey}; + /// use rsa::{BigUint, RsaPublicKey, Oaep, }; /// use base64ct::{Base64, Encoding}; /// /// let n = Base64::decode_vec("ALHgDoZmBQIx+jTmgeeHW6KsPOrj11f6CvWsiRleJlQpW77AwSZhd21ZDmlTKfaIHBSUxRUsuYNh7E2SHx8rkFVCQA2/gXkZ5GK2IUbzSTio9qXA25MWHvVxjMfKSL8ZAxZyKbrG94FLLszFAFOaiLLY8ECs7g+dXOriYtBwLUJK+lppbd+El+8ZA/zH0bk7vbqph5pIoiWggxwdq3mEz4LnrUln7r6dagSQzYErKewY8GADVpXcq5mfHC1xF2DFBub7bFjMVM5fHq7RK+pG5xjNDiYITbhLYrbVv3X0z75OvN0dY49ITWjM7xyvMWJXVJS7sJlgmCCL6RwWgP8PhcE=").unwrap(); @@ -92,7 +88,7 @@ impl Oaep { /// ``` /// use sha1::Sha1; /// use sha2::Sha256; - /// use rsa::{BigUint, RsaPublicKey, Oaep, PublicKey}; + /// use rsa::{BigUint, RsaPublicKey, Oaep, }; /// use base64ct::{Base64, Encoding}; /// /// let n = Base64::decode_vec("ALHgDoZmBQIx+jTmgeeHW6KsPOrj11f6CvWsiRleJlQpW77AwSZhd21ZDmlTKfaIHBSUxRUsuYNh7E2SHx8rkFVCQA2/gXkZ5GK2IUbzSTio9qXA25MWHvVxjMfKSL8ZAxZyKbrG94FLLszFAFOaiLLY8ECs7g+dXOriYtBwLUJK+lppbd+El+8ZA/zH0bk7vbqph5pIoiWggxwdq3mEz4LnrUln7r6dagSQzYErKewY8GADVpXcq5mfHC1xF2DFBub7bFjMVM5fHq7RK+pG5xjNDiYITbhLYrbVv3X0z75OvN0dY49ITWjM7xyvMWJXVJS7sJlgmCCL6RwWgP8PhcE=").unwrap(); @@ -131,10 +127,10 @@ impl Oaep { } impl PaddingScheme for Oaep { - fn decrypt( + fn decrypt( mut self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { decrypt( @@ -147,10 +143,10 @@ impl PaddingScheme for Oaep { ) } - fn encrypt( + fn encrypt( mut self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { encrypt( @@ -174,41 +170,6 @@ impl fmt::Debug for Oaep { } } -#[inline] -fn encrypt_internal( - rng: &mut R, - pub_key: &K, - msg: &[u8], - p_hash: &[u8], - h_size: usize, - mut mgf: MGF, -) -> Result> { - key::check_public(pub_key)?; - - let k = pub_key.size(); - - if msg.len() + 2 * h_size + 2 > k { - return Err(Error::MessageTooLong); - } - - let mut em = Zeroizing::new(vec![0u8; k]); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - rng.fill_bytes(seed); - - // Data block DB = pHash || PS || 01 || M - let db_len = k - h_size - 1; - - db[0..h_size].copy_from_slice(p_hash); - db[db_len - msg.len() - 1] = 1; - db[db_len - msg.len()..].copy_from_slice(msg); - - mgf(seed, db); - - pub_key.raw_encryption_primitive(&em, pub_key.size()) -} - /// Encrypts the given message with RSA and the padding scheme from /// [PKCS#1 OAEP]. /// @@ -217,28 +178,20 @@ fn encrypt_internal( +fn encrypt( rng: &mut R, - pub_key: &K, + pub_key: &RsaPublicKey, msg: &[u8], digest: &mut dyn DynDigest, mgf_digest: &mut dyn DynDigest, label: Option, ) -> Result> { - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - digest.update(label.as_bytes()); - let p_hash = digest.finalize_reset(); + let em = oaep_encrypt(rng, msg, digest, mgf_digest, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - mgf1_xor(db, mgf_digest, seed); - mgf1_xor(seed, mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Encrypts the given message with RSA and the padding scheme from @@ -248,32 +201,18 @@ fn encrypt( /// `2 + (2 * hash.size())`. /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 -#[inline] -fn encrypt_digest< - R: CryptoRngCore + ?Sized, - K: PublicKey, - D: Digest, - MGD: Digest + FixedOutputReset, ->( +fn encrypt_digest( rng: &mut R, - pub_key: &K, + pub_key: &RsaPublicKey, msg: &[u8], label: Option, ) -> Result> { - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - let p_hash = D::digest(label.as_bytes()); + let em = oaep_encrypt_digest::<_, D, MGD>(rng, msg, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(db, &mut mgf_digest, seed); - mgf1_xor_digest(seed, &mut mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -289,9 +228,9 @@ fn encrypt_digest< /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 #[inline] -fn decrypt( +fn decrypt( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], digest: &mut dyn DynDigest, mgf_digest: &mut dyn DynDigest, @@ -299,35 +238,14 @@ fn decrypt( ) -> Result> { key::check_public(priv_key)?; - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } - digest.update(label.as_bytes()); - - let expected_p_hash = digest.finalize_reset(); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - mgf1_xor(seed, mgf_digest, db); - mgf1_xor(db, mgf_digest, seed); - }, - )?; - if res.is_none().into() { - return Err(Error::Decryption); - } - - let (out, index) = res.unwrap(); + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + oaep_decrypt(&mut em, digest, mgf_digest, label, priv_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -343,101 +261,22 @@ fn decrypt( /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 #[inline] -fn decrypt_digest< - R: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest, - MGD: Digest + FixedOutputReset, ->( +fn decrypt_digest( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], label: Option, ) -> Result> { key::check_public(priv_key)?; - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } - - let expected_p_hash = D::digest(label.as_bytes()); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(seed, &mut mgf_digest, db); - mgf1_xor_digest(db, &mut mgf_digest, seed); - }, - )?; - if res.is_none().into() { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } - let (out, index) = res.unwrap(); - - Ok(out[index as usize..].to_vec()) -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. -#[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &SK, - ciphertext: &[u8], - h_size: usize, - expected_p_hash: &[u8], - mut mgf: MGF, -) -> Result, u32)>> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - if ciphertext.len() != k || k < h_size * 2 + 2 { - return Err(Error::Decryption); - } - - let mut em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; - - let first_byte_is_zero = em[0].ct_eq(&0u8); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - - mgf(seed, db); - - let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); - - // The remainder of the plaintext must be zero or more 0x00, followed - // by 0x01, followed by the message. - // looking_for_index: 1 if we are still looking for the 0x01 - // index: the offset of the first 0x01 byte - // zero_before_one: 1 if we saw a non-zero byte before the 1 - let mut looking_for_index = Choice::from(1u8); - let mut index = 0u32; - let mut nonzero_before_one = Choice::from(0u8); - - for (i, el) in db.iter().skip(h_size).enumerate() { - let equals0 = el.ct_eq(&0u8); - let equals1 = el.ct_eq(&1u8); - index.conditional_assign(&(i as u32), looking_for_index & equals1); - looking_for_index &= !equals1; - nonzero_before_one |= looking_for_index & !equals0; - } - - let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(CtOption::new((em, index + 2 + (h_size * 2) as u32), valid)) + oaep_decrypt_digest::(&mut em, label, priv_key.size()) } /// Encryption key for PKCS#1 v1.5 encryption as described in [RFC8017 § 7.1]. @@ -491,7 +330,7 @@ where rng: &mut R, msg: &[u8], ) -> Result> { - encrypt_digest::<_, _, D, MGD>(rng, &self.inner, msg, self.label.as_ref().cloned()) + encrypt_digest::<_, D, MGD>(rng, &self.inner, msg, self.label.as_ref().cloned()) } } @@ -542,7 +381,7 @@ where MGD: Digest + FixedOutputReset, { fn decrypt(&self, ciphertext: &[u8]) -> Result> { - decrypt_digest::( + decrypt_digest::( None, &self.inner, ciphertext, @@ -561,7 +400,7 @@ where rng: &mut R, ciphertext: &[u8], ) -> Result> { - decrypt_digest::<_, _, D, MGD>( + decrypt_digest::<_, D, MGD>( Some(rng), &self.inner, ciphertext, @@ -572,7 +411,7 @@ where #[cfg(test)] mod tests { - use crate::key::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; + use crate::key::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::oaep::{DecryptingKey, EncryptingKey, Oaep}; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; diff --git a/src/padding.rs b/src/padding.rs index 391779f6..ce198fc3 100644 --- a/src/padding.rs +++ b/src/padding.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use rand_core::CryptoRngCore; use crate::errors::Result; -use crate::key::{PrivateKey, PublicKey}; +use crate::key::{RsaPrivateKey, RsaPublicKey}; /// Padding scheme used for encryption. pub trait PaddingScheme { @@ -13,18 +13,18 @@ pub trait PaddingScheme { /// /// If an `rng` is passed, it uses RSA blinding to help mitigate timing /// side-channel attacks. - fn decrypt( + fn decrypt( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result>; /// Encrypt the given message using the given public key. - fn encrypt( + fn encrypt( self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result>; } @@ -32,10 +32,10 @@ pub trait PaddingScheme { /// Digital signature scheme. pub trait SignatureScheme { /// Sign the given digest. - fn sign( + fn sign( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result>; @@ -45,5 +45,5 @@ pub trait SignatureScheme { /// passed in through `hash`. /// /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. - fn verify(self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()>; + fn verify(self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()>; } diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index a6cb13a3..d82a7303 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -24,12 +24,13 @@ use signature::{ DigestSigner, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Signer, Verifier, }; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use zeroize::Zeroizing; +use crate::algorithms::pkcs1v15::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; -use crate::key::{self, PrivateKey, PublicKey}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; +use crate::key::{self, PublicKeyParts}; use crate::padding::{PaddingScheme, SignatureScheme}; use crate::traits::{Decryptor, EncryptingKeypair, RandomizedDecryptor, RandomizedEncryptor}; use crate::{RsaPrivateKey, RsaPublicKey}; @@ -39,19 +40,19 @@ use crate::{RsaPrivateKey, RsaPublicKey}; pub struct Pkcs1v15Encrypt; impl PaddingScheme for Pkcs1v15Encrypt { - fn decrypt( + fn decrypt( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { decrypt(rng, priv_key, ciphertext) } - fn encrypt( + fn encrypt( self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { encrypt(rng, pub_key, msg) @@ -79,7 +80,7 @@ impl Pkcs1v15Sign { { Self { hash_len: Some(::output_size()), - prefix: generate_prefix::().into_boxed_slice(), + prefix: pkcs1v15_generate_prefix::().into_boxed_slice(), } } @@ -103,10 +104,10 @@ impl Pkcs1v15Sign { } impl SignatureScheme for Pkcs1v15Sign { - fn sign( + fn sign( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result> { if let Some(hash_len) = self.hash_len { @@ -118,7 +119,7 @@ impl SignatureScheme for Pkcs1v15Sign { sign(rng, priv_key, &self.prefix, hashed) } - fn verify(self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { + fn verify(self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()> { if let Some(hash_len) = self.hash_len { if hashed.len() != hash_len { return Err(Error::InputNotHashed); @@ -192,26 +193,16 @@ impl Display for Signature { /// scheme from PKCS#1 v1.5. The message must be no longer than the /// length of the public modulus minus 11 bytes. #[inline] -pub(crate) fn encrypt( +fn encrypt( rng: &mut R, - pub_key: &PK, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { key::check_public(pub_key)?; - let k = pub_key.size(); - if msg.len() > k - 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x02 || PS || 0x00 || M - let mut em = Zeroizing::new(vec![0u8; k]); - em[1] = 2; - non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); - em[k - msg.len() - 1] = 0; - em[k - msg.len()..].copy_from_slice(msg); - - pub_key.raw_encryption_primitive(&em, pub_key.size()) + let em = pkcs1v15_encrypt_pad(rng, msg, pub_key.size())?; + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5. @@ -224,19 +215,17 @@ pub(crate) fn encrypt( /// forge signatures as if they had the private key. See /// `decrypt_session_key` for a way of solving this problem. #[inline] -pub(crate) fn decrypt( +fn decrypt( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { key::check_public(priv_key)?; - let (valid, out, index) = decrypt_inner(rng, priv_key, ciphertext)?; - if valid == 0 { - return Err(Error::Decryption); - } + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + pkcs1v15_encrypt_unpad(em, priv_key.size()) } /// Calculates the signature of hashed using @@ -253,150 +242,26 @@ pub(crate) fn decrypt( /// messages to signatures and identify the signed messages. As ever, /// signatures provide authenticity, not confidentiality. #[inline] -pub(crate) fn sign( +fn sign( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, prefix: &[u8], hashed: &[u8], ) -> Result> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = priv_key.size(); - if k < t_len + 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut em = vec![0xff; k]; - em[0] = 0; - em[1] = 1; - em[k - t_len - 1] = 0; - em[k - t_len..k - hash_len].copy_from_slice(prefix); - em[k - hash_len..k].copy_from_slice(hashed); - - priv_key.raw_decryption_primitive(rng, &em, priv_key.size()) -} + let em = pkcs1v15_sign_pad(prefix, hashed, priv_key.size())?; -/// Verifies an RSA PKCS#1 v1.5 signature. -#[inline] -pub(crate) fn verify( - pub_key: &PK, - prefix: &[u8], - hashed: &[u8], - sig: &BigUint, -) -> Result<()> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = pub_key.size(); - if k < t_len + 11 { - return Err(Error::Verification); - } - - let em = pub_key.raw_int_encryption_primitive(sig, pub_key.size())?; - - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut ok = em[0].ct_eq(&0u8); - ok &= em[1].ct_eq(&1u8); - ok &= em[k - hash_len..k].ct_eq(hashed); - ok &= em[k - t_len..k - hash_len].ct_eq(prefix); - ok &= em[k - t_len - 1].ct_eq(&0u8); - - for el in em.iter().skip(2).take(k - t_len - 3) { - ok &= el.ct_eq(&0xff) - } - - if ok.unwrap_u8() != 1 { - return Err(Error::Verification); - } - - Ok(()) + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } -/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 -#[inline] -pub(crate) fn generate_prefix() -> Vec -where - D: Digest + AssociatedOid, -{ - let oid = D::OID.as_bytes(); - let oid_len = oid.len() as u8; - let digest_len = ::output_size() as u8; - let mut v = vec![ - 0x30, - oid_len + 8 + digest_len, - 0x30, - oid_len + 4, - 0x6, - oid_len, - ]; - v.extend_from_slice(oid); - v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); - v -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. In either case, the plaintext is -/// returned in em so that it may be read independently of whether it was valid -/// in order to maintain constant memory access patterns. If the plaintext was -/// valid then index contains the index of the original message in em. +/// Verifies an RSA PKCS#1 v1.5 signature. #[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &SK, - ciphertext: &[u8], -) -> Result<(u8, Vec, u32)> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - let em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; - - let first_byte_is_zero = em[0].ct_eq(&0u8); - let second_byte_is_two = em[1].ct_eq(&2u8); - - // The remainder of the plaintext must be a string of non-zero random - // octets, followed by a 0, followed by the message. - // looking_for_index: 1 iff we are still looking for the zero. - // index: the offset of the first zero byte. - let mut looking_for_index = 1u8; - let mut index = 0u32; - - for (i, el) in em.iter().enumerate().skip(2) { - let equals0 = el.ct_eq(&0u8); - index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); - looking_for_index.conditional_assign(&0u8, equals0); - } - - // The PS padding must be at least 8 bytes long, and it starts two - // bytes into em. - // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: - // Ref: https://github.com/dalek-cryptography/subtle/issues/20 - // This is currently copy & paste from the constant time impl in - // go, but very likely not sufficient. - let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); - let valid = - first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; - index = u32::conditional_select(&0, &(index + 1), valid); - - Ok((valid.unwrap_u8(), em, index)) -} +fn verify(pub_key: &RsaPublicKey, prefix: &[u8], hashed: &[u8], sig: &BigUint) -> Result<()> { + let em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; -/// Fills the provided slice with random values, which are guaranteed -/// to not be zero. -#[inline] -fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { - rng.fill_bytes(data); - - for el in data { - if *el == 0u8 { - // TODO: break after a certain amount of time - while *el == 0u8 { - rng.fill_bytes(core::slice::from_mut(el)); - } - } - } + pkcs1v15_sign_unpad(prefix, hashed, &em, pub_key.size()) } /// Signing key for PKCS#1 v1.5 signatures as described in [RFC8017 § 8.2]. @@ -490,7 +355,7 @@ where pub fn new(key: RsaPrivateKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -499,7 +364,7 @@ where pub fn random(rng: &mut R, bit_size: usize) -> Result { Ok(Self { inner: RsaPrivateKey::new(rng, bit_size)?, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), }) } @@ -543,7 +408,7 @@ where D: Digest, { fn try_sign(&self, msg: &[u8]) -> signature::Result { - sign::(None, &self.inner, &self.prefix, &D::digest(msg))? + sign::(None, &self.inner, &self.prefix, &D::digest(msg))? .as_slice() .try_into() } @@ -569,7 +434,7 @@ where D: Digest, { fn try_sign_digest(&self, digest: D) -> signature::Result { - sign::(None, &self.inner, &self.prefix, &digest.finalize())? + sign::(None, &self.inner, &self.prefix, &digest.finalize())? .as_slice() .try_into() } @@ -595,7 +460,7 @@ where D: Digest, { fn sign_prehash(&self, prehash: &[u8]) -> signature::Result { - sign::(None, &self.inner, &self.prefix, prehash)? + sign::(None, &self.inner, &self.prefix, prehash)? .as_slice() .try_into() } @@ -694,7 +559,7 @@ where pub fn new(key: RsaPublicKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -819,7 +684,7 @@ impl DecryptingKey { impl Decryptor for DecryptingKey { fn decrypt(&self, ciphertext: &[u8]) -> Result> { - decrypt::(None, &self.inner, ciphertext) + decrypt::(None, &self.inner, ciphertext) } } @@ -899,19 +764,7 @@ mod tests { use sha3::Sha3_256; use signature::{RandomizedSigner, Signer, Verifier}; - use crate::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; - - #[test] - fn test_non_zero_bytes() { - for _ in 0..10 { - let mut rng = ChaCha8Rng::from_seed([42; 32]); - let mut b = vec![0u8; 512]; - non_zero_random_bytes(&mut rng, &mut b); - for el in &b { - assert_ne!(*el, 0u8); - } - } - } + use crate::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; fn get_private_key() -> RsaPrivateKey { // In order to generate new test vectors you'll need the PEM form of this key: diff --git a/src/pss.rs b/src/pss.rs index 1ef9a259..8713a8dd 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -9,13 +9,13 @@ //! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme //! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 -use alloc::boxed::Box; -use alloc::vec::Vec; +use alloc::{boxed::Box, string::ToString, vec::Vec}; use core::fmt::{self, Debug, Display, Formatter, LowerHex, UpperHex}; use core::marker::PhantomData; use const_oid::{AssociatedOid, ObjectIdentifier}; use digest::{Digest, DynDigest, FixedOutputReset}; +use num_bigint::BigUint; use pkcs1::RsaPssParams; use pkcs8::{ spki::{ @@ -30,11 +30,11 @@ use signature::{ hazmat::{PrehashVerifier, RandomizedPrehashSigner}, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Verifier, }; -use subtle::{Choice, ConstantTimeEq}; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::pss::*; use crate::errors::{Error, Result}; -use crate::key::{PrivateKey, PublicKey}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; +use crate::key::PublicKeyParts; use crate::padding::SignatureScheme; use crate::{RsaPrivateKey, RsaPublicKey}; @@ -89,10 +89,10 @@ impl Pss { } impl SignatureScheme for Pss { - fn sign( + fn sign( mut self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result> { sign( @@ -105,8 +105,15 @@ impl SignatureScheme for Pss { ) } - fn verify(mut self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { - verify(pub_key, hashed, sig, &mut *self.digest, self.salt_len) + fn verify(mut self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()> { + verify( + pub_key, + hashed, + &BigUint::from_bytes_be(sig), + sig.len(), + &mut *self.digest, + self.salt_len, + ) } } @@ -125,62 +132,48 @@ impl Debug for Pss { /// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 #[derive(Clone, PartialEq, Eq)] pub struct Signature { - bytes: Box<[u8]>, + inner: BigUint, + len: usize, } impl SignatureEncoding for Signature { type Repr = Box<[u8]>; } -impl From> for Signature { - fn from(bytes: Box<[u8]>) -> Self { - Self { bytes } - } -} - impl TryFrom<&[u8]> for Signature { type Error = signature::Error; fn try_from(bytes: &[u8]) -> signature::Result { Ok(Self { - bytes: bytes.into(), + len: bytes.len(), + inner: BigUint::from_bytes_be(bytes), }) } } impl From for Box<[u8]> { fn from(signature: Signature) -> Box<[u8]> { - signature.bytes + signature.inner.to_bytes_be().into_boxed_slice() } } impl Debug for Signature { fn fmt(&self, fmt: &mut Formatter<'_>) -> core::result::Result<(), core::fmt::Error> { - fmt.debug_list().entries(self.bytes.iter()).finish() - } -} - -impl AsRef<[u8]> for Signature { - fn as_ref(&self) -> &[u8] { - self.bytes.as_ref() + fmt.debug_tuple("Signature") + .field(&self.to_string()) + .finish() } } impl LowerHex for Signature { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - for byte in self.bytes.iter() { - write!(f, "{:02x}", byte)?; - } - Ok(()) + write!(f, "{:x}", &self.inner) } } impl UpperHex for Signature { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - for byte in self.bytes.iter() { - write!(f, "{:02X}", byte)?; - } - Ok(()) + write!(f, "{:X}", &self.inner) } } @@ -190,51 +183,40 @@ impl Display for Signature { } } -pub(crate) fn verify( - pub_key: &PK, +pub(crate) fn verify( + pub_key: &RsaPublicKey, hashed: &[u8], - sig: &[u8], + sig: &BigUint, + sig_len: usize, digest: &mut dyn DynDigest, salt_len: usize, ) -> Result<()> { - if sig.len() != pub_key.size() { + if sig_len != pub_key.size() { return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify( - hashed, - &mut em[key_len - em_len..], - em_bits, - salt_len, - digest, - ) + emsa_pss_verify(hashed, &mut em, salt_len, digest, pub_key.n().bits()) } -pub(crate) fn verify_digest( - pub_key: &PK, +pub(crate) fn verify_digest( + pub_key: &RsaPublicKey, hashed: &[u8], - sig: &[u8], + sig: &BigUint, + sig_len: usize, salt_len: usize, ) -> Result<()> where - PK: PublicKey, D: Digest + FixedOutputReset, { - if sig.len() != pub_key.size() { + if sig_len != pub_key.size() { return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) + emsa_pss_verify_digest::(hashed, &mut em, salt_len, pub_key.n().bits()) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -242,10 +224,10 @@ where /// Note that hashed must be the result of hashing the input message using the /// given hash function. The opts argument may be nil, in which case sensible /// defaults are used. -pub(crate) fn sign( +pub(crate) fn sign( rng: &mut T, blind: bool, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt_len: usize, digest: &mut dyn DynDigest, @@ -256,21 +238,17 @@ pub(crate) fn sign( sign_pss_with_salt(blind.then_some(rng), priv_key, hashed, &salt, digest) } -pub(crate) fn sign_digest< - T: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest + FixedOutputReset, ->( +pub(crate) fn sign_digest( rng: &mut T, blind: bool, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt_len: usize, ) -> Result> { let mut salt = vec![0; salt_len]; rng.fill_bytes(&mut salt[..]); - sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) + sign_pss_with_salt_digest::<_, D>(blind.then_some(rng), priv_key, hashed, &salt) } /// signPSSWithSalt calculates the signature of hashed using PSS with specified salt. @@ -278,9 +256,9 @@ pub(crate) fn sign_digest< /// Note that hashed must be the result of hashing the input message using the /// given hash function. salt is a random sequence of bytes whose length will be /// later used to verify the signature. -fn sign_pss_with_salt( +fn sign_pss_with_salt( blind_rng: Option<&mut T>, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt: &[u8], digest: &mut dyn DynDigest, @@ -288,328 +266,25 @@ fn sign_pss_with_salt( let em_bits = priv_key.n().bits() - 1; let em = emsa_pss_encode(hashed, em_bits, salt, digest)?; - priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } -fn sign_pss_with_salt_digest< - T: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest + FixedOutputReset, ->( +fn sign_pss_with_salt_digest( blind_rng: Option<&mut T>, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt: &[u8], ) -> Result> { let em_bits = priv_key.n().bits() - 1; let em = emsa_pss_encode_digest::(hashed, em_bits, salt)?; - priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) -} - -fn emsa_pss_encode( - m_hash: &[u8], - em_bits: usize, - salt: &[u8], - hash: &mut dyn DynDigest, -) -> Result> { - // See [1], section 9.1.1 - let h_len = hash.output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - hash.update(&prefix); - hash.update(m_hash); - hash.update(salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor(db, hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_encode_digest(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result> -where - D: Digest + FixedOutputReset, -{ - // See [1], section 9.1.1 - let h_len = ::output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - let mut hash = D::new(); - - Digest::update(&mut hash, prefix); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor_digest(db, &mut hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_verify_pre<'a>( - m_hash: &[u8], - em: &'a mut [u8], - em_bits: usize, - s_len: usize, - h_len: usize, -) -> Result<(&'a mut [u8], &'a mut [u8])> { - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" - // and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen - if m_hash.len() != h_len { - return Err(Error::Verification); - } - - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. - let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len + 2 { - return Err(Error::Verification); - } - - // 4. If the rightmost octet of EM does not have hexadecimal value - // 0xbc, output "inconsistent" and stop. - if em[em.len() - 1] != 0xBC { - return Err(Error::Verification); - } - - // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and - // let H be the next hLen octets. - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..h_len]; - - // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB are not all equal to zero, output "inconsistent" and - // stop. - if db[0] - & (0xFF_u8 - .checked_shl(8 - (8 * em_len - em_bits) as u32) - .unwrap_or(0)) - != 0 - { - return Err(Error::Verification); - } - - Ok((db, h)) -} - -fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - let valid: Choice = zeroes - .iter() - .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); - - valid & rest[0].ct_eq(&0x01) -} - -fn emsa_pss_verify( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, - hash: &mut dyn DynDigest, -) -> Result<()> { - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = hash.output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor(db, hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - hash.update(&prefix[..]); - hash.update(m_hash); - hash.update(salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } -} - -fn emsa_pss_verify_digest( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, -) -> Result<()> -where - D: Digest + FixedOutputReset, -{ - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = ::output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - let mut hash = D::new(); - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor_digest::(db, &mut hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - Digest::update(&mut hash, &prefix[..]); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } /// Signing key for producing RSASSA-PSS signatures as described in @@ -762,9 +437,9 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len)? + .as_slice() + .try_into() } } @@ -777,9 +452,9 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len)? + .as_slice() + .try_into() } } @@ -792,9 +467,9 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, prehash, self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, false, &self.inner, prehash, self.salt_len)? + .as_slice() + .try_into() } } @@ -935,9 +610,9 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len)? + .as_slice() + .try_into() } } @@ -950,9 +625,9 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len)? + .as_slice() + .try_into() } } @@ -965,9 +640,9 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, prehash, self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, D>(rng, true, &self.inner, prehash, self.salt_len)? + .as_slice() + .try_into() } } @@ -1060,10 +735,11 @@ where D: Digest + FixedOutputReset, { fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>( + verify_digest::( &self.inner, &D::digest(msg), - signature.as_ref(), + &signature.inner, + signature.len, self.salt_len, ) .map_err(|e| e.into()) @@ -1075,10 +751,11 @@ where D: Digest + FixedOutputReset, { fn verify_digest(&self, digest: D, signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>( + verify_digest::( &self.inner, &digest.finalize(), - signature.as_ref(), + &signature.inner, + signature.len, self.salt_len, ) .map_err(|e| e.into()) @@ -1090,8 +767,14 @@ where D: Digest + FixedOutputReset, { fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, prehash, signature.as_ref(), self.salt_len) - .map_err(|e| e.into()) + verify_digest::( + &self.inner, + prehash, + &signature.inner, + signature.len, + self.salt_len, + ) + .map_err(|e| e.into()) } } @@ -1116,7 +799,7 @@ where #[cfg(test)] mod test { use crate::pss::{BlindedSigningKey, Pss, Signature, SigningKey, VerifyingKey}; - use crate::{PublicKey, RsaPrivateKey, RsaPublicKey}; + use crate::{RsaPrivateKey, RsaPublicKey}; use hex_literal::hex; use num_bigint::BigUint; diff --git a/src/raw.rs b/src/raw.rs deleted file mode 100644 index 793e68c2..00000000 --- a/src/raw.rs +++ /dev/null @@ -1,64 +0,0 @@ -use alloc::vec::Vec; -use num_bigint::BigUint; -use rand_core::CryptoRngCore; -use zeroize::Zeroizing; - -use crate::errors::Result; -use crate::internals; -use crate::key::{RsaPrivateKey, RsaPublicKey}; - -pub trait EncryptionPrimitive { - /// Do NOT use directly! Only for implementors. - fn raw_encryption_primitive(&self, plaintext: &[u8], pad_size: usize) -> Result> { - let int = Zeroizing::new(BigUint::from_bytes_be(plaintext)); - self.raw_int_encryption_primitive(&int, pad_size) - } - - fn raw_int_encryption_primitive(&self, plaintext: &BigUint, pad_size: usize) - -> Result>; -} - -pub trait DecryptionPrimitive { - /// Do NOT use directly! Only for implementors. - fn raw_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &[u8], - pad_size: usize, - ) -> Result> { - let int = Zeroizing::new(BigUint::from_bytes_be(ciphertext)); - self.raw_int_decryption_primitive(rng, &int, pad_size) - } - - fn raw_int_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &BigUint, - pad_size: usize, - ) -> Result>; -} - -impl EncryptionPrimitive for RsaPublicKey { - fn raw_int_encryption_primitive( - &self, - plaintext: &BigUint, - pad_size: usize, - ) -> Result> { - let c = Zeroizing::new(internals::encrypt(self, &plaintext)); - let c_bytes = Zeroizing::new(c.to_bytes_be()); - internals::left_pad(&c_bytes, pad_size) - } -} - -impl DecryptionPrimitive for RsaPrivateKey { - fn raw_int_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &BigUint, - pad_size: usize, - ) -> Result> { - let m = Zeroizing::new(internals::decrypt_and_check(rng, self, &ciphertext)?); - let m_bytes = Zeroizing::new(m.to_bytes_be()); - internals::left_pad(&m_bytes, pad_size) - } -}