From 71e51ff11d6fa9c72c2f61ded444a27360ce27ef Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Sat, 15 Apr 2023 03:09:33 +0300 Subject: [PATCH 1/3] pss: fix unsalted RSA PSS support Current new() and random() functions cause confusion. There is the default from ASN.1 encoding of RSAPSS parameters (20). There is also another default of (mod_size - 2 - hash_size). And there is a recommendation to use salt_len of hash_size. Drop old defaults and always use digest output size as the salt_len. Clearly document new default. Signed-off-by: Dmitry Baryshkov --- src/pss.rs | 89 +++++++++++++++++------------------------------------- 1 file changed, 28 insertions(+), 61 deletions(-) diff --git a/src/pss.rs b/src/pss.rs index 5042eda5..fa8a5a96 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -47,17 +47,14 @@ pub struct Pss { pub digest: Box, /// Salt length. - pub salt_len: Option, + pub salt_len: usize, } impl Pss { /// New PSS padding for the given digest. + /// Digest output size is used as a salt length. pub fn new() -> Self { - Self { - blinded: false, - digest: Box::new(T::new()), - salt_len: None, - } + Self::new_with_salt::(::output_size()) } /// New PSS padding for the given digest with a salt value of the given length. @@ -65,17 +62,14 @@ impl Pss { Self { blinded: false, digest: Box::new(T::new()), - salt_len: Some(len), + salt_len: len, } } /// New PSS padding for blinded signatures (RSA-BSSA) for the given digest. + /// Digest output size is used as a salt length. pub fn new_blinded() -> Self { - Self { - blinded: true, - digest: Box::new(T::new()), - salt_len: None, - } + Self::new_blinded_with_salt::(::output_size()) } /// New PSS padding for blinded signatures (RSA-BSSA) for the given digest @@ -86,7 +80,7 @@ impl Pss { Self { blinded: true, digest: Box::new(T::new()), - salt_len: Some(len), + salt_len: len, } } } @@ -238,10 +232,11 @@ pub(crate) fn sign( blind: bool, priv_key: &SK, hashed: &[u8], - salt_len: Option, + salt_len: usize, digest: &mut dyn DynDigest, ) -> Result> { - let salt = generate_salt(rng, priv_key, salt_len, digest.output_size()); + let mut salt = vec![0; salt_len]; + rng.fill_bytes(&mut salt[..]); sign_pss_with_salt(blind.then_some(rng), priv_key, hashed, &salt, digest) } @@ -255,27 +250,12 @@ pub(crate) fn sign_digest< blind: bool, priv_key: &SK, hashed: &[u8], - salt_len: Option, + salt_len: usize, ) -> Result> { - let salt = generate_salt(rng, priv_key, salt_len, ::output_size()); - - sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) -} - -fn generate_salt( - rng: &mut T, - priv_key: &SK, - salt_len: Option, - digest_size: usize, -) -> Vec { - let em_bits = priv_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let salt_len = salt_len.unwrap_or(em_len - 2 - digest_size); - let mut salt = vec![0; salt_len]; rng.fill_bytes(&mut salt[..]); - salt + sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) } /// signPSSWithSalt calculates the signature of hashed using PSS with specified salt. @@ -644,7 +624,7 @@ where D: Digest, { inner: RsaPrivateKey, - salt_len: Option, + salt_len: usize, phantom: PhantomData, } @@ -653,30 +633,24 @@ where D: Digest, { /// Create a new RSASSA-PSS signing key. + /// Digest output size is used as a salt length. pub fn new(key: RsaPrivateKey) -> Self { - Self { - inner: key, - salt_len: None, - phantom: Default::default(), - } + Self::new_with_salt_len(key, ::output_size()) } /// Create a new RSASSA-PSS signing key with a salt of the given length. pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self { Self { inner: key, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), } } /// Generate a new random RSASSA-PSS signing key. + /// Digest output size is used as a salt length. pub fn random(rng: &mut R, bit_size: usize) -> Result { - Ok(Self { - inner: RsaPrivateKey::new(rng, bit_size)?, - salt_len: None, - phantom: Default::default(), - }) + Self::random_with_salt_len(rng, bit_size, ::output_size()) } /// Generate a new random RSASSA-PSS signing key with a salt of the given length. @@ -687,28 +661,24 @@ where ) -> Result { Ok(Self { inner: RsaPrivateKey::new(rng, bit_size)?, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), }) } /// Return specified salt length for this key - pub fn salt_len(&self) -> Option { + pub fn salt_len(&self) -> usize { self.salt_len } } -fn get_pss_signature_algo_id( - salt_len: Option, -) -> pkcs8::spki::Result +fn get_pss_signature_algo_id(salt_len: u8) -> pkcs8::spki::Result where D: Digest + AssociatedOid, { const ID_MGF_1: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.8"); const ID_RSASSA_PSS: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.10"); - let salt_len = salt_len.map_or(RsaPssParams::SALT_LEN_DEFAULT, |l| l as u8); - let pss_params = RsaPssParams { hash: AlgorithmIdentifierRef { oid: D::OID, @@ -745,7 +715,7 @@ where D: Digest + AssociatedOid, { fn signature_algorithm_identifier(&self) -> pkcs8::spki::Result { - get_pss_signature_algo_id::(self.salt_len) + get_pss_signature_algo_id::(self.salt_len as u8) } } @@ -851,7 +821,7 @@ where D: Digest, { inner: RsaPrivateKey, - salt_len: Option, + salt_len: usize, phantom: PhantomData, } @@ -861,12 +831,9 @@ where { /// Create a new RSASSA-PSS signing key which produces "blinded" /// signatures. + /// Digest output size is used as a salt length. pub fn new(key: RsaPrivateKey) -> Self { - Self { - inner: key, - salt_len: None, - phantom: Default::default(), - } + Self::new_with_salt_len(key, ::output_size()) } /// Create a new RSASSA-PSS signing key which produces "blinded" @@ -874,13 +841,13 @@ where pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self { Self { inner: key, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), } } /// Return specified salt length for this key - pub fn salt_len(&self) -> Option { + pub fn salt_len(&self) -> usize { self.salt_len } } @@ -899,7 +866,7 @@ where D: Digest + AssociatedOid, { fn signature_algorithm_identifier(&self) -> pkcs8::spki::Result { - get_pss_signature_algo_id::(self.salt_len) + get_pss_signature_algo_id::(self.salt_len as u8) } } From b6c0332ee850d537262c0521b5d8ac90e147d73f Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Sat, 15 Apr 2023 04:07:14 +0300 Subject: [PATCH 2/3] pss: specify salt_len when verifying the message All RSA PSS standards (e.g. RFC 8017) clearly specify that RSA PSS verification has an explicit salt length parameter (rather than determining it from the message). Drop our 'automagic' code and pass salt length when verifying the message. Old functions now default to digest output size as a hash length. Signed-off-by: Dmitry Baryshkov --- src/pss.rs | 96 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 55 insertions(+), 41 deletions(-) diff --git a/src/pss.rs b/src/pss.rs index fa8a5a96..b418aecf 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -103,7 +103,7 @@ impl SignatureScheme for Pss { } fn verify(mut self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { - verify(pub_key, hashed, sig, &mut *self.digest) + verify(pub_key, hashed, sig, &mut *self.digest, self.salt_len) } } @@ -192,6 +192,7 @@ pub(crate) fn verify( hashed: &[u8], sig: &[u8], digest: &mut dyn DynDigest, + salt_len: usize, ) -> Result<()> { if sig.len() != pub_key.size() { return Err(Error::Verification); @@ -202,10 +203,21 @@ pub(crate) fn verify( let key_len = pub_key.size(); let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify(hashed, &mut em[key_len - em_len..], em_bits, None, digest) + emsa_pss_verify( + hashed, + &mut em[key_len - em_len..], + em_bits, + salt_len, + digest, + ) } -pub(crate) fn verify_digest(pub_key: &PK, hashed: &[u8], sig: &[u8]) -> Result<()> +pub(crate) fn verify_digest( + pub_key: &PK, + hashed: &[u8], + sig: &[u8], + salt_len: usize, +) -> Result<()> where PK: PublicKey, D: Digest + FixedOutputReset, @@ -219,7 +231,7 @@ where let key_len = pub_key.size(); let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, None) + emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -442,7 +454,7 @@ fn emsa_pss_verify_pre<'a>( m_hash: &[u8], em: &'a mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, h_len: usize, ) -> Result<(&'a mut [u8], &'a mut [u8])> { // 1. If the length of M is greater than the input limitation for the @@ -456,7 +468,7 @@ fn emsa_pss_verify_pre<'a>( // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len.unwrap_or_default() + 2 { + if em_len < h_len + s_len + 2 { return Err(Error::Verification); } @@ -486,34 +498,14 @@ fn emsa_pss_verify_pre<'a>( Ok((db, h)) } -fn emsa_pss_get_salt( - db: &[u8], - em_len: usize, - s_len: Option, - h_len: usize, -) -> Result<&[u8]> { - let s_len = match s_len { - None => (0..=em_len - (h_len + 2)) - .rev() - .try_fold(None, |state, i| match (state, db[em_len - h_len - i - 2]) { - (Some(i), _) => Ok(Some(i)), - (_, 1) => Ok(Some(i)), - (_, 0) => Ok(None), - _ => Err(Error::Verification), - })? - .ok_or(Error::Verification)?, - Some(s_len) => { - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 { - return Err(Error::Verification); - } - - s_len - } +fn emsa_pss_get_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Result<&[u8]> { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 { + return Err(Error::Verification); }; // 11. Let salt be the last s_len octets of DB. @@ -526,7 +518,7 @@ fn emsa_pss_verify( m_hash: &[u8], em: &mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, hash: &mut dyn DynDigest, ) -> Result<()> { let em_len = em.len(); //(em_bits + 7) / 8; @@ -570,7 +562,7 @@ fn emsa_pss_verify_digest( m_hash: &[u8], em: &mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, ) -> Result<()> where D: Digest + FixedOutputReset, @@ -754,6 +746,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -905,6 +898,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -974,6 +968,7 @@ where D: Digest, { inner: RsaPublicKey, + salt_len: usize, phantom: PhantomData, } @@ -985,6 +980,7 @@ where fn clone(&self) -> Self { Self { inner: self.inner.clone(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -995,9 +991,16 @@ where D: Digest, { /// Create a new RSASSA-PSS verifying key. + /// Digest output size is used as a salt length. pub fn new(key: RsaPublicKey) -> Self { + Self::new_with_salt_len(key, ::output_size()) + } + + /// Create a new RSASSA-PSS verifying key. + pub fn new_with_salt_len(key: RsaPublicKey, salt_len: usize) -> Self { Self { inner: key, + salt_len, phantom: Default::default(), } } @@ -1035,8 +1038,13 @@ where D: Digest + FixedOutputReset, { fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, &D::digest(msg), signature.as_ref()) - .map_err(|e| e.into()) + verify_digest::<_, D>( + &self.inner, + &D::digest(msg), + signature.as_ref(), + self.salt_len, + ) + .map_err(|e| e.into()) } } @@ -1045,8 +1053,13 @@ where D: Digest + FixedOutputReset, { fn verify_digest(&self, digest: D, signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, &digest.finalize(), signature.as_ref()) - .map_err(|e| e.into()) + verify_digest::<_, D>( + &self.inner, + &digest.finalize(), + signature.as_ref(), + self.salt_len, + ) + .map_err(|e| e.into()) } } @@ -1055,7 +1068,8 @@ where D: Digest + FixedOutputReset, { fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, prehash, signature.as_ref()).map_err(|e| e.into()) + verify_digest::<_, D>(&self.inner, prehash, signature.as_ref(), self.salt_len) + .map_err(|e| e.into()) } } From d6a4fb09cfbecfa7dc586567b5aa90e9a50b2842 Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Sat, 15 Apr 2023 04:35:10 +0300 Subject: [PATCH 3/3] pss: remove possible non-constatnt time operation in PSS salt handling The emsa_pss_get_salt() is possibly non-constant-time op. Change it to be a contant-time operation. Signed-off-by: Dmitry Baryshkov --- src/pss.rs | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/pss.rs b/src/pss.rs index b418aecf..bc34bbe0 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -30,7 +30,7 @@ use signature::{ hazmat::{PrehashVerifier, RandomizedPrehashSigner}, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Verifier, }; -use subtle::ConstantTimeEq; +use subtle::{Choice, ConstantTimeEq}; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::errors::{Error, Result}; @@ -498,20 +498,17 @@ fn emsa_pss_verify_pre<'a>( Ok((db, h)) } -fn emsa_pss_get_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Result<&[u8]> { +fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero // or if the octet at position emLen - hLen - sLen - 1 (the leftmost // position is "position 1") does not have hexadecimal value 0x01, // output "inconsistent" and stop. let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 { - return Err(Error::Verification); - }; - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; + let valid: Choice = zeroes + .iter() + .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); - Ok(salt) + valid & rest[0].ct_eq(&0x01) } fn emsa_pss_verify( @@ -535,7 +532,10 @@ fn emsa_pss_verify( // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?; + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; // 12. Let // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; @@ -551,7 +551,7 @@ fn emsa_pss_verify( let h0 = hash.finalize_reset(); // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if h0.ct_eq(h).into() { + if (salt_valid & h0.ct_eq(h)).into() { Ok(()) } else { Err(Error::Verification) @@ -583,7 +583,10 @@ where // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?; + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; // 12. Let // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; @@ -599,7 +602,7 @@ where let h0 = hash.finalize_reset(); // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if h0.ct_eq(h).into() { + if (salt_valid & h0.ct_eq(h)).into() { Ok(()) } else { Err(Error::Verification)