diff --git a/elliptic-curve/Cargo.toml b/elliptic-curve/Cargo.toml index 9e476cda1..30528f54b 100644 --- a/elliptic-curve/Cargo.toml +++ b/elliptic-curve/Cargo.toml @@ -49,7 +49,7 @@ alloc = ["der/alloc", "sec1/alloc", "zeroize/alloc"] # todo: use weak activation arithmetic = ["ff", "group"] bits = ["arithmetic", "ff/bits"] dev = ["arithmetic", "hex-literal", "pem", "pkcs8"] -hash2curve = ["digest", "ff", "group"] +hash2curve = ["arithmetic", "digest"] ecdh = ["arithmetic"] hazmat = [] jwk = ["alloc", "base64ct/alloc", "serde", "serde_json", "zeroize/alloc"] diff --git a/elliptic-curve/src/hash2curve/group_digest.rs b/elliptic-curve/src/hash2curve/group_digest.rs index 4fac9c53c..644d149a7 100644 --- a/elliptic-curve/src/hash2curve/group_digest.rs +++ b/elliptic-curve/src/hash2curve/group_digest.rs @@ -3,12 +3,12 @@ use super::MapToCurve; use crate::{ hash2field::{hash_to_field, ExpandMsg, FromOkm}, - Result, + ProjectiveArithmetic, Result, }; use group::cofactor::CofactorGroup; /// Adds hashing arbitrary byte sequences to a valid group element -pub trait GroupDigest { +pub trait GroupDigest: ProjectiveArithmetic { /// The field element representation for a group value with multiple elements type FieldElement: FromOkm + MapToCurve + Default + Copy; /// The resulting group element @@ -30,18 +30,30 @@ pub trait GroupDigest { /// ## Using a fixed size hash function /// /// ```ignore - /// let pt = ProjectivePoint::hash_from_bytes::>(b"test data", b"CURVE_XMD:SHA-256_SSWU_RO_"); + /// let pt = ProjectivePoint::hash_from_bytes::>(b"test data", b"CURVE_XMD:SHA-256_SSWU_RO_"); /// ``` /// /// ## Using an extendable output function /// /// ```ignore - /// let pt = ProjectivePoint::hash_from_bytes::>(b"test data", b"CURVE_XOF:SHAKE-256_SSWU_RO_"); + /// let pt = ProjectivePoint::hash_from_bytes::>(b"test data", b"CURVE_XOF:SHAKE-256_SSWU_RO_"); /// ``` /// - fn hash_from_bytes(msg: &[u8], dst: &'static [u8]) -> Result { + /// # Errors + /// See implementors of [`ExpandMsg`] for errors: + /// - [`ExpandMsgXmd`] + /// - [`ExpandMsgXof`] + /// + /// `len_in_bytes = T::Length * 2` + /// + /// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd + /// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof + fn hash_from_bytes<'a, X: ExpandMsg<'a>>( + msgs: &[&[u8]], + dst: &'a [u8], + ) -> Result { let mut u = [Self::FieldElement::default(), Self::FieldElement::default()]; - hash_to_field::(msg, dst, &mut u)?; + hash_to_field::(msgs, dst, &mut u)?; let q0 = u[0].map_to_curve(); let q1 = u[1].map_to_curve(); // Ideally we could add and then clear cofactor once @@ -66,10 +78,45 @@ pub trait GroupDigest { /// > uniformly random in G: the set of possible outputs of /// > encode_to_curve is only a fraction of the points in G, and some /// > points in this set are more likely to be output than others. - fn encode_from_bytes(msg: &[u8], dst: &'static [u8]) -> Result { + /// + /// # Errors + /// See implementors of [`ExpandMsg`] for errors: + /// - [`ExpandMsgXmd`] + /// - [`ExpandMsgXof`] + /// + /// `len_in_bytes = T::Length` + /// + /// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd + /// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof + fn encode_from_bytes<'a, X: ExpandMsg<'a>>( + msgs: &[&[u8]], + dst: &'a [u8], + ) -> Result { let mut u = [Self::FieldElement::default()]; - hash_to_field::(msg, dst, &mut u)?; + hash_to_field::(msgs, dst, &mut u)?; let q0 = u[0].map_to_curve(); Ok(q0.clear_cofactor()) } + + /// Computes the hash to field routine according to + /// + /// and returns a scalar. + /// + /// # Errors + /// See implementors of [`ExpandMsg`] for errors: + /// - [`ExpandMsgXmd`] + /// - [`ExpandMsgXof`] + /// + /// `len_in_bytes = T::Length` + /// + /// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd + /// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof + fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result + where + Self::Scalar: FromOkm, + { + let mut u = [Self::Scalar::default()]; + hash_to_field::(msgs, dst, &mut u)?; + Ok(u[0]) + } } diff --git a/elliptic-curve/src/hash2field.rs b/elliptic-curve/src/hash2field.rs index 94faa958c..369e96cf5 100644 --- a/elliptic-curve/src/hash2field.rs +++ b/elliptic-curve/src/hash2field.rs @@ -21,9 +21,19 @@ pub trait FromOkm { /// Convert an arbitrary byte sequence into a field element. /// /// -pub fn hash_to_field(data: &[u8], domain: &'static [u8], out: &mut [T]) -> Result<()> +/// +/// # Errors +/// See implementors of [`ExpandMsg`] for errors: +/// - [`ExpandMsgXmd`] +/// - [`ExpandMsgXof`] +/// +/// `len_in_bytes = T::Length * out.len()` +/// +/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd +/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof +pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()> where - E: ExpandMsg, + E: ExpandMsg<'a>, T: FromOkm + Default, { let len_in_bytes = T::Length::to_usize() * out.len(); diff --git a/elliptic-curve/src/hash2field/expand_msg.rs b/elliptic-curve/src/hash2field/expand_msg.rs index 90e81e1f6..fda9fcc83 100644 --- a/elliptic-curve/src/hash2field/expand_msg.rs +++ b/elliptic-curve/src/hash2field/expand_msg.rs @@ -14,13 +14,23 @@ const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-"; const MAX_DST_LEN: usize = 255; /// Trait for types implementing expand_message interface for `hash_to_field`. -pub trait ExpandMsg: Sized { +/// +/// # Errors +/// See implementors of [`ExpandMsg`] for errors. +pub trait ExpandMsg<'a> { + /// Type holding data for the [`Expander`]. + type Expander: Expander + Sized; + /// Expands `msg` to the required number of bytes. /// /// Returns an expander that can be used to call `read` until enough /// bytes have been consumed - fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Result; + fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize) + -> Result; +} +/// Expander that, call `read` until enough bytes have been consumed. +pub trait Expander { /// Fill the array with the expanded bytes fn fill_bytes(&mut self, okm: &mut [u8]); } @@ -30,21 +40,21 @@ pub trait ExpandMsg: Sized { /// Implements [section 5.4.3 of `draft-irtf-cfrg-hash-to-curve-13`][dst]. /// /// [dst]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-13#section-5.4.3 -pub(crate) enum Domain +pub(crate) enum Domain<'a, L> where L: ArrayLength + IsLess, { /// > 255 Hashed(GenericArray), /// <= 255 - Array(&'static [u8]), + Array(&'a [u8]), } -impl Domain +impl<'a, L> Domain<'a, L> where L: ArrayLength + IsLess, { - pub fn xof(dst: &'static [u8]) -> Self + pub fn xof(dst: &'a [u8]) -> Self where X: Default + ExtendableOutput + Update, { @@ -61,7 +71,7 @@ where } } - pub fn xmd(dst: &'static [u8]) -> Self + pub fn xmd(dst: &'a [u8]) -> Self where X: Digest, { diff --git a/elliptic-curve/src/hash2field/expand_msg/xmd.rs b/elliptic-curve/src/hash2field/expand_msg/xmd.rs index a1e96d969..8e05e715b 100644 --- a/elliptic-curve/src/hash2field/expand_msg/xmd.rs +++ b/elliptic-curve/src/hash2field/expand_msg/xmd.rs @@ -1,6 +1,8 @@ //! `expand_message_xmd` based on a hash function. -use super::{Domain, ExpandMsg}; +use core::marker::PhantomData; + +use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; use digest::{ generic_array::{ @@ -11,52 +13,19 @@ use digest::{ }; /// Placeholder type for implementing `expand_message_xmd` based on a hash function -pub struct ExpandMsgXmd +/// +/// # Errors +/// - `len_in_bytes == 0` +/// - `len_in_bytes > u16::MAX` +/// - `len_in_bytes > 255 * HashT::OutputSize` +pub struct ExpandMsgXmd(PhantomData) where HashT: Digest + BlockInput, HashT::OutputSize: IsLess, - HashT::OutputSize: IsLessOrEqual, -{ - b_0: GenericArray, - b_vals: GenericArray, - domain: Domain, - index: u8, - offset: usize, - ell: u8, -} - -impl ExpandMsgXmd -where - HashT: Digest + BlockInput, - HashT::OutputSize: IsLess, - HashT::OutputSize: IsLessOrEqual, -{ - fn next(&mut self) -> bool { - if self.index < self.ell { - self.index += 1; - self.offset = 0; - // b_0 XOR b_(idx - 1) - let mut tmp = GenericArray::::default(); - self.b_0 - .iter() - .zip(&self.b_vals[..]) - .enumerate() - .for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val); - self.b_vals = HashT::new() - .chain(tmp) - .chain([self.index]) - .chain(self.domain.data()) - .chain([self.domain.len()]) - .finalize(); - true - } else { - false - } - } -} + HashT::OutputSize: IsLessOrEqual; /// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait -impl ExpandMsg for ExpandMsgXmd +impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd where HashT: Digest + BlockInput, // If `len_in_bytes` is bigger then 256, length of the `DST` will depend on @@ -67,7 +36,13 @@ where // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4 HashT::OutputSize: IsLessOrEqual, { - fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Result { + type Expander = ExpanderXmd<'a, HashT>; + + fn expand_message( + msgs: &[&[u8]], + dst: &'a [u8], + len_in_bytes: usize, + ) -> Result { if len_in_bytes == 0 { return Err(Error); } @@ -78,9 +53,13 @@ where let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?; let domain = Domain::xmd::(dst); - let b_0 = HashT::new() - .chain(GenericArray::::default()) - .chain(msg) + let mut b_0 = HashT::new().chain(GenericArray::::default()); + + for msg in msgs { + b_0 = b_0.chain(msg); + } + + let b_0 = b_0 .chain(len_in_bytes_u16.to_be_bytes()) .chain([0]) .chain(domain.data()) @@ -94,7 +73,7 @@ where .chain([domain.len()]) .finalize(); - Ok(Self { + Ok(ExpanderXmd { b_0, b_vals, domain, @@ -103,7 +82,59 @@ where ell, }) } +} +/// [`Expander`] type for [`ExpandMsgXmd`]. +pub struct ExpanderXmd<'a, HashT> +where + HashT: Digest + BlockInput, + HashT::OutputSize: IsLess, + HashT::OutputSize: IsLessOrEqual, +{ + b_0: GenericArray, + b_vals: GenericArray, + domain: Domain<'a, HashT::OutputSize>, + index: u8, + offset: usize, + ell: u8, +} + +impl<'a, HashT> ExpanderXmd<'a, HashT> +where + HashT: Digest + BlockInput, + HashT::OutputSize: IsLess, + HashT::OutputSize: IsLessOrEqual, +{ + fn next(&mut self) -> bool { + if self.index < self.ell { + self.index += 1; + self.offset = 0; + // b_0 XOR b_(idx - 1) + let mut tmp = GenericArray::::default(); + self.b_0 + .iter() + .zip(&self.b_vals[..]) + .enumerate() + .for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val); + self.b_vals = HashT::new() + .chain(tmp) + .chain([self.index]) + .chain(self.domain.data()) + .chain([self.domain.len()]) + .finalize(); + true + } else { + false + } + } +} + +impl<'a, HashT> Expander for ExpanderXmd<'a, HashT> +where + HashT: Digest + BlockInput, + HashT::OutputSize: IsLess, + HashT::OutputSize: IsLessOrEqual, +{ fn fill_bytes(&mut self, okm: &mut [u8]) { for b in okm { if self.offset == self.b_vals.len() && !self.next() { @@ -128,7 +159,7 @@ mod test { fn assert_message( msg: &[u8], - domain: &Domain, + domain: &Domain<'_, HashT::OutputSize>, len_in_bytes: u16, bytes: &[u8], ) where @@ -169,7 +200,7 @@ mod test { fn assert>( &self, dst: &'static [u8], - domain: &Domain, + domain: &Domain<'_, HashT::OutputSize>, ) -> Result<()> where HashT: Digest + BlockInput, @@ -178,7 +209,7 @@ mod test { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let mut expander = - as ExpandMsg>::expand_message(self.msg, dst, L::to_usize())?; + ExpandMsgXmd::::expand_message(&[self.msg], dst, L::to_usize())?; let mut uniform_bytes = GenericArray::::default(); expander.fill_bytes(&mut uniform_bytes); diff --git a/elliptic-curve/src/hash2field/expand_msg/xof.rs b/elliptic-curve/src/hash2field/expand_msg/xof.rs index f30cb370e..d7ca714cf 100644 --- a/elliptic-curve/src/hash2field/expand_msg/xof.rs +++ b/elliptic-curve/src/hash2field/expand_msg/xof.rs @@ -1,11 +1,15 @@ //! `expand_message_xof` for the `ExpandMsg` trait -use super::ExpandMsg; -use crate::{hash2field::Domain, Error, Result}; +use super::{Domain, ExpandMsg, Expander}; +use crate::{Error, Result}; use digest::{ExtendableOutput, Update, XofReader}; use generic_array::typenum::U32; /// Placeholder type for implementing `expand_message_xof` based on an extendable output function +/// +/// # Errors +/// - `len_in_bytes == 0` +/// - `len_in_bytes > u16::MAX` pub struct ExpandMsgXof where HashT: Default + ExtendableOutput + Update, @@ -14,11 +18,17 @@ where } /// ExpandMsgXof implements `expand_message_xof` for the [`ExpandMsg`] trait -impl ExpandMsg for ExpandMsgXof +impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXof where HashT: Default + ExtendableOutput + Update, { - fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Result { + type Expander = Self; + + fn expand_message( + msgs: &[&[u8]], + dst: &'a [u8], + len_in_bytes: usize, + ) -> Result { if len_in_bytes == 0 { return Err(Error); } @@ -26,19 +36,30 @@ where let len_in_bytes = u16::try_from(len_in_bytes).map_err(|_| Error)?; let domain = Domain::::xof::(dst); - let reader = HashT::default() - .chain(msg) + let mut reader = HashT::default(); + + for msg in msgs { + reader = reader.chain(msg); + } + + let reader = reader .chain(len_in_bytes.to_be_bytes()) .chain(domain.data()) .chain([domain.len()]) .finalize_xof(); Ok(Self { reader }) } +} +impl Expander for ExpandMsgXof +where + HashT: Default + ExtendableOutput + Update, +{ fn fill_bytes(&mut self, okm: &mut [u8]) { self.reader.read(okm); } } + #[cfg(test)] mod test { use super::*; @@ -50,7 +71,12 @@ mod test { use hex_literal::hex; use sha3::Shake128; - fn assert_message(msg: &[u8], domain: &Domain, len_in_bytes: u16, bytes: &[u8]) { + fn assert_message( + msg: &[u8], + domain: &Domain<'_, U32>, + len_in_bytes: u16, + bytes: &[u8], + ) { let msg_len = msg.len(); assert_eq!(msg, &bytes[..msg_len]); @@ -76,7 +102,7 @@ mod test { } impl TestVector { - fn assert(&self, dst: &'static [u8], domain: &Domain) -> Result<()> + fn assert(&self, dst: &'static [u8], domain: &Domain<'_, U32>) -> Result<()> where HashT: Default + ExtendableOutput + Update, L: ArrayLength, @@ -84,7 +110,7 @@ mod test { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let mut expander = - as ExpandMsg>::expand_message(self.msg, dst, L::to_usize())?; + ExpandMsgXof::::expand_message(&[self.msg], dst, L::to_usize())?; let mut uniform_bytes = GenericArray::::default(); expander.fill_bytes(&mut uniform_bytes);