diff --git a/base64ct/src/decoder.rs b/base64ct/src/decoder.rs index 364f10ad4..9a02ae184 100644 --- a/base64ct/src/decoder.rs +++ b/base64ct/src/decoder.rs @@ -1,6 +1,7 @@ //! Buffered Base64 decoder. use crate::{ + encoding, variant::Variant, Encoding, Error::{self, InvalidLength}, @@ -33,6 +34,9 @@ pub struct Decoder<'i, E: Variant> { /// Base64 input data reader. line_reader: LineReader<'i>, + /// Length of the remaining data after Base64 decoding. + decoded_len: usize, + /// Block buffer used for non-block-aligned data. block_buffer: BlockBuffer, @@ -48,13 +52,13 @@ impl<'i, E: Variant> Decoder<'i, E> { /// - `Ok(decoder)` on success. /// - `Err(Error::InvalidLength)` if the input buffer is empty. pub fn new(input: &'i [u8]) -> Result { - if input.is_empty() { - return Err(InvalidLength); - } + let line_reader = LineReader::new_unwrapped(input)?; + let decoded_len = line_reader.decoded_len::()?; Ok(Self { - line: Line::new(input), - line_reader: LineReader::default(), + line: Line::default(), + line_reader, + decoded_len, block_buffer: BlockBuffer::default(), encoding: PhantomData, }) @@ -85,13 +89,13 @@ impl<'i, E: Variant> Decoder<'i, E> { /// /// [RFC7468]: https://datatracker.ietf.org/doc/html/rfc7468 pub fn new_wrapped(input: &'i [u8], line_width: usize) -> Result { - if input.is_empty() { - return Err(InvalidLength); - } + let line_reader = LineReader::new_wrapped(input, line_width)?; + let decoded_len = line_reader.decoded_len::()?; Ok(Self { line: Line::default(), - line_reader: LineReader::new(input, line_width)?, + line_reader, + decoded_len, block_buffer: BlockBuffer::default(), encoding: PhantomData, }) @@ -153,9 +157,21 @@ impl<'i, E: Variant> Decoder<'i, E> { } } + self.decoded_len = self + .decoded_len + .checked_sub(out.len()) + .ok_or(InvalidLength)?; + Ok(out) } + /// Get the length of the remaining data after Base64 decoding. + /// + /// Decreases every time data is decoded. + pub fn decoded_len(&self) -> usize { + self.decoded_len + } + /// Has all of the input data been decoded? pub fn is_finished(&self) -> bool { self.line.is_empty() && self.line_reader.is_empty() && self.block_buffer.is_empty() @@ -285,8 +301,8 @@ impl<'i> Default for Line<'i> { } impl<'i> Line<'i> { - /// Create a new line which wraps the given input data - pub fn new(bytes: &'i [u8]) -> Self { + /// Create a new line which wraps the given input data. + fn new(bytes: &'i [u8]) -> Self { Self { remaining: bytes } } @@ -302,6 +318,12 @@ impl<'i> Line<'i> { bytes } + /// Slice off a tail of a given length. + fn slice_tail(&self, nbytes: usize) -> Result<&'i [u8], Error> { + let offset = self.len().checked_sub(nbytes).ok_or(InvalidLength)?; + self.remaining.get(offset..).ok_or(InvalidLength) + } + /// Get the number of bytes remaining in this line. fn len(&self) -> usize { self.remaining.len() @@ -311,10 +333,20 @@ impl<'i> Line<'i> { fn is_empty(&self) -> bool { self.len() == 0 } + + /// Trim the newline off the end of this line. + fn trim_end(&self) -> Self { + Line::new(match self.remaining { + [line @ .., CHAR_CR, CHAR_LF] => line, + [line @ .., CHAR_CR] => line, + [line @ .., CHAR_LF] => line, + line => line, + }) + } } /// Iterator over multi-line Base64 input. -#[derive(Clone, Default)] +#[derive(Clone)] struct LineReader<'i> { /// Remaining linewrapped data to be processed. remaining: &'i [u8], @@ -324,22 +356,103 @@ struct LineReader<'i> { } impl<'i> LineReader<'i> { + /// Create a new reader which operates over continugous unwrapped data. + fn new_unwrapped(bytes: &'i [u8]) -> Result { + if bytes.is_empty() { + Err(InvalidLength) + } else { + Ok(Self { + remaining: bytes, + line_width: None, + }) + } + } + /// Create a new reader which operates over linewrapped data. - fn new(bytes: &'i [u8], line_width: usize) -> Result { - if line_width == 0 { + fn new_wrapped(bytes: &'i [u8], line_width: usize) -> Result { + if line_width < 4 { return Err(InvalidLength); } - Ok(Self { - remaining: bytes, - line_width: Some(line_width), - }) + let mut reader = Self::new_unwrapped(bytes)?; + reader.line_width = Some(line_width); + Ok(reader) } /// Is this line reader empty? fn is_empty(&self) -> bool { self.remaining.is_empty() } + + /// Get the total length of the data decoded from this line reader. + fn decoded_len(&self) -> Result { + let mut buffer = [0u8; 4]; + let mut lines = self.clone(); + let mut line = match lines.next().transpose()? { + Some(l) => l, + None => return Ok(0), + }; + let mut base64_len = 0usize; + + loop { + base64_len = base64_len.checked_add(line.len()).ok_or(InvalidLength)?; + + match lines.next().transpose()? { + Some(l) => { + // Store the end of the line in the buffer so we can + // reassemble the last block to determine the real length + buffer.copy_from_slice(line.slice_tail(4)?); + + line = l + } + + // To compute an exact decoded length we need to decode the + // last Base64 block and get the decoded length. + // + // This is what the somewhat complex code below is doing. + None => { + // Compute number of bytes in the last block (may be unpadded) + let base64_last_block_len = match base64_len % 4 { + 0 => 4, + n => n, + }; + + // Compute decoded length without the last block + let decoded_len = encoding::decoded_len( + base64_len + .checked_sub(base64_last_block_len) + .ok_or(InvalidLength)?, + ); + + // Compute the decoded length of the last block + let mut out = [0u8; 3]; + let last_block_len = if line.len() < base64_last_block_len { + let buffered_part_len = base64_last_block_len + .checked_sub(line.len()) + .ok_or(InvalidLength)?; + + let offset = 4usize.checked_sub(buffered_part_len).ok_or(InvalidLength)?; + + for i in 0..buffered_part_len { + buffer[i] = buffer[offset.checked_add(i).ok_or(InvalidLength)?]; + } + + buffer[buffered_part_len..][..line.len()].copy_from_slice(line.remaining); + let buffer_len = buffered_part_len + .checked_add(line.len()) + .ok_or(InvalidLength)?; + + E::decode(&buffer[..buffer_len], &mut out)?.len() + } else { + let last_block = line.slice_tail(base64_last_block_len)?; + E::decode(last_block, &mut out)?.len() + }; + + return decoded_len.checked_add(last_block_len).ok_or(InvalidLength); + } + } + } + } } impl<'i> Iterator for LineReader<'i> { @@ -352,7 +465,7 @@ impl<'i> Iterator for LineReader<'i> { if self.remaining.is_empty() { return None; } else { - let line = Line::new(self.remaining); + let line = Line::new(self.remaining).trim_end(); self.remaining = &[]; return Some(Ok(line)); } @@ -369,6 +482,15 @@ impl<'i> Iterator for LineReader<'i> { let line = Line::new(&self.remaining[..line_width]); self.remaining = rest; Some(Ok(line)) + } else if !self.remaining.is_empty() { + let line = Line::new(self.remaining).trim_end(); + self.remaining = b""; + + if line.is_empty() { + None + } else { + Some(Ok(line)) + } } else { None } @@ -416,15 +538,20 @@ mod tests { { for chunk_size in 1..expected.len() { let mut decoder = f(); + let mut remaining_len = decoder.decoded_len(); let mut buffer = [0u8; 1024]; for chunk in expected.chunks(chunk_size) { assert!(!decoder.is_finished()); let decoded = decoder.decode(&mut buffer[..chunk.len()]).unwrap(); assert_eq!(chunk, decoded); + + remaining_len -= decoded.len(); + assert_eq!(remaining_len, decoder.decoded_len()); } assert!(decoder.is_finished()); + assert_eq!(decoder.decoded_len(), 0); } } } diff --git a/base64ct/src/encoding.rs b/base64ct/src/encoding.rs index 3173880d4..77b6a90f9 100644 --- a/base64ct/src/encoding.rs +++ b/base64ct/src/encoding.rs @@ -321,7 +321,7 @@ fn validate_padding(encoded: &[u8], decoded: &[u8]) -> Result<(), Er /// Note that this function does not fully validate the Base64 is well-formed /// and may return incorrect results for malformed Base64. #[inline(always)] -fn decoded_len(input_len: usize) -> usize { +pub(crate) fn decoded_len(input_len: usize) -> usize { // overflow-proof computation of `(3*n)/4` let k = input_len / 4; let l = input_len - 4 * k; diff --git a/base64ct/tests/proptests.proptest-regressions b/base64ct/tests/proptests.proptest-regressions index 9f31a8ac0..473fcecd2 100644 --- a/base64ct/tests/proptests.proptest-regressions +++ b/base64ct/tests/proptests.proptest-regressions @@ -7,3 +7,4 @@ cc ea4af6a6a3c5feddd17be51d3bb3d863881547acf50b553e76da3f34f8b755d4 # shrinks to base64ish = "" cc 348d4acf2c3d1e8db3772f5645179e24b50178747469da9709e60800175eef80 # shrinks to bytes = [240, 144, 128, 128, 240, 144, 128, 128, 32, 32, 32, 194, 161, 48, 97, 97, 65, 194, 161, 32, 97, 194, 161, 32, 240, 144, 128, 128, 194, 161, 48, 32, 97, 194, 161, 240, 144, 128, 128, 32, 224, 160, 128, 97, 224, 160, 128, 48, 48, 194, 161, 32, 240, 144, 128, 128, 11, 65, 97, 48, 65, 65, 97, 11, 240, 144, 128, 128, 240, 144, 128, 128, 48, 224, 160, 128, 194, 161, 32, 32, 194, 161, 32, 48, 97, 240, 144, 128, 128, 224, 160, 128, 240, 144, 128, 128, 0, 224, 160, 128, 32, 240, 144, 128, 128, 0, 32, 32, 97, 240, 144, 128, 128, 240, 144, 128, 128, 240, 144, 128, 128, 240, 144, 128, 128, 0, 0, 240, 144, 128, 128, 32, 240, 144, 128, 128, 32, 48, 65, 11, 32, 65, 48, 48, 65, 65, 194, 161, 32, 224, 160, 128, 240, 144, 128, 128, 224, 160, 128, 0, 65, 0, 65, 32, 194, 161, 240, 144, 128, 128, 32, 65, 32, 0, 97, 32, 97, 11, 11, 48, 97, 97, 240, 144, 128, 128, 65, 240, 144, 128, 128, 194, 161], line_width = 10, chunk_size = 163 cc 0c0ee7f6a60d24431333f5c39c506b818a6c21022e39288619c8f78f29d30b1c # shrinks to bytes = [240, 144, 128, 128, 194, 161, 194, 161, 240, 144, 128, 128, 194, 161, 240, 144, 128, 128, 65, 224, 160, 128, 97, 224, 160, 128, 32, 97, 32, 65, 224, 160, 128, 0, 97, 0, 240, 144, 128, 128, 97, 194, 161, 32, 240, 144, 128, 128, 11, 48, 32, 65, 32, 240, 144, 128, 128, 97, 194, 161, 48, 48, 240, 144, 128, 128, 194, 161, 194, 161, 32, 194, 161, 48, 0, 32, 48, 224, 160, 128, 65, 240, 144, 128, 128, 11, 65, 11, 240, 144, 128, 128, 32, 32, 194, 161, 240, 144, 128, 128, 224, 160, 128, 240, 144, 128, 128, 194, 161, 224, 160, 128, 65, 32, 240, 144, 128, 128, 32, 240, 144, 128, 128, 48, 240, 144, 128, 128, 0, 48, 240, 144, 128, 128, 48, 65, 65, 11, 0, 65, 240, 144, 128, 128, 240, 144, 128, 128, 32, 65, 240, 144, 128, 128, 112, 75, 46, 232, 143, 132, 240, 159, 149, 180, 101, 92, 11, 42, 98, 244, 142, 150, 136, 83, 13, 243, 189, 168, 131, 194, 154, 9, 243, 129, 165, 130, 241, 138, 188, 150, 39, 241, 170, 133, 154, 39, 61, 244, 136, 146, 157, 46, 91, 108, 34, 66, 0, 239, 187, 191, 34, 240, 158, 187, 152, 241, 187, 172, 188, 46, 239, 191, 189, 244, 143, 139, 131, 13, 13, 226, 128, 174, 60, 200, 186, 194, 151, 27, 105, 43, 226, 128, 174, 70, 0, 38, 127, 194, 133, 195, 177, 123, 127, 121, 241, 128, 141, 141, 244, 137, 146, 189, 55, 54, 9, 240, 159, 149, 180, 2, 209, 168, 239, 187, 191, 11, 34, 123, 32, 42, 242, 171, 149, 149, 102, 241, 174, 190, 188, 242, 144, 186, 145, 1, 84, 34, 56, 7, 0, 194, 188, 43, 117, 48, 96, 11, 60, 242, 190, 170, 187, 47, 99, 37, 241, 175, 142, 186, 240, 178, 162, 136, 46, 2, 241, 176, 162, 162, 37, 242, 148, 135, 179, 11, 36, 104, 244, 130, 136, 177], line_width = 24, chunk_size = 240 +cc b6d81102accbff17f00786b06c6040fc59fee8aa087033c9b5604d2a3f246afd # shrinks to bytes = [32, 65, 11, 97, 97, 32, 240, 144, 128, 128, 97, 32, 65, 0, 0, 32, 240, 144, 128, 128, 97, 65, 97, 97, 240, 144, 128, 128, 240, 144, 128, 128, 65, 48, 240, 144, 128, 128, 240, 144, 128, 128, 32, 0, 97, 97, 240, 144, 128, 128, 65, 32, 194, 161, 65, 0, 32, 11, 97, 32, 32, 11, 32, 240, 144, 128, 128, 240, 144, 128, 128, 194, 128, 32, 48, 65, 32, 240, 144, 128, 128, 240, 144, 128, 128, 240, 144, 128, 128, 194, 161, 32, 194, 161, 48, 224, 160, 128, 240, 144, 128, 128, 97, 32, 0, 48, 240, 144, 128, 128, 0, 11, 240, 144, 128, 128, 97, 240, 144, 128, 128, 11, 32, 0, 32, 0, 194, 161, 194, 161, 56, 242, 150, 180, 168, 243, 187, 153, 181, 46, 36, 121, 70, 8, 226, 128, 174, 242, 135, 172, 189, 0, 194, 169, 244, 130, 145, 146, 240, 159, 149, 180, 63, 240, 184, 155, 139, 27, 243, 185, 138, 139, 194, 162, 46, 242, 148, 129, 171, 195, 143, 56, 241, 147, 151, 173, 240, 159, 149, 180, 33, 89, 36, 37, 240, 159, 149, 180, 200, 186, 117, 194, 165, 77, 241, 171, 180, 143, 60, 96, 242, 175, 134, 177, 27, 1, 42, 242, 145, 189, 151, 92, 39, 96, 38, 243, 181, 148, 171, 243, 164, 185, 188, 47, 195, 181, 0, 226, 128, 174, 13, 233, 136, 141, 57, 200, 186, 243, 129, 145, 159, 242, 137, 177, 176, 122, 61, 243, 140, 180, 151, 239, 191, 189, 80, 194, 144, 121, 42, 239, 191, 189, 231, 173, 145, 75, 91, 0, 123, 238, 154, 139, 58, 240, 179, 187, 172, 107, 13, 13, 123, 241, 152, 132, 160, 242, 130, 149, 190, 92, 239, 187, 191, 117, 241, 182, 130, 165, 241, 165, 155, 168, 39, 60, 0, 0, 13, 200, 186, 83, 37, 243, 174, 183, 166, 11, 0, 237, 134, 157, 39, 58, 113, 44, 243, 135, 142, 174, 9, 9, 195, 184, 74, 241, 146, 132, 133, 34, 58, 92, 123, 239, 187, 191, 37, 58, 239, 187, 191, 77, 9, 243, 183, 143, 189, 243, 159, 143, 171, 243, 162, 128, 179, 241, 137, 158, 163, 127, 60, 195, 159, 106, 47, 242, 135, 154, 161, 51, 243, 160, 136, 149, 91, 241, 175, 181, 149, 96, 58, 46, 11, 37, 107, 32, 52, 237, 136, 144, 77, 194, 156, 42, 13, 39, 61, 2, 59, 48, 58, 240, 159, 149, 180, 4, 96, 127, 230, 166, 145, 58, 239, 187, 191, 242, 135, 132, 146, 241, 178, 129, 185, 36], line_width = 118, chunk_size = 147 diff --git a/base64ct/tests/proptests.rs b/base64ct/tests/proptests.rs index 7ce36f608..f0980940a 100644 --- a/base64ct/tests/proptests.rs +++ b/base64ct/tests/proptests.rs @@ -33,15 +33,20 @@ proptest! { let mut buffer = [0u8; 384]; let mut decoder = Decoder::new(encoded.as_bytes()).unwrap(); + let mut remaining_len = decoder.decoded_len(); for chunk in bytes.chunks(chunk_size) { prop_assert!(!decoder.is_finished()); let decoded = decoder.decode(&mut buffer[..chunk.len()]); prop_assert_eq!(Ok(chunk), decoded); + + remaining_len -= decoded.unwrap().len(); + prop_assert_eq!(remaining_len, decoder.decoded_len()); } prop_assert!(decoder.is_finished()); + prop_assert_eq!(decoder.decoded_len(), 0); } #[test] @@ -76,15 +81,20 @@ proptest! { let mut buffer = [0u8; 384]; let mut decoder = Decoder::new_wrapped(&encoded_wrapped, line_width).unwrap(); + let mut remaining_len = decoder.decoded_len(); for chunk in bytes.chunks(chunk_size) { prop_assert!(!decoder.is_finished()); let decoded = decoder.decode(&mut buffer[..chunk.len()]); prop_assert_eq!(Ok(chunk), decoded); + + remaining_len -= decoded.unwrap().len(); + prop_assert_eq!(remaining_len, decoder.decoded_len()); } prop_assert!(decoder.is_finished()); + prop_assert_eq!(decoder.decoded_len(), 0); } }