diff --git a/src/pss.rs b/src/pss.rs index 5042eda5..bc34bbe0 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -30,7 +30,7 @@ use signature::{ hazmat::{PrehashVerifier, RandomizedPrehashSigner}, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Verifier, }; -use subtle::ConstantTimeEq; +use subtle::{Choice, ConstantTimeEq}; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::errors::{Error, Result}; @@ -47,17 +47,14 @@ pub struct Pss { pub digest: Box, /// Salt length. - pub salt_len: Option, + pub salt_len: usize, } impl Pss { /// New PSS padding for the given digest. + /// Digest output size is used as a salt length. pub fn new() -> Self { - Self { - blinded: false, - digest: Box::new(T::new()), - salt_len: None, - } + Self::new_with_salt::(::output_size()) } /// New PSS padding for the given digest with a salt value of the given length. @@ -65,17 +62,14 @@ impl Pss { Self { blinded: false, digest: Box::new(T::new()), - salt_len: Some(len), + salt_len: len, } } /// New PSS padding for blinded signatures (RSA-BSSA) for the given digest. + /// Digest output size is used as a salt length. pub fn new_blinded() -> Self { - Self { - blinded: true, - digest: Box::new(T::new()), - salt_len: None, - } + Self::new_blinded_with_salt::(::output_size()) } /// New PSS padding for blinded signatures (RSA-BSSA) for the given digest @@ -86,7 +80,7 @@ impl Pss { Self { blinded: true, digest: Box::new(T::new()), - salt_len: Some(len), + salt_len: len, } } } @@ -109,7 +103,7 @@ impl SignatureScheme for Pss { } fn verify(mut self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { - verify(pub_key, hashed, sig, &mut *self.digest) + verify(pub_key, hashed, sig, &mut *self.digest, self.salt_len) } } @@ -198,6 +192,7 @@ pub(crate) fn verify( hashed: &[u8], sig: &[u8], digest: &mut dyn DynDigest, + salt_len: usize, ) -> Result<()> { if sig.len() != pub_key.size() { return Err(Error::Verification); @@ -208,10 +203,21 @@ pub(crate) fn verify( let key_len = pub_key.size(); let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify(hashed, &mut em[key_len - em_len..], em_bits, None, digest) + emsa_pss_verify( + hashed, + &mut em[key_len - em_len..], + em_bits, + salt_len, + digest, + ) } -pub(crate) fn verify_digest(pub_key: &PK, hashed: &[u8], sig: &[u8]) -> Result<()> +pub(crate) fn verify_digest( + pub_key: &PK, + hashed: &[u8], + sig: &[u8], + salt_len: usize, +) -> Result<()> where PK: PublicKey, D: Digest + FixedOutputReset, @@ -225,7 +231,7 @@ where let key_len = pub_key.size(); let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, None) + emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -238,10 +244,11 @@ pub(crate) fn sign( blind: bool, priv_key: &SK, hashed: &[u8], - salt_len: Option, + salt_len: usize, digest: &mut dyn DynDigest, ) -> Result> { - let salt = generate_salt(rng, priv_key, salt_len, digest.output_size()); + let mut salt = vec![0; salt_len]; + rng.fill_bytes(&mut salt[..]); sign_pss_with_salt(blind.then_some(rng), priv_key, hashed, &salt, digest) } @@ -255,27 +262,12 @@ pub(crate) fn sign_digest< blind: bool, priv_key: &SK, hashed: &[u8], - salt_len: Option, + salt_len: usize, ) -> Result> { - let salt = generate_salt(rng, priv_key, salt_len, ::output_size()); - - sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) -} - -fn generate_salt( - rng: &mut T, - priv_key: &SK, - salt_len: Option, - digest_size: usize, -) -> Vec { - let em_bits = priv_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let salt_len = salt_len.unwrap_or(em_len - 2 - digest_size); - let mut salt = vec![0; salt_len]; rng.fill_bytes(&mut salt[..]); - salt + sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) } /// signPSSWithSalt calculates the signature of hashed using PSS with specified salt. @@ -462,7 +454,7 @@ fn emsa_pss_verify_pre<'a>( m_hash: &[u8], em: &'a mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, h_len: usize, ) -> Result<(&'a mut [u8], &'a mut [u8])> { // 1. If the length of M is greater than the input limitation for the @@ -476,7 +468,7 @@ fn emsa_pss_verify_pre<'a>( // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len.unwrap_or_default() + 2 { + if em_len < h_len + s_len + 2 { return Err(Error::Verification); } @@ -506,47 +498,24 @@ fn emsa_pss_verify_pre<'a>( Ok((db, h)) } -fn emsa_pss_get_salt( - db: &[u8], - em_len: usize, - s_len: Option, - h_len: usize, -) -> Result<&[u8]> { - let s_len = match s_len { - None => (0..=em_len - (h_len + 2)) - .rev() - .try_fold(None, |state, i| match (state, db[em_len - h_len - i - 2]) { - (Some(i), _) => Ok(Some(i)), - (_, 1) => Ok(Some(i)), - (_, 0) => Ok(None), - _ => Err(Error::Verification), - })? - .ok_or(Error::Verification)?, - Some(s_len) => { - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 { - return Err(Error::Verification); - } - - s_len - } - }; +fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + let valid: Choice = zeroes + .iter() + .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - Ok(salt) + valid & rest[0].ct_eq(&0x01) } fn emsa_pss_verify( m_hash: &[u8], em: &mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, hash: &mut dyn DynDigest, ) -> Result<()> { let em_len = em.len(); //(em_bits + 7) / 8; @@ -563,7 +532,10 @@ fn emsa_pss_verify( // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?; + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; // 12. Let // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; @@ -579,7 +551,7 @@ fn emsa_pss_verify( let h0 = hash.finalize_reset(); // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if h0.ct_eq(h).into() { + if (salt_valid & h0.ct_eq(h)).into() { Ok(()) } else { Err(Error::Verification) @@ -590,7 +562,7 @@ fn emsa_pss_verify_digest( m_hash: &[u8], em: &mut [u8], em_bits: usize, - s_len: Option, + s_len: usize, ) -> Result<()> where D: Digest + FixedOutputReset, @@ -611,7 +583,10 @@ where // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?; + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; // 12. Let // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; @@ -627,7 +602,7 @@ where let h0 = hash.finalize_reset(); // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if h0.ct_eq(h).into() { + if (salt_valid & h0.ct_eq(h)).into() { Ok(()) } else { Err(Error::Verification) @@ -644,7 +619,7 @@ where D: Digest, { inner: RsaPrivateKey, - salt_len: Option, + salt_len: usize, phantom: PhantomData, } @@ -653,30 +628,24 @@ where D: Digest, { /// Create a new RSASSA-PSS signing key. + /// Digest output size is used as a salt length. pub fn new(key: RsaPrivateKey) -> Self { - Self { - inner: key, - salt_len: None, - phantom: Default::default(), - } + Self::new_with_salt_len(key, ::output_size()) } /// Create a new RSASSA-PSS signing key with a salt of the given length. pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self { Self { inner: key, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), } } /// Generate a new random RSASSA-PSS signing key. + /// Digest output size is used as a salt length. pub fn random(rng: &mut R, bit_size: usize) -> Result { - Ok(Self { - inner: RsaPrivateKey::new(rng, bit_size)?, - salt_len: None, - phantom: Default::default(), - }) + Self::random_with_salt_len(rng, bit_size, ::output_size()) } /// Generate a new random RSASSA-PSS signing key with a salt of the given length. @@ -687,28 +656,24 @@ where ) -> Result { Ok(Self { inner: RsaPrivateKey::new(rng, bit_size)?, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), }) } /// Return specified salt length for this key - pub fn salt_len(&self) -> Option { + pub fn salt_len(&self) -> usize { self.salt_len } } -fn get_pss_signature_algo_id( - salt_len: Option, -) -> pkcs8::spki::Result +fn get_pss_signature_algo_id(salt_len: u8) -> pkcs8::spki::Result where D: Digest + AssociatedOid, { const ID_MGF_1: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.8"); const ID_RSASSA_PSS: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.10"); - let salt_len = salt_len.map_or(RsaPssParams::SALT_LEN_DEFAULT, |l| l as u8); - let pss_params = RsaPssParams { hash: AlgorithmIdentifierRef { oid: D::OID, @@ -745,7 +710,7 @@ where D: Digest + AssociatedOid, { fn signature_algorithm_identifier(&self) -> pkcs8::spki::Result { - get_pss_signature_algo_id::(self.salt_len) + get_pss_signature_algo_id::(self.salt_len as u8) } } @@ -784,6 +749,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -851,7 +817,7 @@ where D: Digest, { inner: RsaPrivateKey, - salt_len: Option, + salt_len: usize, phantom: PhantomData, } @@ -861,12 +827,9 @@ where { /// Create a new RSASSA-PSS signing key which produces "blinded" /// signatures. + /// Digest output size is used as a salt length. pub fn new(key: RsaPrivateKey) -> Self { - Self { - inner: key, - salt_len: None, - phantom: Default::default(), - } + Self::new_with_salt_len(key, ::output_size()) } /// Create a new RSASSA-PSS signing key which produces "blinded" @@ -874,13 +837,13 @@ where pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self { Self { inner: key, - salt_len: Some(salt_len), + salt_len, phantom: Default::default(), } } /// Return specified salt length for this key - pub fn salt_len(&self) -> Option { + pub fn salt_len(&self) -> usize { self.salt_len } } @@ -899,7 +862,7 @@ where D: Digest + AssociatedOid, { fn signature_algorithm_identifier(&self) -> pkcs8::spki::Result { - get_pss_signature_algo_id::(self.salt_len) + get_pss_signature_algo_id::(self.salt_len as u8) } } @@ -938,6 +901,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -1007,6 +971,7 @@ where D: Digest, { inner: RsaPublicKey, + salt_len: usize, phantom: PhantomData, } @@ -1018,6 +983,7 @@ where fn clone(&self) -> Self { Self { inner: self.inner.clone(), + salt_len: self.salt_len, phantom: Default::default(), } } @@ -1028,9 +994,16 @@ where D: Digest, { /// Create a new RSASSA-PSS verifying key. + /// Digest output size is used as a salt length. pub fn new(key: RsaPublicKey) -> Self { + Self::new_with_salt_len(key, ::output_size()) + } + + /// Create a new RSASSA-PSS verifying key. + pub fn new_with_salt_len(key: RsaPublicKey, salt_len: usize) -> Self { Self { inner: key, + salt_len, phantom: Default::default(), } } @@ -1068,8 +1041,13 @@ where D: Digest + FixedOutputReset, { fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, &D::digest(msg), signature.as_ref()) - .map_err(|e| e.into()) + verify_digest::<_, D>( + &self.inner, + &D::digest(msg), + signature.as_ref(), + self.salt_len, + ) + .map_err(|e| e.into()) } } @@ -1078,8 +1056,13 @@ where D: Digest + FixedOutputReset, { fn verify_digest(&self, digest: D, signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, &digest.finalize(), signature.as_ref()) - .map_err(|e| e.into()) + verify_digest::<_, D>( + &self.inner, + &digest.finalize(), + signature.as_ref(), + self.salt_len, + ) + .map_err(|e| e.into()) } } @@ -1088,7 +1071,8 @@ where D: Digest + FixedOutputReset, { fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, prehash, signature.as_ref()).map_err(|e| e.into()) + verify_digest::<_, D>(&self.inner, prehash, signature.as_ref(), self.salt_len) + .map_err(|e| e.into()) } }