Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ once_cell = "1.18.0"
itertools = "0.12.0"
rand = "0.8.5"
ref-cast = "1.0.20"
derive_more = "0.99.17"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", version = "0.1.4" }
Expand Down
213 changes: 28 additions & 185 deletions src/provider/bn256_grumpkin.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
//! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`.
use crate::{
provider::{
traits::{CompressedGroup, DlogGroup},
util::msm::cpu_best_msm,
},
impl_traits,
provider::{traits::DlogGroup, util::msm::cpu_best_msm},
traits::{Group, PrimeFieldExt, TranscriptReprTrait},
};
use digest::{ExtendableOutput, Update};
use ff::{FromUniformBytes, PrimeField};
use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding};
use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup};
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use grumpkin_msm::{bn256 as bn256_msm, grumpkin as grumpkin_msm};
use num_bigint::BigInt;
use num_traits::Num;
// Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves
Expand All @@ -17,201 +17,44 @@ use rayon::prelude::*;
use sha3::Shake256;
use std::io::Read;

use halo2curves::bn256::{
G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G1 as Bn256Point,
};
use halo2curves::grumpkin::{
G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint,
};

/// Re-exports that give access to the standard aliases used in the code base, for bn256
pub mod bn256 {
pub use halo2curves::bn256::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point};
pub use halo2curves::bn256::{
Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point,
};
}

/// Re-exports that give access to the standard aliases used in the code base, for grumpkin
pub mod grumpkin {
pub use halo2curves::grumpkin::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point};
}

macro_rules! impl_traits {
(
$name:ident,
$name_compressed:ident,
$name_curve:ident,
$name_curve_affine:ident,
$order_str:literal,
$base_str:literal
) => {
impl Group for $name::Point {
type Base = $name::Base;
type Scalar = $name::Scalar;

fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();
let base = BigInt::from_str_radix($base_str, 16).unwrap();

(A, B, order, base)
}
}

impl DlogGroup for $name::Point {
type CompressedGroupElement = $name_compressed;
type PreprocessedGroupElement = $name::Affine;

#[tracing::instrument(
skip_all,
level = "trace",
name = "<_ as Group>::vartime_multiscalar_mul"
)]
fn vartime_multiscalar_mul(
scalars: &[Self::Scalar],
bases: &[Self::PreprocessedGroupElement],
) -> Self {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
if scalars.len() >= 128 {
grumpkin_msm::$name(bases, scalars)
} else {
cpu_best_msm(scalars, bases)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
cpu_best_msm(scalars, bases)
}
fn preprocessed(&self) -> Self::PreprocessedGroupElement {
self.to_affine()
}

fn compress(&self) -> Self::CompressedGroupElement {
self.to_bytes()
}

fn from_label(label: &'static [u8], n: usize) -> Vec<Self::PreprocessedGroupElement> {
let mut shake = Shake256::default();
shake.update(label);
let mut reader = shake.finalize_xof();
let mut uniform_bytes_vec = Vec::new();
for _ in 0..n {
let mut uniform_bytes = [0u8; 32];
reader.read_exact(&mut uniform_bytes).unwrap();
uniform_bytes_vec.push(uniform_bytes);
}
let gens_proj: Vec<$name_curve> = (0..n)
.into_par_iter()
.map(|i| {
let hash = $name_curve::hash_to_curve("from_uniform_bytes");
hash(&uniform_bytes_vec[i])
})
.collect();

let num_threads = rayon::current_num_threads();
if gens_proj.len() > num_threads {
let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize;
(0..num_threads)
.into_par_iter()
.flat_map(|i| {
let start = i * chunk;
let end = if i == num_threads - 1 {
gens_proj.len()
} else {
core::cmp::min((i + 1) * chunk, gens_proj.len())
};
if end > start {
let mut gens = vec![$name_curve_affine::identity(); end - start];
<Self as Curve>::batch_normalize(&gens_proj[start..end], &mut gens);
gens
} else {
vec![]
}
})
.collect()
} else {
let mut gens = vec![$name_curve_affine::identity(); n];
<Self as Curve>::batch_normalize(&gens_proj, &mut gens);
gens
}
}

fn zero() -> Self {
$name::Point::identity()
}

fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) {
let coordinates = self.to_affine().coordinates();
if coordinates.is_some().unwrap_u8() == 1
&& ($name_curve_affine::identity() != self.to_affine())
{
(*coordinates.unwrap().x(), *coordinates.unwrap().y(), false)
} else {
(Self::Base::zero(), Self::Base::zero(), true)
}
}
}

impl PrimeFieldExt for $name::Scalar {
fn from_uniform(bytes: &[u8]) -> Self {
let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
$name::Scalar::from_uniform_bytes(&bytes_arr)
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name_compressed {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.as_ref().to_vec()
}
}

impl CompressedGroup for $name_compressed {
type GroupElement = $name::Point;

fn decompress(&self) -> Option<$name::Point> {
Some($name_curve::from_bytes(&self).unwrap())
}
}

impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.to_repr().to_vec()
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
fn to_transcript_bytes(&self) -> Vec<u8> {
let (x, y, is_infinity_byte) = {
let coordinates = self.coordinates();
if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) {
let c = coordinates.unwrap();
(*c.x(), *c.y(), u8::from(false))
} else {
($name::Base::zero(), $name::Base::zero(), u8::from(false))
}
};

x.to_repr()
.into_iter()
.chain(y.to_repr().into_iter())
.chain(std::iter::once(is_infinity_byte))
.collect()
}
}
pub use halo2curves::grumpkin::{
Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point,
};
}

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
impl_traits!(
bn256,
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
bn256_msm
);
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl_traits!(
bn256,
Bn256Compressed,
Bn256Point,
Bn256Affine,
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47"
);

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
impl_traits!(
grumpkin,
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
grumpkin_msm
);
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl_traits!(
grumpkin,
GrumpkinCompressed,
GrumpkinPoint,
GrumpkinAffine,
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
);
Expand All @@ -237,7 +80,7 @@ mod tests {
.map(|_| bn256::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let cpu_msm = cpu_best_msm(&points, &scalars);
let gpu_msm = bn256::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
Expand All @@ -253,7 +96,7 @@ mod tests {
.map(|_| grumpkin::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let cpu_msm = cpu_best_msm(&points, &scalars);
let gpu_msm = grumpkin::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
Expand Down
2 changes: 1 addition & 1 deletion src/provider/kzg_commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct KZGCommitmentEngine<E: Engine> {
impl<E: Engine, NE: NovaEngine<GE = E::G1, Scalar = E::Fr>> CommitmentEngineTrait<NE>
for KZGCommitmentEngine<E>
where
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
E::G1Affine: Serialize + for<'de> Deserialize<'de>,
E::G2Affine: Serialize + for<'de> Deserialize<'de>,
E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional
Expand Down
16 changes: 8 additions & 8 deletions src/provider/mlkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl<E, NE> EvaluationEngine<E, NE>
where
E: Engine,
NE: NovaEngine<GE = E::G1, Scalar = E::Fr>,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup,
E::Fr: TranscriptReprTrait<E::G1>,
E::G1Affine: TranscriptReprTrait<E::G1>, // TODO: this bound on DlogGroup is really unusable!
{
Expand Down Expand Up @@ -104,7 +104,7 @@ where
E::Fr: Serialize + DeserializeOwned,
E::G1Affine: Serialize + DeserializeOwned,
E::G2Affine: Serialize + DeserializeOwned,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
<E::G1 as Group>::Base: TranscriptReprTrait<E::G1>, // Note: due to the move of the bound TranscriptReprTrait<G> on G::Base from Group to Engine
E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional
E::Fr: TranscriptReprTrait<E::G1>,
Expand Down Expand Up @@ -161,7 +161,7 @@ where

<NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &h)
.comm
.preprocessed()
.to_affine()
};

let kzg_open_batch = |C: &[E::G1Affine],
Expand Down Expand Up @@ -253,7 +253,7 @@ where
.map(|i| {
<NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &polys[i])
.comm
.preprocessed()
.to_affine()
})
.collect();

Expand All @@ -265,7 +265,7 @@ where

// Phase 3 -- create response
let mut com_all = comms.clone();
com_all.insert(0, C.comm.preprocessed());
com_all.insert(0, C.comm.to_affine());
let (w, evals) = kzg_open_batch(&com_all, &polys, &u, transcript);

Ok(EvaluationArgument { comms, w, evals })
Expand Down Expand Up @@ -300,7 +300,7 @@ where

// Compute the commitment to the batched polynomial B(X)
let c_0: E::G1 = C[0].into();
let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).preprocessed();
let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).to_affine();

// Compute the batched openings
// compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1]
Expand Down Expand Up @@ -356,10 +356,10 @@ where
// obtained from the transcript
let r = Self::compute_challenge(&com, transcript);

if r == E::Fr::ZERO || C.comm == E::G1::zero() {
if r == E::Fr::ZERO || C.comm == E::G1::identity() {
return Err(NovaError::ProofVerifyError);
}
com.insert(0, C.comm.preprocessed()); // set com_0 = C, shifts other commitments to the right
com.insert(0, C.comm.to_affine()); // set com_0 = C, shifts other commitments to the right

let u = vec![r, -r, r * r];

Expand Down
2 changes: 1 addition & 1 deletion src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
acc + *base * coeff
});

assert_eq!(naive, cpu_best_msm(&coeffs, &bases))
assert_eq!(naive, cpu_best_msm(&bases, &coeffs))
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/provider/non_hiding_kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ pub struct UVKZGPCS<E> {

impl<E: MultiMillerLoop> UVKZGPCS<E>
where
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
{
/// Generate a commitment for a polynomial
/// Note that the scheme is not hidding
Expand Down Expand Up @@ -327,7 +327,7 @@ mod tests {
fn end_to_end_test_template<E>() -> Result<(), NovaError>
where
E: MultiMillerLoop,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
E::Fr: PrimeFieldBits,
{
for _ in 0..100 {
Expand Down
Loading