diff --git a/Cargo.toml b/Cargo.toml index 3f58cff..2706786 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ typenum = { version = "1.16.0" } [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } jf-primitives = { git = "https://github.com/espressosystems/jellyfish", features = ["test-srs"] } +rand_chacha = { version = "0.3.1", default-features = false } sha2 = { version = "0.10" } [[bench]] diff --git a/src/qc/bit_vector.rs b/src/qc/bit_vector.rs index 40ffb82..e57c6b9 100644 --- a/src/qc/bit_vector.rs +++ b/src/qc/bit_vector.rs @@ -1,7 +1,10 @@ //! Implementation for BitVectorQC that uses BLS signature + Bit vector. //! See more details in HotShot paper. -use crate::qc::QuorumCertificate; +use crate::{ + qc::QuorumCertificate, + stake_table::{SnapshotVersion, StakeTableScheme}, +}; use ark_std::{ fmt::Debug, format, @@ -20,32 +23,30 @@ use serde::{Deserialize, Serialize}; use typenum::U32; /// An implementation of QC using BLS signature and a bit-vector. -#[derive(Serialize, Deserialize)] -pub struct BitVectorQC Deserialize<'a>>( +pub struct BitVectorQC( PhantomData, + PhantomData, ); #[derive(Serialize, Deserialize, PartialEq, Debug)] -pub struct StakeTableEntry { - pub stake_key: V, - pub stake_amount: U256, -} - -#[derive(Serialize, Deserialize, PartialEq, Debug)] -pub struct QCParams { - pub stake_entries: Vec>, +pub struct QCParams { + pub stake_table: ST, pub threshold: U256, - pub agg_sig_pp: P, + pub agg_sig_pp: A::PublicParameter, } -impl QuorumCertificate for BitVectorQC +impl QuorumCertificate for BitVectorQC where - A: AggregateableSignatureSchemes + Serialize + for<'a> Deserialize<'a>, + A: AggregateableSignatureSchemes + Serialize + for<'a> Deserialize<'a> + PartialEq, + ST: StakeTableScheme + + Serialize + + for<'a> Deserialize<'a> + + PartialEq, { - type QCProverParams = QCParams; + type QCProverParams = QCParams; // TODO: later with SNARKs we'll use a smaller verifier parameter - type QCVerifierParams = QCParams; + type QCVerifierParams = QCParams; type QC = (A::Signature, BitVec); type MessageLength = U32; @@ -65,25 +66,28 @@ where signers: &BitSlice, sigs: &[A::Signature], ) -> Result { - if signers.len() != qc_pp.stake_entries.len() { + let st_len = qc_pp.stake_table.len(SnapshotVersion::LastEpochStart)?; + if signers.len() != st_len { return Err(ParameterError(format!( "bit vector len {} != the number of stake entries {}", signers.len(), - qc_pp.stake_entries.len(), + st_len, ))); } - let total_weight: U256 = - qc_pp - .stake_entries - .iter() - .zip(signers.iter()) - .fold(U256::zero(), |acc, (entry, b)| { + let total_weight: U256 = qc_pp + .stake_table + .iter(SnapshotVersion::LastEpochStart)? + .zip(signers.iter()) + .fold( + U256::zero(), + |acc, (entry, b)| { if *b { - acc + entry.stake_amount + acc + entry.1 } else { acc } - }); + }, + ); if total_weight < qc_pp.threshold { return Err(ParameterError(format!( "total_weight {} less than threshold {}", @@ -91,9 +95,13 @@ where ))); } let mut ver_keys = vec![]; - for (entry, b) in qc_pp.stake_entries.iter().zip(signers.iter()) { + for (entry, b) in qc_pp + .stake_table + .iter(SnapshotVersion::LastEpochStart)? + .zip(signers.iter()) + { if *b { - ver_keys.push(entry.stake_key.clone()); + ver_keys.push(entry.0.clone()); } } if ver_keys.len() != sigs.len() { @@ -114,25 +122,28 @@ where qc: &Self::QC, ) -> Result { let (sig, signers) = qc; - if signers.len() != qc_vp.stake_entries.len() { + let st_len = qc_vp.stake_table.len(SnapshotVersion::LastEpochStart)?; + if signers.len() != st_len { return Err(ParameterError(format!( "signers bit vector len {} != the number of stake entries {}", signers.len(), - qc_vp.stake_entries.len(), + st_len, ))); } - let total_weight: U256 = - qc_vp - .stake_entries - .iter() - .zip(signers.iter()) - .fold(U256::zero(), |acc, (entry, b)| { + let total_weight: U256 = qc_vp + .stake_table + .iter(SnapshotVersion::LastEpochStart)? + .zip(signers.iter()) + .fold( + U256::zero(), + |acc, (entry, b)| { if *b { - acc + entry.stake_amount + acc + entry.1 } else { acc } - }); + }, + ); if total_weight < qc_vp.threshold { return Err(ParameterError(format!( "total_weight {} less than threshold {}", @@ -140,9 +151,13 @@ where ))); } let mut ver_keys = vec![]; - for (entry, b) in qc_vp.stake_entries.iter().zip(signers.iter()) { + for (entry, b) in qc_vp + .stake_table + .iter(SnapshotVersion::LastEpochStart)? + .zip(signers.iter()) + { if *b { - ver_keys.push(entry.stake_key.clone()); + ver_keys.push(entry.0.clone()); } } A::multi_sig_verify(&qc_vp.agg_sig_pp, &ver_keys[..], message, sig)?; @@ -156,22 +171,23 @@ where qc: &Self::QC, ) -> Result::VerificationKey>, PrimitivesError> { let (_sig, signers) = qc; - if signers.len() != qc_vp.stake_entries.len() { + let st_len = qc_vp.stake_table.len(SnapshotVersion::LastEpochStart)?; + if signers.len() != st_len { return Err(ParameterError(format!( "signers bit vector len {} != the number of stake entries {}", signers.len(), - qc_vp.stake_entries.len(), + st_len, ))); } Self::check(qc_vp, message, qc)?; let signer_pks: Vec<_> = qc_vp - .stake_entries - .iter() + .stake_table + .iter(SnapshotVersion::LastEpochStart)? .zip(signers.iter()) .filter(|(_, b)| **b) - .map(|(pk, _)| pk.stake_key.clone()) + .map(|(pk, _)| pk.0) .collect(); Ok(signer_pks) } @@ -180,49 +196,49 @@ where #[cfg(test)] mod tests { use super::*; + use crate::stake_table::{StakeTable, StakeTableScheme}; use jf_primitives::signatures::bls_over_bn254::{BLSOverBN254CurveSignatureScheme, KeyPair}; use jf_primitives::signatures::SignatureScheme; macro_rules! test_quorum_certificate { ($aggsig:tt) => { + type ST = StakeTable<<$aggsig as SignatureScheme>::VerificationKey>; let mut rng = jf_utils::test_rng(); + let agg_sig_pp = $aggsig::param_gen(Some(&mut rng)).unwrap(); let key_pair1 = KeyPair::generate(&mut rng); let key_pair2 = KeyPair::generate(&mut rng); let key_pair3 = KeyPair::generate(&mut rng); - let entry1 = StakeTableEntry { - stake_key: key_pair1.ver_key(), - stake_amount: U256::from(3u8), - }; - let entry2 = StakeTableEntry { - stake_key: key_pair2.ver_key(), - stake_amount: U256::from(5u8), - }; - let entry3 = StakeTableEntry { - stake_key: key_pair3.ver_key(), - stake_amount: U256::from(7u8), - }; + + let mut st = ST::new(3); + st.register(key_pair1.ver_key(), U256::from(3u8)).unwrap(); + st.register(key_pair2.ver_key(), U256::from(5u8)).unwrap(); + st.register(key_pair3.ver_key(), U256::from(7u8)).unwrap(); + st.advance(); + st.advance(); + let qc_pp = QCParams { - stake_entries: vec![entry1, entry2, entry3], + stake_table: st, threshold: U256::from(10u8), agg_sig_pp, }; + let msg = [72u8; 32]; - let sig1 = BitVectorQC::<$aggsig>::sign( + let sig1 = BitVectorQC::<$aggsig, ST>::sign( &agg_sig_pp, &msg.into(), key_pair1.sign_key_ref(), &mut rng, ) .unwrap(); - let sig2 = BitVectorQC::<$aggsig>::sign( + let sig2 = BitVectorQC::<$aggsig, ST>::sign( &agg_sig_pp, &msg.into(), key_pair2.sign_key_ref(), &mut rng, ) .unwrap(); - let sig3 = BitVectorQC::<$aggsig>::sign( + let sig3 = BitVectorQC::<$aggsig, ST>::sign( &agg_sig_pp, &msg.into(), key_pair3.sign_key_ref(), @@ -232,15 +248,15 @@ mod tests { // happy path let signers = bitvec![0, 1, 1]; - let qc = BitVectorQC::<$aggsig>::assemble( + let qc = BitVectorQC::<$aggsig, ST>::assemble( &qc_pp, signers.as_bitslice(), &[sig2.clone(), sig3.clone()], ) .unwrap(); - assert!(BitVectorQC::<$aggsig>::check(&qc_pp, &msg.into(), &qc).is_ok()); + assert!(BitVectorQC::<$aggsig, ST>::check(&qc_pp, &msg.into(), &qc).is_ok()); assert_eq!( - BitVectorQC::<$aggsig>::trace(&qc_pp, &msg.into(), &qc).unwrap(), + BitVectorQC::<$aggsig, ST>::trace(&qc_pp, &msg.into(), &qc).unwrap(), vec![key_pair2.ver_key(), key_pair3.ver_key()], ); @@ -250,14 +266,34 @@ mod tests { bincode::deserialize(&bincode::serialize(&qc).unwrap()).unwrap() ); + // (alex) since deserialized stake table's leaf would contain normalized projective + // points with Z=1, which differs from the original projective representation. + // We compare individual fields for equivalence instead. + let de_qc_pp: QCParams<$aggsig, ST> = + bincode::deserialize(&bincode::serialize(&qc_pp).unwrap()).unwrap(); + assert_eq!( + qc_pp.stake_table.commitment(SnapshotVersion::Head).unwrap(), + de_qc_pp + .stake_table + .commitment(SnapshotVersion::Head) + .unwrap(), + ); assert_eq!( - qc_pp, - bincode::deserialize(&bincode::serialize(&qc_pp).unwrap()).unwrap() + qc_pp + .stake_table + .commitment(SnapshotVersion::LastEpochStart) + .unwrap(), + de_qc_pp + .stake_table + .commitment(SnapshotVersion::LastEpochStart) + .unwrap(), ); + assert_eq!(qc_pp.threshold, de_qc_pp.threshold); + assert_eq!(qc_pp.agg_sig_pp, de_qc_pp.agg_sig_pp); // bad paths // number of signatures unmatch - assert!(BitVectorQC::<$aggsig>::assemble( + assert!(BitVectorQC::<$aggsig, ST>::assemble( &qc_pp, signers.as_bitslice(), &[sig2.clone()] @@ -265,7 +301,7 @@ mod tests { .is_err()); // total weight under threshold let active_bad = bitvec![1, 1, 0]; - assert!(BitVectorQC::<$aggsig>::assemble( + assert!(BitVectorQC::<$aggsig, ST>::assemble( &qc_pp, active_bad.as_bitslice(), &[sig1.clone(), sig2.clone()] @@ -273,33 +309,35 @@ mod tests { .is_err()); // wrong bool vector length let active_bad_2 = bitvec![0, 1, 1, 0]; - assert!(BitVectorQC::<$aggsig>::assemble( + assert!(BitVectorQC::<$aggsig, ST>::assemble( &qc_pp, active_bad_2.as_bitslice(), &[sig2, sig3], ) .is_err()); - assert!(BitVectorQC::<$aggsig>::check( + assert!(BitVectorQC::<$aggsig, ST>::check( &qc_pp, &msg.into(), &(qc.0.clone(), active_bad) ) .is_err()); - assert!(BitVectorQC::<$aggsig>::check( + assert!(BitVectorQC::<$aggsig, ST>::check( &qc_pp, &msg.into(), &(qc.0.clone(), active_bad_2) ) .is_err()); let bad_msg = [70u8; 32]; - assert!(BitVectorQC::<$aggsig>::check(&qc_pp, &bad_msg.into(), &qc).is_err()); + assert!(BitVectorQC::<$aggsig, ST>::check(&qc_pp, &bad_msg.into(), &qc).is_err()); let bad_sig = &sig1; - assert!( - BitVectorQC::<$aggsig>::check(&qc_pp, &msg.into(), &(bad_sig.clone(), qc.1)) - .is_err() - ); + assert!(BitVectorQC::<$aggsig, ST>::check( + &qc_pp, + &msg.into(), + &(bad_sig.clone(), qc.1) + ) + .is_err()); }; } #[test] diff --git a/src/stake_table/config.rs b/src/stake_table/config.rs index 0388576..eacc80a 100644 --- a/src/stake_table/config.rs +++ b/src/stake_table/config.rs @@ -9,7 +9,7 @@ use jf_primitives::crhf::FixedLengthRescueCRHF; pub(crate) const TREE_BRANCH: usize = 3; /// Internal type of Merkle node value(commitment) -pub(crate) type FieldType = ark_bn254::Fr; +pub(crate) type FieldType = ark_bn254::Fq; /// Hash algorithm used in Merkle tree, using a RATE-3 rescue pub(crate) type Digest = FixedLengthRescueCRHF; diff --git a/src/stake_table/error.rs b/src/stake_table/error.rs index 2641945..3fca626 100644 --- a/src/stake_table/error.rs +++ b/src/stake_table/error.rs @@ -1,4 +1,6 @@ +use ark_std::string::ToString; use displaydoc::Display; +use jf_primitives::errors::PrimitivesError; #[derive(Debug, Display)] pub enum StakeTableError { @@ -18,4 +20,15 @@ pub enum StakeTableError { InsufficientFund, /// The number of stake exceed U256 StakeOverflow, + /// The historical snapshot requested is not supported. + SnapshotUnsupported, +} + +impl ark_std::error::Error for StakeTableError {} + +impl From for PrimitivesError { + fn from(value: StakeTableError) -> Self { + // FIXME: (alex) should we define a PrimitivesError::General()? + Self::ParameterError(value.to_string()) + } } diff --git a/src/stake_table/mod.rs b/src/stake_table/mod.rs index c8e5259..5572a07 100644 --- a/src/stake_table/mod.rs +++ b/src/stake_table/mod.rs @@ -1,198 +1,268 @@ use self::{ error::StakeTableError, - utils::{to_merkle_path, PersistentMerkleNode}, -}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use ark_std::{ - collections::HashMap, - rand::{CryptoRng, RngCore}, - sync::Arc, - vec::Vec, + utils::{to_merkle_path, Key, PersistentMerkleNode}, }; +use ark_std::{collections::HashMap, rand::SeedableRng, sync::Arc, vec::Vec}; +use digest::crypto_common::rand_core::CryptoRngCore; use ethereum_types::{U256, U512}; use serde::{Deserialize, Serialize}; -use tagged_base64::tagged; mod config; mod utils; // Exports pub mod error; -pub use utils::MerkleCommitment; -pub use utils::MerklePath; -pub use utils::MerklePathEntry; -pub use utils::MerkleProof; - -/// Copied from HotShot repo. -/// Type saftey wrapper for byte encoded keys. -/// Assume that the content is a canonically serialized public key -#[tagged("PUBKEY")] -#[derive( - Clone, Debug, Hash, CanonicalSerialize, CanonicalDeserialize, PartialEq, Eq, PartialOrd, Ord, -)] -pub struct EncodedPublicKey(pub Vec); - -/// Enum type for stake table version -/// * `STVersion::PENDING`: the most up-to-date stake table, where the incoming transactions shall be performed on. -/// * `STVersion::FROZEN`: when an epoch ends, the PENDING stake table is frozen for leader elections for next epoch. -/// * `STVersion::ACTIVE`: the active stake table for leader election. -pub enum STVersion { - PENDING, - FROZEN, - ACTIVE, +pub use utils::{MerkleCommitment, MerklePath, MerklePathEntry, MerkleProof}; + +/// Snapshots of the stake table +/// - the latest "Head" where all new changes are applied to +/// - `EpochStart` marks the snapshot at the beginning of the current epoch +/// - `LastEpochStart` marks the beginning of the last epoch +/// - `BlockNum(u64)` at arbitrary block height +pub enum SnapshotVersion { + Head, + EpochStart, + LastEpochStart, + BlockNum(u64), +} + +/// Common interfaces required for a stake table used in HotShot System. +/// APIs that doesn't take `version: SnapshotVersion` as an input by default works on the head/latest version. +pub trait StakeTableScheme { + /// type for stake key + type Key: Clone; + /// type for the staked amount + type Amount: Clone + Copy; + /// type for the commitment to the current stake table + type Commitment; + /// type for the proof associated with the lookup result (if any) + type LookupProof; + /// type for the iterator over (key, value) entries + type IntoIter: Iterator; + + /// Register a new key into the stake table. + fn register(&mut self, new_key: Self::Key, amount: Self::Amount) + -> Result<(), StakeTableError>; + + /// Batch register a list of new keys. A default implementation is provided + /// w/o batch optimization. + fn batch_register(&mut self, new_keys: I, amounts: J) -> Result<(), StakeTableError> + where + I: IntoIterator, + J: IntoIterator, + { + let _ = new_keys + .into_iter() + .zip(amounts.into_iter()) + .try_for_each(|(key, amount)| Self::register(self, key, amount)); + Ok(()) + } + + /// Deregister an existing key from the stake table. + /// Returns error if some keys are not found. + fn deregister(&mut self, existing_key: &Self::Key) -> Result<(), StakeTableError>; + + /// Batch deregister a list of keys. A default implementation is provided + /// w/o batch optimization. + fn batch_deregister<'a, I>(&mut self, existing_keys: I) -> Result<(), StakeTableError> + where + I: IntoIterator::Key>, + ::Key: 'a, + { + let _ = existing_keys + .into_iter() + .try_for_each(|key| Self::deregister(self, key)); + Ok(()) + } + + /// Returns the commitment to the `version` of stake table. + fn commitment(&self, version: SnapshotVersion) -> Result; + + /// Returns the accumulated stakes of all registered keys of the `version` + /// of stake table. + fn total_stake(&self, version: SnapshotVersion) -> Result; + + /// Returns the number of keys in the `version` of the table. + fn len(&self, version: SnapshotVersion) -> Result; + + /// Returns true if `key` is currently registered, else returns false. + fn contains_key(&self, key: &Self::Key) -> bool; + + /// Lookup the stake under a key against a specific historical `version`, + /// returns error if keys unregistered. + fn lookup( + &self, + version: SnapshotVersion, + key: &Self::Key, + ) -> Result<(Self::Amount, Self::LookupProof), StakeTableError>; + + /// Returns the stakes withhelded by a public key, None if the key is not registered. + /// If you need a lookup proof, use [`Self::lookup()`] instead (which is usually more expensive). + fn simple_lookup( + &self, + version: SnapshotVersion, + key: &Self::Key, + ) -> Result; + + /// Update the stake of the `key` with `(negative ? -1 : 1) * delta`. + /// Return the updated stake or error. + fn update( + &mut self, + key: &Self::Key, + delta: Self::Amount, + negative: bool, + ) -> Result; + + /// Batch update the stake balance of `keys`. Read documentation about + /// [`Self::update()`]. By default, we call `Self::update()` on each + /// (key, amount, negative) tuple. + fn batch_update( + &mut self, + keys: &[Self::Key], + amounts: &[Self::Amount], + negative_flags: Vec, + ) -> Result, StakeTableError> { + let updated_amounts = keys + .iter() + .zip(amounts.iter()) + .zip(negative_flags.iter()) + .map(|((key, &amount), negative)| Self::update(self, key, amount, *negative)) + .collect::, _>>()?; + + Ok(updated_amounts) + } + + /// Randomly sample a (key, stake_amount) pair proportional to the stake distributions, + /// given a fixed seed for `rng`, this sampling should be deterministic. + fn sample( + &self, + rng: &mut (impl SeedableRng + CryptoRngCore), + ) -> Option<(&Self::Key, &Self::Amount)>; + + /// Returns an iterator over all (key, value) entries of the `version` of the table + fn iter(&self, version: SnapshotVersion) -> Result; } /// Locally maintained stake table +/// generic over public key type `K`. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct StakeTable { +#[serde(bound = "K: Key")] +pub struct StakeTable { /// The most up-to-date stake table, where the incoming transactions shall be performed on. - pending: Arc, - /// When an epoch ends, the PENDING stake table is frozen for leader elections for next epoch. - frozen: Arc, - /// The active stake table for leader election. - active: Arc, + head: Arc>, + /// The snapshot of stake table at the beginning of the current epoch + epoch_start: Arc>, + /// The stake table used for leader election. + last_epoch_start: Arc>, /// Height of the underlying merkle tree, determines the capacity. /// The capacity is `TREE_BRANCH.pow(height)`. height: usize, /// The mapping from public keys to their location in the Merkle tree. - mapping: HashMap, + #[serde(skip)] + mapping: HashMap, } -impl StakeTable { - /// Initiating an empty stake table. - /// Overall capacity is `TREE_BRANCH.pow(height)`. - pub fn new(height: usize) -> Self { - Self { - pending: Arc::new(PersistentMerkleNode::Empty), - frozen: Arc::new(PersistentMerkleNode::Empty), - active: Arc::new(PersistentMerkleNode::Empty), - height, - mapping: HashMap::new(), +impl StakeTableScheme for StakeTable { + type Key = K; + type Amount = U256; + type Commitment = MerkleCommitment; + type LookupProof = MerkleProof; + type IntoIter = utils::IntoIter; + + fn register( + &mut self, + new_key: Self::Key, + amount: Self::Amount, + ) -> Result<(), StakeTableError> { + match self.mapping.get(&new_key) { + Some(_) => Err(StakeTableError::ExistingKey), + None => { + let pos = self.mapping.len(); + self.head = self.head.register( + self.height, + &to_merkle_path(pos, self.height), + &new_key, + amount, + )?; + self.mapping.insert(new_key, pos); + Ok(()) + } } } - /// Update the stake table when the epoch number advances, should be manually called. - pub fn advance(&mut self) { - self.active = self.frozen.clone(); - self.frozen = self.pending.clone(); + fn deregister(&mut self, _existing_key: &Self::Key) -> Result<(), StakeTableError> { + // TODO: (alex) work on this in a future PR + unimplemented!() } - /// Returns the number of stakes holding by the input key for a specific stake table version - pub fn simple_lookup( - &self, - version: STVersion, - key: &EncodedPublicKey, - ) -> Result { - let root = match version { - STVersion::PENDING => &self.pending, - STVersion::FROZEN => &self.frozen, - STVersion::ACTIVE => &self.active, - }; - match self.mapping.get(key) { - Some(index) => { - let branches = to_merkle_path(*index, self.height); - root.simple_lookup(self.height, &branches) - } - None => Err(StakeTableError::KeyNotFound), - } + fn commitment(&self, version: SnapshotVersion) -> Result { + let root = Self::get_root(self, version)?; + Ok(MerkleCommitment::new( + root.commitment(), + self.height, + root.num_keys(), + )) } - /// Returns a membership proof for the input key for a specific stake table version - pub fn lookup( - &self, - version: STVersion, - key: &EncodedPublicKey, - ) -> Result { - let root = match version { - STVersion::PENDING => &self.pending, - STVersion::FROZEN => &self.frozen, - STVersion::ACTIVE => &self.active, - }; - match self.mapping.get(key) { - Some(index) => { - let branches = to_merkle_path(*index, self.height); - root.lookup(self.height, &branches) - } - None => Err(StakeTableError::KeyNotFound), - } + fn total_stake(&self, version: SnapshotVersion) -> Result { + let root = Self::get_root(self, version)?; + Ok(root.total_stakes()) } - /// Returns a succint commitment for a specific stake table version - pub fn commitment(&self, version: STVersion) -> MerkleCommitment { - let root = match version { - STVersion::PENDING => &self.pending, - STVersion::FROZEN => &self.frozen, - STVersion::ACTIVE => &self.active, - }; - MerkleCommitment::new(root.commitment(), self.height, root.num_keys()) + fn len(&self, version: SnapshotVersion) -> Result { + let root = Self::get_root(self, version)?; + Ok(root.num_keys()) } - /// Returns the total amount of stakes for a specific stake table version - pub fn total_stakes(&self, version: STVersion) -> U256 { - let root = match version { - STVersion::PENDING => &self.pending, - STVersion::FROZEN => &self.frozen, - STVersion::ACTIVE => &self.active, - }; - root.total_stakes() + fn contains_key(&self, key: &Self::Key) -> bool { + self.mapping.contains_key(key) } - /// Returns the number of keys for a specific stake table version - pub fn num_keys(&self, version: STVersion) -> usize { - let root = match version { - STVersion::PENDING => &self.pending, - STVersion::FROZEN => &self.frozen, - STVersion::ACTIVE => &self.active, - }; - root.num_keys() - } + fn lookup( + &self, + version: SnapshotVersion, + key: &Self::Key, + ) -> Result<(Self::Amount, Self::LookupProof), StakeTableError> { + let root = Self::get_root(self, version)?; - /// Almost uniformly samples a key weighted by its stake from the active stake table - pub fn sample_key_by_stake(&self, rng: &mut R) -> &EncodedPublicKey { - let mut bytes = [0u8; 64]; - rng.fill_bytes(&mut bytes); - let r = U512::from_big_endian(&bytes); - let m = U512::from(self.active.total_stakes()); - let pos: U256 = (r % m).try_into().unwrap(); - self.active.get_key_by_stake(pos).unwrap() + let proof = match self.mapping.get(key) { + Some(index) => { + let branches = to_merkle_path(*index, self.height); + root.lookup(self.height, &branches) + } + None => Err(StakeTableError::KeyNotFound), + }?; + let amount = *proof.get_value().ok_or(StakeTableError::KeyNotFound)?; + Ok((amount, proof)) } - /// Set the stake withheld by `key` to be `value`. - /// Return the previous stake if succeed. - pub fn set_value( - &mut self, - key: &EncodedPublicKey, - value: U256, - ) -> Result { + fn simple_lookup( + &self, + version: SnapshotVersion, + key: &K, + ) -> Result { + let root = Self::get_root(self, version)?; match self.mapping.get(key) { - Some(pos) => { - let old_value: U256; - (self.pending, old_value) = self.pending.set_value( - self.height, - &to_merkle_path(*pos, self.height), - key, - value, - )?; - Ok(old_value) + Some(index) => { + let branches = to_merkle_path(*index, self.height); + root.simple_lookup(self.height, &branches) } None => Err(StakeTableError::KeyNotFound), } } - /// Update the stake of the `key` with `(negative ? -1 : 1) * delta`. - /// Return the updated stake - pub fn update( + fn update( &mut self, - key: &EncodedPublicKey, - delta: U256, + key: &Self::Key, + delta: Self::Amount, negative: bool, - ) -> Result { + ) -> Result { match self.mapping.get(key) { Some(pos) => { let value: U256; - (self.pending, value) = self.pending.update( + (self.head, value) = self.head.update( self.height, &to_merkle_path(*pos, self.height), key, @@ -205,49 +275,102 @@ impl StakeTable { } } - /// Register a new key from the pending stake table - pub fn register(&mut self, key: &EncodedPublicKey, value: U256) -> Result<(), StakeTableError> { + /// Almost uniformly samples a key weighted by its stake from the + /// last_epoch_start stake table + fn sample( + &self, + rng: &mut (impl SeedableRng + CryptoRngCore), + ) -> Option<(&Self::Key, &Self::Amount)> { + let mut bytes = [0u8; 64]; + rng.fill_bytes(&mut bytes); + let r = U512::from_big_endian(&bytes); + let m = U512::from(self.last_epoch_start.total_stakes()); + let pos: U256 = (r % m).try_into().unwrap(); // won't fail + self.last_epoch_start.get_key_by_stake(pos) + } + + fn iter(&self, version: SnapshotVersion) -> Result { + let root = Self::get_root(self, version)?; + Ok(utils::IntoIter::new(root)) + } +} + +impl StakeTable { + /// Initiating an empty stake table. + /// Overall capacity is `TREE_BRANCH.pow(height)`. + pub fn new(height: usize) -> Self { + Self { + head: Arc::new(PersistentMerkleNode::Empty), + epoch_start: Arc::new(PersistentMerkleNode::Empty), + last_epoch_start: Arc::new(PersistentMerkleNode::Empty), + height, + mapping: HashMap::new(), + } + } + + // returns the root of stake table at `version` + fn get_root( + &self, + version: SnapshotVersion, + ) -> Result>, StakeTableError> { + match version { + SnapshotVersion::Head => Ok(Arc::clone(&self.head)), + SnapshotVersion::EpochStart => Ok(Arc::clone(&self.epoch_start)), + SnapshotVersion::LastEpochStart => Ok(Arc::clone(&self.last_epoch_start)), + SnapshotVersion::BlockNum(_) => Err(StakeTableError::SnapshotUnsupported), + } + } + + /// Update the stake table when the epoch number advances, should be manually called. + pub fn advance(&mut self) { + self.last_epoch_start = self.epoch_start.clone(); + self.epoch_start = self.head.clone(); + } + + /// Set the stake withheld by `key` to be `value`. + /// Return the previous stake if succeed. + pub fn set_value(&mut self, key: &K, value: U256) -> Result { match self.mapping.get(key) { - Some(_) => Err(StakeTableError::ExistingKey), - None => { - let pos = self.mapping.len(); - self.mapping.insert(key.clone(), pos); - self.pending = self.pending.register( + Some(pos) => { + let old_value: U256; + (self.head, old_value) = self.head.set_value( self.height, - &to_merkle_path(pos, self.height), + &to_merkle_path(*pos, self.height), key, value, )?; - Ok(()) + Ok(old_value) } + None => Err(StakeTableError::KeyNotFound), } } } #[cfg(test)] mod tests { - use crate::stake_table::STVersion; - - use super::{config::FieldType, EncodedPublicKey, StakeTable}; - use ark_std::vec::Vec; + use super::{error::StakeTableError, SnapshotVersion, StakeTable, StakeTableScheme}; + use ark_std::{rand::SeedableRng, vec::Vec}; use ethereum_types::U256; - use jf_utils::to_bytes; + + // Hotshot use bn254::Fq as key type. + type Key = ark_bn254::Fq; #[test] - fn test_stake_table() { + fn test_stake_table() -> Result<(), StakeTableError> { let mut st = StakeTable::new(3); - let keys = (0..10) - .map(|i| EncodedPublicKey(to_bytes!(&FieldType::from(i)).unwrap())) - .collect::>(); - assert_eq!(st.total_stakes(STVersion::PENDING), U256::from(0)); + let keys = (0..10).map(Key::from).collect::>(); + assert_eq!(st.total_stake(SnapshotVersion::Head)?, U256::from(0)); // Registering keys keys.iter() .take(4) - .for_each(|key| st.register(key, U256::from(100)).unwrap()); - assert_eq!(st.total_stakes(STVersion::PENDING), U256::from(400)); - assert_eq!(st.total_stakes(STVersion::FROZEN), U256::from(0)); - assert_eq!(st.total_stakes(STVersion::ACTIVE), U256::from(0)); + .for_each(|&key| st.register(key, U256::from(100)).unwrap()); + assert_eq!(st.total_stake(SnapshotVersion::Head)?, U256::from(400)); + assert_eq!(st.total_stake(SnapshotVersion::EpochStart)?, U256::from(0)); + assert_eq!( + st.total_stake(SnapshotVersion::LastEpochStart)?, + U256::from(0) + ); // set to zero for futher sampling test assert_eq!( st.set_value(&keys[1], U256::from(0)).unwrap(), @@ -257,34 +380,51 @@ mod tests { keys.iter() .skip(4) .take(3) - .for_each(|key| st.register(key, U256::from(100)).unwrap()); - assert_eq!(st.total_stakes(STVersion::PENDING), U256::from(600)); - assert_eq!(st.total_stakes(STVersion::FROZEN), U256::from(300)); - assert_eq!(st.total_stakes(STVersion::ACTIVE), U256::from(0)); + .for_each(|&key| st.register(key, U256::from(100)).unwrap()); + assert_eq!(st.total_stake(SnapshotVersion::Head)?, U256::from(600)); + assert_eq!( + st.total_stake(SnapshotVersion::EpochStart)?, + U256::from(300) + ); + assert_eq!( + st.total_stake(SnapshotVersion::LastEpochStart)?, + U256::from(0) + ); st.advance(); keys.iter() .skip(7) - .for_each(|key| st.register(key, U256::from(100)).unwrap()); - assert_eq!(st.total_stakes(STVersion::PENDING), U256::from(900)); - assert_eq!(st.total_stakes(STVersion::FROZEN), U256::from(600)); - assert_eq!(st.total_stakes(STVersion::ACTIVE), U256::from(300)); + .for_each(|&key| st.register(key, U256::from(100)).unwrap()); + assert_eq!(st.total_stake(SnapshotVersion::Head)?, U256::from(900)); + assert_eq!( + st.total_stake(SnapshotVersion::EpochStart)?, + U256::from(600) + ); + assert_eq!( + st.total_stake(SnapshotVersion::LastEpochStart)?, + U256::from(300) + ); // No duplicate register - assert!(st.register(&keys[0], U256::from(100)).is_err()); - // The 9-th key is still in pending stake table - assert!(st.simple_lookup(STVersion::FROZEN, &keys[9]).is_err()); - assert!(st.simple_lookup(STVersion::FROZEN, &keys[5]).is_ok()); + assert!(st.register(keys[0], U256::from(100)).is_err()); + // The 9-th key is still in head stake table + assert!(st.lookup(SnapshotVersion::EpochStart, &keys[9]).is_err()); + assert!(st.lookup(SnapshotVersion::EpochStart, &keys[5]).is_ok()); // The 6-th key is still frozen - assert!(st.simple_lookup(STVersion::ACTIVE, &keys[6]).is_err()); - assert!(st.simple_lookup(STVersion::ACTIVE, &keys[2]).is_ok()); + assert!(st + .lookup(SnapshotVersion::LastEpochStart, &keys[6]) + .is_err()); + assert!(st.lookup(SnapshotVersion::LastEpochStart, &keys[2]).is_ok()); // Set value shall return the old value assert_eq!( st.set_value(&keys[0], U256::from(101)).unwrap(), U256::from(100) ); - assert_eq!(st.total_stakes(STVersion::PENDING), U256::from(901)); - assert_eq!(st.total_stakes(STVersion::FROZEN), U256::from(600)); + assert_eq!(st.total_stake(SnapshotVersion::Head)?, U256::from(901)); + assert_eq!( + st.total_stake(SnapshotVersion::EpochStart)?, + U256::from(600) + ); // Update that results in a negative stake assert!(st.update(&keys[0], U256::from(1000), true).is_err()); @@ -299,18 +439,26 @@ mod tests { ); // Testing membership proof - let proof = st.lookup(STVersion::FROZEN, &keys[5]).unwrap(); - assert!(proof.verify(&st.commitment(STVersion::FROZEN)).is_ok()); + let proof = st.lookup(SnapshotVersion::EpochStart, &keys[5])?.1; + assert!(proof + .verify(&st.commitment(SnapshotVersion::EpochStart)?) + .is_ok()); // Membership proofs are tied with a specific version - assert!(proof.verify(&st.commitment(STVersion::PENDING)).is_err()); - assert!(proof.verify(&st.commitment(STVersion::ACTIVE)).is_err()); + assert!(proof + .verify(&st.commitment(SnapshotVersion::Head)?) + .is_err()); + assert!(proof + .verify(&st.commitment(SnapshotVersion::LastEpochStart)?) + .is_err()); // Random test for sampling keys - let mut rng = jf_utils::test_rng(); + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(41u64); for _ in 0..100 { - let key = st.sample_key_by_stake(&mut rng); + let (_key, value) = st.sample(&mut rng).unwrap(); // Sampled keys should have positive stake - assert!(st.simple_lookup(STVersion::ACTIVE, key).unwrap() > U256::from(0)); + assert!(value > &U256::from(0)); } + + Ok(()) } } diff --git a/src/stake_table/utils.rs b/src/stake_table/utils.rs index ce6b67a..c624ee7 100644 --- a/src/stake_table/utils.rs +++ b/src/stake_table/utils.rs @@ -3,24 +3,63 @@ use super::{ config::{u256_to_field, Digest, FieldType, TREE_BRANCH}, error::StakeTableError, - EncodedPublicKey, }; +use ark_ff::{Field, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use ark_std::{sync::Arc, vec, vec::Vec}; +use ark_std::{hash::Hash, sync::Arc, vec, vec::Vec}; use ethereum_types::U256; -use jf_primitives::crhf::CRHF; +use jf_primitives::{crhf::CRHF, signatures::bls_over_bn254}; use jf_utils::canonical; use serde::{Deserialize, Serialize}; use tagged_base64::tagged; +/// Common trait bounds for generic key type `K` for [`PersistentMerkleNode`] +pub trait Key: + Clone + CanonicalSerialize + CanonicalDeserialize + PartialEq + Eq + IntoFields + Hash +{ +} +impl Key for T where + T: Clone + + CanonicalSerialize + + CanonicalDeserialize + + PartialEq + + Eq + + IntoFields + + Hash +{ +} + +/// A trait that converts into a field element. +/// Help avoid "cannot impl foreign traits on foreign types" problem +pub trait IntoFields { + fn into_fields(self) -> [F; 2]; +} + +impl IntoFields for FieldType { + fn into_fields(self) -> [FieldType; 2] { + [FieldType::default(), self] + } +} + +impl IntoFields for bls_over_bn254::VerKey { + fn into_fields(self) -> [FieldType; 2] { + let bytes = jf_utils::to_bytes!(&self.to_affine()).unwrap(); + let x = ::from_le_bytes_mod_order(&bytes[..32]); + let y = ::from_le_bytes_mod_order(&bytes[32..]); + [x, y] + } +} + /// A persistent merkle tree tailored for the stake table. +/// Generic over the key type `K` #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub(crate) enum PersistentMerkleNode { +#[serde(bound = "K: Key")] +pub(crate) enum PersistentMerkleNode { Empty, Branch { #[serde(with = "canonical")] comm: FieldType, - children: [Arc; TREE_BRANCH], + children: [Arc>; TREE_BRANCH], num_keys: usize, total_stakes: U256, }, @@ -28,37 +67,37 @@ pub(crate) enum PersistentMerkleNode { #[serde(with = "canonical")] comm: FieldType, #[serde(with = "canonical")] - key: EncodedPublicKey, + key: K, value: U256, }, } /// A compressed Merkle node for Merkle path #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum MerklePathEntry { +pub enum MerklePathEntry { Branch { pos: usize, #[serde(with = "canonical")] siblings: [FieldType; TREE_BRANCH - 1], }, Leaf { - key: EncodedPublicKey, + key: K, value: U256, }, } /// Path from a Merkle root to a leaf -pub type MerklePath = Vec; +pub type MerklePath = Vec>; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] /// An existential proof -pub struct MerkleProof { +pub struct MerkleProof { /// Index for the given key pub index: usize, /// A Merkle path for the given leaf - pub path: MerklePath, + pub path: MerklePath, } -impl MerkleProof { +impl MerkleProof { pub fn tree_height(&self) -> usize { self.path.len() - 1 } @@ -67,7 +106,7 @@ impl MerkleProof { &self.index } - pub fn get_key(&self) -> Option<&EncodedPublicKey> { + pub fn get_key(&self) -> Option<&K> { match self.path.first() { Some(MerklePathEntry::Leaf { key, value: _ }) => Some(key), _ => None, @@ -81,7 +120,7 @@ impl MerkleProof { } } - pub fn get_key_value(&self) -> Option<(&EncodedPublicKey, &U256)> { + pub fn get_key_value(&self) -> Option<(&K, &U256)> { match self.path.first() { Some(MerklePathEntry::Leaf { key, value }) => Some((key, value)), _ => None, @@ -91,12 +130,9 @@ impl MerkleProof { pub fn compute_root(&self) -> Result { match self.path.first() { Some(MerklePathEntry::Leaf { key, value }) => { - let input = [ - FieldType::from(0), - ::deserialize_compressed(&key.0[..]) - .unwrap(), - u256_to_field(value), - ]; + let mut input = [FieldType::default(); 3]; + input[..2].copy_from_slice(&(*key).clone().into_fields()[..]); + input[2] = u256_to_field(value); let init = Digest::evaluate(input).map_err(|_| StakeTableError::RescueError)?[0]; self.path .iter() @@ -160,7 +196,7 @@ impl MerkleCommitment { } } -impl PersistentMerkleNode { +impl PersistentMerkleNode { /// Returns the succint commitment of this subtree pub fn commitment(&self) -> FieldType { match self { @@ -234,7 +270,7 @@ impl PersistentMerkleNode { } /// Returns a Merkle proof to the given location - pub fn lookup(&self, height: usize, path: &[usize]) -> Result { + pub fn lookup(&self, height: usize, path: &[usize]) -> Result, StakeTableError> { match self { PersistentMerkleNode::Empty => Err(StakeTableError::KeyNotFound), PersistentMerkleNode::Branch { @@ -274,7 +310,7 @@ impl PersistentMerkleNode { /// Imagine that the keys in this subtree is sorted, returns the first key such that /// the prefix sum of withholding stakes is greater or equal the given `stake_number`. /// Useful for key sampling weighted by withholding stakes - pub fn get_key_by_stake(&self, mut stake_number: U256) -> Option<&EncodedPublicKey> { + pub fn get_key_by_stake(&self, mut stake_number: U256) -> Option<(&K, &U256)> { if stake_number >= self.total_stakes() { None } else { @@ -296,8 +332,8 @@ impl PersistentMerkleNode { PersistentMerkleNode::Leaf { comm: _, key, - value: _, - } => Some(key), + value, + } => Some((key, value)), } } } @@ -307,17 +343,14 @@ impl PersistentMerkleNode { &self, height: usize, path: &[usize], - key: &EncodedPublicKey, + key: &K, value: U256, ) -> Result, StakeTableError> { if height == 0 { if matches!(self, PersistentMerkleNode::Empty) { - let input = [ - FieldType::from(0u64), - ::deserialize_compressed(&key.0[..]) - .unwrap(), - u256_to_field(&value), - ]; + let mut input = [FieldType::default(); 3]; + input[..2].copy_from_slice(&(*key).clone().into_fields()[..]); + input[2] = u256_to_field(&value); Ok(Arc::new(PersistentMerkleNode::Leaf { comm: Digest::evaluate(input).map_err(|_| StakeTableError::RescueError)?[0], key: key.clone(), @@ -362,7 +395,7 @@ impl PersistentMerkleNode { &self, height: usize, path: &[usize], - key: &EncodedPublicKey, + key: &K, delta: U256, negative: bool, ) -> Result<(Arc, U256), StakeTableError> { @@ -410,12 +443,9 @@ impl PersistentMerkleNode { .checked_add(delta) .ok_or(StakeTableError::StakeOverflow) }?; - let input = [ - FieldType::from(0), - ::deserialize_compressed(&key.0[..]) - .unwrap(), - u256_to_field(&value), - ]; + let mut input = [FieldType::default(); 3]; + input[..2].copy_from_slice(&(*key).clone().into_fields()[..]); + input[2] = u256_to_field(&value); Ok(( Arc::new(PersistentMerkleNode::Leaf { comm: Digest::evaluate(input) @@ -438,7 +468,7 @@ impl PersistentMerkleNode { &self, height: usize, path: &[usize], - key: &EncodedPublicKey, + key: &K, value: U256, ) -> Result<(Arc, U256), StakeTableError> { match self { @@ -480,12 +510,9 @@ impl PersistentMerkleNode { value: old_value, } => { if key == cur_key { - let input = [ - FieldType::from(0), - ::deserialize_compressed(&key.0[..]) - .unwrap(), - u256_to_field(&value), - ]; + let mut input = [FieldType::default(); 3]; + input[..2].copy_from_slice(&(*key).clone().into_fields()[..]); + input[2] = u256_to_field(&value); Ok(( Arc::new(PersistentMerkleNode::Leaf { comm: Digest::evaluate(input) @@ -503,6 +530,68 @@ impl PersistentMerkleNode { } } +/// An owning iterator over the (key, value) entries of a `PersistentMerkleNode` +/// Traverse using post-order: children from left to right, finally visit the current. +pub struct IntoIter { + unvisited: Vec>>, + num_visited: usize, +} + +impl IntoIter { + /// create a new merkle tree iterator from a `root`. + /// This (abstract) `root` can be an internal node of a larger tree, our iterator + /// will iterate over all of its children. + pub(crate) fn new(root: Arc>) -> Self { + Self { + unvisited: vec![root], + num_visited: 0, + } + } +} + +impl Iterator for IntoIter { + type Item = (K, U256); + fn next(&mut self) -> Option { + if self.unvisited.is_empty() { + return None; + } + + let visiting = (**self.unvisited.last()?).clone(); + match visiting { + PersistentMerkleNode::Empty => None, + PersistentMerkleNode::Leaf { + comm: _, + key, + value, + } => { + self.unvisited.pop(); + self.num_visited += 1; + Some((key, value)) + } + PersistentMerkleNode::Branch { + comm: _, + children, + num_keys: _, + total_stakes: _, + } => { + self.unvisited.pop(); + // put the left-most child to the last, so it is visited first. + self.unvisited.extend(children.into_iter().rev()); + self.next() + } + } + } +} + +impl IntoIterator for PersistentMerkleNode { + type Item = (K, U256); + type IntoIter = self::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + Self::IntoIter::new(Arc::new(self)) + } +} + /// Convert an index to a list of Merkle path branches pub fn to_merkle_path(idx: usize, height: usize) -> Vec { let mut pos = idx; @@ -526,21 +615,26 @@ pub fn from_merkle_path(path: &[usize]) -> usize { #[cfg(test)] mod tests { use super::{to_merkle_path, PersistentMerkleNode}; - use crate::stake_table::{config::FieldType, EncodedPublicKey}; - use ark_std::{sync::Arc, vec, vec::Vec}; + use crate::stake_table::config; + use ark_std::{ + rand::{Rng, RngCore}, + sync::Arc, + vec, + vec::Vec, + }; use ethereum_types::U256; - use jf_utils::to_bytes; + use jf_utils::test_rng; + + type Key = ark_bn254::Fq; #[test] fn test_persistent_merkle_tree() { let height = 3; - let mut roots = vec![Arc::new(PersistentMerkleNode::Empty)]; + let mut roots = vec![Arc::new(PersistentMerkleNode::::Empty)]; let path = (0..10) .map(|idx| to_merkle_path(idx, height)) .collect::>(); - let keys = (0..10) - .map(|i| EncodedPublicKey(to_bytes!(&FieldType::from(i)).unwrap())) - .collect::>(); + let keys = (0..10).map(Key::from).collect::>(); // Insert key (0..10) with associated value 100 to the persistent merkle tree for (i, key) in keys.iter().enumerate() { roots.push( @@ -573,6 +667,7 @@ mod tests { .unwrap() .get_key_by_stake(U256::from(i as u64 * 100 + i as u64 + 1)) .unwrap() + .0 ); }); @@ -644,4 +739,31 @@ mod tests { ); assert_eq!(U256::from(1000), roots.last().unwrap().total_stakes()); } + + #[test] + fn test_mt_iter() { + let height = 3; + let capacity = config::TREE_BRANCH.pow(height); + let mut rng = test_rng(); + + for _ in 0..5 { + let num_keys = rng.gen_range(1..capacity); + let keys: Vec = (0..num_keys).map(|i| Key::from(i as u64)).collect(); + let paths = (0..num_keys) + .map(|idx| to_merkle_path(idx, height as usize)) + .collect::>(); + let amounts: Vec = (0..num_keys).map(|_| U256::from(rng.next_u64())).collect(); + + // register all `num_keys` of (key, amount) pair. + let mut root = Arc::new(PersistentMerkleNode::::Empty); + for i in 0..num_keys { + root = root + .register(height as usize, &paths[i], &keys[i], amounts[i]) + .unwrap(); + } + for (i, (k, v)) in (*root).clone().into_iter().enumerate() { + assert_eq!((k, v), (keys[i], amounts[i])); + } + } + } }