Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion elliptic-curve/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ alloc = ["der/alloc", "sec1/alloc", "zeroize/alloc"] # todo: use weak activation
arithmetic = ["ff", "group"]
bits = ["arithmetic", "ff/bits"]
dev = ["arithmetic", "hex-literal", "pem", "pkcs8"]
hash2curve = ["digest", "ff", "group"]
hash2curve = ["arithmetic", "digest"]
ecdh = ["arithmetic"]
hazmat = []
jwk = ["alloc", "base64ct/alloc", "serde", "serde_json", "zeroize/alloc"]
Expand Down
51 changes: 43 additions & 8 deletions elliptic-curve/src/hash2curve/group_digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
use super::MapToCurve;
use crate::{
hash2field::{hash_to_field, ExpandMsg, FromOkm},
Result,
ProjectiveArithmetic, Result,
};
use group::cofactor::CofactorGroup;

/// Adds hashing arbitrary byte sequences to a valid group element
pub trait GroupDigest {
pub trait GroupDigest: ProjectiveArithmetic<ProjectivePoint = Self::Output> {
/// The field element representation for a group value with multiple elements
type FieldElement: FromOkm + MapToCurve<Output = Self::Output> + Default + Copy;
/// The resulting group element
Expand All @@ -30,18 +30,26 @@ pub trait GroupDigest {
/// ## Using a fixed size hash function
///
/// ```ignore
/// let pt = ProjectivePoint::hash_from_bytes::<hash2field::ExpandMsgXmd<sha2::Sha256>>(b"test data", b"CURVE_XMD:SHA-256_SSWU_RO_");
/// let pt = ProjectivePoint::hash_from_bytes::<ExpandMsgXmd<sha2::Sha256>>(b"test data", b"CURVE_XMD:SHA-256_SSWU_RO_");
/// ```
///
/// ## Using an extendable output function
///
/// ```ignore
/// let pt = ProjectivePoint::hash_from_bytes::<hash2field::ExpandMsgXof<sha3::Shake256>>(b"test data", b"CURVE_XOF:SHAKE-256_SSWU_RO_");
/// let pt = ProjectivePoint::hash_from_bytes::<ExpandMsgXof<sha3::Shake256>>(b"test data", b"CURVE_XOF:SHAKE-256_SSWU_RO_");
/// ```
///
fn hash_from_bytes<X: ExpandMsg>(msg: &[u8], dst: &'static [u8]) -> Result<Self::Output> {
/// # Errors
/// Can't fail with [`ExpandMsgXmd`] or [`ExpandMsgXof`].
///
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
fn hash_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
) -> Result<Self::Output> {
let mut u = [Self::FieldElement::default(), Self::FieldElement::default()];
hash_to_field::<X, _>(msg, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dst, &mut u)?;
let q0 = u[0].map_to_curve();
let q1 = u[1].map_to_curve();
// Ideally we could add and then clear cofactor once
Expand All @@ -66,10 +74,37 @@ pub trait GroupDigest {
/// > uniformly random in G: the set of possible outputs of
/// > encode_to_curve is only a fraction of the points in G, and some
/// > points in this set are more likely to be output than others.
fn encode_from_bytes<X: ExpandMsg>(msg: &[u8], dst: &'static [u8]) -> Result<Self::Output> {
///
/// # Errors
/// Can't fail with [`ExpandMsgXmd`] or [`ExpandMsgXof`].
///
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
fn encode_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
) -> Result<Self::Output> {
let mut u = [Self::FieldElement::default()];
hash_to_field::<X, _>(msg, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dst, &mut u)?;
let q0 = u[0].map_to_curve();
Ok(q0.clear_cofactor())
}

/// Computes the hash to field routine according to
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5>
/// and returns a scalar.
///
/// # Errors
/// Can't fail with [`ExpandMsgXmd`] or [`ExpandMsgXof`].
///
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result<Self::Scalar>
where
Self::Scalar: FromOkm,
{
let mut u = [Self::Scalar::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
Ok(u[0])
}
}
4 changes: 2 additions & 2 deletions elliptic-curve/src/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ pub trait FromOkm {
/// Convert an arbitrary byte sequence into a field element.
///
/// <https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-11#section-5.3>
pub fn hash_to_field<E, T>(data: &[u8], domain: &'static [u8], out: &mut [T]) -> Result<()>
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()>
where
E: ExpandMsg,
E: ExpandMsg<'a>,
T: FromOkm + Default,
{
let len_in_bytes = T::Length::to_usize() * out.len();
Expand Down
23 changes: 16 additions & 7 deletions elliptic-curve/src/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@ const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";
const MAX_DST_LEN: usize = 255;

/// Trait for types implementing expand_message interface for `hash_to_field`.
pub trait ExpandMsg: Sized {
/// # Errors
/// See implementors of `ExpandMsg` for errors.
pub trait ExpandMsg<'a> {
/// Type holding data for the [`Expander`].
type Expander: Expander + Sized;

/// Expands `msg` to the required number of bytes.
///
/// Returns an expander that can be used to call `read` until enough
/// bytes have been consumed
fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Result<Self>;
fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize)
-> Result<Self::Expander>;
}

/// Expander that, call `read` until enough bytes have been consumed.
pub trait Expander {
/// Fill the array with the expanded bytes
fn fill_bytes(&mut self, okm: &mut [u8]);
}
Expand All @@ -30,21 +39,21 @@ pub trait ExpandMsg: Sized {
/// Implements [section 5.4.3 of `draft-irtf-cfrg-hash-to-curve-13`][dst].
///
/// [dst]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-13#section-5.4.3
pub(crate) enum Domain<L>
pub(crate) enum Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
/// > 255
Hashed(GenericArray<u8, L>),
/// <= 255
Array(&'static [u8]),
Array(&'a [u8]),
}

impl<L> Domain<L>
impl<'a, L> Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
pub fn xof<X>(dst: &'static [u8]) -> Self
pub fn xof<X>(dst: &'a [u8]) -> Self
where
X: Default + ExtendableOutput + Update,
{
Expand All @@ -61,7 +70,7 @@ where
}
}

pub fn xmd<X>(dst: &'static [u8]) -> Self
pub fn xmd<X>(dst: &'a [u8]) -> Self
where
X: Digest<OutputSize = L>,
{
Expand Down
132 changes: 82 additions & 50 deletions elliptic-curve/src/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! `expand_message_xmd` based on a hash function.

use super::{Domain, ExpandMsg};
use core::marker::PhantomData;

use super::{Domain, ExpandMsg, Expander};
use crate::{Error, Result};
use digest::{
generic_array::{
Expand All @@ -11,52 +13,20 @@ use digest::{
};

/// Placeholder type for implementing `expand_message_xmd` based on a hash function
pub struct ExpandMsgXmd<HashT>
///
/// # Errors
/// - `len_in_bytes == 0`
/// - `len_in_bytes != out.len()`
/// - `len_in_bytes > u16::MAX`
/// - `len_in_bytes > 255 * HashT::OutputSize`
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
where
HashT: Digest + BlockInput,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
b_0: GenericArray<u8, HashT::OutputSize>,
b_vals: GenericArray<u8, HashT::OutputSize>,
domain: Domain<HashT::OutputSize>,
index: u8,
offset: usize,
ell: u8,
}

impl<HashT> ExpandMsgXmd<HashT>
where
HashT: Digest + BlockInput,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
fn next(&mut self) -> bool {
if self.index < self.ell {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
let mut tmp = GenericArray::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
self.b_vals = HashT::new()
.chain(tmp)
.chain([self.index])
.chain(self.domain.data())
.chain([self.domain.len()])
.finalize();
true
} else {
false
}
}
}
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>;

/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait
impl<HashT> ExpandMsg for ExpandMsgXmd<HashT>
impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd<HashT>
where
HashT: Digest + BlockInput,
// If `len_in_bytes` is bigger then 256, length of the `DST` will depend on
Expand All @@ -67,7 +37,13 @@ where
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Result<Self> {
type Expander = ExpanderXmd<'a, HashT>;

fn expand_message(
msgs: &[&[u8]],
dst: &'a [u8],
len_in_bytes: usize,
) -> Result<Self::Expander> {
if len_in_bytes == 0 {
return Err(Error);
}
Expand All @@ -78,9 +54,13 @@ where
let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?;

let domain = Domain::xmd::<HashT>(dst);
let b_0 = HashT::new()
.chain(GenericArray::<u8, HashT::BlockSize>::default())
.chain(msg)
let mut b_0 = HashT::new().chain(GenericArray::<u8, HashT::BlockSize>::default());

for msg in msgs {
b_0 = b_0.chain(msg);
}

let b_0 = b_0
.chain(len_in_bytes_u16.to_be_bytes())
.chain([0])
.chain(domain.data())
Expand All @@ -94,7 +74,7 @@ where
.chain([domain.len()])
.finalize();

Ok(Self {
Ok(ExpanderXmd {
b_0,
b_vals,
domain,
Expand All @@ -103,7 +83,59 @@ where
ell,
})
}
}

/// [`Expander`] type for [`ExpandMsgXmd`].
pub struct ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockInput,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
b_0: GenericArray<u8, HashT::OutputSize>,
b_vals: GenericArray<u8, HashT::OutputSize>,
domain: Domain<'a, HashT::OutputSize>,
index: u8,
offset: usize,
ell: u8,
}

impl<'a, HashT> ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockInput,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
fn next(&mut self) -> bool {
if self.index < self.ell {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
let mut tmp = GenericArray::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
self.b_vals = HashT::new()
.chain(tmp)
.chain([self.index])
.chain(self.domain.data())
.chain([self.domain.len()])
.finalize();
true
} else {
false
}
}
}

impl<'a, HashT> Expander for ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockInput,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
for b in okm {
if self.offset == self.b_vals.len() && !self.next() {
Expand All @@ -128,7 +160,7 @@ mod test {

fn assert_message<HashT>(
msg: &[u8],
domain: &Domain<HashT::OutputSize>,
domain: &Domain<'_, HashT::OutputSize>,
len_in_bytes: u16,
bytes: &[u8],
) where
Expand Down Expand Up @@ -169,7 +201,7 @@ mod test {
fn assert<HashT, L: ArrayLength<u8>>(
&self,
dst: &'static [u8],
domain: &Domain<HashT::OutputSize>,
domain: &Domain<'_, HashT::OutputSize>,
) -> Result<()>
where
HashT: Digest + BlockInput,
Expand All @@ -178,7 +210,7 @@ mod test {
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);

let mut expander =
<ExpandMsgXmd<HashT> as ExpandMsg>::expand_message(self.msg, dst, L::to_usize())?;
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], dst, L::to_usize())?;

let mut uniform_bytes = GenericArray::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);
Expand Down
Loading