diff --git a/der/src/asn1/any.rs b/der/src/asn1/any.rs index 5a629ab78..3856191bd 100644 --- a/der/src/asn1/any.rs +++ b/der/src/asn1/any.rs @@ -1,18 +1,16 @@ //! ASN.1 `ANY` type. + #![cfg_attr(feature = "arbitrary", allow(clippy::integer_arithmetic))] use crate::{ - asn1::*, BytesRef, Choice, Decode, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, - FixedTag, Header, Length, Reader, Result, SliceReader, Tag, Tagged, ValueOrd, Writer, + BytesRef, Choice, Decode, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, Header, Length, + Reader, Result, SliceReader, Tag, Tagged, ValueOrd, Writer, }; use core::cmp::Ordering; #[cfg(feature = "alloc")] use {crate::BytesOwned, alloc::boxed::Box}; -#[cfg(feature = "oid")] -use crate::asn1::ObjectIdentifier; - /// ASN.1 `ANY`: represents any explicitly tagged ASN.1 value. /// /// This is a zero-copy reference type which borrows from the input data. @@ -58,11 +56,14 @@ impl<'a> AnyRef<'a> { } /// Attempt to decode this [`AnyRef`] type into the inner value. - pub fn decode_into(self) -> Result + pub fn decode_as(self) -> Result where - T: DecodeValue<'a> + FixedTag, + T: Choice<'a> + DecodeValue<'a>, { - self.tag.assert_eq(T::TAG)?; + if !T::can_decode(self.tag) { + return Err(self.tag.unexpected_error(None)); + } + let header = Header { tag: self.tag, length: self.value.len(), @@ -78,48 +79,6 @@ impl<'a> AnyRef<'a> { self == Self::NULL } - /// Attempt to decode an ASN.1 `BIT STRING`. - pub fn bit_string(self) -> Result> { - self.try_into() - } - - /// Attempt to decode an ASN.1 `CONTEXT-SPECIFIC` field. - pub fn context_specific(self) -> Result> - where - T: Decode<'a>, - { - self.try_into() - } - - /// Attempt to decode an ASN.1 `GeneralizedTime`. - pub fn generalized_time(self) -> Result { - self.try_into() - } - - /// Attempt to decode an ASN.1 `OCTET STRING`. - pub fn octet_string(self) -> Result> { - self.try_into() - } - - /// Attempt to decode an ASN.1 `OBJECT IDENTIFIER`. - #[cfg(feature = "oid")] - #[cfg_attr(docsrs, doc(cfg(feature = "oid")))] - pub fn oid(self) -> Result { - self.try_into() - } - - /// Attempt to decode an ASN.1 `OPTIONAL` value. - pub fn optional(self) -> Result> - where - T: Choice<'a> + TryFrom, - { - if T::can_decode(self.tag) { - T::try_from(self).map(Some) - } else { - Ok(None) - } - } - /// Attempt to decode this value an ASN.1 `SEQUENCE`, creating a new /// nested reader and calling the provided argument with it. pub fn sequence(self, f: F) -> Result @@ -131,11 +90,6 @@ impl<'a> AnyRef<'a> { let result = f(&mut reader)?; reader.finish(result) } - - /// Attempt to decode an ASN.1 `UTCTime`. - pub fn utc_time(self) -> Result { - self.try_into() - } } impl<'a> Choice<'a> for AnyRef<'a> { @@ -231,19 +185,20 @@ impl Any { } /// Attempt to decode this [`Any`] type into the inner value. - pub fn decode_into<'a, T>(&'a self) -> Result + pub fn decode_as<'a, T>(&'a self) -> Result where - T: DecodeValue<'a> + FixedTag, + T: Choice<'a> + DecodeValue<'a>, { - self.tag.assert_eq(T::TAG)?; - let header = Header { - tag: self.tag, - length: self.value.len(), - }; + AnyRef::from(self).decode_as() + } - let mut decoder = SliceReader::new(self.value.as_slice())?; - let result = T::decode_value(&mut decoder, header)?; - decoder.finish(result) + /// Attempt to decode this value an ASN.1 `SEQUENCE`, creating a new + /// nested reader and calling the provided argument with it. + pub fn sequence<'a, F, T>(&'a self, f: F) -> Result + where + F: FnOnce(&mut SliceReader<'a>) -> Result, + { + AnyRef::from(self).sequence(f) } } diff --git a/der/src/asn1/bit_string.rs b/der/src/asn1/bit_string.rs index b41c60345..b049d4d6f 100644 --- a/der/src/asn1/bit_string.rs +++ b/der/src/asn1/bit_string.rs @@ -163,7 +163,7 @@ impl<'a> TryFrom> for BitStringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/generalized_time.rs b/der/src/asn1/generalized_time.rs index ccb5218c5..7e7f0f5e9 100644 --- a/der/src/asn1/generalized_time.rs +++ b/der/src/asn1/generalized_time.rs @@ -168,7 +168,7 @@ impl TryFrom> for GeneralizedTime { type Error = Error; fn try_from(any: AnyRef<'_>) -> Result { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/integer.rs b/der/src/asn1/integer.rs index f7d963416..52dbd8468 100644 --- a/der/src/asn1/integer.rs +++ b/der/src/asn1/integer.rs @@ -65,7 +65,7 @@ macro_rules! impl_int_encoding { type Error = Error; fn try_from(any: AnyRef<'_>) -> Result { - any.decode_into() + any.decode_as() } } )+ @@ -113,7 +113,7 @@ macro_rules! impl_uint_encoding { type Error = Error; fn try_from(any: AnyRef<'_>) -> Result { - any.decode_into() + any.decode_as() } } )+ diff --git a/der/src/asn1/integer/bigint.rs b/der/src/asn1/integer/bigint.rs index 2039cf608..3178c4ed4 100644 --- a/der/src/asn1/integer/bigint.rs +++ b/der/src/asn1/integer/bigint.rs @@ -79,7 +79,7 @@ impl<'a> TryFrom> for IntRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } @@ -167,7 +167,7 @@ impl<'a> TryFrom> for UintRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } @@ -265,7 +265,7 @@ mod allocating { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result { - any.decode_into() + any.decode_as() } } @@ -372,7 +372,7 @@ mod allocating { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/internal_macros.rs b/der/src/asn1/internal_macros.rs index 2627e368a..119af1d16 100644 --- a/der/src/asn1/internal_macros.rs +++ b/der/src/asn1/internal_macros.rs @@ -53,7 +53,7 @@ macro_rules! impl_string_type { type Error = Error; fn try_from(any: &'__der Any) -> Result<$type> { - any.decode_into() + any.decode_as() } } } diff --git a/der/src/asn1/null.rs b/der/src/asn1/null.rs index f9a46a8ad..6bea65781 100644 --- a/der/src/asn1/null.rs +++ b/der/src/asn1/null.rs @@ -45,7 +45,7 @@ impl TryFrom> for Null { type Error = Error; fn try_from(any: AnyRef<'_>) -> Result { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/octet_string.rs b/der/src/asn1/octet_string.rs index fef4f09d5..3bbfd047b 100644 --- a/der/src/asn1/octet_string.rs +++ b/der/src/asn1/octet_string.rs @@ -87,7 +87,7 @@ impl<'a> TryFrom> for OctetStringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/printable_string.rs b/der/src/asn1/printable_string.rs index 69fff9fd0..f1929307f 100644 --- a/der/src/asn1/printable_string.rs +++ b/der/src/asn1/printable_string.rs @@ -112,7 +112,7 @@ impl<'a> TryFrom> for PrintableStringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } @@ -206,7 +206,7 @@ mod allocation { type Error = Error; fn try_from(any: &AnyRef<'a>) -> Result { - (*any).decode_into() + (*any).decode_as() } } diff --git a/der/src/asn1/teletex_string.rs b/der/src/asn1/teletex_string.rs index 1aa8a4495..fdaac856a 100644 --- a/der/src/asn1/teletex_string.rs +++ b/der/src/asn1/teletex_string.rs @@ -83,7 +83,7 @@ impl<'a> TryFrom> for TeletexStringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } impl<'a> From> for AnyRef<'a> { @@ -164,7 +164,7 @@ mod allocation { type Error = Error; fn try_from(any: &AnyRef<'a>) -> Result { - (*any).decode_into() + (*any).decode_as() } } diff --git a/der/src/asn1/utc_time.rs b/der/src/asn1/utc_time.rs index 1765f2e9e..6d4b3b5bc 100644 --- a/der/src/asn1/utc_time.rs +++ b/der/src/asn1/utc_time.rs @@ -191,7 +191,7 @@ impl TryFrom> for UtcTime { type Error = Error; fn try_from(any: AnyRef<'_>) -> Result { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/utf8_string.rs b/der/src/asn1/utf8_string.rs index 3ca2d9076..7ec5f6c7f 100644 --- a/der/src/asn1/utf8_string.rs +++ b/der/src/asn1/utf8_string.rs @@ -94,7 +94,7 @@ impl<'a> TryFrom> for Utf8StringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } @@ -103,7 +103,7 @@ impl<'a> TryFrom<&'a Any> for Utf8StringRef<'a> { type Error = Error; fn try_from(any: &'a Any) -> Result> { - any.decode_into() + any.decode_as() } } diff --git a/der/src/asn1/videotex_string.rs b/der/src/asn1/videotex_string.rs index 21577c3a0..672f657df 100644 --- a/der/src/asn1/videotex_string.rs +++ b/der/src/asn1/videotex_string.rs @@ -101,7 +101,7 @@ impl<'a> TryFrom> for VideotexStringRef<'a> { type Error = Error; fn try_from(any: AnyRef<'a>) -> Result> { - any.decode_into() + any.decode_as() } } @@ -110,7 +110,7 @@ impl<'a> TryFrom<&'a Any> for VideotexStringRef<'a> { type Error = Error; fn try_from(any: &'a Any) -> Result> { - any.decode_into() + any.decode_as() } } diff --git a/pkcs1/tests/params.rs b/pkcs1/tests/params.rs index 597e263f0..7a653b0ce 100644 --- a/pkcs1/tests/params.rs +++ b/pkcs1/tests/params.rs @@ -105,7 +105,11 @@ fn decode_oaep_param() { .assert_algorithm_oid(db::rfc5912::ID_P_SPECIFIED) .is_ok()); assert_eq!( - param.p_source.parameters_any().unwrap().octet_string(), + param + .p_source + .parameters_any() + .unwrap() + .decode_as::>(), OctetStringRef::new(&[0xab, 0xcd, 0xef]) ); } @@ -147,7 +151,7 @@ fn decode_oaep_param_default() { .p_source .parameters_any() .unwrap() - .octet_string() + .decode_as::>() .unwrap() .is_empty(),); assert_eq!(param, Default::default()) diff --git a/pkcs5/src/pbes2.rs b/pkcs5/src/pbes2.rs index e8c60abb5..3af39d0bd 100644 --- a/pkcs5/src/pbes2.rs +++ b/pkcs5/src/pbes2.rs @@ -320,7 +320,7 @@ impl<'a> TryFrom> for EncryptionScheme<'a> { fn try_from(alg: AlgorithmIdentifierRef<'a>) -> der::Result { // TODO(tarcieri): support for non-AES algorithms? let iv = match alg.parameters { - Some(params) => params.octet_string()?.as_bytes(), + Some(params) => params.decode_as::>()?.as_bytes(), None => return Err(Tag::OctetString.value_error()), }; diff --git a/pkcs7/tests/content_tests.rs b/pkcs7/tests/content_tests.rs index 00fe33b1c..406e4cf73 100644 --- a/pkcs7/tests/content_tests.rs +++ b/pkcs7/tests/content_tests.rs @@ -104,7 +104,7 @@ fn decode_signed_mdm_example() { signer_infos: _, })) => { let _content = content - .decode_into::() + .decode_as::() .expect("Content should be in the correct format: SequenceRef"); } _ => panic!("expected ContentInfo::SignedData(Some(_))"), @@ -132,7 +132,7 @@ fn decode_signed_scep_example() { signer_infos: _, })) => { let _content = content - .decode_into::() + .decode_as::() .expect("Content should be in the correct format: OctetStringRef"); assert_eq!(ver, CmsVersion::V1) diff --git a/pkcs8/tests/private_key.rs b/pkcs8/tests/private_key.rs index 15d669495..36f119f81 100644 --- a/pkcs8/tests/private_key.rs +++ b/pkcs8/tests/private_key.rs @@ -1,5 +1,6 @@ //! PKCS#8 private key tests +use der::asn1::ObjectIdentifier; use hex_literal::hex; use pkcs8::{PrivateKeyInfo, Version}; @@ -48,7 +49,11 @@ fn decode_ec_p256_der() { assert_eq!(pk.algorithm.oid, "1.2.840.10045.2.1".parse().unwrap()); assert_eq!( - pk.algorithm.parameters.unwrap().oid().unwrap(), + pk.algorithm + .parameters + .unwrap() + .decode_as::() + .unwrap(), "1.2.840.10045.3.1.7".parse().unwrap() ); diff --git a/spki/src/algorithm.rs b/spki/src/algorithm.rs index aeeac3206..739815f4f 100644 --- a/spki/src/algorithm.rs +++ b/spki/src/algorithm.rs @@ -157,7 +157,7 @@ impl<'a> AlgorithmIdentifierRef<'a> { None => None, Some(p) => match p { AnyRef::NULL => None, - _ => Some(p.oid()?), + _ => Some(p.decode_as::()?), }, }, )) diff --git a/spki/tests/spki.rs b/spki/tests/spki.rs index e655c0b29..0a3baa466 100644 --- a/spki/tests/spki.rs +++ b/spki/tests/spki.rs @@ -1,5 +1,6 @@ //! `SubjectPublicKeyInfo` tests. +use der::asn1::ObjectIdentifier; use hex_literal::hex; use spki::SubjectPublicKeyInfoRef; @@ -51,7 +52,11 @@ fn decode_ec_p256_der() { assert_eq!(spki.algorithm.oid, "1.2.840.10045.2.1".parse().unwrap()); assert_eq!( - spki.algorithm.parameters.unwrap().oid().unwrap(), + spki.algorithm + .parameters + .unwrap() + .decode_as::() + .unwrap(), "1.2.840.10045.3.1.7".parse().unwrap() ); diff --git a/x509-cert/tests/certreq.rs b/x509-cert/tests/certreq.rs index 08f6d6eb3..e1e67fa76 100644 --- a/x509-cert/tests/certreq.rs +++ b/x509-cert/tests/certreq.rs @@ -65,7 +65,7 @@ fn decode_rsa_2048_der() { // Check the extensions. let extensions: x509_cert::ext::Extensions = - attribute.values.get(0).unwrap().decode_into().unwrap(); + attribute.values.get(0).unwrap().decode_as().unwrap(); for (ext, (oid, val)) in extensions.iter().zip(EXTENSIONS) { assert_eq!(ext.extn_id, oid.parse().unwrap()); assert_eq!(ext.extn_value.as_bytes(), *val);