From 0f16437a27fb271e8b7e1718c2be12967bb961b0 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Sun, 9 Jan 2022 20:56:58 +0100 Subject: [PATCH 1/9] Decouple element from `Group` --- src/group/mod.rs | 73 ++++++++-------- src/group/p256.rs | 44 +++++----- src/group/ristretto.rs | 31 ++++--- src/group/tests.rs | 14 +-- src/lib.rs | 83 ++++++++++-------- src/serialization.rs | 8 +- src/tests/voprf_test_vectors.rs | 42 +++++---- src/util.rs | 7 +- src/voprf.rs | 149 ++++++++++++++++---------------- 9 files changed, 237 insertions(+), 214 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index 8bc58b2..8f405b3 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -21,6 +21,8 @@ use digest::{Digest, FixedOutputReset}; use generic_array::typenum::U1; use generic_array::{ArrayLength, GenericArray}; use rand_core::{CryptoRng, RngCore}; +#[cfg(feature = "ristretto255")] +pub use ristretto::Ristretto255; use subtle::ConstantTimeEq; use zeroize::Zeroize; @@ -28,22 +30,38 @@ use crate::{Error, Result}; /// A prime-order subgroup of a base field (EC, prime-order field ...). This /// subgroup is noted additively — as in the draft RFC — in this trait. -pub trait Group: - Copy - + Sized - + ConstantTimeEq - + for<'a> Mul<&'a ::Scalar, Output = Self> - + for<'a> Add<&'a Self, Output = Self> -{ +pub trait Group { /// The ciphersuite identifier as dictated by /// const SUITE_ID: usize; + /// The type of group elements + type Elem: Copy + + Sized + + ConstantTimeEq + + Zeroize + + for<'a> Mul<&'a Self::Scalar, Output = Self::Elem> + + for<'a> Add<&'a Self::Elem, Output = Self::Elem>; + + /// The byte length necessary to represent group elements + type ElemLen: ArrayLength + 'static; + + /// The type of base field scalars + type Scalar: Zeroize + + Copy + + ConstantTimeEq + + for<'a> Add<&'a Self::Scalar, Output = Self::Scalar> + + for<'a> Sub<&'a Self::Scalar, Output = Self::Scalar> + + for<'a> Mul<&'a Self::Scalar, Output = Self::Scalar>; + + /// The byte length necessary to represent scalars + type ScalarLen: ArrayLength + 'static; + /// transforms a password and domain separation tag (DST) into a curve point fn hash_to_curve + Add>( msg: &[u8], dst: GenericArray, - ) -> Result + ) -> Result where >::Output: ArrayLength; @@ -60,16 +78,6 @@ pub trait Group: where >::Output: ArrayLength; - /// The type of base field scalars - type Scalar: Zeroize - + Copy - + ConstantTimeEq - + for<'a> Add<&'a Self::Scalar, Output = Self::Scalar> - + for<'a> Sub<&'a Self::Scalar, Output = Self::Scalar> - + for<'a> Mul<&'a Self::Scalar, Output = Self::Scalar>; - /// The byte length necessary to represent scalars - type ScalarLen: ArrayLength + 'static; - /// Return a scalar from its fixed-length bytes representation, without /// checking if the scalar is zero. fn from_scalar_slice_unchecked( @@ -90,28 +98,28 @@ pub trait Group: /// picks a scalar at random fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar; + /// Serializes a scalar to bytes fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray; + /// The multiplicative inverse of this scalar fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar; - /// The byte length necessary to represent group elements - type ElemLen: ArrayLength + 'static; - /// Return an element from its fixed-length bytes representation. This is /// the unchecked version, which does not check for deserializing the /// identity element - fn from_element_slice_unchecked(element_bits: &GenericArray) - -> Result; + fn from_element_slice_unchecked( + element_bits: &GenericArray, + ) -> Result; /// Return an element from its fixed-length bytes representation. If the /// element is the identity element, return an error. fn from_element_slice<'a>( element_bits: impl Into<&'a GenericArray>, - ) -> Result { + ) -> Result { let elem = Self::from_element_slice_unchecked(element_bits.into())?; - if Self::ct_eq(&elem, &::identity()).into() { + if Self::Elem::ct_eq(&elem, &Self::identity()).into() { // found the identity element return Err(Error::PointError); } @@ -120,26 +128,21 @@ pub trait Group: } /// Serializes the `self` group element - fn to_arr(&self) -> GenericArray; + fn to_arr(elem: Self::Elem) -> GenericArray; /// Get the base point for the group - fn base_point() -> Self; + fn base_point() -> Self::Elem; /// Returns if the group element is equal to the identity (1) - fn is_identity(&self) -> bool { - self.ct_eq(&::identity()).into() + fn is_identity(elem: Self::Elem) -> bool { + elem.ct_eq(&Self::identity()).into() } /// Returns the identity group element - fn identity() -> Self; + fn identity() -> Self::Elem; /// Returns the scalar representing zero fn scalar_zero() -> Self::Scalar; - - /// Set the contents of self to the identity value - fn zeroize(&mut self) { - *self = ::identity(); - } } #[cfg(test)] diff --git a/src/group/p256.rs b/src/group/p256.rs index 0510999..cec7132 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -29,7 +29,7 @@ use p256_::elliptic_curve::group::GroupEncoding; use p256_::elliptic_curve::ops::Reduce; use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use p256_::elliptic_curve::Field; -use p256_::{AffinePoint, EncodedPoint, ProjectivePoint}; +use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, Scalar}; use rand_core::{CryptoRng, RngCore}; use subtle::{Choice, ConditionallySelectable}; @@ -41,15 +41,23 @@ use crate::{Error, Result}; pub type L = U48; #[cfg(feature = "p256")] -impl Group for ProjectivePoint { +impl Group for NistP256 { const SUITE_ID: usize = 0x0003; + type Elem = ProjectivePoint; + + type ElemLen = U33; + + type Scalar = Scalar; + + type ScalarLen = U32; + // Implements the `hash_to_curve()` function from // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-3 fn hash_to_curve + Add>( msg: &[u8], dst: GenericArray, - ) -> Result + ) -> Result where >::Output: ArrayLength, { @@ -133,21 +141,17 @@ impl Group for ProjectivePoint { let mut result = GenericArray::default(); result[..bytes.len()].copy_from_slice(&bytes); - Ok(p256_::Scalar::from_be_bytes_reduced(result)) + Ok(Scalar::from_be_bytes_reduced(result)) } - type ElemLen = U33; - type Scalar = p256_::Scalar; - type ScalarLen = U32; - fn from_scalar_slice_unchecked( scalar_bits: &GenericArray, ) -> Result { - Ok(Self::Scalar::from_be_bytes_reduced(*scalar_bits)) + Ok(Scalar::from_be_bytes_reduced(*scalar_bits)) } fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar { - Self::Scalar::random(rng) + Scalar::random(rng) } fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray { @@ -155,33 +159,33 @@ impl Group for ProjectivePoint { } fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar { - scalar.invert().unwrap_or(Self::Scalar::zero()) + scalar.invert().unwrap_or(Scalar::zero()) } fn from_element_slice_unchecked( element_bits: &GenericArray, - ) -> Result { - Option::from(Self::from_bytes(element_bits)).ok_or(Error::PointError) + ) -> Result { + Option::from(ProjectivePoint::from_bytes(element_bits)).ok_or(Error::PointError) } - fn to_arr(&self) -> GenericArray { - let bytes = self.to_affine().to_encoded_point(true); + fn to_arr(elem: Self::Elem) -> GenericArray { + let bytes = elem.to_affine().to_encoded_point(true); let bytes = bytes.as_bytes(); let mut result = GenericArray::default(); result[..bytes.len()].copy_from_slice(bytes); result } - fn base_point() -> Self { - Self::generator() + fn base_point() -> Self::Elem { + ProjectivePoint::generator() } - fn identity() -> Self { - Self::identity() + fn identity() -> Self::Elem { + ProjectivePoint::identity() } fn scalar_zero() -> Self::Scalar { - Self::Scalar::zero() + Scalar::zero() } } diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 2426cb7..d8ff040 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -21,18 +21,26 @@ use rand_core::{CryptoRng, RngCore}; use super::Group; use crate::{Error, Result}; +/// [`Group`] implementation for Ristretto255. +pub struct Ristretto255; + // `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428 #[cfg(feature = "ristretto255")] -/// The implementation of such a subgroup for Ristretto -impl Group for RistrettoPoint { +impl Group for Ristretto255 { const SUITE_ID: usize = 0x0001; + type Elem = RistrettoPoint; + + type Scalar = Scalar; + + type ScalarLen = U32; + // Implements the `hash_to_ristretto255()` function from // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.txt fn hash_to_curve + Add>( msg: &[u8], dst: GenericArray, - ) -> Result + ) -> Result where >::Output: ArrayLength, { @@ -70,8 +78,6 @@ impl Group for RistrettoPoint { )) } - type Scalar = Scalar; - type ScalarLen = U32; fn from_scalar_slice_unchecked( scalar_bits: &GenericArray, ) -> Result { @@ -104,25 +110,26 @@ impl Group for RistrettoPoint { type ElemLen = U32; fn from_element_slice_unchecked( element_bits: &GenericArray, - ) -> Result { + ) -> Result { CompressedRistretto::from_slice(element_bits) .decompress() .ok_or(Error::PointError) } + // serialization of a group element - fn to_arr(&self) -> GenericArray { - self.compress().to_bytes().into() + fn to_arr(elem: Self::Elem) -> GenericArray { + elem.compress().to_bytes().into() } - fn base_point() -> Self { + fn base_point() -> Self::Elem { RISTRETTO_BASEPOINT_POINT } - fn identity() -> Self { - ::identity() + fn identity() -> Self::Elem { + RistrettoPoint::identity() } fn scalar_zero() -> Self::Scalar { - Self::Scalar::zero() + Scalar::zero() } } diff --git a/src/group/tests.rs b/src/group/tests.rs index c763d5d..58115b6 100644 --- a/src/group/tests.rs +++ b/src/group/tests.rs @@ -16,18 +16,18 @@ use crate::{Error, Group, Result}; fn test_group_properties() -> Result<()> { #[cfg(feature = "ristretto255")] { - use curve25519_dalek::ristretto::RistrettoPoint; + use crate::Ristretto255; - test_identity_element_error::()?; - test_zero_scalar_error::()?; + test_identity_element_error::()?; + test_zero_scalar_error::()?; } #[cfg(feature = "p256")] { - use p256_::ProjectivePoint; + use p256_::NistP256; - test_identity_element_error::()?; - test_zero_scalar_error::()?; + test_identity_element_error::()?; + test_zero_scalar_error::()?; } Ok(()) @@ -36,7 +36,7 @@ fn test_group_properties() -> Result<()> { // Checks that the identity element cannot be deserialized fn test_identity_element_error() -> Result<()> { let identity = G::identity(); - let result = G::from_element_slice(&identity.to_arr()); + let result = G::from_element_slice(&G::to_arr(identity)); assert!(matches!(result, Err(Error::PointError))); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 82aa96d..ed37593 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ //! We will use the following choices in this example: //! //! ```ignore -//! type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! type Group = voprf::Ristretto255; //! type Hash = sha2::Sha512; //! ``` //! @@ -52,11 +52,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! use rand::rngs::OsRng; @@ -78,11 +78,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! use rand::rngs::OsRng; @@ -104,11 +104,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::NonVerifiableClient; @@ -136,11 +136,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::NonVerifiableClient; @@ -187,11 +187,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! use rand::rngs::OsRng; @@ -220,11 +220,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! use rand::rngs::OsRng; @@ -246,11 +246,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::VerifiableClient; @@ -279,11 +279,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::VerifiableClient; @@ -336,11 +336,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::VerifiableClient; @@ -364,11 +364,11 @@ //! //! ``` //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::{VerifiableServerBatchEvaluatePrepareResult, VerifiableServerBatchEvaluateFinishResult, VerifiableClient}; @@ -407,11 +407,11 @@ //! ``` //! # #[cfg(feature = "alloc")] { //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::{VerifiableServerBatchEvaluateResult, VerifiableClient}; @@ -446,11 +446,11 @@ //! ``` //! # #[cfg(feature = "alloc")] { //! # #[cfg(feature = "ristretto255")] -//! # type Group = curve25519_dalek::ristretto::RistrettoPoint; +//! # type Group = voprf::Ristretto255; //! # #[cfg(feature = "ristretto255")] //! # type Hash = sha2::Sha512; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] -//! # type Group = p256_::ProjectivePoint; +//! # type Group = p256_::NistP256; //! # #[cfg(all(feature = "p256", not(feature = "ristretto255")))] //! # type Hash = sha2::Sha256; //! # use voprf::{VerifiableServerBatchEvaluateResult, VerifiableClient}; @@ -507,9 +507,10 @@ //! - The `alloc` feature requires Rusts [`alloc`] crate and enables batching //! VOPRF evaluations. //! -//! - The `p256` feature enables using p256 as the underlying group for the -//! [Group] choice and increases the MSRV to 1.56. Note that this is currently -//! an experimental feature ⚠️, and is not yet ready for production use. +//! - The `p256` feature enables using [`NistP256`](p256_::NistP256) as the +//! underlying group for the [Group] choice and increases the MSRV to 1.56. +//! Note that this is currently an experimental feature ⚠️, and is not yet +//! ready for production use. //! //! - The `serde` feature, enabled by default, provides convenience functions //! for serializing and deserializing with [serde](https://serde.rs/). @@ -520,18 +521,21 @@ //! that need access to these raw values and are able to perform the necessary //! validations on them (such as being valid group elements). //! -//! - The backend features are re-exported from [curve25519-dalek](https://doc.dalek.rs/curve25519_dalek/index.html#backends-and-features) -//! and allow for selecting the corresponding backend for the curve arithmetic -//! used. The `ristretto255_u64` feature is included as the default. Other -//! features are mapped as `ristretto255_u32`, `ristretto255_fiat_u64` and -//! `ristretto255_fiat_u32`. Any `ristretto255_*` backend feature will enable -//! the `ristretto255` feature, which can be used too, but keep in mind that -//! `curve25519-dalek` will fail to compile without a selected backend. -//! -//! - The `ristretto255_simd` feature is re-exported from [curve25519-dalek](https://doc.dalek.rs/curve25519_dalek/index.html#backends-and-features) -//! and enables parallel formulas, using either AVX2 or AVX512-IFMA. This will +//! - The `ristretto255` feature enables using [`Ristretto255`] as the +//! underlying group for the [Group] choice. A backend feature, which are +//! re-exported from [curve25519-dalek] and allow for selecting the +//! corresponding backend for the curve arithmetic used, has to be selected, +//! otherwise compilation will fail. The `ristretto255_u64` feature is +//! included as the default. Other features are mapped as `ristretto255_u32`, +//! `ristretto255_fiat_u64` and `ristretto255_fiat_u32`. Any `ristretto255_*` +//! backend feature will enable the `ristretto255` feature. +//! +//! - The `ristretto255_simd` feature is re-exported from [curve25519-dalek] and +//! enables parallel formulas, using either AVX2 or AVX512-IFMA. This will //! automatically enable the `ristretto255_u64` feature and requires Rust //! nightly. +//! +//! [curve25519-dalek]: (https://doc.dalek.rs/curve25519_dalek/index.html#backends-and-features) #![deny(unsafe_code)] #![no_std] @@ -557,6 +561,9 @@ mod tests; // Exports +#[cfg(feature = "ristretto255")] +pub use group::Ristretto255; + pub use crate::error::{Error, Result}; pub use crate::group::Group; #[cfg(feature = "alloc")] diff --git a/src/serialization.rs b/src/serialization.rs index dd6a0d4..1b8c229 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -53,7 +53,7 @@ impl VerifiableClient, Sum: ArrayLength, { - G::scalar_as_bytes(self.blind).concat(self.blinded_element.to_arr()) + G::scalar_as_bytes(self.blind).concat(G::to_arr(self.blinded_element)) } /// Deserialization from bytes @@ -97,7 +97,7 @@ impl VerifiableServer, Sum: ArrayLength, { - G::scalar_as_bytes(self.sk).concat(self.pk.to_arr()) + G::scalar_as_bytes(self.sk).concat(G::to_arr(self.pk)) } /// Deserialization from bytes @@ -143,7 +143,7 @@ impl Proof { impl BlindedElement { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - self.value.to_arr() + G::to_arr(self.value) } /// Deserialization from bytes @@ -162,7 +162,7 @@ impl BlindedElement EvaluationElement { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - self.value.to_arr() + G::to_arr(self.value) } /// Deserialization from bytes diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index dd949aa..83f5f34 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -90,9 +90,10 @@ fn test_vectors() -> Result<()> { #[cfg(feature = "ristretto255")] { - use curve25519_dalek::ristretto::RistrettoPoint; use sha2::Sha512; + use crate::Ristretto255; + let ristretto_base_tvs = json_to_test_vectors!( rfc, String::from("ristretto255, SHA-512"), @@ -105,20 +106,20 @@ fn test_vectors() -> Result<()> { String::from("Verifiable") ); - test_base_seed_to_key::(&ristretto_base_tvs)?; - test_base_blind::(&ristretto_base_tvs)?; - test_base_evaluate::(&ristretto_base_tvs)?; - test_base_finalize::(&ristretto_base_tvs)?; + test_base_seed_to_key::(&ristretto_base_tvs)?; + test_base_blind::(&ristretto_base_tvs)?; + test_base_evaluate::(&ristretto_base_tvs)?; + test_base_finalize::(&ristretto_base_tvs)?; - test_verifiable_seed_to_key::(&ristretto_verifiable_tvs)?; - test_verifiable_blind::(&ristretto_verifiable_tvs)?; - test_verifiable_evaluate::(&ristretto_verifiable_tvs)?; - test_verifiable_finalize::(&ristretto_verifiable_tvs)?; + test_verifiable_seed_to_key::(&ristretto_verifiable_tvs)?; + test_verifiable_blind::(&ristretto_verifiable_tvs)?; + test_verifiable_evaluate::(&ristretto_verifiable_tvs)?; + test_verifiable_finalize::(&ristretto_verifiable_tvs)?; } #[cfg(feature = "p256")] { - use p256_::ProjectivePoint; + use p256_::NistP256; use sha2::Sha256; let p256_base_tvs = @@ -130,15 +131,15 @@ fn test_vectors() -> Result<()> { String::from("Verifiable") ); - test_base_seed_to_key::(&p256_base_tvs)?; - test_base_blind::(&p256_base_tvs)?; - test_base_evaluate::(&p256_base_tvs)?; - test_base_finalize::(&p256_base_tvs)?; + test_base_seed_to_key::(&p256_base_tvs)?; + test_base_blind::(&p256_base_tvs)?; + test_base_evaluate::(&p256_base_tvs)?; + test_base_finalize::(&p256_base_tvs)?; - test_verifiable_seed_to_key::(&p256_verifiable_tvs)?; - test_verifiable_blind::(&p256_verifiable_tvs)?; - test_verifiable_evaluate::(&p256_verifiable_tvs)?; - test_verifiable_finalize::(&p256_verifiable_tvs)?; + test_verifiable_seed_to_key::(&p256_verifiable_tvs)?; + test_verifiable_blind::(&p256_verifiable_tvs)?; + test_verifiable_evaluate::(&p256_verifiable_tvs)?; + test_verifiable_finalize::(&p256_verifiable_tvs)?; } Ok(()) @@ -168,7 +169,10 @@ fn test_verifiable_seed_to_key { #[cfg(feature = "ristretto255")] { - let _ = - $item::::deserialize( - &$bytes[..], - ); + let _ = $item::::deserialize(&$bytes[..]); } #[cfg(feature = "p256")] { - let _ = $item::::deserialize(&$bytes[..]); + let _ = $item::::deserialize(&$bytes[..]); } }; } diff --git a/src/voprf.rs b/src/voprf.rs index f4eeee8..e8321ca 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -74,18 +74,18 @@ pub struct NonVerifiableClient { pub(crate) blind: G::Scalar, - pub(crate) blinded_element: G, + pub(crate) blinded_element: G::Elem, #[derive_where(skip(Zeroize))] pub(crate) hash: PhantomData, } @@ -113,18 +113,18 @@ pub struct NonVerifiableServer { pub(crate) sk: G::Scalar, - pub(crate) pk: G, + pub(crate) pk: G::Elem, #[derive_where(skip(Zeroize))] pub(crate) hash: PhantomData, } @@ -153,17 +153,17 @@ pub struct Proof { /// server (either verifiable or not). #[derive(DeriveWhere)] #[derive_where(Clone, Zeroize(drop))] -#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; G)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; G::Elem)] #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), serde(bound( - deserialize = "G: serde::Deserialize<'de>", - serialize = "G: serde::Serialize" + deserialize = "G::Elem: serde::Deserialize<'de>", + serialize = "G::Elem: serde::Serialize" )) )] pub struct BlindedElement { - pub(crate) value: G, + pub(crate) value: G::Elem, #[derive_where(skip(Zeroize))] pub(crate) hash: PhantomData, } @@ -172,17 +172,17 @@ pub struct BlindedElement { - pub(crate) value: G, + pub(crate) value: G::Elem, #[derive_where(skip(Zeroize))] pub(crate) hash: PhantomData, } @@ -328,7 +328,7 @@ impl VerifiableClient, proof: &Proof, - pk: G, + pk: G::Elem, metadata: Option<&[u8]>, ) -> Result> { // `core::array::from_ref` needs a MSRV of 1.53 @@ -350,7 +350,7 @@ impl VerifiableClient, - pk: G, + pk: G::Elem, metadata: Option<&'a [u8]>, ) -> Result> where @@ -379,7 +379,7 @@ impl VerifiableClient Self { + pub fn from_blind_and_element(blind: G::Scalar, blinded_element: G::Elem) -> Self { Self { blind, blinded_element, @@ -647,7 +647,7 @@ impl VerifiableServer G { + pub fn get_public_key(&self) -> G::Elem { self.pk } } @@ -783,7 +783,7 @@ impl BlindedElement Self { + pub fn from_value_unchecked(value: G::Elem) -> Self { Self { value, hash: PhantomData, @@ -792,7 +792,7 @@ impl BlindedElement G { + pub fn value(&self) -> G::Elem { self.value } } @@ -813,7 +813,7 @@ impl EvaluationElement Self { + pub fn from_value_unchecked(value: G::Elem) -> Self { Self { value, hash: PhantomData, @@ -822,7 +822,7 @@ impl EvaluationElement G { + pub fn value(&self) -> G::Elem { self.value } } @@ -832,7 +832,7 @@ fn blind Result<(G::Scalar, G)> { +) -> Result<(G::Scalar, G::Elem)> { // Choose a random scalar that must be non-zero let blind = G::random_nonzero_scalar(blinding_factor_rng); let blinded_element = deterministic_blind_unchecked::(input, &blind, mode)?; @@ -846,7 +846,7 @@ fn deterministic_blind_unchecked Result { +) -> Result { let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)?); let hashed_point = G::hash_to_curve::(input, dst)?; Ok(hashed_point * blind) @@ -860,7 +860,7 @@ type VerifiableUnblindResult<'a, G, H, IC, IM> = Map< >, <&'a IM as IntoIterator>::IntoIter, >, - fn((::Scalar, &EvaluationElement)) -> G, + fn((::Scalar, &EvaluationElement)) -> ::Elem, >; fn verifiable_unblind< @@ -872,7 +872,7 @@ fn verifiable_unblind< >( clients: &'a IC, messages: &'a IM, - pk: G, + pk: G::Elem, proof: &Proof, info: &[u8], ) -> Result> @@ -921,8 +921,8 @@ fn generate_proof< >( rng: &mut R, k: G::Scalar, - a: G, - b: G, + a: G::Elem, + b: G::Elem, cs: impl Iterator> + ExactSizeIterator, ds: impl Iterator> + ExactSizeIterator, ) -> Result> { @@ -936,11 +936,11 @@ fn generate_proof< GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)?); chain!( h2_input, - Serialize::::from_owned(b.to_arr())?, - Serialize::::from_owned(m.to_arr())?, - Serialize::::from_owned(z.to_arr())?, - Serialize::::from_owned(t2.to_arr())?, - Serialize::::from_owned(t3.to_arr())?, + Serialize::::from_owned(G::to_arr(b))?, + Serialize::::from_owned(G::to_arr(m))?, + Serialize::::from_owned(G::to_arr(z))?, + Serialize::::from_owned(G::to_arr(t2))?, + Serialize::::from_owned(G::to_arr(t3))?, Serialize::::from_owned(challenge_dst)?, ); @@ -959,8 +959,8 @@ fn generate_proof< #[allow(clippy::many_single_char_names)] fn verify_proof( - a: G, - b: G, + a: G::Elem, + b: G::Elem, cs: impl Iterator> + ExactSizeIterator, ds: impl Iterator> + ExactSizeIterator, proof: &Proof, @@ -973,11 +973,11 @@ fn verify_proof( GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)?); chain!( h2_input, - Serialize::::from_owned(b.to_arr())?, - Serialize::::from_owned(m.to_arr())?, - Serialize::::from_owned(z.to_arr())?, - Serialize::::from_owned(t2.to_arr())?, - Serialize::::from_owned(t3.to_arr())?, + Serialize::::from_owned(G::to_arr(b))?, + Serialize::::from_owned(G::to_arr(m))?, + Serialize::::from_owned(G::to_arr(z))?, + Serialize::::from_owned(G::to_arr(t2))?, + Serialize::::from_owned(G::to_arr(t3))?, Serialize::::from_owned(challenge_dst)?, ); @@ -993,7 +993,7 @@ fn verify_proof( type FinalizeAfterUnblindResult<'a, G, H, I, IE> = Map< Zip)>>, - fn(((I, G), (&'a [u8], GenericArray))) -> Result>, + fn(((I, ::Elem), (&'a [u8], GenericArray))) -> Result>, >; fn finalize_after_unblind< @@ -1001,7 +1001,7 @@ fn finalize_after_unblind< G: Group, H: BlockSizeUser + Digest + FixedOutputReset, I: AsRef<[u8]>, - IE: 'a + Iterator, + IE: 'a + Iterator, >( inputs_and_unblinded_elements: IE, info: &'a [u8], @@ -1018,7 +1018,7 @@ fn finalize_after_unblind< hash_input, Serialize::::from(input.as_ref())?, Serialize::::from(info)?, - Serialize::::from_owned(unblinded_element.to_arr())?, + Serialize::::from_owned(G::to_arr(unblinded_element))?, Serialize::::from_owned(finalize_dst)?, ); @@ -1030,10 +1030,10 @@ fn finalize_after_unblind< fn compute_composites( k_option: Option, - b: G, + b: G::Elem, c_slice: impl Iterator> + ExactSizeIterator, d_slice: impl Iterator> + ExactSizeIterator, -) -> Result<(G, G)> { +) -> Result<(G::Elem, G::Elem)> { if c_slice.len() != d_slice.len() { return Err(Error::MismatchedLengthsForCompositeInputs); } @@ -1044,7 +1044,7 @@ fn compute_composites( chain!( h1_input, - Serialize::::from_owned(b.to_arr())?, + Serialize::::from_owned(G::to_arr(b))?, Serialize::::from_owned(seed_dst)?, ); let seed = h1_input @@ -1058,8 +1058,8 @@ fn compute_composites( chain!(h2_input, Serialize::::from_owned(seed.clone())?, i2osp::(i)? => |x| Some(x.as_slice()), - Serialize::::from_owned(c.value.to_arr())?, - Serialize::::from_owned(d.value.to_arr())?, + Serialize::::from_owned(G::to_arr(c.value))?, + Serialize::::from_owned(G::to_arr(d.value))?, Serialize::::from_owned(composite_dst)?, ); let dst = GenericArray::from(STR_HASH_TO_SCALAR) @@ -1415,38 +1415,39 @@ mod tests { fn test_functionality() -> Result<()> { #[cfg(feature = "ristretto255")] { - use curve25519_dalek::ristretto::RistrettoPoint; use sha2::Sha512; - base_retrieval::(); - base_inversion_unsalted::(); - verifiable_retrieval::(); - verifiable_batch_retrieval::(); - verifiable_bad_public_key::(); - verifiable_batch_bad_public_key::(); - - zeroize_base_client::(); - zeroize_base_server::(); - zeroize_verifiable_client::(); - zeroize_verifiable_server::(); + use crate::Ristretto255; + + base_retrieval::(); + base_inversion_unsalted::(); + verifiable_retrieval::(); + verifiable_batch_retrieval::(); + verifiable_bad_public_key::(); + verifiable_batch_bad_public_key::(); + + zeroize_base_client::(); + zeroize_base_server::(); + zeroize_verifiable_client::(); + zeroize_verifiable_server::(); } #[cfg(feature = "p256")] { - use p256_::ProjectivePoint; + use p256_::NistP256; use sha2::Sha256; - base_retrieval::(); - base_inversion_unsalted::(); - verifiable_retrieval::(); - verifiable_batch_retrieval::(); - verifiable_bad_public_key::(); - verifiable_batch_bad_public_key::(); - - zeroize_base_client::(); - zeroize_base_server::(); - zeroize_verifiable_client::(); - zeroize_verifiable_server::(); + base_retrieval::(); + base_inversion_unsalted::(); + verifiable_retrieval::(); + verifiable_batch_retrieval::(); + verifiable_bad_public_key::(); + verifiable_batch_bad_public_key::(); + + zeroize_base_client::(); + zeroize_base_server::(); + zeroize_verifiable_client::(); + zeroize_verifiable_server::(); } Ok(()) From 142714cdeab7e16a01b54ab775536d6eb375cff7 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Sun, 9 Jan 2022 21:18:41 +0100 Subject: [PATCH 2/9] Change `SUITE_ID` to `u16` and rework `get_context_string()` --- src/group/mod.rs | 2 +- src/group/p256.rs | 2 +- src/group/ristretto.rs | 2 +- src/voprf.rs | 68 +++++++++++++++++++++++------------------- 4 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index 8f405b3..2509fb0 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -33,7 +33,7 @@ use crate::{Error, Result}; pub trait Group { /// The ciphersuite identifier as dictated by /// - const SUITE_ID: usize; + const SUITE_ID: u16; /// The type of group elements type Elem: Copy diff --git a/src/group/p256.rs b/src/group/p256.rs index cec7132..a056d83 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -42,7 +42,7 @@ pub type L = U48; #[cfg(feature = "p256")] impl Group for NistP256 { - const SUITE_ID: usize = 0x0003; + const SUITE_ID: u16 = 0x0003; type Elem = ProjectivePoint; diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index d8ff040..269d2da 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -27,7 +27,7 @@ pub struct Ristretto255; // `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428 #[cfg(feature = "ristretto255")] impl Group for Ristretto255 { - const SUITE_ID: usize = 0x0001; + const SUITE_ID: u16 = 0x0001; type Elem = RistrettoPoint; diff --git a/src/voprf.rs b/src/voprf.rs index e8321ca..bf94460 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -17,7 +17,7 @@ use derive_where::DeriveWhere; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutputReset, Output}; use generic_array::sequence::Concat; -use generic_array::typenum::{U1, U11, U2, U20}; +use generic_array::typenum::{U11, U2, U20}; use generic_array::GenericArray; use rand_core::{CryptoRng, RngCore}; use subtle::ConstantTimeEq; @@ -42,8 +42,17 @@ static STR_VOPRF: [u8; 8] = *b"VOPRF08-"; /// Determines the mode of operation (either base mode or verifiable mode) #[derive(Clone, Copy)] enum Mode { - Base = 0, - Verifiable = 1, + Base, + Verifiable, +} + +impl Mode { + fn to_u8(self) -> u8 { + match self { + Mode::Base => 0, + Mode::Verifiable => 1, + } + } } //////////////////////////// @@ -418,7 +427,7 @@ impl NonVerifiableServer /// Corresponds to DeriveKeyPair() function from the VOPRF specification. pub fn new_from_seed(seed: &[u8]) -> Result { let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)?); + GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)); let sk = G::hash_to_scalar::(Some(seed), dst)?; Ok(Self { sk, @@ -443,11 +452,11 @@ impl NonVerifiableServer chain!( context, STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Base)? => |x| Some(x.as_slice()), + get_context_string::(Mode::Base) => |x| Some(x.as_slice()), Serialize::::from(metadata.unwrap_or_default())?, ); let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)?); + GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)); let m = G::hash_to_scalar::(context, dst)?; let t = self.sk + &m; let evaluation_element = blinded_element.value * &G::scalar_invert(&t); @@ -486,7 +495,7 @@ impl VerifiableServer Result { let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)?); + .concat(get_context_string::(Mode::Verifiable)); let sk = G::hash_to_scalar::(Some(seed), dst)?; let pk = G::base_point() * &sk; Ok(Self { @@ -581,11 +590,11 @@ impl VerifiableServer Result> { chain!(context, STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Verifiable)? => |x| Some(x.as_slice()), + get_context_string::(Mode::Verifiable) => |x| Some(x.as_slice()), Serialize::::from(metadata.unwrap_or_default())?, ); let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)?); + .concat(get_context_string::(Mode::Verifiable)); let m = G::hash_to_scalar::(context, dst)?; let t = self.sk + &m; let evaluation_elements = blinded_elements @@ -847,7 +856,7 @@ fn deterministic_blind_unchecked Result { - let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)?); + let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)); let hashed_point = G::hash_to_curve::(input, dst)?; Ok(hashed_point * blind) } @@ -884,12 +893,12 @@ where { chain!(context, STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Verifiable)? => |x| Some(x.as_slice()), + get_context_string::(Mode::Verifiable) => |x| Some(x.as_slice()), Serialize::::from(info)?, ); let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); let m = G::hash_to_scalar::(context, dst)?; let g = G::base_point(); @@ -933,7 +942,7 @@ fn generate_proof< let t3 = m * &r; let challenge_dst = - GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); chain!( h2_input, Serialize::::from_owned(G::to_arr(b))?, @@ -945,7 +954,7 @@ fn generate_proof< ); let hash_to_scalar_dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); let c_scalar = G::hash_to_scalar::(h2_input, hash_to_scalar_dst)?; let s_scalar = r - &(c_scalar * &k); @@ -970,7 +979,7 @@ fn verify_proof( let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar); let challenge_dst = - GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); chain!( h2_input, Serialize::::from_owned(G::to_arr(b))?, @@ -982,7 +991,7 @@ fn verify_proof( ); let hash_to_scalar_dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); let c = G::hash_to_scalar::(h2_input, hash_to_scalar_dst)?; match c.ct_eq(&proof.c_scalar).into() { @@ -1007,7 +1016,7 @@ fn finalize_after_unblind< info: &'a [u8], mode: Mode, ) -> Result> { - let finalize_dst = GenericArray::from(STR_FINALIZE).concat(get_context_string::(mode)?); + let finalize_dst = GenericArray::from(STR_FINALIZE).concat(get_context_string::(mode)); Ok(inputs_and_unblinded_elements // To make a return type possible, we have to convert to a `fn` pointer, @@ -1038,9 +1047,9 @@ fn compute_composites( return Err(Error::MismatchedLengthsForCompositeInputs); } - let seed_dst = GenericArray::from(STR_SEED).concat(get_context_string::(Mode::Verifiable)?); + let seed_dst = GenericArray::from(STR_SEED).concat(get_context_string::(Mode::Verifiable)); let composite_dst = - GenericArray::from(STR_COMPOSITE).concat(get_context_string::(Mode::Verifiable)?); + GenericArray::from(STR_COMPOSITE).concat(get_context_string::(Mode::Verifiable)); chain!( h1_input, @@ -1063,7 +1072,7 @@ fn compute_composites( Serialize::::from_owned(composite_dst)?, ); let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)?); + .concat(get_context_string::(Mode::Verifiable)); let di = G::hash_to_scalar::(h2_input, dst)?; m = c.value * &di + &m; z = match k_option { @@ -1082,10 +1091,10 @@ fn compute_composites( /// Generates the contextString parameter as defined in /// -fn get_context_string(mode: Mode) -> Result> { - Ok(GenericArray::from(STR_VOPRF) - .concat(i2osp::(mode as usize)?) - .concat(i2osp::(G::SUITE_ID)?)) +fn get_context_string(mode: Mode) -> GenericArray { + GenericArray::from(STR_VOPRF) + .concat([mode.to_u8()].into()) + .concat(G::SUITE_ID.to_be_bytes().into()) } /////////// @@ -1113,18 +1122,16 @@ mod tests { info: &[u8], mode: Mode, ) -> Output { - let dst = - GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode).unwrap()); + let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)); let point = G::hash_to_curve::(input, dst).unwrap(); chain!(context, STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(mode).unwrap() => |x| Some(x.as_slice()), + get_context_string::(mode) => |x| Some(x.as_slice()), Serialize::::from(info).unwrap(), ); - let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(mode).unwrap()); + let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(mode)); let m = G::hash_to_scalar::(context, dst).unwrap(); let res = point * &G::scalar_invert(&(key + &m)); @@ -1315,8 +1322,7 @@ mod tests { ) .unwrap(); - let dst = GenericArray::from(STR_HASH_TO_GROUP) - .concat(get_context_string::(Mode::Base).unwrap()); + let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(Mode::Base)); let point = G::hash_to_curve::(&input, dst).unwrap(); let res2 = finalize_after_unblind::( Some((input.as_ref(), point)).into_iter(), From f8cd7a13302f8fc9e8ab7ef3cdb27689716f7a0e Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 00:45:47 +0100 Subject: [PATCH 3/9] Rework scalar de-serialization --- src/error.rs | 4 ++-- src/group/mod.rs | 18 ++---------------- src/group/p256.rs | 10 +++++----- src/group/ristretto.rs | 8 ++++---- src/group/tests.rs | 4 ++-- src/serialization.rs | 12 ++++++------ src/tests/voprf_test_vectors.rs | 8 ++++---- src/voprf.rs | 4 ++-- 8 files changed, 27 insertions(+), 41 deletions(-) diff --git a/src/error.rs b/src/error.rs index 3968af7..084ead2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,8 +34,8 @@ pub enum Error { ProofVerificationError, /// Encountered insufficient bytes when attempting to deserialize SizeError, - /// Encountered a zero scalar - ZeroScalarError, + /// Encountered an invalid scalar + ScalarError, } #[cfg(feature = "std")] diff --git a/src/group/mod.rs b/src/group/mod.rs index 2509fb0..a0d3e4c 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -78,23 +78,9 @@ pub trait Group { where >::Output: ArrayLength; - /// Return a scalar from its fixed-length bytes representation, without - /// checking if the scalar is zero. - fn from_scalar_slice_unchecked( - scalar_bits: &GenericArray, - ) -> Result; - /// Return a scalar from its fixed-length bytes representation. If the - /// scalar is zero, then return an error. - fn from_scalar_slice<'a>( - scalar_bits: impl Into<&'a GenericArray>, - ) -> Result { - let scalar = Self::from_scalar_slice_unchecked(scalar_bits.into())?; - if scalar.ct_eq(&Self::scalar_zero()).into() { - return Err(Error::ZeroScalarError); - } - Ok(scalar) - } + /// scalar is zero or invalid, then return an error. + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result; /// picks a scalar at random fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar; diff --git a/src/group/p256.rs b/src/group/p256.rs index a056d83..55e2d3e 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -29,7 +29,7 @@ use p256_::elliptic_curve::group::GroupEncoding; use p256_::elliptic_curve::ops::Reduce; use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use p256_::elliptic_curve::Field; -use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, Scalar}; +use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, Scalar, SecretKey}; use rand_core::{CryptoRng, RngCore}; use subtle::{Choice, ConditionallySelectable}; @@ -144,10 +144,10 @@ impl Group for NistP256 { Ok(Scalar::from_be_bytes_reduced(result)) } - fn from_scalar_slice_unchecked( - scalar_bits: &GenericArray, - ) -> Result { - Ok(Scalar::from_be_bytes_reduced(*scalar_bits)) + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { + SecretKey::from_be_bytes(scalar_bits) + .map(|secret_key| *secret_key.to_nonzero_scalar()) + .map_err(|_| Error::ScalarError) } fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar { diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 269d2da..3a5c5c6 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -78,10 +78,10 @@ impl Group for Ristretto255 { )) } - fn from_scalar_slice_unchecked( - scalar_bits: &GenericArray, - ) -> Result { - Ok(Scalar::from_bytes_mod_order(*scalar_bits.as_ref())) + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { + Scalar::from_canonical_bytes((*scalar_bits).into()) + .filter(|scalar| scalar != &Scalar::zero()) + .ok_or(Error::ScalarError) } fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar { diff --git a/src/group/tests.rs b/src/group/tests.rs index 58115b6..20ac4a7 100644 --- a/src/group/tests.rs +++ b/src/group/tests.rs @@ -45,8 +45,8 @@ fn test_identity_element_error() -> Result<()> { // Checks that the zero scalar cannot be deserialized fn test_zero_scalar_error() -> Result<()> { let zero_scalar = G::scalar_zero(); - let result = G::from_scalar_slice(&G::scalar_as_bytes(zero_scalar)); - assert!(matches!(result, Err(Error::ZeroScalarError))); + let result = G::deserialize_scalar(&G::scalar_as_bytes(zero_scalar)); + assert!(matches!(result, Err(Error::ScalarError))); Ok(()) } diff --git a/src/serialization.rs b/src/serialization.rs index 1b8c229..047fa84 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -37,7 +37,7 @@ impl NonVerifiableClient pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let blind = G::from_scalar_slice(&deserialize(&mut input)?)?; + let blind = G::deserialize_scalar(&deserialize(&mut input)?)?; Ok(Self { blind, @@ -60,7 +60,7 @@ impl VerifiableClient Result { let mut input = input.iter().copied(); - let blind = G::from_scalar_slice(&deserialize(&mut input)?)?; + let blind = G::deserialize_scalar(&deserialize(&mut input)?)?; let blinded_element = G::from_element_slice(&deserialize(&mut input)?)?; Ok(Self { @@ -81,7 +81,7 @@ impl NonVerifiableServer pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let sk = G::from_scalar_slice(&deserialize(&mut input)?)?; + let sk = G::deserialize_scalar(&deserialize(&mut input)?)?; Ok(Self { sk, @@ -104,7 +104,7 @@ impl VerifiableServer Result { let mut input = input.iter().copied(); - let sk = G::from_scalar_slice(&deserialize(&mut input)?)?; + let sk = G::deserialize_scalar(&deserialize(&mut input)?)?; let pk = G::from_element_slice(&deserialize(&mut input)?)?; Ok(Self { @@ -129,8 +129,8 @@ impl Proof { pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let c_scalar = G::from_scalar_slice(&deserialize(&mut input)?)?; - let s_scalar = G::from_scalar_slice(&deserialize(&mut input)?)?; + let c_scalar = G::deserialize_scalar(&deserialize(&mut input)?)?; + let s_scalar = G::deserialize_scalar(&deserialize(&mut input)?)?; Ok(Proof { c_scalar, diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index 83f5f34..7d3cc3f 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -184,7 +184,7 @@ fn test_base_blind( for parameters in tvs { for i in 0..parameters.input.len() { let blind = - G::from_scalar_slice(&GenericArray::clone_from_slice(¶meters.blind[i]))?; + G::deserialize_scalar(&GenericArray::clone_from_slice(¶meters.blind[i]))?; let client_result = NonVerifiableClient::::deterministic_blind_unchecked( ¶meters.input[i], blind, @@ -210,7 +210,7 @@ fn test_verifiable_blind for parameters in tvs { for i in 0..parameters.input.len() { let blind = - G::from_scalar_slice(&GenericArray::clone_from_slice(¶meters.blind[i]))?; + G::deserialize_scalar(&GenericArray::clone_from_slice(¶meters.blind[i]))?; let client_blind_result = VerifiableClient::::deterministic_blind_unchecked( ¶meters.input[i], blind, @@ -299,7 +299,7 @@ fn test_base_finalize( ) -> Result<()> { for parameters in tvs { for i in 0..parameters.input.len() { - let client = NonVerifiableClient::::from_blind(G::from_scalar_slice( + let client = NonVerifiableClient::::from_blind(G::deserialize_scalar( &GenericArray::clone_from_slice(¶meters.blind[i]), )?); @@ -322,7 +322,7 @@ fn test_verifiable_finalize::from_blind_and_element( - G::from_scalar_slice(&GenericArray::clone_from_slice(¶meters.blind[i]))?, + G::deserialize_scalar(&GenericArray::clone_from_slice(¶meters.blind[i]))?, G::from_element_slice(&GenericArray::clone_from_slice( ¶meters.blinded_element[i], ))?, diff --git a/src/voprf.rs b/src/voprf.rs index bf94460..6807524 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -414,7 +414,7 @@ impl NonVerifiableServer /// Produces a new instance of a [NonVerifiableServer] using a supplied set /// of bytes to represent the server's private key pub fn new_with_key(private_key_bytes: &[u8]) -> Result { - let sk = G::from_scalar_slice(private_key_bytes)?; + let sk = G::deserialize_scalar(private_key_bytes.into())?; Ok(Self { sk, hash: PhantomData, @@ -480,7 +480,7 @@ impl VerifiableServer Result { - let sk = G::from_scalar_slice(key)?; + let sk = G::deserialize_scalar(key.into())?; let pk = G::base_point() * &sk; Ok(Self { sk, From 0efb3d9aa357f413674e9270d6e7fc081c0dacbb Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 01:36:02 +0100 Subject: [PATCH 4/9] Rename `Group` methods - `random_nonzero_scalar` -> `random_scalar` - `scalar_as_bytes` -> `serialize_scalar` - `scalar_invert` -> `invert_scalar` --- src/group/mod.rs | 6 +++--- src/group/p256.rs | 10 +++++----- src/group/ristretto.rs | 10 +++++----- src/group/tests.rs | 2 +- src/serialization.rs | 10 +++++----- src/tests/voprf_test_vectors.rs | 8 ++++---- src/voprf.rs | 14 +++++++------- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index a0d3e4c..b3ba055 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -83,13 +83,13 @@ pub trait Group { fn deserialize_scalar(scalar_bits: &GenericArray) -> Result; /// picks a scalar at random - fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar; + fn random_scalar(rng: &mut R) -> Self::Scalar; /// Serializes a scalar to bytes - fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray; + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray; /// The multiplicative inverse of this scalar - fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar; + fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar; /// Return an element from its fixed-length bytes representation. This is /// the unchecked version, which does not check for deserializing the diff --git a/src/group/p256.rs b/src/group/p256.rs index 55e2d3e..d8ae809 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -150,16 +150,16 @@ impl Group for NistP256 { .map_err(|_| Error::ScalarError) } - fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar { - Scalar::random(rng) + fn random_scalar(rng: &mut R) -> Self::Scalar { + *SecretKey::random(rng).to_nonzero_scalar() } - fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray { + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { scalar.into() } - fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar { - scalar.invert().unwrap_or(Scalar::zero()) + fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar { + Option::from(scalar.invert()).unwrap() } fn from_element_slice_unchecked( diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 3a5c5c6..437d837 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -31,6 +31,8 @@ impl Group for Ristretto255 { type Elem = RistrettoPoint; + type ElemLen = U32; + type Scalar = Scalar; type ScalarLen = U32; @@ -84,7 +86,7 @@ impl Group for Ristretto255 { .ok_or(Error::ScalarError) } - fn random_nonzero_scalar(rng: &mut R) -> Self::Scalar { + fn random_scalar(rng: &mut R) -> Self::Scalar { loop { let scalar = { let mut scalar_bytes = [0u8; 64]; @@ -98,16 +100,14 @@ impl Group for Ristretto255 { } } - fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray { + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { scalar.to_bytes().into() } - fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar { + fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar { scalar.invert() } - // The byte length necessary to represent group elements - type ElemLen = U32; fn from_element_slice_unchecked( element_bits: &GenericArray, ) -> Result { diff --git a/src/group/tests.rs b/src/group/tests.rs index 20ac4a7..f2eb2e6 100644 --- a/src/group/tests.rs +++ b/src/group/tests.rs @@ -45,7 +45,7 @@ fn test_identity_element_error() -> Result<()> { // Checks that the zero scalar cannot be deserialized fn test_zero_scalar_error() -> Result<()> { let zero_scalar = G::scalar_zero(); - let result = G::deserialize_scalar(&G::scalar_as_bytes(zero_scalar)); + let result = G::deserialize_scalar(&G::serialize_scalar(zero_scalar)); assert!(matches!(result, Err(Error::ScalarError))); Ok(()) diff --git a/src/serialization.rs b/src/serialization.rs index 047fa84..a2649e8 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -30,7 +30,7 @@ use crate::{ impl NonVerifiableClient { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - G::scalar_as_bytes(self.blind) + G::serialize_scalar(self.blind) } /// Deserialization from bytes @@ -53,7 +53,7 @@ impl VerifiableClient, Sum: ArrayLength, { - G::scalar_as_bytes(self.blind).concat(G::to_arr(self.blinded_element)) + G::serialize_scalar(self.blind).concat(G::to_arr(self.blinded_element)) } /// Deserialization from bytes @@ -74,7 +74,7 @@ impl VerifiableClient NonVerifiableServer { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - G::scalar_as_bytes(self.sk) + G::serialize_scalar(self.sk) } /// Deserialization from bytes @@ -97,7 +97,7 @@ impl VerifiableServer, Sum: ArrayLength, { - G::scalar_as_bytes(self.sk).concat(G::to_arr(self.pk)) + G::serialize_scalar(self.sk).concat(G::to_arr(self.pk)) } /// Deserialization from bytes @@ -122,7 +122,7 @@ impl Proof { G::ScalarLen: Add, Sum: ArrayLength, { - G::scalar_as_bytes(self.c_scalar).concat(G::scalar_as_bytes(self.s_scalar)) + G::serialize_scalar(self.c_scalar).concat(G::serialize_scalar(self.s_scalar)) } /// Deserialization from bytes diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index 7d3cc3f..6e2aadb 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -153,7 +153,7 @@ fn test_base_seed_to_key assert_eq!( ¶meters.sksm, - &G::scalar_as_bytes(server.get_private_key()).to_vec() + &G::serialize_scalar(server.get_private_key()).to_vec() ); } Ok(()) @@ -167,7 +167,7 @@ fn test_verifiable_seed_to_key( assert_eq!( ¶meters.blind[i], - &G::scalar_as_bytes(client_result.state.blind).to_vec() + &G::serialize_scalar(client_result.state.blind).to_vec() ); assert_eq!( parameters.blinded_element[i].as_slice(), @@ -218,7 +218,7 @@ fn test_verifiable_blind assert_eq!( ¶meters.blind[i], - &G::scalar_as_bytes(client_blind_result.state.get_blind()).to_vec() + &G::serialize_scalar(client_blind_result.state.get_blind()).to_vec() ); assert_eq!( parameters.blinded_element[i].as_slice(), diff --git a/src/voprf.rs b/src/voprf.rs index 6807524..2b51e69 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -255,7 +255,7 @@ impl NonVerifiableClient evaluation_element: &EvaluationElement, metadata: Option<&[u8]>, ) -> Result> { - let unblinded_element = evaluation_element.value * &G::scalar_invert(&self.blind); + let unblinded_element = evaluation_element.value * &G::invert_scalar(self.blind); let mut outputs = finalize_after_unblind::( Some((input, unblinded_element)).into_iter(), metadata.unwrap_or_default(), @@ -459,7 +459,7 @@ impl NonVerifiableServer GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)); let m = G::hash_to_scalar::(context, dst)?; let t = self.sk + &m; - let evaluation_element = blinded_element.value * &G::scalar_invert(&t); + let evaluation_element = blinded_element.value * &G::invert_scalar(t); Ok(NonVerifiableServerEvaluateResult { message: EvaluationElement { value: evaluation_element, @@ -600,7 +600,7 @@ impl VerifiableServer, _)) -> _>::from(|(x, t)| { PreparedEvaluationElement(EvaluationElement { value: x.value * &t, @@ -843,7 +843,7 @@ fn blind Result<(G::Scalar, G::Elem)> { // Choose a random scalar that must be non-zero - let blind = G::random_nonzero_scalar(blinding_factor_rng); + let blind = G::random_scalar(blinding_factor_rng); let blinded_element = deterministic_blind_unchecked::(input, &blind, mode)?; Ok((blind, blinded_element)) } @@ -919,7 +919,7 @@ where Ok(blinds .zip(messages.into_iter()) - .map(|(blind, x)| x.value * &G::scalar_invert(&blind))) + .map(|(blind, x)| x.value * &G::invert_scalar(blind))) } #[allow(clippy::many_single_char_names)] @@ -937,7 +937,7 @@ fn generate_proof< ) -> Result> { let (m, z) = compute_composites(Some(k), b, cs, ds)?; - let r = G::random_nonzero_scalar(rng); + let r = G::random_scalar(rng); let t2 = a * &r; let t3 = m * &r; @@ -1134,7 +1134,7 @@ mod tests { let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(mode)); let m = G::hash_to_scalar::(context, dst).unwrap(); - let res = point * &G::scalar_invert(&(key + &m)); + let res = point * &G::invert_scalar(key + &m); finalize_after_unblind::(Some((input, res)).into_iter(), info, mode) .unwrap() From 95810fb64593823cb627da07edfbb1274734fae8 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 01:36:25 +0100 Subject: [PATCH 5/9] Rework element de-serialization --- src/group/mod.rs | 22 ++-------------------- src/group/p256.rs | 11 +++++------ src/group/ristretto.rs | 5 ++--- src/group/tests.rs | 2 +- src/serialization.rs | 8 ++++---- src/tests/voprf_test_vectors.rs | 4 ++-- 6 files changed, 16 insertions(+), 36 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index b3ba055..0f2a1d5 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -26,7 +26,7 @@ pub use ristretto::Ristretto255; use subtle::ConstantTimeEq; use zeroize::Zeroize; -use crate::{Error, Result}; +use crate::Result; /// A prime-order subgroup of a base field (EC, prime-order field ...). This /// subgroup is noted additively — as in the draft RFC — in this trait. @@ -91,27 +91,9 @@ pub trait Group { /// The multiplicative inverse of this scalar fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar; - /// Return an element from its fixed-length bytes representation. This is - /// the unchecked version, which does not check for deserializing the - /// identity element - fn from_element_slice_unchecked( - element_bits: &GenericArray, - ) -> Result; - /// Return an element from its fixed-length bytes representation. If the /// element is the identity element, return an error. - fn from_element_slice<'a>( - element_bits: impl Into<&'a GenericArray>, - ) -> Result { - let elem = Self::from_element_slice_unchecked(element_bits.into())?; - - if Self::Elem::ct_eq(&elem, &Self::identity()).into() { - // found the identity element - return Err(Error::PointError); - } - - Ok(elem) - } + fn deserialize_elem(element_bits: &GenericArray) -> Result; /// Serializes the `self` group element fn to_arr(elem: Self::Elem) -> GenericArray; diff --git a/src/group/p256.rs b/src/group/p256.rs index d8ae809..6e3ba0a 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -25,11 +25,10 @@ use num_integer::Integer; use num_traits::{One, ToPrimitive, Zero}; use once_cell::unsync::Lazy; use p256_::elliptic_curve::group::prime::PrimeCurveAffine; -use p256_::elliptic_curve::group::GroupEncoding; use p256_::elliptic_curve::ops::Reduce; use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use p256_::elliptic_curve::Field; -use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, Scalar, SecretKey}; +use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, PublicKey, Scalar, SecretKey}; use rand_core::{CryptoRng, RngCore}; use subtle::{Choice, ConditionallySelectable}; @@ -162,10 +161,10 @@ impl Group for NistP256 { Option::from(scalar.invert()).unwrap() } - fn from_element_slice_unchecked( - element_bits: &GenericArray, - ) -> Result { - Option::from(ProjectivePoint::from_bytes(element_bits)).ok_or(Error::PointError) + fn deserialize_elem(element_bits: &GenericArray) -> Result { + PublicKey::from_sec1_bytes(element_bits) + .map(|public_key| public_key.to_projective()) + .map_err(|_| Error::PointError) } fn to_arr(elem: Self::Elem) -> GenericArray { diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 437d837..fa018a5 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -108,11 +108,10 @@ impl Group for Ristretto255 { scalar.invert() } - fn from_element_slice_unchecked( - element_bits: &GenericArray, - ) -> Result { + fn deserialize_elem(element_bits: &GenericArray) -> Result { CompressedRistretto::from_slice(element_bits) .decompress() + .filter(|point| point != &RistrettoPoint::identity()) .ok_or(Error::PointError) } diff --git a/src/group/tests.rs b/src/group/tests.rs index f2eb2e6..6cff83b 100644 --- a/src/group/tests.rs +++ b/src/group/tests.rs @@ -36,7 +36,7 @@ fn test_group_properties() -> Result<()> { // Checks that the identity element cannot be deserialized fn test_identity_element_error() -> Result<()> { let identity = G::identity(); - let result = G::from_element_slice(&G::to_arr(identity)); + let result = G::deserialize_elem(&G::to_arr(identity)); assert!(matches!(result, Err(Error::PointError))); Ok(()) diff --git a/src/serialization.rs b/src/serialization.rs index a2649e8..9fe3256 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -61,7 +61,7 @@ impl VerifiableClient VerifiableServer BlindedElement Result { let mut input = input.iter().copied(); - let value = G::from_element_slice(&deserialize(&mut input)?)?; + let value = G::deserialize_elem(&deserialize(&mut input)?)?; Ok(Self { value, @@ -169,7 +169,7 @@ impl EvaluationElement Result { let mut input = input.iter().copied(); - let value = G::from_element_slice(&deserialize(&mut input)?)?; + let value = G::deserialize_elem(&deserialize(&mut input)?)?; Ok(Self { value, diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index 6e2aadb..4e173d2 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -323,7 +323,7 @@ fn test_verifiable_finalize::from_blind_and_element( G::deserialize_scalar(&GenericArray::clone_from_slice(¶meters.blind[i]))?, - G::from_element_slice(&GenericArray::clone_from_slice( + G::deserialize_elem(&GenericArray::clone_from_slice( ¶meters.blinded_element[i], ))?, ); @@ -341,7 +341,7 @@ fn test_verifiable_finalize Date: Mon, 10 Jan 2022 01:55:44 +0100 Subject: [PATCH 6/9] Rename and remove `Group` methods `to_arr` -> `serialize_elem` `base_point` -> `base_elem` `is_identity` -> removed `identity` -> `identity_elem` `zero_scalar` -> hidden behind `cfg(test)` --- src/group/mod.rs | 14 +++++------- src/group/p256.rs | 10 +++++---- src/group/ristretto.rs | 9 ++++---- src/group/tests.rs | 6 ++--- src/serialization.rs | 8 +++---- src/tests/voprf_test_vectors.rs | 2 +- src/voprf.rs | 40 ++++++++++++++++----------------- 7 files changed, 44 insertions(+), 45 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index 0f2a1d5..8eba51f 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -96,21 +96,17 @@ pub trait Group { fn deserialize_elem(element_bits: &GenericArray) -> Result; /// Serializes the `self` group element - fn to_arr(elem: Self::Elem) -> GenericArray; + fn serialize_elem(elem: Self::Elem) -> GenericArray; /// Get the base point for the group - fn base_point() -> Self::Elem; - - /// Returns if the group element is equal to the identity (1) - fn is_identity(elem: Self::Elem) -> bool { - elem.ct_eq(&Self::identity()).into() - } + fn base_elem() -> Self::Elem; /// Returns the identity group element - fn identity() -> Self::Elem; + fn identity_elem() -> Self::Elem; /// Returns the scalar representing zero - fn scalar_zero() -> Self::Scalar; + #[cfg(test)] + fn zero_scalar() -> Self::Scalar; } #[cfg(test)] diff --git a/src/group/p256.rs b/src/group/p256.rs index 6e3ba0a..b1ee3da 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -27,6 +27,7 @@ use once_cell::unsync::Lazy; use p256_::elliptic_curve::group::prime::PrimeCurveAffine; use p256_::elliptic_curve::ops::Reduce; use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; +#[cfg(test)] use p256_::elliptic_curve::Field; use p256_::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, PublicKey, Scalar, SecretKey}; use rand_core::{CryptoRng, RngCore}; @@ -167,7 +168,7 @@ impl Group for NistP256 { .map_err(|_| Error::PointError) } - fn to_arr(elem: Self::Elem) -> GenericArray { + fn serialize_elem(elem: Self::Elem) -> GenericArray { let bytes = elem.to_affine().to_encoded_point(true); let bytes = bytes.as_bytes(); let mut result = GenericArray::default(); @@ -175,15 +176,16 @@ impl Group for NistP256 { result } - fn base_point() -> Self::Elem { + fn base_elem() -> Self::Elem { ProjectivePoint::generator() } - fn identity() -> Self::Elem { + fn identity_elem() -> Self::Elem { ProjectivePoint::identity() } - fn scalar_zero() -> Self::Scalar { + #[cfg(test)] + fn zero_scalar() -> Self::Scalar { Scalar::zero() } } diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index fa018a5..2b0315d 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -116,19 +116,20 @@ impl Group for Ristretto255 { } // serialization of a group element - fn to_arr(elem: Self::Elem) -> GenericArray { + fn serialize_elem(elem: Self::Elem) -> GenericArray { elem.compress().to_bytes().into() } - fn base_point() -> Self::Elem { + fn base_elem() -> Self::Elem { RISTRETTO_BASEPOINT_POINT } - fn identity() -> Self::Elem { + fn identity_elem() -> Self::Elem { RistrettoPoint::identity() } - fn scalar_zero() -> Self::Scalar { + #[cfg(test)] + fn zero_scalar() -> Self::Scalar { Scalar::zero() } } diff --git a/src/group/tests.rs b/src/group/tests.rs index 6cff83b..d389048 100644 --- a/src/group/tests.rs +++ b/src/group/tests.rs @@ -35,8 +35,8 @@ fn test_group_properties() -> Result<()> { // Checks that the identity element cannot be deserialized fn test_identity_element_error() -> Result<()> { - let identity = G::identity(); - let result = G::deserialize_elem(&G::to_arr(identity)); + let identity = G::identity_elem(); + let result = G::deserialize_elem(&G::serialize_elem(identity)); assert!(matches!(result, Err(Error::PointError))); Ok(()) @@ -44,7 +44,7 @@ fn test_identity_element_error() -> Result<()> { // Checks that the zero scalar cannot be deserialized fn test_zero_scalar_error() -> Result<()> { - let zero_scalar = G::scalar_zero(); + let zero_scalar = G::zero_scalar(); let result = G::deserialize_scalar(&G::serialize_scalar(zero_scalar)); assert!(matches!(result, Err(Error::ScalarError))); diff --git a/src/serialization.rs b/src/serialization.rs index 9fe3256..1d4abff 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -53,7 +53,7 @@ impl VerifiableClient, Sum: ArrayLength, { - G::serialize_scalar(self.blind).concat(G::to_arr(self.blinded_element)) + G::serialize_scalar(self.blind).concat(G::serialize_elem(self.blinded_element)) } /// Deserialization from bytes @@ -97,7 +97,7 @@ impl VerifiableServer, Sum: ArrayLength, { - G::serialize_scalar(self.sk).concat(G::to_arr(self.pk)) + G::serialize_scalar(self.sk).concat(G::serialize_elem(self.pk)) } /// Deserialization from bytes @@ -143,7 +143,7 @@ impl Proof { impl BlindedElement { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - G::to_arr(self.value) + G::serialize_elem(self.value) } /// Deserialization from bytes @@ -162,7 +162,7 @@ impl BlindedElement EvaluationElement { /// Serialization into bytes pub fn serialize(&self) -> GenericArray { - G::to_arr(self.value) + G::serialize_elem(self.value) } /// Deserialization from bytes diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index 4e173d2..898404d 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -171,7 +171,7 @@ fn test_verifiable_seed_to_key VerifiableServer Result { let sk = G::deserialize_scalar(key.into())?; - let pk = G::base_point() * &sk; + let pk = G::base_elem() * &sk; Ok(Self { sk, pk, @@ -497,7 +497,7 @@ impl VerifiableServer(Mode::Verifiable)); let sk = G::hash_to_scalar::(Some(seed), dst)?; - let pk = G::base_point() * &sk; + let pk = G::base_elem() * &sk; Ok(Self { sk, pk, @@ -632,7 +632,7 @@ impl VerifiableServer>, <&'b IE as IntoIterator>::IntoIter: ExactSizeIterator, { - let g = G::base_point(); + let g = G::base_elem(); let u = g * t; let proof = generate_proof( @@ -901,7 +901,7 @@ where GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); let m = G::hash_to_scalar::(context, dst)?; - let g = G::base_point(); + let g = G::base_elem(); let t = g * &m; let u = t + &pk; @@ -945,11 +945,11 @@ fn generate_proof< GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); chain!( h2_input, - Serialize::::from_owned(G::to_arr(b))?, - Serialize::::from_owned(G::to_arr(m))?, - Serialize::::from_owned(G::to_arr(z))?, - Serialize::::from_owned(G::to_arr(t2))?, - Serialize::::from_owned(G::to_arr(t3))?, + Serialize::::from_owned(G::serialize_elem(b))?, + Serialize::::from_owned(G::serialize_elem(m))?, + Serialize::::from_owned(G::serialize_elem(z))?, + Serialize::::from_owned(G::serialize_elem(t2))?, + Serialize::::from_owned(G::serialize_elem(t3))?, Serialize::::from_owned(challenge_dst)?, ); @@ -982,11 +982,11 @@ fn verify_proof( GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); chain!( h2_input, - Serialize::::from_owned(G::to_arr(b))?, - Serialize::::from_owned(G::to_arr(m))?, - Serialize::::from_owned(G::to_arr(z))?, - Serialize::::from_owned(G::to_arr(t2))?, - Serialize::::from_owned(G::to_arr(t3))?, + Serialize::::from_owned(G::serialize_elem(b))?, + Serialize::::from_owned(G::serialize_elem(m))?, + Serialize::::from_owned(G::serialize_elem(z))?, + Serialize::::from_owned(G::serialize_elem(t2))?, + Serialize::::from_owned(G::serialize_elem(t3))?, Serialize::::from_owned(challenge_dst)?, ); @@ -1027,7 +1027,7 @@ fn finalize_after_unblind< hash_input, Serialize::::from(input.as_ref())?, Serialize::::from(info)?, - Serialize::::from_owned(G::to_arr(unblinded_element))?, + Serialize::::from_owned(G::serialize_elem(unblinded_element))?, Serialize::::from_owned(finalize_dst)?, ); @@ -1053,22 +1053,22 @@ fn compute_composites( chain!( h1_input, - Serialize::::from_owned(G::to_arr(b))?, + Serialize::::from_owned(G::serialize_elem(b))?, Serialize::::from_owned(seed_dst)?, ); let seed = h1_input .fold(H::new(), |h, bytes| h.chain_update(bytes)) .finalize(); - let mut m = G::identity(); - let mut z = G::identity(); + let mut m = G::identity_elem(); + let mut z = G::identity_elem(); for (i, (c, d)) in c_slice.zip(d_slice).enumerate() { chain!(h2_input, Serialize::::from_owned(seed.clone())?, i2osp::(i)? => |x| Some(x.as_slice()), - Serialize::::from_owned(G::to_arr(c.value))?, - Serialize::::from_owned(G::to_arr(d.value))?, + Serialize::::from_owned(G::serialize_elem(c.value))?, + Serialize::::from_owned(G::serialize_elem(d.value))?, Serialize::::from_owned(composite_dst)?, ); let dst = GenericArray::from(STR_HASH_TO_SCALAR) From f6a1bacd55ad68e1c7cb69c1183e9ec5f25e9bc4 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 01:59:48 +0100 Subject: [PATCH 7/9] Sort `Group` methods --- src/group/mod.rs | 34 +++++++++++++------------- src/group/p256.rs | 48 ++++++++++++++++++------------------- src/group/ristretto.rs | 54 +++++++++++++++++++++--------------------- 3 files changed, 68 insertions(+), 68 deletions(-) diff --git a/src/group/mod.rs b/src/group/mod.rs index 8eba51f..4a039e4 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -78,35 +78,35 @@ pub trait Group { where >::Output: ArrayLength; - /// Return a scalar from its fixed-length bytes representation. If the - /// scalar is zero or invalid, then return an error. - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result; - - /// picks a scalar at random - fn random_scalar(rng: &mut R) -> Self::Scalar; + /// Get the base point for the group + fn base_elem() -> Self::Elem; - /// Serializes a scalar to bytes - fn serialize_scalar(scalar: Self::Scalar) -> GenericArray; + /// Returns the identity group element + fn identity_elem() -> Self::Elem; - /// The multiplicative inverse of this scalar - fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar; + /// Serializes the `self` group element + fn serialize_elem(elem: Self::Elem) -> GenericArray; /// Return an element from its fixed-length bytes representation. If the /// element is the identity element, return an error. fn deserialize_elem(element_bits: &GenericArray) -> Result; - /// Serializes the `self` group element - fn serialize_elem(elem: Self::Elem) -> GenericArray; - - /// Get the base point for the group - fn base_elem() -> Self::Elem; + /// picks a scalar at random + fn random_scalar(rng: &mut R) -> Self::Scalar; - /// Returns the identity group element - fn identity_elem() -> Self::Elem; + /// The multiplicative inverse of this scalar + fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar; /// Returns the scalar representing zero #[cfg(test)] fn zero_scalar() -> Self::Scalar; + + /// Serializes a scalar to bytes + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray; + + /// Return a scalar from its fixed-length bytes representation. If the + /// scalar is zero or invalid, then return an error. + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result; } #[cfg(test)] diff --git a/src/group/p256.rs b/src/group/p256.rs index b1ee3da..55acdca 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -144,28 +144,12 @@ impl Group for NistP256 { Ok(Scalar::from_be_bytes_reduced(result)) } - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { - SecretKey::from_be_bytes(scalar_bits) - .map(|secret_key| *secret_key.to_nonzero_scalar()) - .map_err(|_| Error::ScalarError) - } - - fn random_scalar(rng: &mut R) -> Self::Scalar { - *SecretKey::random(rng).to_nonzero_scalar() - } - - fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { - scalar.into() - } - - fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar { - Option::from(scalar.invert()).unwrap() + fn base_elem() -> Self::Elem { + ProjectivePoint::generator() } - fn deserialize_elem(element_bits: &GenericArray) -> Result { - PublicKey::from_sec1_bytes(element_bits) - .map(|public_key| public_key.to_projective()) - .map_err(|_| Error::PointError) + fn identity_elem() -> Self::Elem { + ProjectivePoint::identity() } fn serialize_elem(elem: Self::Elem) -> GenericArray { @@ -176,18 +160,34 @@ impl Group for NistP256 { result } - fn base_elem() -> Self::Elem { - ProjectivePoint::generator() + fn deserialize_elem(element_bits: &GenericArray) -> Result { + PublicKey::from_sec1_bytes(element_bits) + .map(|public_key| public_key.to_projective()) + .map_err(|_| Error::PointError) } - fn identity_elem() -> Self::Elem { - ProjectivePoint::identity() + fn random_scalar(rng: &mut R) -> Self::Scalar { + *SecretKey::random(rng).to_nonzero_scalar() + } + + fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar { + Option::from(scalar.invert()).unwrap() } #[cfg(test)] fn zero_scalar() -> Self::Scalar { Scalar::zero() } + + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { + scalar.into() + } + + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { + SecretKey::from_be_bytes(scalar_bits) + .map(|secret_key| *secret_key.to_nonzero_scalar()) + .map_err(|_| Error::ScalarError) + } } /// Corresponds to the hash_to_curve_simple_swu() function defined in diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 2b0315d..14c44d6 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -80,10 +80,24 @@ impl Group for Ristretto255 { )) } - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { - Scalar::from_canonical_bytes((*scalar_bits).into()) - .filter(|scalar| scalar != &Scalar::zero()) - .ok_or(Error::ScalarError) + fn base_elem() -> Self::Elem { + RISTRETTO_BASEPOINT_POINT + } + + fn identity_elem() -> Self::Elem { + RistrettoPoint::identity() + } + + // serialization of a group element + fn serialize_elem(elem: Self::Elem) -> GenericArray { + elem.compress().to_bytes().into() + } + + fn deserialize_elem(element_bits: &GenericArray) -> Result { + CompressedRistretto::from_slice(element_bits) + .decompress() + .filter(|point| point != &RistrettoPoint::identity()) + .ok_or(Error::PointError) } fn random_scalar(rng: &mut R) -> Self::Scalar { @@ -100,36 +114,22 @@ impl Group for Ristretto255 { } } - fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { - scalar.to_bytes().into() - } - fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar { scalar.invert() } - fn deserialize_elem(element_bits: &GenericArray) -> Result { - CompressedRistretto::from_slice(element_bits) - .decompress() - .filter(|point| point != &RistrettoPoint::identity()) - .ok_or(Error::PointError) - } - - // serialization of a group element - fn serialize_elem(elem: Self::Elem) -> GenericArray { - elem.compress().to_bytes().into() - } - - fn base_elem() -> Self::Elem { - RISTRETTO_BASEPOINT_POINT + #[cfg(test)] + fn zero_scalar() -> Self::Scalar { + Scalar::zero() } - fn identity_elem() -> Self::Elem { - RistrettoPoint::identity() + fn serialize_scalar(scalar: Self::Scalar) -> GenericArray { + scalar.to_bytes().into() } - #[cfg(test)] - fn zero_scalar() -> Self::Scalar { - Scalar::zero() + fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { + Scalar::from_canonical_bytes((*scalar_bits).into()) + .filter(|scalar| scalar != &Scalar::zero()) + .ok_or(Error::ScalarError) } } From 9079bdef824169440b983fd119810dd8e9265ff9 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 05:45:37 +0100 Subject: [PATCH 8/9] Rework `expand_message_xmd` and remove utility --- Cargo.toml | 2 +- src/group/expand.rs | 110 ++++++------- src/group/mod.rs | 30 ++-- src/group/p256.rs | 56 ++++--- src/group/ristretto.rs | 55 +++---- src/lib.rs | 2 +- src/tests/voprf_vectors.rs | 2 +- src/util.rs | 122 ++------------ src/voprf.rs | 320 +++++++++++++++++++++++-------------- 9 files changed, 332 insertions(+), 367 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 56a2ec4..816ac30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ p256 = [ "once_cell", "p256_", ] -ristretto255 = [] +ristretto255 = ["generic-array/more_lengths"] ristretto255_fiat_u32 = ["curve25519-dalek/fiat_u32_backend", "ristretto255"] ristretto255_fiat_u64 = ["curve25519-dalek/fiat_u64_backend", "ristretto255"] ristretto255_simd = ["curve25519-dalek/simd_backend", "ristretto255"] diff --git a/src/group/expand.rs b/src/group/expand.rs index 300723e..ec54907 100644 --- a/src/group/expand.rs +++ b/src/group/expand.rs @@ -5,74 +5,82 @@ // License, Version 2.0 found in the LICENSE-APACHE file in the root directory // of this source tree. -use core::ops::Add; +use core::convert::TryFrom; -use digest::core_api::BlockSizeUser; +use digest::core_api::{Block, BlockSizeUser}; use digest::{Digest, FixedOutputReset}; -use generic_array::sequence::Concat; -use generic_array::typenum::{Unsigned, U1, U2}; +use generic_array::typenum::{IsLess, NonZero, Unsigned, U65536}; use generic_array::{ArrayLength, GenericArray}; -use crate::util::i2osp; use crate::{Error, Result}; -// Computes ceil(x / y) -fn div_ceil(x: usize, y: usize) -> usize { - let additive = (x % y != 0) as usize; - x / y + additive -} - fn xor>(x: GenericArray, y: GenericArray) -> GenericArray { x.into_iter().zip(y).map(|(x1, x2)| x1 ^ x2).collect() } /// Corresponds to the expand_message_xmd() function defined in /// -pub fn expand_message_xmd< - 'a, - H: BlockSizeUser + Digest + FixedOutputReset, - L: ArrayLength, - M: IntoIterator, - D: ArrayLength + Add, ->( - msg: M, - dst: GenericArray, +pub fn expand_message_xmd>( + msg: &[&[u8]], + dst: &[u8], ) -> Result> where - >::Output: ArrayLength, + // Constraint set by `expand_message_xmd`: + // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6 + L: NonZero + IsLess, { - let digest_len = H::OutputSize::USIZE; - let ell = div_ceil(L::USIZE, digest_len); - if ell > 255 { + // DST, a byte string of at most 255 bytes. + let dst_len = u8::try_from(dst.len()).map_err(|_| Error::HashToCurveError)?; + + // b_in_bytes, b / 8 for b the output size of H in bits. + let b_in_bytes = H::OutputSize::to_usize(); + + // Constraint set by `expand_message_xmd`: + // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4 + if b_in_bytes > H::BlockSize::USIZE { return Err(Error::HashToCurveError); } - let dst_prime = dst.concat(i2osp::(D::USIZE)?); - let z_pad = i2osp::(0)?; - let l_i_b_str = i2osp::(L::USIZE)?; - let mut h = H::new(); + // ell = ceil(len_in_bytes / b_in_bytes) + // ABORT if ell > 255 + let ell = u8::try_from((L::USIZE + b_in_bytes - 1) / b_in_bytes) + .map_err(|_| Error::HashToCurveError)?; + + let mut hash = H::new(); + // b_0 = H(msg_prime) // msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime - Digest::update(&mut h, z_pad); - for bytes in msg { - Digest::update(&mut h, bytes) + // Z_pad = I2OSP(0, s_in_bytes) + // s_in_bytes, the input block size of H, measured in bytes + Digest::update(&mut hash, Block::::default()); + for msg in msg { + Digest::update(&mut hash, msg); } - Digest::update(&mut h, l_i_b_str); - Digest::update(&mut h, i2osp::(0)?); - Digest::update(&mut h, &dst_prime); + // l_i_b_str = I2OSP(len_in_bytes, 2) + Digest::update(&mut hash, L::U16.to_be_bytes()); + Digest::update(&mut hash, [0]); + // DST_prime = DST || I2OSP(len(DST), 1) + Digest::update(&mut hash, dst); + Digest::update(&mut hash, [dst_len]); + let b_0 = hash.finalize_reset(); - // b[0] - let b_0 = h.finalize_reset(); let mut b_i = GenericArray::default(); let mut uniform_bytes = GenericArray::default(); - for (i, chunk) in (1..(ell + 1)).zip(uniform_bytes.chunks_mut(digest_len)) { - Digest::update(&mut h, xor(b_0.clone(), b_i.clone())); - Digest::update(&mut h, i2osp::(i)?); - Digest::update(&mut h, &dst_prime); - b_i = h.finalize_reset(); - chunk.copy_from_slice(&b_i[..digest_len.min(chunk.len())]); + // b_1 = H(b_0 || I2OSP(1, 1) || DST_prime) + // for i in (2, ..., ell): + for (i, chunk) in (1..(ell + 1)).zip(uniform_bytes.chunks_mut(b_in_bytes)) { + // b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime) + Digest::update(&mut hash, xor(b_0.clone(), b_i.clone())); + Digest::update(&mut hash, [i]); + // DST_prime = DST || I2OSP(len(DST), 1) + Digest::update(&mut hash, dst); + Digest::update(&mut hash, [dst_len]); + b_i = hash.finalize_reset(); + // uniform_bytes = b_1 || ... || b_ell + // return substr(uniform_bytes, 0, len_in_bytes) + chunk.copy_from_slice(&b_i[..b_in_bytes.min(chunk.len())]); } Ok(uniform_bytes) @@ -81,7 +89,6 @@ where #[cfg(test)] mod tests { use generic_array::typenum::{U128, U32}; - use generic_array::GenericArray; struct Params { msg: &'static str, @@ -91,6 +98,8 @@ mod tests { #[test] fn test_expand_message_xmd() { + const DST: [u8; 27] = *b"QUUX-V01-CS02-with-expander"; + // Test vectors taken from Section K.1 of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.txt let test_vectors: alloc::vec::Vec = alloc::vec![ Params { @@ -190,20 +199,13 @@ mod tests { 378fba044a31f5cb44583a892f5969dcd73b3fa128816e", }, ]; - let dst = GenericArray::from(*b"QUUX-V01-CS02-with-expander"); for tv in test_vectors { let uniform_bytes = match tv.len_in_bytes { - 32 => super::expand_message_xmd::( - Some(tv.msg.as_bytes()), - dst, - ) - .map(|bytes| bytes.to_vec()), - 128 => super::expand_message_xmd::( - Some(tv.msg.as_bytes()), - dst, - ) - .map(|bytes| bytes.to_vec()), + 32 => super::expand_message_xmd::(&[tv.msg.as_bytes()], &DST) + .map(|bytes| bytes.to_vec()), + 128 => super::expand_message_xmd::(&[tv.msg.as_bytes()], &DST) + .map(|bytes| bytes.to_vec()), _ => unimplemented!(), } .unwrap(); diff --git a/src/group/mod.rs b/src/group/mod.rs index 4a039e4..47d5fb7 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -18,7 +18,6 @@ use core::ops::{Add, Mul, Sub}; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutputReset}; -use generic_array::typenum::U1; use generic_array::{ArrayLength, GenericArray}; use rand_core::{CryptoRng, RngCore}; #[cfg(feature = "ristretto255")] @@ -26,8 +25,12 @@ pub use ristretto::Ristretto255; use subtle::ConstantTimeEq; use zeroize::Zeroize; +use crate::voprf::Mode; use crate::Result; +pub(crate) const STR_HASH_TO_SCALAR: [u8; 13] = *b"HashToScalar-"; +pub(crate) const STR_HASH_TO_GROUP: [u8; 12] = *b"HashToGroup-"; + /// A prime-order subgroup of a base field (EC, prime-order field ...). This /// subgroup is noted additively — as in the draft RFC — in this trait. pub trait Group { @@ -58,25 +61,16 @@ pub trait Group { type ScalarLen: ArrayLength + 'static; /// transforms a password and domain separation tag (DST) into a curve point - fn hash_to_curve + Add>( - msg: &[u8], - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength; + fn hash_to_curve( + msg: &[&[u8]], + mode: Mode, + ) -> Result; /// Hashes a slice of pseudo-random bytes to a scalar - fn hash_to_scalar< - 'a, - H: BlockSizeUser + Digest + FixedOutputReset, - D: ArrayLength + Add, - I: IntoIterator, - >( - input: I, - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength; + fn hash_to_scalar( + input: &[&[u8]], + mode: Mode, + ) -> Result; /// Get the base point for the group fn base_elem() -> Self::Elem; diff --git a/src/group/p256.rs b/src/group/p256.rs index 55acdca..36613be 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -18,7 +18,8 @@ use core::str::FromStr; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutputReset}; -use generic_array::typenum::{Unsigned, U1, U2, U32, U33, U48}; +use generic_array::sequence::Concat; +use generic_array::typenum::{Unsigned, U2, U32, U33, U48}; use generic_array::{ArrayLength, GenericArray}; use num_bigint::{BigInt, Sign}; use num_integer::Integer; @@ -34,6 +35,8 @@ use rand_core::{CryptoRng, RngCore}; use subtle::{Choice, ConditionallySelectable}; use super::Group; +use crate::group::{STR_HASH_TO_GROUP, STR_HASH_TO_SCALAR}; +use crate::voprf::{self, Mode}; use crate::{Error, Result}; // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-8.2 @@ -54,13 +57,13 @@ impl Group for NistP256 { // Implements the `hash_to_curve()` function from // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-3 - fn hash_to_curve + Add>( - msg: &[u8], - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength, - { + fn hash_to_curve( + msg: &[&[u8]], + mode: Mode, + ) -> Result { + let dst = + GenericArray::from(STR_HASH_TO_GROUP).concat(voprf::get_context_string::(mode)); + // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-8.2 // `p: 2^256 - 2^224 + 2^192 + 2^96 - 1` const P: Lazy = Lazy::new(|| { @@ -87,7 +90,7 @@ impl Group for NistP256 { // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-5.3 // `hash_to_field` calls `expand_message` with a `len_in_bytes` of `count * L` let uniform_bytes = - super::expand::expand_message_xmd::>::Output, _, _>(Some(msg), dst)?; + super::expand::expand_message_xmd::>::Output>(msg, &dst)?; // hash to curve let (q0x, q0y) = hash_to_curve_simple_swu(&uniform_bytes[..L::USIZE], &A, &B, &P, &Z); @@ -108,18 +111,13 @@ impl Group for NistP256 { } // Implements the `HashToScalar()` function - fn hash_to_scalar< - 'a, - H: BlockSizeUser + Digest + FixedOutputReset, - D: ArrayLength + Add, - I: IntoIterator, - >( - input: I, - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength, - { + fn hash_to_scalar( + input: &[&[u8]], + mode: Mode, + ) -> Result { + let dst = + GenericArray::from(STR_HASH_TO_SCALAR).concat(voprf::get_context_string::(mode)); + // https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf#[{%22num%22:211,%22gen%22:0},{%22name%22:%22XYZ%22},70,700,0] // P-256 `n` is defined as // `115792089210356248762697446949407573529996955224135760342 @@ -133,7 +131,7 @@ impl Group for NistP256 { // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-5.3 // `HashToScalar` is `hash_to_field` - let uniform_bytes = super::expand::expand_message_xmd::(input, dst)?; + let uniform_bytes = super::expand::expand_message_xmd::(input, &dst)?; let bytes = BigInt::from_bytes_be(Sign::Plus, &uniform_bytes) .mod_floor(&N) .to_bytes_be() @@ -464,6 +462,8 @@ mod tests { #[test] fn hash_to_curve_simple_swu() { + const DST: [u8; 44] = *b"QUUX-V01-CS02-with-P256_XMD:SHA-256_SSWU_RO_"; + const P: Lazy = Lazy::new(|| { BigInt::from_str( "115792089210356248762697446949407573530086143415290314195533631308867097853951", @@ -549,15 +549,13 @@ mod tests { q1y: "f6ed88a7aab56a488100e6f1174fa9810b47db13e86be999644922961206e184", }, ]; - let dst = GenericArray::from(*b"QUUX-V01-CS02-with-P256_XMD:SHA-256_SSWU_RO_"); for tv in test_vectors { - let uniform_bytes = - super::super::expand::expand_message_xmd::( - Some(tv.msg.as_bytes()), - dst, - ) - .unwrap(); + let uniform_bytes = super::super::expand::expand_message_xmd::( + &[tv.msg.as_bytes()], + &DST, + ) + .unwrap(); let u0 = BigInt::from_bytes_be(Sign::Plus, &uniform_bytes[..48]).mod_floor(&P); let u1 = BigInt::from_bytes_be(Sign::Plus, &uniform_bytes[48..]).mod_floor(&P); diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 14c44d6..ea1cc47 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -6,7 +6,6 @@ // of this source tree. use core::convert::TryInto; -use core::ops::Add; use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT; use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint}; @@ -14,11 +13,13 @@ use curve25519_dalek::scalar::Scalar; use curve25519_dalek::traits::Identity; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutputReset}; -use generic_array::typenum::{U1, U32, U64}; -use generic_array::{ArrayLength, GenericArray}; +use generic_array::sequence::Concat; +use generic_array::typenum::{U32, U64}; +use generic_array::GenericArray; use rand_core::{CryptoRng, RngCore}; -use super::Group; +use super::{expand, Group, STR_HASH_TO_GROUP, STR_HASH_TO_SCALAR}; +use crate::voprf::{self, Mode}; use crate::{Error, Result}; /// [`Group`] implementation for Ristretto255. @@ -39,38 +40,28 @@ impl Group for Ristretto255 { // Implements the `hash_to_ristretto255()` function from // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.txt - fn hash_to_curve + Add>( - msg: &[u8], - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength, - { - let uniform_bytes = super::expand::expand_message_xmd::(Some(msg), dst)?; - - Ok(RistrettoPoint::from_uniform_bytes( - uniform_bytes - .as_slice() - .try_into() - .map_err(|_| Error::HashToCurveError)?, - )) + fn hash_to_curve( + msg: &[&[u8]], + mode: Mode, + ) -> Result { + let dst = + GenericArray::from(STR_HASH_TO_GROUP).concat(voprf::get_context_string::(mode)); + + let uniform_bytes = expand::expand_message_xmd::(msg, &dst)?; + + Ok(RistrettoPoint::from_uniform_bytes(&uniform_bytes.into())) } // Implements the `HashToScalar()` function from // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-07.html#section-4.1 - fn hash_to_scalar< - 'a, - H: BlockSizeUser + Digest + FixedOutputReset, - D: ArrayLength + Add, - I: IntoIterator, - >( - input: I, - dst: GenericArray, - ) -> Result - where - >::Output: ArrayLength, - { - let uniform_bytes = super::expand::expand_message_xmd::(input, dst)?; + fn hash_to_scalar<'a, H: BlockSizeUser + Digest + FixedOutputReset>( + input: &[&[u8]], + mode: Mode, + ) -> Result { + let dst = + GenericArray::from(STR_HASH_TO_SCALAR).concat(voprf::get_context_string::(mode)); + + let uniform_bytes = expand::expand_message_xmd::(input, &dst)?; Ok(Scalar::from_bytes_mod_order_wide( uniform_bytes diff --git a/src/lib.rs b/src/lib.rs index ed37593..8aa59b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -569,7 +569,7 @@ pub use crate::group::Group; #[cfg(feature = "alloc")] pub use crate::voprf::VerifiableServerBatchEvaluateResult; pub use crate::voprf::{ - BlindedElement, EvaluationElement, NonVerifiableClient, NonVerifiableClientBlindResult, + BlindedElement, EvaluationElement, Mode, NonVerifiableClient, NonVerifiableClientBlindResult, NonVerifiableServer, NonVerifiableServerEvaluateResult, PreparedEvaluationElement, PreparedTscalar, Proof, VerifiableClient, VerifiableClientBatchFinalizeResult, VerifiableClientBlindResult, VerifiableServer, VerifiableServerBatchEvaluateFinishResult, diff --git a/src/tests/voprf_vectors.rs b/src/tests/voprf_vectors.rs index f1f6cc3..9203072 100644 --- a/src/tests/voprf_vectors.rs +++ b/src/tests/voprf_vectors.rs @@ -8,7 +8,7 @@ //! The VOPRF test vectors taken from: //! https://github.com/cfrg/draft-irtf-cfrg-voprf/blob/master/draft-irtf-cfrg-voprf.md -pub(crate) static VECTORS: &str = r#" +pub(crate) const VECTORS: &str = r#" ## OPRF(ristretto255, SHA-512) ### Base Mode diff --git a/src/util.rs b/src/util.rs index 97bf909..e4fe7f9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -7,137 +7,35 @@ //! Helper functions -use core::array::IntoIter; +use core::convert::TryFrom; -use generic_array::typenum::U0; +use generic_array::typenum::{IsLess, U2, U256}; use generic_array::{ArrayLength, GenericArray}; use crate::{Error, Result}; -// Corresponds to the I2OSP() function from RFC8017 -pub(crate) fn i2osp>(input: usize) -> Result> { - const SIZEOF_USIZE: usize = core::mem::size_of::(); - - // Make sure input fits in output. - if (SIZEOF_USIZE as u32 - input.leading_zeros() / 8) > L::U32 { - return Err(Error::SerializationError); - } - - let mut output = GenericArray::default(); - output[L::USIZE.saturating_sub(SIZEOF_USIZE)..] - .copy_from_slice(&input.to_be_bytes()[SIZEOF_USIZE.saturating_sub(L::USIZE)..]); - Ok(output) -} - -/// Computes `I2OSP(len(input), max_bytes) || input` and helps hold output -/// without allocation. -pub(crate) struct Serialize<'a, L1: ArrayLength, L2: ArrayLength = U0> { - octet: GenericArray, - input: Input<'a, L2>, -} - -enum Input<'a, L: ArrayLength> { - Owned(GenericArray), - Borrowed(&'a [u8]), -} - -impl<'a, L1: ArrayLength, L2: ArrayLength> IntoIterator for &'a Serialize<'a, L1, L2> { - type Item = &'a [u8]; - - type IntoIter = IntoIter<&'a [u8], 2>; - - fn into_iter(self) -> Self::IntoIter { - // MSRV: array `into_iter` isn't available in 1.51 - #[allow(deprecated)] - IntoIter::new([ - &self.octet, - match self.input { - Input::Owned(ref bytes) => bytes, - Input::Borrowed(bytes) => bytes, - }, - ]) - } -} - -impl<'a, L1: ArrayLength, L2: ArrayLength> Serialize<'a, L1, L2> { - // Variation of `serialize` that takes a borrowed `input. - pub(crate) fn from(input: &[u8]) -> Result> { - Ok(Serialize { - octet: i2osp::(input.len())?, - input: Input::Borrowed(input), - }) - } - - pub(crate) fn from_owned(input: GenericArray) -> Result> { - Ok(Serialize { - octet: i2osp::(input.len())?, - input: Input::Owned(input), - }) - } -} - -macro_rules! chain_name { - ($var:ident, $mod:ident) => { - $mod - }; - ($var:ident) => { - $var - }; -} - -macro_rules! chain_skip { - ($var:ident, $feed:expr) => { - $feed - }; - ($var:ident) => { - &$var - }; +pub(crate) fn i2osp_2(input: usize) -> Result> { + u16::try_from(input) + .map(|input| input.to_be_bytes().into()) + .map_err(|_| Error::SerializationError) } -/// The purpose of this macro is to replace -/// [`concat`](alloc::slice::Concat::concat)ing slices into an [`Iterator`] to -/// avoid allocation -macro_rules! chain { - ( - $var:ident, - $item1:expr $(=> |$mod1:ident| $feed1:expr)?, - $($item2:expr $(=> |$mod2:ident| $feed2:expr)?),+$(,)? - ) => { - let chain_name!(__temp$(, $mod1)?) = $item1; - let $var = (chain_skip!(__temp$(, $feed1)?)).into_iter(); - $( - let chain_name!(__temp$(, $mod2)?) = $item2; - let $var = $var.chain(chain_skip!(__temp$(, $feed2)?)); - )+ - }; +pub(crate) fn i2osp_2_array + IsLess>( + _: GenericArray, +) -> GenericArray { + L::U16.to_be_bytes().into() } #[cfg(test)] mod unit_tests { - use generic_array::typenum::{U1, U2}; use proptest::collection::vec; use proptest::prelude::*; - use super::*; use crate::{ BlindedElement, EvaluationElement, NonVerifiableClient, NonVerifiableServer, Proof, VerifiableClient, VerifiableServer, }; - // Test the error condition for I2OSP - #[test] - fn test_i2osp_err_check() { - assert!(i2osp::(0).is_ok()); - - assert!(i2osp::(255).is_ok()); - assert!(i2osp::(256).is_err()); - assert!(i2osp::(257).is_err()); - - assert!(i2osp::(256 * 256 - 1).is_ok()); - assert!(i2osp::(256 * 256).is_err()); - assert!(i2osp::(256 * 256 + 1).is_err()); - } - macro_rules! test_deserialize { ($item:ident, $bytes:ident) => { #[cfg(feature = "ristretto255")] diff --git a/src/voprf.rs b/src/voprf.rs index 088a2fa..6605faf 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -9,7 +9,7 @@ #[cfg(feature = "alloc")] use alloc::vec::Vec; -use core::convert::TryInto; +use core::convert::{TryFrom, TryInto}; use core::iter::{self, Map, Repeat, Zip}; use core::marker::PhantomData; @@ -17,12 +17,12 @@ use derive_where::DeriveWhere; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutputReset, Output}; use generic_array::sequence::Concat; -use generic_array::typenum::{U11, U2, U20}; +use generic_array::typenum::{Unsigned, U11, U20}; use generic_array::GenericArray; use rand_core::{CryptoRng, RngCore}; use subtle::ConstantTimeEq; -use crate::util::{i2osp, Serialize}; +use crate::util::{i2osp_2, i2osp_2_array}; use crate::{Error, Group, Result}; /////////////// @@ -30,24 +30,26 @@ use crate::{Error, Group, Result}; // ========= // /////////////// -static STR_HASH_TO_SCALAR: [u8; 13] = *b"HashToScalar-"; -static STR_HASH_TO_GROUP: [u8; 12] = *b"HashToGroup-"; -static STR_FINALIZE: [u8; 9] = *b"Finalize-"; -static STR_SEED: [u8; 5] = *b"Seed-"; -static STR_CONTEXT: [u8; 8] = *b"Context-"; -static STR_COMPOSITE: [u8; 10] = *b"Composite-"; -static STR_CHALLENGE: [u8; 10] = *b"Challenge-"; -static STR_VOPRF: [u8; 8] = *b"VOPRF08-"; +const STR_FINALIZE: [u8; 9] = *b"Finalize-"; +const STR_SEED: [u8; 5] = *b"Seed-"; +const STR_CONTEXT: [u8; 8] = *b"Context-"; +const STR_COMPOSITE: [u8; 10] = *b"Composite-"; +const STR_CHALLENGE: [u8; 10] = *b"Challenge-"; +const STR_VOPRF: [u8; 8] = *b"VOPRF08-"; -/// Determines the mode of operation (either base mode or verifiable mode) +/// Determines the mode of operation (either base mode or verifiable mode). This +/// is only used for custom implementations for [`Group`]. #[derive(Clone, Copy)] -enum Mode { +pub enum Mode { + /// Non-verifiable mode. Base, + /// Verifiable mode. Verifiable, } impl Mode { - fn to_u8(self) -> u8 { + /// Mode as it is represented in a context string. + pub fn to_u8(self) -> u8 { match self { Mode::Base => 0, Mode::Verifiable => 1, @@ -426,9 +428,7 @@ impl NonVerifiableServer /// /// Corresponds to DeriveKeyPair() function from the VOPRF specification. pub fn new_from_seed(seed: &[u8]) -> Result { - let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)); - let sk = G::hash_to_scalar::(Some(seed), dst)?; + let sk = G::hash_to_scalar::(&[seed], Mode::Base)?; Ok(Self { sk, hash: PhantomData, @@ -449,20 +449,27 @@ impl NonVerifiableServer blinded_element: &BlindedElement, metadata: Option<&[u8]>, ) -> Result> { - chain!( - context, - STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Base) => |x| Some(x.as_slice()), - Serialize::::from(metadata.unwrap_or_default())?, - ); - let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Base)); - let m = G::hash_to_scalar::(context, dst)?; + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.1.1-1 + + let context_string = get_context_string::(Mode::Base); + let metadata = metadata.unwrap_or_default(); + + // context = "Context-" || contextString || I2OSP(len(info), 2) || info + let context = GenericArray::from(STR_CONTEXT) + .concat(context_string) + .concat(i2osp_2(metadata.len())?); + let context = [&context, metadata]; + + // m = GG.HashToScalar(context) + let m = G::hash_to_scalar::(&context, Mode::Base)?; + // t = skS + m let t = self.sk + &m; - let evaluation_element = blinded_element.value * &G::invert_scalar(t); + // Z = t^(-1) * R + let z = blinded_element.value * &G::invert_scalar(t); + Ok(NonVerifiableServerEvaluateResult { message: EvaluationElement { - value: evaluation_element, + value: z, hash: PhantomData, }, }) @@ -494,9 +501,7 @@ impl VerifiableServer Result { - let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)); - let sk = G::hash_to_scalar::(Some(seed), dst)?; + let sk = G::hash_to_scalar::(&[seed], Mode::Verifiable)?; let pk = G::base_elem() * &sk; Ok(Self { sk, @@ -588,14 +593,18 @@ impl VerifiableServer, ) -> Result> { - chain!(context, - STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Verifiable) => |x| Some(x.as_slice()), - Serialize::::from(metadata.unwrap_or_default())?, - ); - let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)); - let m = G::hash_to_scalar::(context, dst)?; + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.2.1-1 + + let context_string = get_context_string::(Mode::Verifiable); + let metadata = metadata.unwrap_or_default(); + + // context = "Context-" || contextString || I2OSP(len(info), 2) || info + let context = GenericArray::from(STR_CONTEXT) + .concat(context_string) + .concat(i2osp_2(metadata.len())?); + let context = [&context, metadata]; + + let m = G::hash_to_scalar::(&context, Mode::Verifiable)?; let t = self.sk + &m; let evaluation_elements = blinded_elements // To make a return type possible, we have to convert to a `fn` pointer, which isn't @@ -856,8 +865,7 @@ fn deterministic_blind_unchecked Result { - let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)); - let hashed_point = G::hash_to_curve::(input, dst)?; + let hashed_point = G::hash_to_curve::(&[input], mode)?; Ok(hashed_point * blind) } @@ -891,15 +899,17 @@ where &'a IM: 'a + IntoIterator>, <&'a IM as IntoIterator>::IntoIter: ExactSizeIterator, { - chain!(context, - STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(Mode::Verifiable) => |x| Some(x.as_slice()), - Serialize::::from(info)?, - ); + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.4.2-2 + + let context_string = get_context_string::(Mode::Verifiable); - let dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); - let m = G::hash_to_scalar::(context, dst)?; + // context = "Context-" || contextString || I2OSP(len(info), 2) || info + let context = GenericArray::from(STR_CONTEXT) + .concat(context_string) + .concat(i2osp_2(info.len())?); + let context = [&context, info]; + + let m = G::hash_to_scalar::(&context, Mode::Verifiable)?; let g = G::base_elem(); let t = g * &m; @@ -935,28 +945,53 @@ fn generate_proof< cs: impl Iterator> + ExactSizeIterator, ds: impl Iterator> + ExactSizeIterator, ) -> Result> { + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.2.2-1 + let (m, z) = compute_composites(Some(k), b, cs, ds)?; let r = G::random_scalar(rng); let t2 = a * &r; let t3 = m * &r; + // Bm = GG.SerializeElement(B) + let bm = G::serialize_elem(b); + // a0 = GG.SerializeElement(M) + let a0 = G::serialize_elem(m); + // a1 = GG.SerializeElement(Z) + let a1 = G::serialize_elem(z); + // a2 = GG.SerializeElement(t2) + let a2 = G::serialize_elem(t2); + // a3 = GG.SerializeElement(t3) + let a3 = G::serialize_elem(t3); + + let elem_len = G::ElemLen::U16.to_be_bytes(); + + // challengeDST = "Challenge-" || contextString let challenge_dst = GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); - chain!( - h2_input, - Serialize::::from_owned(G::serialize_elem(b))?, - Serialize::::from_owned(G::serialize_elem(m))?, - Serialize::::from_owned(G::serialize_elem(z))?, - Serialize::::from_owned(G::serialize_elem(t2))?, - Serialize::::from_owned(G::serialize_elem(t3))?, - Serialize::::from_owned(challenge_dst)?, - ); - - let hash_to_scalar_dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); - - let c_scalar = G::hash_to_scalar::(h2_input, hash_to_scalar_dst)?; + let challenge_dst_len = i2osp_2_array(challenge_dst); + // h2Input = I2OSP(len(Bm), 2) || Bm || + // I2OSP(len(a0), 2) || a0 || + // I2OSP(len(a1), 2) || a1 || + // I2OSP(len(a2), 2) || a2 || + // I2OSP(len(a3), 2) || a3 || + // I2OSP(len(challengeDST), 2) || challengeDST + let h2_input = [ + &elem_len, + bm.as_slice(), + &elem_len, + &a0, + &elem_len, + &a1, + &elem_len, + &a2, + &elem_len, + &a3, + &challenge_dst_len, + &challenge_dst, + ]; + + let c_scalar = G::hash_to_scalar::(&h2_input, Mode::Verifiable)?; let s_scalar = r - &(c_scalar * &k); Ok(Proof { @@ -974,25 +1009,50 @@ fn verify_proof( ds: impl Iterator> + ExactSizeIterator, proof: &Proof, ) -> Result<()> { + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.4.1-2 let (m, z) = compute_composites(None, b, cs, ds)?; let t2 = (a * &proof.s_scalar) + &(b * &proof.c_scalar); let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar); + // Bm = GG.SerializeElement(B) + let bm = G::serialize_elem(b); + // a0 = GG.SerializeElement(M) + let a0 = G::serialize_elem(m); + // a1 = GG.SerializeElement(Z) + let a1 = G::serialize_elem(z); + // a2 = GG.SerializeElement(t2) + let a2 = G::serialize_elem(t2); + // a3 = GG.SerializeElement(t3) + let a3 = G::serialize_elem(t3); + + let elem_len = G::ElemLen::U16.to_be_bytes(); + + // challengeDST = "Challenge-" || contextString let challenge_dst = GenericArray::from(STR_CHALLENGE).concat(get_context_string::(Mode::Verifiable)); - chain!( - h2_input, - Serialize::::from_owned(G::serialize_elem(b))?, - Serialize::::from_owned(G::serialize_elem(m))?, - Serialize::::from_owned(G::serialize_elem(z))?, - Serialize::::from_owned(G::serialize_elem(t2))?, - Serialize::::from_owned(G::serialize_elem(t3))?, - Serialize::::from_owned(challenge_dst)?, - ); - - let hash_to_scalar_dst = - GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(Mode::Verifiable)); - let c = G::hash_to_scalar::(h2_input, hash_to_scalar_dst)?; + let challenge_dst_len = i2osp_2_array(challenge_dst); + // h2Input = I2OSP(len(Bm), 2) || Bm || + // I2OSP(len(a0), 2) || a0 || + // I2OSP(len(a1), 2) || a1 || + // I2OSP(len(a2), 2) || a2 || + // I2OSP(len(a3), 2) || a3 || + // I2OSP(len(challengeDST), 2) || challengeDST + let h2_input = [ + &elem_len, + bm.as_slice(), + &elem_len, + &a0, + &elem_len, + &a1, + &elem_len, + &a2, + &elem_len, + &a3, + &challenge_dst_len, + &challenge_dst, + ]; + + let c = G::hash_to_scalar::(&h2_input, Mode::Verifiable)?; match c.ct_eq(&proof.c_scalar).into() { true => Ok(()), @@ -1016,6 +1076,10 @@ fn finalize_after_unblind< info: &'a [u8], mode: Mode, ) -> Result> { + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.3.2-2 + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.4.3-1 + + // finalizeDST = "Finalize-" || contextString let finalize_dst = GenericArray::from(STR_FINALIZE).concat(get_context_string::(mode)); Ok(inputs_and_unblinded_elements @@ -1023,16 +1087,23 @@ fn finalize_after_unblind< // which isn't possible if we `move` from context. .zip(iter::repeat((info, finalize_dst))) .map(|((input, unblinded_element), (info, finalize_dst))| { - chain!( - hash_input, - Serialize::::from(input.as_ref())?, - Serialize::::from(info)?, - Serialize::::from_owned(G::serialize_elem(unblinded_element))?, - Serialize::::from_owned(finalize_dst)?, - ); - - Ok(hash_input - .fold(H::new(), |h, bytes| h.chain_update(bytes)) + let finalize_dst_len = i2osp_2_array(finalize_dst); + let elem_len = G::ElemLen::U16.to_be_bytes(); + + // hashInput = I2OSP(len(input), 2) || input || + // I2OSP(len(info), 2) || info || + // I2OSP(len(unblindedElement), 2) || unblindedElement || + // I2OSP(len(finalizeDST), 2) || finalizeDST + // return Hash(hashInput) + Ok(H::new() + .chain_update(i2osp_2(input.as_ref().len())?) + .chain_update(input.as_ref()) + .chain_update(i2osp_2(info.len())?) + .chain_update(info) + .chain_update(elem_len) + .chain_update(G::serialize_elem(unblinded_element)) + .chain_update(finalize_dst_len) + .chain_update(finalize_dst) .finalize()) })) } @@ -1043,37 +1114,53 @@ fn compute_composites( c_slice: impl Iterator> + ExactSizeIterator, d_slice: impl Iterator> + ExactSizeIterator, ) -> Result<(G::Elem, G::Elem)> { + // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.2.3-2 + + let elem_len = G::ElemLen::U16.to_be_bytes(); + if c_slice.len() != d_slice.len() { return Err(Error::MismatchedLengthsForCompositeInputs); } + let len = u16::try_from(c_slice.len()).map_err(|_| Error::SerializationError)?; + let seed_dst = GenericArray::from(STR_SEED).concat(get_context_string::(Mode::Verifiable)); let composite_dst = GenericArray::from(STR_COMPOSITE).concat(get_context_string::(Mode::Verifiable)); + let composite_dst_len = i2osp_2_array(composite_dst); - chain!( - h1_input, - Serialize::::from_owned(G::serialize_elem(b))?, - Serialize::::from_owned(seed_dst)?, - ); - let seed = h1_input - .fold(H::new(), |h, bytes| h.chain_update(bytes)) + let seed = H::new() + .chain_update(&elem_len) + .chain_update(G::serialize_elem(b)) + .chain_update(i2osp_2_array(seed_dst)) + .chain_update(seed_dst) .finalize(); + let seed_len = i2osp_2(seed.len())?; let mut m = G::identity_elem(); let mut z = G::identity_elem(); - for (i, (c, d)) in c_slice.zip(d_slice).enumerate() { - chain!(h2_input, - Serialize::::from_owned(seed.clone())?, - i2osp::(i)? => |x| Some(x.as_slice()), - Serialize::::from_owned(G::serialize_elem(c.value))?, - Serialize::::from_owned(G::serialize_elem(d.value))?, - Serialize::::from_owned(composite_dst)?, - ); - let dst = GenericArray::from(STR_HASH_TO_SCALAR) - .concat(get_context_string::(Mode::Verifiable)); - let di = G::hash_to_scalar::(h2_input, dst)?; + for (i, (c, d)) in (0..len).zip(c_slice.zip(d_slice)) { + // Ci = GG.SerializeElement(Cs[i]) + let ci = G::serialize_elem(c.value); + // Di = GG.SerializeElement(Ds[i]) + let di = G::serialize_elem(d.value); + // h2Input = I2OSP(len(seed), 2) || seed || I2OSP(i, 2) || + // I2OSP(len(Ci), 2) || Ci || + // I2OSP(len(Di), 2) || Di || + // I2OSP(len(compositeDST), 2) || compositeDST + let h2_input = [ + &seed_len, + seed.as_slice(), + &i.to_be_bytes(), + &elem_len, + &ci, + &elem_len, + &di, + &composite_dst_len, + &composite_dst, + ]; + let di = G::hash_to_scalar::(&h2_input, Mode::Verifiable)?; m = c.value * &di + &m; z = match k_option { Some(_) => z, @@ -1091,7 +1178,7 @@ fn compute_composites( /// Generates the contextString parameter as defined in /// -fn get_context_string(mode: Mode) -> GenericArray { +pub(crate) fn get_context_string(mode: Mode) -> GenericArray { GenericArray::from(STR_VOPRF) .concat([mode.to_u8()].into()) .concat(G::SUITE_ID.to_be_bytes().into()) @@ -1109,7 +1196,7 @@ mod tests { use ::alloc::vec; use ::alloc::vec::Vec; use generic_array::typenum::Sum; - use generic_array::{ArrayLength, GenericArray}; + use generic_array::ArrayLength; use rand::rngs::OsRng; use zeroize::Zeroize; @@ -1122,17 +1209,13 @@ mod tests { info: &[u8], mode: Mode, ) -> Output { - let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(mode)); - let point = G::hash_to_curve::(input, dst).unwrap(); + let point = G::hash_to_curve::(&[input], mode).unwrap(); - chain!(context, - STR_CONTEXT => |x| Some(x.as_ref()), - get_context_string::(mode) => |x| Some(x.as_slice()), - Serialize::::from(info).unwrap(), - ); + let context_string = get_context_string::(mode); + let info_len = i2osp_2(info.len()).unwrap(); + let context = [&STR_CONTEXT, context_string.as_slice(), &info_len, info]; - let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::(mode)); - let m = G::hash_to_scalar::(context, dst).unwrap(); + let m = G::hash_to_scalar::(&context, mode).unwrap(); let res = point * &G::invert_scalar(key + &m); @@ -1194,7 +1277,7 @@ mod tests { .unwrap(); let wrong_pk = { // Choose a group element that is unlikely to be the right public key - G::hash_to_curve::(b"msg", (*b"dst").into()).unwrap() + G::hash_to_curve::(&[b"msg"], Mode::Base).unwrap() }; let client_finalize_result = client_blind_result.state.finalize( input, @@ -1291,7 +1374,7 @@ mod tests { let messages: Vec<_> = messages.collect(); let wrong_pk = { // Choose a group element that is unlikely to be the right public key - G::hash_to_curve::(b"msg", (*b"dst").into()).unwrap() + G::hash_to_curve::(&[b"msg"], Mode::Base).unwrap() }; let client_finalize_result = VerifiableClient::batch_finalize( &inputs, @@ -1322,8 +1405,7 @@ mod tests { ) .unwrap(); - let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::(Mode::Base)); - let point = G::hash_to_curve::(&input, dst).unwrap(); + let point = G::hash_to_curve::(&[&input], Mode::Base).unwrap(); let res2 = finalize_after_unblind::( Some((input.as_ref(), point)).into_iter(), info, From c2ebc48e398670fa217e7f8c11a441be2dec0443 Mon Sep 17 00:00:00 2001 From: dAxpeDDa Date: Mon, 10 Jan 2022 06:57:44 +0100 Subject: [PATCH 9/9] Improve P256 `hash_to_scalar` --- src/group/p256.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/group/p256.rs b/src/group/p256.rs index 36613be..da8cb7b 100644 --- a/src/group/p256.rs +++ b/src/group/p256.rs @@ -25,6 +25,7 @@ use num_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, ToPrimitive, Zero}; use once_cell::unsync::Lazy; +use p256_::elliptic_curve::bigint::{Encoding, U384}; use p256_::elliptic_curve::group::prime::PrimeCurveAffine; use p256_::elliptic_curve::ops::Reduce; use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; @@ -122,24 +123,19 @@ impl Group for NistP256 { // P-256 `n` is defined as // `115792089210356248762697446949407573529996955224135760342 // 422259061068512044369` - const N: Lazy = Lazy::new(|| { - BigInt::from_str( - "115792089210356248762697446949407573529996955224135760342422259061068512044369", - ) - .unwrap() - }); + const N: U384 = + U384::from_be_hex("00000000000000000000000000000000FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"); // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-5.3 // `HashToScalar` is `hash_to_field` let uniform_bytes = super::expand::expand_message_xmd::(input, &dst)?; - let bytes = BigInt::from_bytes_be(Sign::Plus, &uniform_bytes) - .mod_floor(&N) - .to_bytes_be() - .1; - let mut result = GenericArray::default(); - result[..bytes.len()].copy_from_slice(&bytes); + let bytes = Option::::from(U384::from_be_slice(&uniform_bytes).reduce(&N)) + .unwrap() + .to_be_bytes(); - Ok(Scalar::from_be_bytes_reduced(result)) + Ok(Scalar::from_be_bytes_reduced( + GenericArray::clone_from_slice(&bytes[16..]), + )) } fn base_elem() -> Self::Elem {