diff --git a/der/derive/src/choice/variant.rs b/der/derive/src/choice/variant.rs index 536aa423d..cb2639b0a 100644 --- a/der/derive/src/choice/variant.rs +++ b/der/derive/src/choice/variant.rs @@ -83,8 +83,9 @@ impl ChoiceVariant { #[cfg(test)] mod tests { use super::ChoiceVariant; - use crate::{Asn1Type, FieldAttrs, Tag}; + use crate::{Asn1Type, FieldAttrs, Tag, TagNumber}; use proc_macro2::Span; + use quote::quote; use syn::Ident; #[test] @@ -97,25 +98,91 @@ mod tests { tag: Tag::Universal(Asn1Type::Utf8String), }; - // TODO(tarcieri): better comparison, possibly using `quote!` assert_eq!( variant.to_decode_tokens().to_string(), - ":: der :: Tag :: Utf8String => Ok (Self :: ExampleVariant (decoder . decode () ? . try_into () ?)) ," + quote! { + ::der::Tag::Utf8String => Ok(Self::ExampleVariant( + decoder.decode()? + .try_into()? + )), + } + .to_string() ); assert_eq!( variant.to_encode_tokens().to_string(), - "Self :: ExampleVariant (variant) => encoder . encode (variant) ? ," + quote! { + Self::ExampleVariant(variant) => encoder.encode(variant)?, + } + .to_string() ); assert_eq!( variant.to_encoded_len_tokens().to_string(), - "Self :: ExampleVariant (variant) => variant . encoded_len () ," + quote! { + Self::ExampleVariant(variant) => variant.encoded_len(), + } + .to_string() ); assert_eq!( variant.to_tagged_tokens().to_string(), - "Self :: ExampleVariant (_) => :: der :: Tag :: Utf8String ," + quote! { + Self::ExampleVariant(_) => ::der::Tag::Utf8String, + } + .to_string() + ) + } + + #[test] + fn implicit() { + let span = Span::call_site(); + + let variant = ChoiceVariant { + ident: Ident::new("ImplicitVariant", span), + attrs: FieldAttrs::default(), + tag: Tag::ContextSpecific { + constructed: false, + number: TagNumber(0), + }, + }; + + assert_eq!( + variant.to_decode_tokens().to_string(), + quote! { + ::der::Tag::ContextSpecific { + constructed: false, + number: ::der::TagNumber::N0, + } => Ok(Self::ImplicitVariant(decoder.decode()?.try_into()?)), + } + .to_string() + ); + + assert_eq!( + variant.to_encode_tokens().to_string(), + quote! { + Self::ImplicitVariant(variant) => encoder.encode(variant)?, + } + .to_string() + ); + + assert_eq!( + variant.to_encoded_len_tokens().to_string(), + quote! { + Self::ImplicitVariant(variant) => variant.encoded_len(), + } + .to_string() + ); + + assert_eq!( + variant.to_tagged_tokens().to_string(), + quote! { + Self::ImplicitVariant(_) => ::der::Tag::ContextSpecific { + constructed: false, + number: ::der::TagNumber::N0, + }, + } + .to_string() ) } } diff --git a/der/derive/src/sequence.rs b/der/derive/src/sequence.rs index 33df8037d..307634917 100644 --- a/der/derive/src/sequence.rs +++ b/der/derive/src/sequence.rs @@ -277,4 +277,51 @@ mod tests { assert_eq!(public_key_field.attrs.optional, true); assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit); } + + /// `IMPLICIT` tagged example + #[test] + fn implicit_example() { + let input = parse_quote! { + #[asn1(tag_mode = "IMPLICIT")] + pub struct ImplicitSequence<'a> { + #[asn1(context_specific = "0", type = "BIT STRING")] + bit_string: BitString<'a>, + + #[asn1(context_specific = "1", type = "GeneralizedTime")] + time: GeneralizedTime, + + #[asn1(context_specific = "2", type = "UTF8String")] + utf8_string: String, + } + }; + + let ir = DeriveSequence::new(input); + assert_eq!(ir.ident, "ImplicitSequence"); + assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); + assert_eq!(ir.fields.len(), 3); + + let bit_string = &ir.fields[0]; + assert_eq!(bit_string.ident, "bit_string"); + assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); + assert_eq!( + bit_string.attrs.context_specific, + Some("0".parse().unwrap()) + ); + assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); + + let time = &ir.fields[1]; + assert_eq!(time.ident, "time"); + assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); + assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); + assert_eq!(time.attrs.tag_mode, TagMode::Implicit); + + let utf8_string = &ir.fields[2]; + assert_eq!(utf8_string.ident, "utf8_string"); + assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); + assert_eq!( + utf8_string.attrs.context_specific, + Some("2".parse().unwrap()) + ); + assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); + } } diff --git a/der/derive/src/sequence/field.rs b/der/derive/src/sequence/field.rs index ac7218d58..5768998c0 100644 --- a/der/derive/src/sequence/field.rs +++ b/der/derive/src/sequence/field.rs @@ -1,6 +1,6 @@ //! Sequence field IR and lowerings -use crate::{Asn1Type, FieldAttrs, TypeAttrs}; +use crate::{Asn1Type, FieldAttrs, TagMode, TagNumber, TypeAttrs}; use proc_macro2::TokenStream; use proc_macro_error::abort; use quote::quote; @@ -37,6 +37,13 @@ impl SequenceField { ); } + if attrs.default.is_some() && attrs.optional { + abort!( + ident, + "`optional` and `default` field qualifiers are mutually exclusive" + ); + } + Self { ident, attrs, @@ -71,19 +78,27 @@ impl SequenceField { /// Derive code for encoding a field of a sequence. pub(super) fn to_encode_tokens(&self) -> TokenStream { let mut lowerer = LowerFieldEncoder::new(&self.ident); + let attrs = &self.attrs; - if let Some(ty) = &self.attrs.asn1_type { - lowerer.apply_asn1_type(ty, self.attrs.optional); - } - - if let Some(default) = &self.attrs.default { + if let Some(ty) = &attrs.asn1_type { // TODO(tarcieri): default in conjunction with ASN.1 types? debug_assert!( - self.attrs.asn1_type.is_none(), + attrs.default.is_none(), "`type` and `default` are mutually exclusive" ); + lowerer.apply_asn1_type(ty, attrs.optional); + } - lowerer.apply_default(default); + if let Some(tag_number) = &attrs.context_specific { + lowerer.apply_context_specific(tag_number, &attrs.tag_mode, attrs.optional); + } + + if let Some(default) = &attrs.default { + debug_assert!( + !attrs.optional, + "`default`, and `optional` are mutually exclusive" + ); + lowerer.apply_default(&self.ident, default, attrs.context_specific.is_none()); } lowerer.into_tokens() @@ -146,13 +161,14 @@ impl LowerFieldEncoder { /// Create a new field encoder lowerer. fn new(ident: &Ident) -> Self { Self { - encoder: quote!(&self.#ident), + encoder: quote!(self.#ident), } } /// the field encoder to tokens. fn into_tokens(self) -> TokenStream { - self.encoder + let encoder = self.encoder; + quote! { &#encoder } } /// Apply the ASN.1 type (if defined). @@ -163,38 +179,76 @@ impl LowerFieldEncoder { let map_arg = quote!(field); let encoder = asn1_type.encoder(&map_arg); - // TODO(tarcieri): refactor this to get rid of `Result` type annotation quote! { #binding.as_ref().map(|#map_arg| { - let res: der::Result<_> = Ok(#encoder); - res + der::Result::Ok(#encoder) }).transpose()? } } else { let encoder = asn1_type.encoder(binding); - quote!(&#encoder) + quote!(#encoder) }; } /// Handle default value for a type. - fn apply_default(&mut self, default: &Path) { - let encoder = &self.encoder; + fn apply_default(&mut self, ident: &Ident, default: &Path, is_bare: bool) { + let mut encoder = &self.encoder; self.encoder = quote! { - &::der::asn1::OptionalRef(if #encoder == &#default() { + if &self.#ident == &#default() { None } else { Some(#encoder) - }) + } }; + + if is_bare { + encoder = &self.encoder; + self.encoder = quote! { + ::der::asn1::OptionalRef(#encoder.as_ref()) + }; + } + } + + /// Make this field context-specific. + fn apply_context_specific( + &mut self, + tag_number: &TagNumber, + tag_mode: &TagMode, + optional: bool, + ) { + let encoder = &self.encoder; + let number_tokens = tag_number.to_tokens(); + let mode_tokens = tag_mode.to_tokens(); + + if optional { + self.encoder = quote! { + #encoder.as_ref().map(|field| { + ::der::asn1::ContextSpecificRef { + tag_number: #number_tokens, + tag_mode: #mode_tokens, + value: field, + } + }) + }; + } else { + self.encoder = quote! { + ::der::asn1::ContextSpecificRef { + tag_number: #number_tokens, + tag_mode: #mode_tokens, + value: &#encoder, + } + }; + } } } #[cfg(test)] mod tests { use super::SequenceField; - use crate::{FieldAttrs, TagMode}; + use crate::{FieldAttrs, TagMode, TagNumber}; use proc_macro2::Span; + use quote::quote; use syn::{punctuated::Punctuated, Ident, Path, PathSegment, Type, TypePath}; /// Create a [`Type::Path`]. @@ -236,15 +290,74 @@ mod tests { field_type: type_path(field_type), }; - // TODO(tarcieri): better comparison, possibly using `quote!` assert_eq!( field.to_decode_tokens().to_string(), - "let example_field = decoder . decode () ? ;" + quote! { + let example_field = decoder.decode()?; + } + .to_string() + ); + + assert_eq!( + field.to_encode_tokens().to_string(), + quote! { + &self.example_field + } + .to_string() + ); + } + + #[test] + fn implicit() { + let span = Span::call_site(); + let ident = Ident::new("implicit_field", span); + + let attrs = FieldAttrs { + asn1_type: None, + context_specific: Some(TagNumber(0)), + default: None, + extensible: false, + optional: false, + tag_mode: TagMode::Implicit, + }; + + let field_type = Ident::new("String", span); + + let field = SequenceField { + ident, + attrs, + field_type: type_path(field_type), + }; + + assert_eq!( + field.to_decode_tokens().to_string(), + quote! { + let implicit_field = ::der::asn1::ContextSpecific::<>::decode_implicit( + decoder, + ::der::TagNumber::N0 + )? + .ok_or_else(|| { + der::Tag::ContextSpecific { + number: ::der::TagNumber::N0, + constructed: false + } + .value_error() + })? + .value; + } + .to_string() ); assert_eq!( field.to_encode_tokens().to_string(), - "& self . example_field" + quote! { + &::der::asn1::ContextSpecificRef { + tag_number: ::der::TagNumber::N0, + tag_mode: ::der::TagMode::Implicit, + value: &self.implicit_field, + } + } + .to_string() ); } } diff --git a/der/derive/src/tag.rs b/der/derive/src/tag.rs index 055d36a8f..f2e39ec18 100644 --- a/der/derive/src/tag.rs +++ b/der/derive/src/tag.rs @@ -106,7 +106,7 @@ impl Display for TagMode { /// ASN.1 tag numbers (i.e. lower 5 bits of a [`Tag`]). #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] -pub(crate) struct TagNumber(u8); +pub(crate) struct TagNumber(pub u8); impl TagNumber { /// Maximum tag number supported (inclusive). diff --git a/der/tests/derive.rs b/der/tests/derive.rs index 4f3a4c24f..9a019a896 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -353,6 +353,28 @@ mod sequence { const ALGORITHM_IDENTIFIER_DER: &[u8] = &hex!("30 13 06 07 2a 86 48 ce 3d 02 01 06 08 2a 86 48 ce 3d 03 01 07"); + #[derive(Sequence)] + #[asn1(tag_mode = "IMPLICIT")] + pub struct TypeCheckExpandedSequenceFieldAttributeCombinations<'a> { + pub simple: bool, + #[asn1(type = "BIT STRING")] + pub typed: &'a [u8], + #[asn1(context_specific = "0")] + pub context_specific: bool, + #[asn1(optional = "true")] + pub optional: Option, + #[asn1(default = "default_false_example")] + pub default: bool, + #[asn1(type = "BIT STRING", context_specific = "1")] + pub typed_context_specific: &'a [u8], + #[asn1(context_specific = "2", optional = "true")] + pub context_specific_optional: Option, + #[asn1(context_specific = "3", default = "default_false_example")] + pub context_specific_default: bool, + #[asn1(type = "BIT STRING", context_specific = "4", optional = "true")] + pub typed_context_specific_optional: Option<&'a [u8]>, + } + #[test] fn idp_test() { let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap();