diff --git a/der/derive/src/sequence/field.rs b/der/derive/src/sequence/field.rs index 23c85835a..057496e5e 100644 --- a/der/derive/src/sequence/field.rs +++ b/der/derive/src/sequence/field.rs @@ -98,7 +98,7 @@ impl SequenceField { !attrs.optional, "`default`, and `optional` are mutually exclusive" ); - lowerer.apply_default(&self.ident, default, attrs.context_specific.is_none()); + lowerer.apply_default(&self.ident, default); } lowerer.into_tokens() @@ -191,8 +191,8 @@ impl LowerFieldEncoder { } /// Handle default value for a type. - fn apply_default(&mut self, ident: &Ident, default: &Path, is_bare: bool) { - let mut encoder = &self.encoder; + fn apply_default(&mut self, ident: &Ident, default: &Path) { + let encoder = &self.encoder; self.encoder = quote! { if &self.#ident == &#default() { @@ -201,13 +201,6 @@ impl LowerFieldEncoder { Some(#encoder) } }; - - if is_bare { - encoder = &self.encoder; - self.encoder = quote! { - ::der::asn1::OptionalRef(#encoder.as_ref()) - }; - } } /// Make this field context-specific. diff --git a/der/src/asn1.rs b/der/src/asn1.rs index 0dbed1bb3..7d7764e14 100644 --- a/der/src/asn1.rs +++ b/der/src/asn1.rs @@ -32,7 +32,6 @@ pub use self::{ integer::bigint::UIntBytes, null::Null, octet_string::OctetString, - optional::OptionalRef, printable_string::PrintableString, sequence::{Sequence, SequenceRef}, sequence_of::{SequenceOf, SequenceOfIter}, diff --git a/der/src/asn1/context_specific.rs b/der/src/asn1/context_specific.rs index 376f5a800..29a169ecf 100644 --- a/der/src/asn1/context_specific.rs +++ b/der/src/asn1/context_specific.rs @@ -1,8 +1,8 @@ //! Context-specific field. use crate::{ - asn1::Any, Choice, Decode, DecodeValue, Decoder, DerOrd, Encode, EncodeValue, Encoder, Error, - Header, Length, Result, Tag, TagMode, TagNumber, Tagged, ValueOrd, + asn1::Any, Choice, Decode, DecodeValue, Decoder, DerOrd, Encode, EncodeValue, EncodeValueRef, + Encoder, Error, Header, Length, Result, Tag, TagMode, TagNumber, Tagged, ValueOrd, }; use core::cmp::Ordering; @@ -110,15 +110,6 @@ impl ContextSpecific { Ok(None) } - - /// Get a [`ContextSpecificRef`] for this field. - pub fn to_ref(&self) -> ContextSpecificRef<'_, T> { - ContextSpecificRef { - tag_number: self.tag_number, - tag_mode: self.tag_mode, - value: &self.value, - } - } } impl<'a, T> Choice<'a> for ContextSpecific @@ -144,11 +135,17 @@ where T: EncodeValue + Tagged, { fn value_len(&self) -> Result { - self.to_ref().value_len() + match self.tag_mode { + TagMode::Explicit => self.value.encoded_len(), + TagMode::Implicit => self.value.value_len(), + } } fn encode_value(&self, encoder: &mut Encoder<'_>) -> Result<()> { - self.to_ref().encode_value(encoder) + match self.tag_mode { + TagMode::Explicit => self.value.encode(encoder), + TagMode::Implicit => self.value.encode_value(encoder), + } } } @@ -157,16 +154,15 @@ where T: Tagged, { fn tag(&self) -> Tag { - self.to_ref().tag() - } -} + let constructed = match self.tag_mode { + TagMode::Explicit => true, + TagMode::Implicit => self.value.tag().is_constructed(), + }; -impl ValueOrd for ContextSpecific -where - T: EncodeValue + ValueOrd + Tagged, -{ - fn value_cmp(&self, other: &Self) -> Result { - self.to_ref().value_cmp(&other.to_ref()) + Tag::ContextSpecific { + number: self.tag_number, + constructed, + } } } @@ -191,6 +187,18 @@ where } } +impl ValueOrd for ContextSpecific +where + T: EncodeValue + ValueOrd + Tagged, +{ + fn value_cmp(&self, other: &Self) -> Result { + match self.tag_mode { + TagMode::Explicit => self.der_cmp(other), + TagMode::Implicit => self.value_cmp(other), + } + } +} + /// Context-specific field reference. /// /// This type encodes a field which is specific to a particular context @@ -208,51 +216,36 @@ pub struct ContextSpecificRef<'a, T> { pub value: &'a T, } -impl EncodeValue for ContextSpecificRef<'_, T> +impl<'a, T> ContextSpecificRef<'a, T> { + /// Convert to a [`ContextSpecific`]. + fn encoder(&self) -> ContextSpecific> { + ContextSpecific { + tag_number: self.tag_number, + tag_mode: self.tag_mode, + value: EncodeValueRef(self.value), + } + } +} + +impl<'a, T> EncodeValue for ContextSpecificRef<'a, T> where T: EncodeValue + Tagged, { fn value_len(&self) -> Result { - match self.tag_mode { - TagMode::Explicit => self.value.encoded_len(), - TagMode::Implicit => self.value.value_len(), - } + self.encoder().value_len() } fn encode_value(&self, encoder: &mut Encoder<'_>) -> Result<()> { - match self.tag_mode { - TagMode::Explicit => self.value.encode(encoder), - TagMode::Implicit => self.value.encode_value(encoder), - } + self.encoder().encode_value(encoder) } } -impl Tagged for ContextSpecificRef<'_, T> +impl<'a, T> Tagged for ContextSpecificRef<'a, T> where T: Tagged, { fn tag(&self) -> Tag { - let constructed = match self.tag_mode { - TagMode::Explicit => true, - TagMode::Implicit => self.value.tag().is_constructed(), - }; - - Tag::ContextSpecific { - number: self.tag_number, - constructed, - } - } -} - -impl ValueOrd for ContextSpecificRef<'_, T> -where - T: EncodeValue + ValueOrd + Tagged, -{ - fn value_cmp(&self, other: &Self) -> Result { - match self.tag_mode { - TagMode::Explicit => self.der_cmp(other), - TagMode::Implicit => self.value_cmp(other), - } + self.encoder().tag() } } diff --git a/der/src/asn1/optional.rs b/der/src/asn1/optional.rs index 0645f0db5..594806728 100644 --- a/der/src/asn1/optional.rs +++ b/der/src/asn1/optional.rs @@ -18,27 +18,6 @@ where } } -impl Encode for Option -where - T: Encode, -{ - fn encoded_len(&self) -> Result { - if let Some(encodable) = self { - encodable.encoded_len() - } else { - Ok(0u8.into()) - } - } - - fn encode(&self, encoder: &mut Encoder<'_>) -> Result<()> { - if let Some(encodable) = self { - encodable.encode(encoder) - } else { - Ok(()) - } - } -} - impl DerOrd for Option where T: DerOrd, @@ -56,15 +35,25 @@ where } } -/// A reference to an ASN.1 `OPTIONAL` type, used for encoding only. -pub struct OptionalRef<'a, T>(pub Option<&'a T>); +impl Encode for Option +where + T: Encode, +{ + fn encoded_len(&self) -> Result { + (&self).encoded_len() + } + + fn encode(&self, encoder: &mut Encoder<'_>) -> Result<()> { + (&self).encode(encoder) + } +} -impl<'a, T> Encode for OptionalRef<'a, T> +impl Encode for &Option where T: Encode, { fn encoded_len(&self) -> Result { - if let Some(encodable) = self.0 { + if let Some(encodable) = self { encodable.encoded_len() } else { Ok(0u8.into()) @@ -72,7 +61,7 @@ where } fn encode(&self, encoder: &mut Encoder<'_>) -> Result<()> { - if let Some(encodable) = self.0 { + if let Some(encodable) = self { encodable.encode(encoder) } else { Ok(()) diff --git a/der/src/encode_ref.rs b/der/src/encode_ref.rs new file mode 100644 index 000000000..8855b4743 --- /dev/null +++ b/der/src/encode_ref.rs @@ -0,0 +1,71 @@ +//! Wrapper object for encoding reference types. +// TODO(tarcieri): replace with blanket impls of `Encode(Value)` for reference types? + +use crate::{Encode, EncodeValue, Encoder, Length, Result, Tag, Tagged, ValueOrd}; +use core::cmp::Ordering; + +/// Reference encoder: wrapper type which impls `Encode` for any reference to a +/// type which impls the same. +pub struct EncodeRef<'a, T>(pub &'a T); + +impl<'a, T> AsRef for EncodeRef<'a, T> { + fn as_ref(&self) -> &T { + self.0 + } +} + +impl<'a, T> Encode for EncodeRef<'a, T> +where + T: Encode, +{ + fn encoded_len(&self) -> Result { + self.0.encoded_len() + } + + fn encode(&self, encoder: &mut Encoder<'_>) -> Result<()> { + self.0.encode(encoder) + } +} + +/// Reference value encoder: wrapper type which impls `EncodeValue` and `Tagged` +/// for any reference type which impls the same. +/// +/// By virtue of the blanket impl, this type also impls `Encode`. +pub struct EncodeValueRef<'a, T>(pub &'a T); + +impl<'a, T> AsRef for EncodeValueRef<'a, T> { + fn as_ref(&self) -> &T { + self.0 + } +} + +impl<'a, T> EncodeValue for EncodeValueRef<'a, T> +where + T: EncodeValue, +{ + fn value_len(&self) -> Result { + self.0.value_len() + } + + fn encode_value(&self, encoder: &mut Encoder<'_>) -> Result<()> { + self.0.encode_value(encoder) + } +} + +impl<'a, T> Tagged for EncodeValueRef<'a, T> +where + T: Tagged, +{ + fn tag(&self) -> Tag { + self.0.tag() + } +} + +impl<'a, T> ValueOrd for EncodeValueRef<'a, T> +where + T: ValueOrd, +{ + fn value_cmp(&self, other: &Self) -> Result { + self.0.value_cmp(other.0) + } +} diff --git a/der/src/encoder.rs b/der/src/encoder.rs index 52c5258f9..b3a47388b 100644 --- a/der/src/encoder.rs +++ b/der/src/encoder.rs @@ -1,8 +1,8 @@ //! DER encoder. use crate::{ - asn1::*, Encode, EncodeValue, Error, ErrorKind, Header, Length, Result, Tag, TagMode, - TagNumber, Tagged, + asn1::*, Encode, EncodeRef, EncodeValue, Error, ErrorKind, Header, Length, Result, Tag, + TagMode, TagNumber, Tagged, }; /// DER encoder. @@ -78,7 +78,7 @@ impl<'a> Encoder<'a> { .and_then(|value| self.encode(&value)) } - /// Encode a `CONTEXT-SPECIFIC` field with `EXPLICIT` tagging. + /// Encode a `CONTEXT-SPECIFIC` field with the provided tag number and mode. pub fn context_specific( &mut self, tag_number: TagNumber, @@ -96,7 +96,7 @@ impl<'a> Encoder<'a> { .encode(self) } - /// Encode the provided value as an ASN.1 `GeneralizedTime` + /// Encode the provided value as an ASN.1 `GeneralizedTime`. pub fn generalized_time(&mut self, value: impl TryInto) -> Result<()> { value .try_into() @@ -135,6 +135,11 @@ impl<'a> Encoder<'a> { .and_then(|value| self.encode(&value)) } + /// Encode an ASN.1 `OPTIONAL` for the given option reference. + pub fn optional(&mut self, value: Option<&T>) -> Result<()> { + value.map(EncodeRef).encode(self) + } + /// Encode the provided value as an ASN.1 `PrintableString` pub fn printable_string(&mut self, value: impl TryInto>) -> Result<()> { value diff --git a/der/src/lib.rs b/der/src/lib.rs index 825d307d4..3a727568f 100644 --- a/der/src/lib.rs +++ b/der/src/lib.rs @@ -351,6 +351,7 @@ mod datetime; mod decode; mod decoder; mod encode; +mod encode_ref; mod encoder; mod error; mod header; @@ -368,6 +369,7 @@ pub use crate::{ decode::{Decode, DecodeOwned, DecodeValue}, decoder::Decoder, encode::{Encode, EncodeValue}, + encode_ref::{EncodeRef, EncodeValueRef}, encoder::Encoder, error::{Error, ErrorKind, Result}, header::Header,