diff --git a/elliptic-curve/src/hash2curve/group_digest.rs b/elliptic-curve/src/hash2curve/group_digest.rs index 4de06140c..ea7f0471f 100644 --- a/elliptic-curve/src/hash2curve/group_digest.rs +++ b/elliptic-curve/src/hash2curve/group_digest.rs @@ -48,10 +48,10 @@ where /// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof fn hash_from_bytes<'a, X: ExpandMsg<'a>>( msgs: &[&[u8]], - dst: &'a [u8], + dsts: &'a [&'a [u8]], ) -> Result> { let mut u = [Self::FieldElement::default(), Self::FieldElement::default()]; - hash_to_field::(msgs, dst, &mut u)?; + hash_to_field::(msgs, dsts, &mut u)?; let q0 = u[0].map_to_curve(); let q1 = u[1].map_to_curve(); // Ideally we could add and then clear cofactor once @@ -88,10 +88,10 @@ where /// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof fn encode_from_bytes<'a, X: ExpandMsg<'a>>( msgs: &[&[u8]], - dst: &'a [u8], + dsts: &'a [&'a [u8]], ) -> Result> { let mut u = [Self::FieldElement::default()]; - hash_to_field::(msgs, dst, &mut u)?; + hash_to_field::(msgs, dsts, &mut u)?; let q0 = u[0].map_to_curve(); Ok(q0.clear_cofactor().into()) } @@ -109,12 +109,15 @@ where /// /// [`ExpandMsgXmd`]: crate::hash2curve::ExpandMsgXmd /// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof - fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result + fn hash_to_scalar<'a, X: ExpandMsg<'a>>( + msgs: &[&[u8]], + dsts: &'a [&'a [u8]], + ) -> Result where Self::Scalar: FromOkm, { let mut u = [Self::Scalar::default()]; - hash_to_field::(msgs, dst, &mut u)?; + hash_to_field::(msgs, dsts, &mut u)?; Ok(u[0]) } } diff --git a/elliptic-curve/src/hash2curve/hash2field.rs b/elliptic-curve/src/hash2curve/hash2field.rs index 6cd0723aa..b0ee5b2d6 100644 --- a/elliptic-curve/src/hash2curve/hash2field.rs +++ b/elliptic-curve/src/hash2curve/hash2field.rs @@ -32,7 +32,7 @@ pub trait FromOkm { /// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd /// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof #[doc(hidden)] -pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()> +pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [&'a [u8]], out: &mut [T]) -> Result<()> where E: ExpandMsg<'a>, T: FromOkm + Default, diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs index 4a4db7119..96a659b9a 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs @@ -25,8 +25,11 @@ pub trait ExpandMsg<'a> { /// /// Returns an expander that can be used to call `read` until enough /// bytes have been consumed - fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize) - -> Result; + fn expand_message( + msgs: &[&[u8]], + dsts: &'a [&'a [u8]], + len_in_bytes: usize, + ) -> Result; } /// Expander that, call `read` until enough bytes have been consumed. @@ -47,54 +50,66 @@ where /// > 255 Hashed(GenericArray), /// <= 255 - Array(&'a [u8]), + Array(&'a [&'a [u8]]), } impl<'a, L> Domain<'a, L> where L: ArrayLength + IsLess, { - pub fn xof(dst: &'a [u8]) -> Result + pub fn xof(dsts: &'a [&'a [u8]]) -> Result where X: Default + ExtendableOutput + Update, { - if dst.is_empty() { + if dsts.is_empty() { Err(Error) - } else if dst.len() > MAX_DST_LEN { + } else if dsts.iter().map(|dst| dst.len()).sum::() > MAX_DST_LEN { let mut data = GenericArray::::default(); - X::default() - .chain(OVERSIZE_DST_SALT) - .chain(dst) - .finalize_xof() - .read(&mut data); + let mut hash = X::default(); + hash.update(OVERSIZE_DST_SALT); + + for dst in dsts { + hash.update(dst); + } + + hash.finalize_xof().read(&mut data); + Ok(Self::Hashed(data)) } else { - Ok(Self::Array(dst)) + Ok(Self::Array(dsts)) } } - pub fn xmd(dst: &'a [u8]) -> Result + pub fn xmd(dsts: &'a [&'a [u8]]) -> Result where X: Digest, { - if dst.is_empty() { + if dsts.is_empty() { Err(Error) - } else if dst.len() > MAX_DST_LEN { + } else if dsts.iter().map(|dst| dst.len()).sum::() > MAX_DST_LEN { Ok(Self::Hashed({ let mut hash = X::new(); hash.update(OVERSIZE_DST_SALT); - hash.update(dst); + + for dst in dsts { + hash.update(dst); + } + hash.finalize() })) } else { - Ok(Self::Array(dst)) + Ok(Self::Array(dsts)) } } - pub fn data(&self) -> &[u8] { + pub fn update_hash(&self, hash: &mut HashT) { match self { - Self::Hashed(d) => &d[..], - Self::Array(d) => d, + Self::Hashed(d) => hash.update(d), + Self::Array(d) => { + for d in d.iter() { + hash.update(d) + } + } } } @@ -103,13 +118,28 @@ where // Can't overflow because it's enforced on a type level. Self::Hashed(_) => L::to_u8(), // Can't overflow because it's checked on creation. - Self::Array(d) => u8::try_from(d.len()).expect("length overflow"), + Self::Array(d) => { + u8::try_from(d.iter().map(|d| d.len()).sum::()).expect("length overflow") + } } } #[cfg(test)] pub fn assert(&self, bytes: &[u8]) { - assert_eq!(self.data(), &bytes[..bytes.len() - 1]); + let data = match self { + Domain::Hashed(d) => d.to_vec(), + Domain::Array(d) => d.iter().copied().flatten().copied().collect(), + }; + assert_eq!(data, bytes); + } + + #[cfg(test)] + pub fn assert_dst(&self, bytes: &[u8]) { + let data = match self { + Domain::Hashed(d) => d.to_vec(), + Domain::Array(d) => d.iter().copied().flatten().copied().collect(), + }; + assert_eq!(data, &bytes[..bytes.len() - 1]); assert_eq!(self.len(), bytes[bytes.len() - 1]); } } diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs index 876b012f5..baf6f31b2 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs @@ -10,7 +10,7 @@ use digest::{ typenum::{IsLess, IsLessOrEqual, Unsigned, U256}, GenericArray, }, - Digest, + FixedOutput, HashMarker, }; /// Placeholder type for implementing `expand_message_xmd` based on a hash function @@ -22,14 +22,14 @@ use digest::{ /// - `len_in_bytes > 255 * HashT::OutputSize` pub struct ExpandMsgXmd(PhantomData) where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, HashT::OutputSize: IsLessOrEqual; /// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, // If `len_in_bytes` is bigger then 256, length of the `DST` will depend on // the output size of the hash, which is still not allowed to be bigger then 256: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6 @@ -42,7 +42,7 @@ where fn expand_message( msgs: &[&[u8]], - dst: &'a [u8], + dsts: &'a [&'a [u8]], len_in_bytes: usize, ) -> Result { if len_in_bytes == 0 { @@ -54,26 +54,26 @@ where let b_in_bytes = HashT::OutputSize::to_usize(); let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?; - let domain = Domain::xmd::(dst)?; - let mut b_0 = HashT::new(); - b_0.update(GenericArray::::default()); + let domain = Domain::xmd::(dsts)?; + let mut b_0 = HashT::default(); + b_0.update(&GenericArray::::default()); for msg in msgs { b_0.update(msg); } - b_0.update(len_in_bytes_u16.to_be_bytes()); - b_0.update([0]); - b_0.update(domain.data()); - b_0.update([domain.len()]); - let b_0 = b_0.finalize(); + b_0.update(&len_in_bytes_u16.to_be_bytes()); + b_0.update(&[0]); + domain.update_hash(&mut b_0); + b_0.update(&[domain.len()]); + let b_0 = b_0.finalize_fixed(); - let mut b_vals = HashT::new(); + let mut b_vals = HashT::default(); b_vals.update(&b_0[..]); - b_vals.update([1u8]); - b_vals.update(domain.data()); - b_vals.update([domain.len()]); - let b_vals = b_vals.finalize(); + b_vals.update(&[1u8]); + domain.update_hash(&mut b_vals); + b_vals.update(&[domain.len()]); + let b_vals = b_vals.finalize_fixed(); Ok(ExpanderXmd { b_0, @@ -89,7 +89,7 @@ where /// [`Expander`] type for [`ExpandMsgXmd`]. pub struct ExpanderXmd<'a, HashT> where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, HashT::OutputSize: IsLessOrEqual, { @@ -103,7 +103,7 @@ where impl<'a, HashT> ExpanderXmd<'a, HashT> where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, HashT::OutputSize: IsLessOrEqual, { @@ -118,12 +118,12 @@ where .zip(&self.b_vals[..]) .enumerate() .for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val); - let mut b_vals = HashT::new(); - b_vals.update(tmp); - b_vals.update([self.index]); - b_vals.update(self.domain.data()); - b_vals.update([self.domain.len()]); - self.b_vals = b_vals.finalize(); + let mut b_vals = HashT::default(); + b_vals.update(&tmp); + b_vals.update(&[self.index]); + self.domain.update_hash(&mut b_vals); + b_vals.update(&[self.domain.len()]); + self.b_vals = b_vals.finalize_fixed(); true } else { false @@ -133,7 +133,7 @@ where impl<'a, HashT> Expander for ExpanderXmd<'a, HashT> where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, HashT::OutputSize: IsLessOrEqual, { @@ -165,7 +165,7 @@ mod test { len_in_bytes: u16, bytes: &[u8], ) where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, { let block = HashT::BlockSize::to_usize(); @@ -183,8 +183,8 @@ mod test { let pad = l + mem::size_of::(); assert_eq!([0], &bytes[l..pad]); - let dst = pad + domain.data().len(); - assert_eq!(domain.data(), &bytes[pad..dst]); + let dst = pad + usize::from(domain.len()); + domain.assert(&bytes[pad..dst]); let dst_len = dst + mem::size_of::(); assert_eq!([domain.len()], &bytes[dst..dst_len]); @@ -205,13 +205,14 @@ mod test { domain: &Domain<'_, HashT::OutputSize>, ) -> Result<()> where - HashT: Digest + BlockSizeUser, + HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess + IsLessOrEqual, { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); + let dst = [dst]; let mut expander = - ExpandMsgXmd::::expand_message(&[self.msg], dst, L::to_usize())?; + ExpandMsgXmd::::expand_message(&[self.msg], &dst, L::to_usize())?; let mut uniform_bytes = GenericArray::::default(); expander.fill_bytes(&mut uniform_bytes); @@ -227,8 +228,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826"); - let dst_prime = Domain::xmd::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::xmd::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector { @@ -299,8 +300,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620"); - let dst_prime = Domain::xmd::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::xmd::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector { @@ -377,8 +378,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626"); - let dst_prime = Domain::xmd::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::xmd::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector { diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs index 107ac5e06..9a5ff19e9 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs @@ -27,7 +27,7 @@ where fn expand_message( msgs: &[&[u8]], - dst: &'a [u8], + dsts: &'a [&'a [u8]], len_in_bytes: usize, ) -> Result { if len_in_bytes == 0 { @@ -36,18 +36,17 @@ where let len_in_bytes = u16::try_from(len_in_bytes).map_err(|_| Error)?; - let domain = Domain::::xof::(dst)?; + let domain = Domain::::xof::(dsts)?; let mut reader = HashT::default(); for msg in msgs { reader = reader.chain(msg); } - let reader = reader - .chain(len_in_bytes.to_be_bytes()) - .chain(domain.data()) - .chain([domain.len()]) - .finalize_xof(); + reader.update(&len_in_bytes.to_be_bytes()); + domain.update_hash(&mut reader); + reader.update(&[domain.len()]); + let reader = reader.finalize_xof(); Ok(Self { reader }) } } @@ -87,8 +86,8 @@ mod test { &bytes[msg_len..len_in_bytes_len] ); - let dst = len_in_bytes_len + domain.data().len(); - assert_eq!(domain.data(), &bytes[len_in_bytes_len..dst]); + let dst = len_in_bytes_len + usize::from(domain.len()); + domain.assert(&bytes[len_in_bytes_len..dst]); let dst_len = dst + mem::size_of::(); assert_eq!([domain.len()], &bytes[dst..dst_len]); @@ -111,7 +110,7 @@ mod test { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let mut expander = - ExpandMsgXof::::expand_message(&[self.msg], dst, L::to_usize())?; + ExpandMsgXof::::expand_message(&[self.msg], &[dst], L::to_usize())?; let mut uniform_bytes = GenericArray::::default(); expander.fill_bytes(&mut uniform_bytes); @@ -127,8 +126,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348414b4531323824"); - let dst_prime = Domain::::xof::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::::xof::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector { @@ -203,8 +202,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("acb9736c0867fdfbd6385519b90fc8c034b5af04a958973212950132d035792f20"); - let dst_prime = Domain::::xof::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::::xof::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector { @@ -281,8 +280,8 @@ mod test { const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348414b4532353624"); - let dst_prime = Domain::::xof::(DST)?; - dst_prime.assert(DST_PRIME); + let dst_prime = Domain::::xof::(&[DST])?; + dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ TestVector {