diff --git a/aead/src/dev.rs b/aead/src/dev.rs index b128014bd..39d83db95 100644 --- a/aead/src/dev.rs +++ b/aead/src/dev.rs @@ -1,6 +1,12 @@ //! Development-related functionality pub use blobby; +#[cfg(all(feature = "alloc", feature = "inout"))] +use { + crate::Tag, alloc::vec, alloc::vec::Vec, core::fmt, crypto_common::typenum::Unsigned, + inout::InOutBuf, +}; + /// Define AEAD test #[macro_export] macro_rules! new_test { @@ -45,6 +51,8 @@ macro_rules! new_test { if res != pt { return Err("decrypted data is different from target plaintext"); } + + $crate::dev::new_test_impl_inout(cipher, nonce, aad, pt, ct, pass)?; Ok(()) } @@ -75,3 +83,127 @@ macro_rules! new_test { } }; } + +/// Helper to run tests against the inout API. +#[cfg(not(all(feature = "alloc", feature = "inout")))] +pub fn new_test_impl_inout( + _cipher: T, + _nonce: &crate::Nonce, + _aad: &[u8], + _pt: &[u8], + _ct: &[u8], + _pass: bool, +) -> Result<(), &'static str> { + Ok(()) +} + +/// Helper to run tests against the inout API. +#[cfg(all(feature = "alloc", feature = "inout"))] +pub fn new_test_impl_inout( + cipher: T, + nonce: &crate::Nonce, + aad: &[u8], + pt: &[u8], + ct: &[u8], + pass: bool, +) -> Result<(), &'static str> { + // Here we assume this is a postfix tagged AEAD + let (ciphertext, expected_tag) = ct.split_at(ct.len() - T::TagSize::to_usize()); + let expected_tag = Tag::::try_from(expected_tag).expect("invariant violation"); + + if !pass { + let mut payload = MockBuffer::from(ciphertext); + let res = cipher.decrypt_inout_detached(nonce, aad, payload.to_in_out_buf(), &expected_tag); + if res.is_ok() { + return Err("Decryption must return an error"); + } + return Ok(()); + } + + let mut payload = MockBuffer::from(pt); + let tag = cipher + .encrypt_inout_detached(nonce, aad, payload.to_in_out_buf()) + .map_err(|_| "encryption failure")?; + + // Here we assume this is a postfix tagged AEAD + let (ciphertext, _tag) = ct.split_at(ct.len() - T::TagSize::to_usize()); + if payload.as_ref() != ciphertext { + return Err("encrypted data is different from target ciphertext"); + } + + let mut payload = MockBuffer::from(ciphertext); + cipher + .decrypt_inout_detached(nonce, aad, payload.to_in_out_buf(), &tag) + .map_err(|_| "decryption failure")?; + + if payload.as_ref() != pt { + return Err("decrypted data is different from target plaintext"); + } + + Ok(()) +} + +/// [`MockBuffer`] is meant for testing InOut-backed APIs. +/// +/// It will split the initial buffer in two different backing buffers. The out buffer will be +/// zeroed. +#[cfg(all(feature = "alloc", feature = "inout"))] +pub struct MockBuffer { + in_buf: Vec, + out_buf: Vec, +} + +#[cfg(all(feature = "alloc", feature = "inout"))] +impl AsRef<[u8]> for MockBuffer { + fn as_ref(&self) -> &[u8] { + &self.out_buf + } +} + +#[cfg(all(feature = "alloc", feature = "inout"))] +impl From<&[u8]> for MockBuffer { + fn from(buf: &[u8]) -> Self { + Self { + in_buf: buf.to_vec(), + out_buf: vec![0u8; buf.len()], + } + } +} + +#[cfg(all(feature = "alloc", feature = "inout"))] +impl From> for MockBuffer { + fn from(buf: Vec) -> Self { + Self { + out_buf: vec![0u8; buf.len()], + in_buf: buf, + } + } +} + +#[cfg(all(feature = "alloc", feature = "inout"))] +impl MockBuffer { + /// Get an [`InOutBuf`] from a [`MockBuffer`] + pub fn to_in_out_buf(&mut self) -> InOutBuf<'_, '_, u8> { + InOutBuf::new(self.in_buf.as_slice(), self.out_buf.as_mut_slice()) + .expect("Invariant violation") + } + + /// Return the length of the payload + #[inline] + pub fn len(&self) -> usize { + self.in_buf.len() + } + + /// Is the payload empty? + #[inline] + pub fn is_empty(&self) -> bool { + self.in_buf.is_empty() + } +} + +#[cfg(all(feature = "alloc", feature = "inout"))] +impl fmt::Debug for MockBuffer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MockBuffer {{...}}") + } +}