diff --git a/Cargo.lock b/Cargo.lock index 16d49859..be9cc1c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -370,8 +370,7 @@ dependencies = [ [[package]] name = "inout" version = "0.2.0-rc.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac5e145e8ade9f74c0a5efc60ccb4e714b0144f7e2220b7ca64254feee71c57f" +source = "git+https://github.com/RustCrypto/utils.git#7560ec86787c1342147a8d8ac61a80f5363589a3" dependencies = [ "hybrid-array", ] @@ -393,6 +392,8 @@ dependencies = [ "ctr", "dbl", "hex-literal 0.4.1", + "hex-literal", + "inout", "subtle", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 24efad89..8307a1d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,4 +30,7 @@ ghash = { git = "https://github.com/RustCrypto/universal-hashes.git" } pmac = { git = "https://github.com/RustCrypto/MACs.git" } -belt-ctr = { git = "https://github.com/RustCrypto/block-modes.git" } \ No newline at end of file +belt-ctr = { git = "https://github.com/RustCrypto/block-modes.git" } + +# https://github.com/RustCrypto/utils/pull/1170 +inout = { git = "https://github.com/RustCrypto/utils.git" } diff --git a/ocb3/Cargo.toml b/ocb3/Cargo.toml index 4e7fd1c2..74991c49 100644 --- a/ocb3/Cargo.toml +++ b/ocb3/Cargo.toml @@ -23,6 +23,7 @@ dbl = "0.4.0-rc.2" subtle = { version = "2", default-features = false } aead-stream = { version = "=0.6.0-pre", optional = true, default-features = false } zeroize = { version = "1", optional = true, default-features = false } +inout = { version = "0.2.0-rc.4", default-features = false } [dev-dependencies] aead = { version = "0.6.0-rc.0", features = ["dev"], default-features = false } diff --git a/ocb3/src/lib.rs b/ocb3/src/lib.rs index a6808ff9..afcde3a5 100644 --- a/ocb3/src/lib.rs +++ b/ocb3/src/lib.rs @@ -21,11 +21,12 @@ pub use aead::{ use aead::{PostfixTagged, array::ArraySize}; use cipher::{ BlockCipherDecrypt, BlockCipherEncrypt, BlockSizeUser, - consts::{U12, U16}, - typenum::Unsigned, + consts::{U2, U12, U16}, + typenum::Prod, }; use core::marker::PhantomData; use dbl::Dbl; +use inout::{InOut, InOutBuf}; use subtle::ConstantTimeEq; /// Number of L values to be precomputed. Precomputing m values, allows @@ -55,7 +56,9 @@ pub type Nonce = Array; /// OCB3 tag pub type Tag = Array; -pub(crate) type Block = Array; +type BlockSize = U16; +pub(crate) type Block = Array; +type DoubleBlock = Array>; mod sealed { use aead::array::{ @@ -210,34 +213,36 @@ where associated_data: &[u8], buffer: &mut [u8], ) -> aead::Result> { + let buffer = InOutBuf::from(buffer); if (buffer.len() > P_MAX) || (associated_data.len() > A_MAX) { unimplemented!() } // First, try to process many blocks at once. - let (processed_bytes, mut offset_i, mut checksum_i) = self.wide_encrypt(nonce, buffer); + let (tail, index, mut offset_i, mut checksum_i) = self.wide_encrypt(nonce, buffer); - let mut i = (processed_bytes / 16) + 1; + let mut i = index; // Then, process the remaining blocks. - for p_i in Block::slice_as_chunks_mut(&mut buffer[processed_bytes..]).0 { + let (blocks, mut tail): (InOutBuf<'_, '_, Block>, _) = tail.into_chunks(); + + for p_i in blocks { // offset_i = offset_{i-1} xor L_{ntz(i)} inplace_xor(&mut offset_i, &self.ll[ntz(i)]); // checksum_i = checksum_{i-1} xor p_i - inplace_xor(&mut checksum_i, p_i); + inplace_xor(&mut checksum_i, p_i.get_in()); // c_i = offset_i xor ENCIPHER(K, p_i xor offset_i) - let c_i = p_i; - inplace_xor(c_i, &offset_i); - self.cipher.encrypt_block(c_i); - inplace_xor(c_i, &offset_i); + let mut c_i = p_i; + c_i.xor_in2out(&offset_i); + self.cipher.encrypt_block(c_i.get_out()); + inplace_xor(c_i.get_out(), &offset_i); i += 1; } // Process any partial blocks. - if (buffer.len() % 16) != 0 { - let processed_bytes = (i - 1) * 16; - let remaining_bytes = buffer.len() - processed_bytes; + if !tail.is_empty() { + let remaining_bytes = tail.len(); // offset_* = offset_m xor L_* inplace_xor(&mut offset_i, &self.ll_star); @@ -247,15 +252,13 @@ where self.cipher.encrypt_block(&mut pad); // checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*))) let checksum_rhs = &mut [0u8; 16]; - checksum_rhs[..remaining_bytes].copy_from_slice(&buffer[processed_bytes..]); + checksum_rhs[..remaining_bytes].copy_from_slice(tail.get_in()); checksum_rhs[remaining_bytes] = 0b1000_0000; inplace_xor(&mut checksum_i, checksum_rhs.as_ref()); // C_* = P_* xor Pad[1..bitlen(P_*)] - let p_star = &mut buffer[processed_bytes..]; + let p_star = tail.get_out(); let pad = &mut pad[..p_star.len()]; - for (aa, bb) in p_star.iter_mut().zip(pad) { - *aa ^= *bb; - } + tail.xor_in2out(pad); } let tag = self.compute_tag(associated_data, &mut checksum_i, &offset_i); @@ -295,32 +298,32 @@ where if (buffer.len() > C_MAX) || (associated_data.len() > A_MAX) { unimplemented!() } + let buffer = InOutBuf::from(buffer); // First, try to process many blocks at once. - let (processed_bytes, mut offset_i, mut checksum_i) = self.wide_decrypt(nonce, buffer); + let (tail, index, mut offset_i, mut checksum_i) = self.wide_decrypt(nonce, buffer); - let mut i = (processed_bytes / 16) + 1; + let mut i = index; // Then, process the remaining blocks. - let (blocks, _remaining) = Block::slice_as_chunks_mut(&mut buffer[processed_bytes..]); + let (blocks, mut tail): (InOutBuf<'_, '_, Block>, _) = tail.into_chunks(); for c_i in blocks { // offset_i = offset_{i-1} xor L_{ntz(i)} inplace_xor(&mut offset_i, &self.ll[ntz(i)]); // p_i = offset_i xor DECIPHER(K, c_i xor offset_i) - let p_i = c_i; - inplace_xor(p_i, &offset_i); - self.cipher.decrypt_block(p_i); - inplace_xor(p_i, &offset_i); + let mut p_i = c_i; + p_i.xor_in2out(&offset_i); + self.cipher.decrypt_block(p_i.get_out()); + inplace_xor(p_i.get_out(), &offset_i); // checksum_i = checksum_{i-1} xor p_i - inplace_xor(&mut checksum_i, p_i); + inplace_xor(&mut checksum_i, p_i.get_out()); i += 1; } // Process any partial blocks. - if (buffer.len() % 16) != 0 { - let processed_bytes = (i - 1) * 16; - let remaining_bytes = buffer.len() - processed_bytes; + if !tail.is_empty() { + let remaining_bytes = tail.len(); // offset_* = offset_m xor L_* inplace_xor(&mut offset_i, &self.ll_star); @@ -329,14 +332,12 @@ where inplace_xor(&mut pad, &offset_i); self.cipher.encrypt_block(&mut pad); // P_* = C_* xor Pad[1..bitlen(C_*)] - let c_star = &mut buffer[processed_bytes..]; + let c_star = tail.get_in(); let pad = &mut pad[..c_star.len()]; - for (aa, bb) in c_star.iter_mut().zip(pad) { - *aa ^= *bb; - } + tail.xor_in2out(pad); // checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*))) let checksum_rhs = &mut [0u8; 16]; - checksum_rhs[..remaining_bytes].copy_from_slice(&buffer[processed_bytes..]); + checksum_rhs[..remaining_bytes].copy_from_slice(tail.get_out()); checksum_rhs[remaining_bytes] = 0b1000_0000; inplace_xor(&mut checksum_i, checksum_rhs.as_ref()); } @@ -347,81 +348,85 @@ where /// Encrypts plaintext in groups of two. /// /// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c - fn wide_encrypt(&self, nonce: &Nonce, buffer: &mut [u8]) -> (usize, Block, Block) { + fn wide_encrypt<'i, 'o>( + &self, + nonce: &Nonce, + buffer: InOutBuf<'i, 'o, u8>, + ) -> (InOutBuf<'i, 'o, u8>, usize, Block, Block) { const WIDTH: usize = 2; let mut i = 1; let mut offset_i = [Block::default(); WIDTH]; - offset_i[offset_i.len() - 1] = initial_offset(&self.cipher, nonce, TagSize::to_u32()); + offset_i[1] = initial_offset(&self.cipher, nonce, TagSize::to_u32()); let mut checksum_i = Block::default(); - for wide_blocks in buffer.chunks_exact_mut(::Size::USIZE * WIDTH) { - let p_i = split_into_two_blocks(wide_blocks); + let (wide_blocks, tail): (InOutBuf<'_, '_, DoubleBlock>, _) = buffer.into_chunks(); + for wide_block in wide_blocks.into_iter() { + let mut p_i = split_into_two_blocks(wide_block); // checksum_i = checksum_{i-1} xor p_i for p_ij in &p_i { - inplace_xor(&mut checksum_i, p_ij); + inplace_xor(&mut checksum_i, p_ij.get_in()); } // offset_i = offset_{i-1} xor L_{ntz(i)} - offset_i[0] = offset_i[offset_i.len() - 1]; + offset_i[0] = offset_i[1]; inplace_xor(&mut offset_i[0], &self.ll[ntz(i)]); - for j in 1..p_i.len() { - offset_i[j] = offset_i[j - 1]; - inplace_xor(&mut offset_i[j], &self.ll[ntz(i + j)]); - } + offset_i[1] = offset_i[0]; + inplace_xor(&mut offset_i[1], &self.ll[ntz(i + 1)]); // c_i = offset_i xor ENCIPHER(K, p_i xor offset_i) for j in 0..p_i.len() { - inplace_xor(p_i[j], &offset_i[j]); - self.cipher.encrypt_block(p_i[j]); - inplace_xor(p_i[j], &offset_i[j]) + p_i[j].xor_in2out(&offset_i[j]); + self.cipher.encrypt_block(p_i[j].get_out()); + inplace_xor(p_i[j].get_out(), &offset_i[j]); } i += WIDTH; } - let processed_bytes = (buffer.len() / (WIDTH * 16)) * (WIDTH * 16); - - (processed_bytes, offset_i[offset_i.len() - 1], checksum_i) + (tail, i, offset_i[offset_i.len() - 1], checksum_i) } /// Decrypts plaintext in groups of two. /// /// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c - fn wide_decrypt(&self, nonce: &Nonce, buffer: &mut [u8]) -> (usize, Block, Block) { + fn wide_decrypt<'i, 'o>( + &self, + nonce: &Nonce, + buffer: InOutBuf<'i, 'o, u8>, + ) -> (InOutBuf<'i, 'o, u8>, usize, Block, Block) { const WIDTH: usize = 2; let mut i = 1; let mut offset_i = [Block::default(); WIDTH]; - offset_i[offset_i.len() - 1] = initial_offset(&self.cipher, nonce, TagSize::to_u32()); + offset_i[1] = initial_offset(&self.cipher, nonce, TagSize::to_u32()); let mut checksum_i = Block::default(); - for wide_blocks in buffer.chunks_exact_mut(16 * WIDTH) { - let c_i = split_into_two_blocks(wide_blocks); + + let (wide_blocks, tail): (InOutBuf<'_, '_, DoubleBlock>, _) = buffer.into_chunks(); + for wide_block in wide_blocks.into_iter() { + let mut c_i = split_into_two_blocks(wide_block); // offset_i = offset_{i-1} xor L_{ntz(i)} - offset_i[0] = offset_i[offset_i.len() - 1]; + offset_i[0] = offset_i[1]; inplace_xor(&mut offset_i[0], &self.ll[ntz(i)]); - for j in 1..c_i.len() { - offset_i[j] = offset_i[j - 1]; - inplace_xor(&mut offset_i[j], &self.ll[ntz(i + j)]); - } + offset_i[1] = offset_i[0]; + inplace_xor(&mut offset_i[1], &self.ll[ntz(i + 1)]); // p_i = offset_i xor DECIPHER(K, c_i xor offset_i) // checksum_i = checksum_{i-1} xor p_i for j in 0..c_i.len() { - inplace_xor(c_i[j], &offset_i[j]); - self.cipher.decrypt_block(c_i[j]); - inplace_xor(c_i[j], &offset_i[j]); - inplace_xor(&mut checksum_i, c_i[j]); + c_i[j].xor_in2out(&offset_i[j]); + self.cipher.decrypt_block(c_i[j].get_out()); + inplace_xor(c_i[j].get_out(), &offset_i[j]); + inplace_xor(&mut checksum_i, c_i[j].get_out()); } i += WIDTH; } - let processed_bytes = (buffer.len() / (WIDTH * 16)) * (WIDTH * 16); - (processed_bytes, offset_i[offset_i.len() - 1], checksum_i) + (tail, i, offset_i[offset_i.len() - 1], checksum_i) } /// Computes HASH function defined in https://www.rfc-editor.org/rfc/rfc7253.html#section-4.1 @@ -580,11 +585,10 @@ pub(crate) fn ntz(n: usize) -> usize { } #[inline] -pub(crate) fn split_into_two_blocks(two_blocks: &mut [u8]) -> [&mut Block; 2] { - const BLOCK_SIZE: usize = 16; - debug_assert_eq!(two_blocks.len(), BLOCK_SIZE * 2); - let (b0, b1) = two_blocks.split_at_mut(BLOCK_SIZE); - [b0.try_into().unwrap(), b1.try_into().unwrap()] +pub(crate) fn split_into_two_blocks<'i, 'o>( + two_blocks: InOut<'i, 'o, DoubleBlock>, +) -> [InOut<'i, 'o, Block>; 2] { + Array::, U2>::from(two_blocks).into() } #[cfg(test)]