diff --git a/tls_codec/src/tls_vec.rs b/tls_codec/src/tls_vec.rs index da0c9f0a0..6540eafbf 100644 --- a/tls_codec/src/tls_vec.rs +++ b/tls_codec/src/tls_vec.rs @@ -13,7 +13,7 @@ use serde::ser::SerializeStruct; use std::io::{Read, Write}; use zeroize::Zeroize; -use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size}; +use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size}; macro_rules! impl_size { ($self:ident, $size:ty, $name:ident, $len_len:literal) => { @@ -127,38 +127,16 @@ macro_rules! impl_serialize { fn serialize(&$self, writer: &mut W) -> Result { // Get the byte length of the content, make sure it's not too // large and write it out. - let tls_serialized_len = $self.tls_serialized_len(); - let byte_length = tls_serialized_len - $len_len; - - let max_len = <$size>::MAX as usize; - debug_assert!( - byte_length <= max_len, - "Vector length can't be encoded in the vector length a {} >= {}", - byte_length, - max_len - ); - if byte_length > max_len { - return Err(Error::InvalidVectorLength); - } + let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; - let mut written = (byte_length as $size).tls_serialize(writer)?; + let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?; // Now serialize the elements for e in $self.as_slice().iter() { written += e.tls_serialize(writer)?; } - debug_assert_eq!( - written, tls_serialized_len, - "{} bytes should have been serialized but {} were written", - tls_serialized_len, written - ); - if written != tls_serialized_len { - return Err(Error::EncodingError(format!( - "{} bytes should have been serialized but {} were written", - tls_serialized_len, written - ))); - } + $self.assert_written_bytes(tls_serialized_len, written)?; Ok(written) } }; @@ -171,6 +149,23 @@ macro_rules! impl_byte_serialize { fn serialize_bytes(&$self, writer: &mut W) -> Result { // Get the byte length of the content, make sure it's not too // large and write it out. + let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; + + let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?; + + // Now serialize the elements + written += writer.write($self.as_slice())?; + + $self.assert_written_bytes(tls_serialized_len, written)?; + Ok(written) + } + }; +} + +macro_rules! impl_serialize_common { + ($self:ident, $size:ty, $name:ident, $len_len:literal $(,#[$std_enabled:meta])?) => { + $(#[$std_enabled])? + fn get_content_lengths(&$self) -> Result<(usize, usize), Error> { let tls_serialized_len = $self.tls_serialized_len(); let byte_length = tls_serialized_len - $len_len; @@ -184,12 +179,11 @@ macro_rules! impl_byte_serialize { if byte_length > max_len { return Err(Error::InvalidVectorLength); } + Ok((tls_serialized_len, byte_length)) + } - let mut written = (byte_length as $size).tls_serialize(writer)?; - - // Now serialize the elements - written += writer.write($self.as_slice())?; - + $(#[$std_enabled])? + fn assert_written_bytes(&$self, tls_serialized_len: usize, written: usize) -> Result<(), Error> { debug_assert_eq!( written, tls_serialized_len, "{} bytes should have been serialized but {} were written", @@ -201,7 +195,28 @@ macro_rules! impl_byte_serialize { tls_serialized_len, written ))); } - Ok(written) + Ok(()) + } + }; +} + +macro_rules! impl_serialize_bytes_bytes { + ($self:ident, $size:ty, $name:ident, $len_len:literal) => { + fn serialize_bytes_bytes(&$self) -> Result, Error> { + let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; + + let mut vec = Vec::::with_capacity(tls_serialized_len); + let length_vec = <$size as SerializeBytes>::tls_serialize(&(byte_length as $size))?; + let mut written = length_vec.len(); + vec.extend_from_slice(&length_vec); + + let bytes = $self.as_slice(); + vec.extend_from_slice(bytes); + written += bytes.len(); + + $self.assert_written_bytes(tls_serialized_len, written)?; + + Ok(vec) } }; } @@ -293,6 +308,12 @@ macro_rules! impl_tls_vec_codec_bytes { Self::deserialize_bytes_bytes(bytes) } } + + impl SerializeBytes for $name { + fn tls_serialize(&self) -> Result, Error> { + self.serialize_bytes_bytes() + } + } }; } @@ -789,6 +810,7 @@ macro_rules! impl_secret_tls_vec { impl_tls_vec_codec_generic!($size, $name, $len_len, Zeroize); impl $name { + impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]); impl_serialize!(self, $size, $name, $len_len); } @@ -825,6 +847,7 @@ macro_rules! impl_public_tls_vec { impl_tls_vec_codec_generic!($size, $name, $len_len); impl $name { + impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]); impl_serialize!(self, $size, $name, $len_len); } @@ -848,7 +871,9 @@ macro_rules! impl_tls_byte_vec { impl $name { // This implements serialize and size for all versions + impl_serialize_common!(self, $size, $name, $len_len); impl_byte_serialize!(self, $size, $name, $len_len); + impl_serialize_bytes_bytes!(self, $size, $name, $len_len); impl_byte_size!(self, $size, $name, $len_len); impl_byte_deserialize!(self, $size, $name, $len_len); } @@ -885,6 +910,7 @@ macro_rules! impl_tls_byte_slice { } impl<'a> $name<'a> { + impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]); impl_byte_serialize!(self, $size, $name, $len_len); impl_byte_size!(self, $size, $name, $len_len); } @@ -940,6 +966,7 @@ macro_rules! impl_tls_slice { } impl<'a, T: Serialize> $name<'a, T> { + impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]); impl_serialize!(self, $size, $name, $len_len); } diff --git a/tls_codec/tests/encode_bytes.rs b/tls_codec/tests/encode_bytes.rs index 20830b5e0..f9b30224b 100644 --- a/tls_codec/tests/encode_bytes.rs +++ b/tls_codec/tests/encode_bytes.rs @@ -1,4 +1,4 @@ -use tls_codec::SerializeBytes; +use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8}; #[test] fn serialize_primitives() { @@ -40,3 +40,30 @@ fn serialize_var_len_boundaries() { let serialized = v.tls_serialize().expect("Error encoding vector"); assert_eq!(&serialized[0..5], &[0x80, 0, 0x40, 0, 99]); } + +#[test] +fn serialize_tls_byte_vec_u8() { + let byte_vec = TlsByteVecU8::from_slice(&[1, 2, 3]); + let actual_result = byte_vec + .tls_serialize() + .expect("Error encoding byte vector"); + assert_eq!(actual_result, vec![3, 1, 2, 3]); +} + +#[test] +fn serialize_tls_byte_vec_u16() { + let byte_vec = TlsByteVecU16::from_slice(&[1, 2, 3]); + let actual_result = byte_vec + .tls_serialize() + .expect("Error encoding byte vector"); + assert_eq!(actual_result, vec![0, 3, 1, 2, 3]); +} + +#[test] +fn serialize_tls_byte_vec_u32() { + let byte_vec = TlsByteVecU32::from_slice(&[1, 2, 3]); + let actual_result = byte_vec + .tls_serialize() + .expect("Error encoding byte vector"); + assert_eq!(actual_result, vec![0, 0, 0, 3, 1, 2, 3]); +}