diff --git a/src/lib.rs b/src/lib.rs index b4681cc86..7afa680eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,7 +189,7 @@ pub use crate::{ pub use subtle; #[cfg(feature = "alloc")] -pub use crate::uint::boxed::{encoding::DecodeError, BoxedUint}; +pub use crate::uint::boxed::BoxedUint; #[cfg(feature = "hybrid-array")] pub use { diff --git a/src/traits.rs b/src/traits.rs index 1dde92611..ff53cf48a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -9,7 +9,7 @@ pub use num_traits::{ pub(crate) use sealed::PrecomputeInverterWithAdjuster; use crate::{Limb, NonZero, Odd, Reciprocal}; -use core::fmt::Debug; +use core::fmt::{self, Debug}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, @@ -22,9 +22,6 @@ use subtle::{ #[cfg(feature = "rand_core")] use rand_core::CryptoRngCore; -#[cfg(feature = "rand_core")] -use core::fmt; - /// Integers whose representation takes a bounded amount of space. pub trait Bounded { /// Size of this integer in bits. @@ -541,6 +538,41 @@ pub trait Encoding: Sized { fn to_le_bytes(&self) -> Self::Repr; } +/// Possible errors in variable-time integer decoding methods. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DecodeError { + /// The input value was empty. + Empty, + + /// The input was not consistent with the format restrictions. + InvalidDigit, + + /// Input size is too small to fit in the given precision. + InputSize, + + /// The deserialized number is larger than the given precision. 
+ Precision, +} + +impl fmt::Display for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Empty => write!(f, "empty value provided"), + Self::InvalidDigit => { + write!(f, "invalid digit character") + } + Self::InputSize => write!(f, "input size is too small to fit in the given precision"), + Self::Precision => write!( + f, + "the deserialized number is larger than the given precision" + ), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for DecodeError {} + /// Support for optimized squaring pub trait Square { /// Computes the same as `self * self`, but may be more efficient. diff --git a/src/uint/boxed/encoding.rs b/src/uint/boxed/encoding.rs index 1baae6adf..650b6272d 100644 --- a/src/uint/boxed/encoding.rs +++ b/src/uint/boxed/encoding.rs @@ -1,36 +1,10 @@ //! Const-friendly decoding operations for [`BoxedUint`]. use super::BoxedUint; -use crate::{uint::encoding, Limb, Word}; -use alloc::boxed::Box; -use core::fmt; +use crate::{uint::encoding, DecodeError, Limb, Word}; +use alloc::{boxed::Box, vec::Vec}; use subtle::{Choice, CtOption}; -/// Decoding errors for [`BoxedUint`]. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum DecodeError { - /// Input size is too small to fit in the given precision. - InputSize, - - /// The deserialized number is larger than the given precision. - Precision, -} - -impl fmt::Display for DecodeError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::InputSize => write!(f, "input size is too small to fit in the given precision"), - Self::Precision => write!( - f, - "the deserialized number is larger than the given precision" - ), - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for DecodeError {} - impl BoxedUint { /// Create a new [`BoxedUint`] from the provided big endian bytes. 
/// @@ -142,7 +116,6 @@ impl BoxedUint { bytes.len() == Limb::BYTES * nlimbs * 2, "hex string is not the expected size" ); - let mut res = vec![Limb::ZERO; nlimbs]; let mut buf = [0u8; Limb::BYTES]; let mut i = 0; @@ -163,6 +136,76 @@ impl BoxedUint { } CtOption::new(Self { limbs: res.into() }, Choice::from((err == 0) as u8)) } + + /// Create a new [`BoxedUint`] from a big-endian string in a given base. + /// + /// The string may begin with a `+` character, and may use underscore + /// characters to separate digits. + /// + /// If the input value contains non-digit characters or digits outside of the range `0..radix` + /// this function will return [`DecodeError::InvalidDigit`]. + /// Panics if `radix` is not in the range from 2 to 36. + pub fn from_str_radix_vartime(src: &str, radix: u32) -> Result { + let mut dec = VecDecodeByLimb::default(); + encoding::decode_str_radix(src, radix, &mut dec)?; + Ok(Self { + limbs: dec.limbs.into(), + }) + } + + /// Create a new [`BoxedUint`] from a big-endian string in a given base, + /// with a given precision. + /// + /// The string may begin with a `+` character, and may use underscore + /// characters to separate digits. + /// + /// The `bits_precision` argument represents the precision of the resulting integer, which is + /// fixed as this type is not arbitrary-precision. + /// The new [`BoxedUint`] will be created with `bits_precision` rounded up to a multiple + /// of [`Limb::BITS`]. + /// + /// If the input value contains non-digit characters or digits outside of the range `0..radix` + /// this function will return [`DecodeError::InvalidDigit`]. + /// If the decoded integer needs more limbs than `bits_precision` (rounded up to a multiple of + /// [`Limb::BITS`]) provides, this function will return [`DecodeError::InputSize`]. + /// If the size of the decoded integer is larger than `bits_precision`, + /// this function will return [`DecodeError::Precision`]. + /// Panics if `radix` is not in the range from 2 to 36. 
+ pub fn from_str_radix_with_precision_vartime( + src: &str, + radix: u32, + bits_precision: u32, + ) -> Result { + let mut ret = Self::zero_with_precision(bits_precision); + encoding::decode_str_radix( + src, + radix, + &mut encoding::SliceDecodeByLimb::new(&mut ret.limbs), + )?; + if bits_precision < ret.bits() { + return Err(DecodeError::Precision); + } + Ok(ret) + } +} + +/// Decoder target producing a Vec +#[derive(Default)] +struct VecDecodeByLimb { + limbs: Vec, +} + +impl encoding::DecodeByLimb for VecDecodeByLimb { + #[inline] + fn limbs_mut(&mut self) -> &mut [Limb] { + self.limbs.as_mut_slice() + } + + #[inline] + fn push_limb(&mut self, limb: Limb) -> bool { + self.limbs.push(limb); + true + } } #[cfg(test)] @@ -381,4 +424,38 @@ mod tests { let n = BoxedUint::from_be_slice(&bytes, 128).unwrap(); assert_eq!(bytes.as_slice(), &*n.to_be_bytes()); } + + #[test] + fn from_str_radix_invalid() { + assert_eq!( + BoxedUint::from_str_radix_vartime("?", 10,), + Err(DecodeError::InvalidDigit) + ); + assert_eq!( + BoxedUint::from_str_radix_with_precision_vartime( + "ffffffffffffffff_ffffffffffffffff_f", + 16, + 128 + ), + Err(DecodeError::InputSize) + ); + assert_eq!( + BoxedUint::from_str_radix_with_precision_vartime("1111111111111111", 2, 10), + Err(DecodeError::Precision) + ); + } + + #[test] + fn from_str_radix_10() { + let dec = "+340_282_366_920_938_463_463_374_607_431_768_211_455"; + let res = BoxedUint::from_str_radix_vartime(dec, 10).expect("error decoding"); + assert_eq!(res, BoxedUint::max(128)); + } + + #[test] + fn from_str_radix_16() { + let hex = "fedcba9876543210fedcba9876543210"; + let res = BoxedUint::from_str_radix_vartime(hex, 16).expect("error decoding"); + assert_eq!(hex, format!("{res:x}")); + } } diff --git a/src/uint/encoding.rs b/src/uint/encoding.rs index 9da8e4992..97da2bc68 100644 --- a/src/uint/encoding.rs +++ b/src/uint/encoding.rs @@ -7,7 +7,7 @@ mod der; mod rlp; use super::Uint; -use crate::{Limb, Word}; +use crate::{DecodeError, 
Limb, Word}; #[cfg(feature = "hybrid-array")] use crate::Encoding; @@ -163,6 +163,22 @@ impl Uint { dst.copy_from_slice(&src.to_le_bytes()); } } + + /// Create a new [`Uint`] from a string slice in a given base. + /// + /// The string may begin with a `+` character, and may use + /// underscore characters to separate digits. + /// + /// If the input value contains non-digit characters or digits outside of the range `0..radix` + /// this function will return [`DecodeError::InvalidDigit`]. + /// If the size of the decoded integer is larger than this type can represent, + /// this function will return [`DecodeError::InputSize`]. + /// Panics if `radix` is not in the range from 2 to 36. + pub fn from_str_radix_vartime(src: &str, radix: u32) -> Result { + let mut slf = Self::ZERO; + decode_str_radix(src, radix, &mut SliceDecodeByLimb::new(&mut slf.limbs))?; + Ok(slf) + } } /// Encode a [`Uint`] to a big endian byte array of the given size. @@ -249,13 +265,224 @@ pub(crate) const fn decode_hex_byte(bytes: [u8; 2]) -> (u8, u16) { (result, err) } +/// Allow decoding of integers into fixed and variable-length types +pub(crate) trait DecodeByLimb { + /// Access the limbs as a mutable slice + fn limbs_mut(&mut self) -> &mut [Limb]; + + /// Append a new most-significant limb + fn push_limb(&mut self, limb: Limb) -> bool; +} + +/// Wrap a `Limb` slice as a target for decoding +pub(crate) struct SliceDecodeByLimb<'de> { + limbs: &'de mut [Limb], + len: usize, +} + +impl<'de> SliceDecodeByLimb<'de> { + #[inline] + pub fn new(limbs: &'de mut [Limb]) -> Self { + Self { limbs, len: 0 } + } +} + +impl DecodeByLimb for SliceDecodeByLimb<'_> { + #[inline] + fn push_limb(&mut self, limb: Limb) -> bool { + if self.len < self.limbs.len() { + self.limbs[self.len] = limb; + self.len += 1; + true + } else { + false + } + } + + #[inline] + fn limbs_mut(&mut self) -> &mut [Limb] { + &mut self.limbs[..self.len] + } +} + +/// Decode an ascii string in base `radix`, writing the result +/// to 
the `DecodeByLimb` instance `out`. +/// The input must be a non-empty ascii string, may begin with a `+` +/// character, and may use `_` as a separator between digits. +pub(crate) fn decode_str_radix( + src: &str, + radix: u32, + out: &mut D, +) -> Result<(), DecodeError> { + if !(2u32..=36).contains(&radix) { + panic!("unsupported radix"); + } + if radix == 2 || radix == 4 || radix == 16 { + decode_str_radix_aligned_digits(src, radix as u8, out) + } else { + decode_str_radix_digits(src, radix as u8, out) + } +} + +#[inline(always)] +/// Perform basic validation and pre-processing on a digit string +fn process_radix_str(src: &str) -> Result<&[u8], DecodeError> { + // Treat the input as ascii bytes + let src_b = src.as_bytes(); + let mut digits = src_b.strip_prefix(b"+").unwrap_or(src_b); + + if digits.is_empty() { + // Blank string or plain "+" not allowed + Err(DecodeError::Empty) + } else if digits.starts_with(b"_") || digits.ends_with(b"_") { + // Leading or trailing underscore not allowed + Err(DecodeError::InvalidDigit) + } else { + // Strip leading zeroes to simplify parsing + while digits[0] == b'0' || digits[0] == b'_' { + digits = &digits[1..]; + if digits.is_empty() { + break; + } + } + Ok(digits) + } +} + +// Decode a string of digits in base `radix` +fn decode_str_radix_digits( + src: &str, + radix: u8, + out: &mut D, +) -> Result<(), DecodeError> { + let digits = process_radix_str(src)?; + let mut buf = [0u8; Limb::BITS as _]; + let mut limb_digits = Word::MAX.ilog(radix as _) as usize; + let mut limb_max = Limb(Word::pow(radix as _, limb_digits as _)); + let mut digits_pos = 0; + let mut buf_pos = 0; + + while digits_pos < digits.len() { + // Parse digits from most significant, to fill buffer limb + loop { + let digit = match digits[digits_pos] { + b @ b'0'..=b'9' => b - b'0', + b @ b'a'..=b'z' => b + 10 - b'a', + b @ b'A'..=b'Z' => b + 10 - b'A', + b'_' => { + digits_pos += 1; + continue; + } + _ => radix, + }; + if digit >= radix { + return 
Err(DecodeError::InvalidDigit); + } + buf[buf_pos] = digit; + buf_pos += 1; + digits_pos += 1; + + if digits_pos == digits.len() || buf_pos == limb_digits { + break; + } + } + + // On the final loop, there may be fewer digits to process + if buf_pos < limb_digits { + limb_digits = buf_pos; + limb_max = Limb(Word::pow(radix as _, limb_digits as _)); + } + + // Combine the digit bytes into a limb + let mut carry = Limb::ZERO; + for c in buf[..limb_digits].iter().copied() { + carry = Limb(carry.0 * (radix as Word) + (c as Word)); + } + // Multiply the existing limbs by `radix` ^ `limb_digits`, + // and add the new least-significant limb + for limb in out.limbs_mut().iter_mut() { + (*limb, carry) = Limb::ZERO.mac(*limb, limb_max, carry); + } + // Append the new carried limb, if any + if carry.0 != 0 && !out.push_limb(carry) { + return Err(DecodeError::InputSize); + } + + buf_pos = 0; + buf[..limb_digits].fill(0); + } + + Ok(()) +} + +// Decode digits for bases where an integer number of characters +// can represent a saturated Limb (specifically 2, 4, and 16). 
+fn decode_str_radix_aligned_digits( + src: &str, + radix: u8, + out: &mut D, +) -> Result<(), DecodeError> { + debug_assert!(radix == 2 || radix == 4 || radix == 16); + + let digits = process_radix_str(src)?; + let shift = radix.trailing_zeros(); + let limb_digits = (Limb::BITS / shift) as usize; + let mut buf = [0u8; Limb::BITS as _]; + let mut buf_pos = 0; + let mut digits_pos = digits.len(); + + while digits_pos > 0 { + // Parse digits from the least significant, to fill the buffer limb + loop { + digits_pos -= 1; + + let digit = match digits[digits_pos] { + b @ b'0'..=b'9' => b - b'0', + b @ b'a'..=b'z' => b + 10 - b'a', + b @ b'A'..=b'Z' => b + 10 - b'A', + b'_' => { + // cannot occur when `digits_pos == 0`: a leading underscore is rejected or stripped by `process_radix_str` + continue; + } + _ => radix, + }; + if digit >= radix { + return Err(DecodeError::InvalidDigit); + } + buf[buf_pos] = digit; + buf_pos += 1; + + if digits_pos == 0 || buf_pos == limb_digits { + break; + } + } + + if buf_pos > 0 { + // Combine the digit bytes into a limb + let mut w: Word = 0; + for c in buf[..buf_pos].iter().rev().copied() { + w = (w << shift) | (c as Word); + } + // Append the new most-significant limb + if !out.push_limb(Limb(w)) { + return Err(DecodeError::InputSize); + } + + buf_pos = 0; + buf[..limb_digits].fill(0); + } + } + + Ok(()) +} + #[cfg(test)] mod tests { - use crate::Limb; + use crate::{DecodeError, Limb, Zero, U128, U64}; use hex_literal::hex; #[cfg(feature = "alloc")] - use {crate::U128, alloc::format}; + use alloc::format; #[cfg(target_pointer_width = "32")] use crate::U64 as UintEx; @@ -361,4 +588,82 @@ mod tests { 1100110011001100110011001100110011011101110111011101110111011101"; assert_eq!(expect, format!("{:b}", n)); } + + #[test] + fn from_str_radix_disallowed() { + let tests = [ + ("", 10, DecodeError::Empty), + ("+", 10, DecodeError::Empty), + ("_", 10, DecodeError::InvalidDigit), + ("0_", 10, DecodeError::InvalidDigit), + ("0_", 10, DecodeError::InvalidDigit), + ("a", 10, DecodeError::InvalidDigit), + (".", 10, 
DecodeError::InvalidDigit), + ( + "99999999999999999999999999999999", + 10, + DecodeError::InputSize, + ), + ]; + for (input, radix, expect) in tests { + assert_eq!(U64::from_str_radix_vartime(input, radix), Err(expect)); + } + } + + #[test] + fn from_str_radix_2() { + let buf = &[b'1'; 128]; + let radix = U128::from_u64(2); + let radix_max = U128::from_u64(1); + let mut last: Option = None; + for idx in (1..buf.len()).rev() { + let res = U128::from_str_radix_vartime( + core::str::from_utf8(&buf[..idx]).expect("utf-8 error"), + 2, + ) + .expect("error decoding"); + assert!(!bool::from(res.is_zero())); + if let Some(prev) = last { + assert_eq!(res.saturating_mul(&radix).saturating_add(&radix_max), prev); + } + last = Some(res); + } + assert_eq!(last, Some(radix_max)); + } + + #[test] + fn from_str_radix_5() { + let buf = &[b'4'; 55]; + let radix = U128::from_u64(5); + let radix_max = U128::from_u64(4); + let mut last: Option = None; + for idx in (1..buf.len()).rev() { + let res = U128::from_str_radix_vartime( + core::str::from_utf8(&buf[..idx]).expect("utf-8 error"), + 5, + ) + .expect("error decoding"); + assert!(!bool::from(res.is_zero())); + if let Some(prev) = last { + assert_eq!(res.saturating_mul(&radix).saturating_add(&radix_max), prev); + } + last = Some(res); + } + assert_eq!(last, Some(radix_max)); + } + + #[test] + fn from_str_radix_10() { + let dec = "+340_282_366_920_938_463_463_374_607_431_768_211_455"; + let res = U128::from_str_radix_vartime(dec, 10).expect("error decoding"); + assert_eq!(res, U128::MAX); + } + + #[cfg(feature = "alloc")] + #[test] + fn from_str_radix_16() { + let hex = "fedcba9876543210fedcba9876543210"; + let res = U128::from_str_radix_vartime(hex, 16).expect("error decoding"); + assert_eq!(hex, format!("{res:x}")); + } }