Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 63 additions & 7 deletions pkcs1/src/params.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! PKCS#1 RSA parameters.

use crate::{Error, Result};
use der::asn1::{AnyRef, ObjectIdentifier};
use der::{
asn1::ContextSpecificRef, Decode, DecodeValue, Encode, EncodeValue, FixedTag, Length, Reader,
Sequence, Tag, TagMode, TagNumber, Writer,
asn1::{AnyRef, ContextSpecificRef, ObjectIdentifier},
oid::AssociatedOid,
Decode, DecodeValue, Encode, EncodeValue, FixedTag, Length, Reader, Sequence, Tag, TagMode,
TagNumber, Writer,
};
use spki::{AlgorithmIdentifier, AlgorithmIdentifierRef};

Expand All @@ -14,7 +15,7 @@ const OID_PSPECIFIED: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.1

const SHA_1_AI: AlgorithmIdentifierRef<'_> = AlgorithmIdentifierRef {
oid: OID_SHA_1,
parameters: None,
parameters: Some(AnyRef::NULL),
};

/// `TrailerField` as defined in [RFC 8017 Appendix 2.3].
Expand Down Expand Up @@ -92,6 +93,28 @@ impl<'a> RsaPssParams<'a> {
/// Default RSA PSS Salt length in RsaPssParams
pub const SALT_LEN_DEFAULT: u8 = 20;

/// Create new RsaPssParams for the provided digest and salt len
pub fn new<D>(salt_len: u8) -> Self
where
D: AssociatedOid,
{
Self {
hash: AlgorithmIdentifierRef {
oid: D::OID,
parameters: Some(AnyRef::NULL),
},
mask_gen: AlgorithmIdentifier {
oid: OID_MGF_1,
parameters: Some(AlgorithmIdentifierRef {
oid: D::OID,
parameters: Some(AnyRef::NULL),
}),
},
salt_len,
trailer_field: Default::default(),
}
}

fn context_specific_hash(&self) -> Option<ContextSpecificRef<'_, AlgorithmIdentifierRef<'a>>> {
if self.hash == SHA_1_AI {
None
Expand Down Expand Up @@ -238,6 +261,35 @@ pub struct RsaOaepParams<'a> {
}

impl<'a> RsaOaepParams<'a> {
/// Create new RsaPssParams for the provided digest and default (empty) label
pub fn new<D>() -> Self
where
D: AssociatedOid,
{
Self::new_with_label::<D>(&[])
}

/// Create new RsaPssParams for the provided digest and specified label
pub fn new_with_label<D>(label: &'a impl AsRef<[u8]>) -> Self
where
D: AssociatedOid,
{
Self {
hash: AlgorithmIdentifierRef {
oid: D::OID,
parameters: Some(AnyRef::NULL),
},
mask_gen: AlgorithmIdentifier {
oid: OID_MGF_1,
parameters: Some(AlgorithmIdentifierRef {
oid: D::OID,
parameters: Some(AnyRef::NULL),
}),
},
p_source: pspecicied_algorithm_identifier(label),
}
}

fn context_specific_hash(&self) -> Option<ContextSpecificRef<'_, AlgorithmIdentifierRef<'a>>> {
if self.hash == SHA_1_AI {
None
Expand Down Expand Up @@ -332,12 +384,16 @@ impl<'a> TryFrom<&'a [u8]> for RsaOaepParams<'a> {
}
}

/// Default Source Algorithm, empty string
fn default_pempty_string<'a>() -> AlgorithmIdentifierRef<'a> {
fn pspecicied_algorithm_identifier(label: &impl AsRef<[u8]>) -> AlgorithmIdentifierRef<'_> {
AlgorithmIdentifierRef {
oid: OID_PSPECIFIED,
parameters: Some(
AnyRef::new(Tag::OctetString, &[]).expect("error creating default OAEP params"),
AnyRef::new(Tag::OctetString, label.as_ref()).expect("error creating OAEP params"),
),
}
}

/// Default Source Algorithm, empty string
fn default_pempty_string<'a>() -> AlgorithmIdentifierRef<'a> {
pspecicied_algorithm_identifier(&[])
}
88 changes: 72 additions & 16 deletions pkcs1/tests/params.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
//! PKCS#1 algorithm params tests

use const_oid::db;
use der::{asn1::OctetStringRef, Encode};
use der::{
asn1::{AnyRef, ObjectIdentifier, OctetStringRef},
oid::AssociatedOid,
Encode,
};
use hex_literal::hex;
use pkcs1::{RsaOaepParams, RsaPssParams, TrailerField};

/// Default PSS parameters using all default values (SHA1, MGF1)
const RSA_PSS_PARAMETERS_DEFAULTS: &[u8] = &hex!("3000");
/// Example PSS parameters using SHA256 instead of SHA1
const RSA_PSS_PARAMETERS_SHA2_256: &[u8] = &hex!("3030a00d300b0609608648016503040201a11a301806092a864886f70d010108300b0609608648016503040201a203020120");
const RSA_PSS_PARAMETERS_SHA2_256: &[u8] = &hex!("3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a203020120");

/// Default OAEP parameters using all default values (SHA1, MGF1, Empty)
const RSA_OAEP_PARAMETERS_DEFAULTS: &[u8] = &hex!("3000");
/// Example OAEP parameters using SHA256 instead of SHA1 and 'abc' as label
const RSA_OAEP_PARAMETERS_SHA2_256: &[u8] = &hex!("303fa00d300b0609608648016503040201a11a301806092a864886f70d010108300b0609608648016503040201a212301006092a864886f70d0101090403abcdef");
/// Example OAEP parameters using SHA256 instead of SHA1
const RSA_OAEP_PARAMETERS_SHA2_256: &[u8] = &hex!("302fa00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500");

struct Sha1Mock {}
impl AssociatedOid for Sha1Mock {
const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.14.3.2.26");
}

struct Sha256Mock {}
impl AssociatedOid for Sha256Mock {
const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.1");
}

#[test]
fn decode_pss_param() {
Expand All @@ -23,7 +37,7 @@ fn decode_pss_param() {
.hash
.assert_algorithm_oid(db::rfc5912::ID_SHA_256)
.is_ok());
assert_eq!(param.hash.parameters, None);
assert_eq!(param.hash.parameters, Some(AnyRef::NULL));
assert!(param
.mask_gen
.assert_algorithm_oid(db::rfc5912::ID_MGF_1)
Expand Down Expand Up @@ -56,7 +70,7 @@ fn decode_pss_param_default() {
.hash
.assert_algorithm_oid(db::rfc5912::ID_SHA_1)
.is_ok());
assert_eq!(param.hash.parameters, None);
assert_eq!(param.hash.parameters, Some(AnyRef::NULL));
assert!(param
.mask_gen
.assert_algorithm_oid(db::rfc5912::ID_MGF_1)
Expand All @@ -67,6 +81,10 @@ fn decode_pss_param_default() {
.unwrap()
.assert_algorithm_oid(db::rfc5912::ID_SHA_1)
.is_ok());
assert_eq!(
param.mask_gen.parameters.unwrap().parameters,
Some(AnyRef::NULL)
);
assert_eq!(param.salt_len, 20);
assert_eq!(param.trailer_field, TrailerField::BC);
assert_eq!(param, Default::default())
Expand All @@ -81,6 +99,23 @@ fn encode_pss_param_default() {
);
}

#[test]
fn new_pss_param() {
let mut buf = [0_u8; 256];

let param = RsaPssParams::new::<Sha1Mock>(20);
assert_eq!(
param.encode_to_slice(&mut buf).unwrap(),
RSA_PSS_PARAMETERS_DEFAULTS
);

let param = RsaPssParams::new::<Sha256Mock>(32);
assert_eq!(
param.encode_to_slice(&mut buf).unwrap(),
RSA_PSS_PARAMETERS_SHA2_256
);
}

#[test]
fn decode_oaep_param() {
let param = RsaOaepParams::try_from(RSA_OAEP_PARAMETERS_SHA2_256).unwrap();
Expand All @@ -89,7 +124,7 @@ fn decode_oaep_param() {
.hash
.assert_algorithm_oid(db::rfc5912::ID_SHA_256)
.is_ok());
assert_eq!(param.hash.parameters, None);
assert_eq!(param.hash.parameters, Some(AnyRef::NULL));
assert!(param
.mask_gen
.assert_algorithm_oid(db::rfc5912::ID_MGF_1)
Expand All @@ -104,14 +139,13 @@ fn decode_oaep_param() {
.p_source
.assert_algorithm_oid(db::rfc5912::ID_P_SPECIFIED)
.is_ok());
assert_eq!(
param
.p_source
.parameters_any()
.unwrap()
.decode_as::<OctetStringRef<'_>>(),
OctetStringRef::new(&[0xab, 0xcd, 0xef])
);
assert!(param
.p_source
.parameters_any()
.unwrap()
.decode_as::<OctetStringRef<'_>>()
.unwrap()
.is_empty(),);
}

#[test]
Expand All @@ -132,7 +166,7 @@ fn decode_oaep_param_default() {
.hash
.assert_algorithm_oid(db::rfc5912::ID_SHA_1)
.is_ok());
assert_eq!(param.hash.parameters, None);
assert_eq!(param.hash.parameters, Some(AnyRef::NULL));
assert!(param
.mask_gen
.assert_algorithm_oid(db::rfc5912::ID_MGF_1)
Expand All @@ -143,6 +177,10 @@ fn decode_oaep_param_default() {
.unwrap()
.assert_algorithm_oid(db::rfc5912::ID_SHA_1)
.is_ok());
assert_eq!(
param.mask_gen.parameters.unwrap().parameters,
Some(AnyRef::NULL)
);
assert!(param
.p_source
.assert_algorithm_oid(db::rfc5912::ID_P_SPECIFIED)
Expand All @@ -165,3 +203,21 @@ fn encode_oaep_param_default() {
RSA_OAEP_PARAMETERS_DEFAULTS
);
}

#[test]
fn new_oaep_param() {
let mut buf = [0_u8; 256];

let param = RsaOaepParams::new::<Sha1Mock>();
assert_eq!(
param.encode_to_slice(&mut buf).unwrap(),
RSA_OAEP_PARAMETERS_DEFAULTS
);

let param = RsaOaepParams::new::<Sha256Mock>();
println!("{:02x?}", param.encode_to_slice(&mut buf).unwrap());
assert_eq!(
param.encode_to_slice(&mut buf).unwrap(),
RSA_OAEP_PARAMETERS_SHA2_256
);
}