Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions akd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ experimental = ["akd_core/experimental"]
default = [
"public_auditing",
"parallel_vrf",
"parallel_insert",
"parallel_azks",
"preload_history",
"greedy_lookup_preload",
"experimental",
Expand All @@ -28,8 +28,8 @@ bench = ["experimental", "public_tests", "tokio/rt-multi-thread"]
# Greedy loading of lookup proof nodes
greedy_lookup_preload = []
public_auditing = ["dep:protobuf", "akd_core/protobuf"]
# Parallelize node insertion during publish
parallel_insert = []
# Parallelize node fetch and insertion during publish
parallel_azks = []
# Parallelize VRF calculations during publish
parallel_vrf = ["akd_core/parallel_vrf"]
# Enable pre-loading of the nodes when generating history proofs
Expand Down
161 changes: 129 additions & 32 deletions akd/src/append_only_zks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ use std::collections::HashSet;
use std::convert::TryFrom;
use std::marker::Sync;
use std::ops::Deref;
use std::sync::Arc;

/// The default azks key
pub const DEFAULT_AZKS_KEY: u8 = 1u8;

/// The default available parallelism for parallel batch insertions, used when
/// available parallelism cannot be determined at runtime. Should be > 1
#[cfg(feature = "parallel_insert")]
#[cfg(feature = "parallel_azks")]
pub const DEFAULT_AVAILABLE_PARALLELISM: usize = 32;

async fn tic_toc<T>(f: impl core::future::Future<Output = T>) -> (T, Option<f64>) {
Expand All @@ -53,10 +54,10 @@ async fn tic_toc<T>(f: impl core::future::Future<Output = T>) -> (T, Option<f64>
}

fn get_parallel_levels() -> Option<u8> {
#[cfg(not(feature = "parallel_insert"))]
#[cfg(not(feature = "parallel_azks"))]
return None;

#[cfg(feature = "parallel_insert")]
#[cfg(feature = "parallel_azks")]
{
// Based on profiling results, the best performance is achieved when the
// number of spawned tasks is equal to the number of available threads.
Expand Down Expand Up @@ -305,7 +306,7 @@ impl Azks {
let azks_element_set = AzksElementSet::from(nodes);

// preload the nodes that we will visit during the insertion
let (_, time_s) = tic_toc(self.preload_nodes(storage, &azks_element_set)).await;
let (_, time_s) = tic_toc(self.preload_nodes(storage, &azks_element_set, false)).await;
if let Some(time) = time_s {
info!("Preload of tree took {} s", time,);
}
Expand Down Expand Up @@ -620,7 +621,7 @@ impl Azks {
Ok(count)
}

pub(crate) async fn preload_lookup_nodes<S: Database + Send + Sync>(
pub(crate) async fn preload_lookup_nodes<S: Database + Send + Sync + 'static>(
&self,
storage: &StorageManager<S>,
lookup_infos: &[LookupInfo],
Expand All @@ -637,47 +638,132 @@ impl Azks {
})
.collect();

// Load nodes.
self.preload_nodes(storage, &AzksElementSet::from(lookup_nodes))
// Load nodes without parallelism, since multiple lookups could be
// happening and parallelism might consume too many resources.
self.preload_nodes(storage, &AzksElementSet::from(lookup_nodes), true)
.await
}

/// Preloads given nodes using breadth-first search.
pub(crate) async fn preload_nodes<S: Database>(
pub(crate) async fn preload_nodes<S: Database + 'static>(
&self,
storage: &StorageManager<S>,
azks_element_set: &AzksElementSet,
no_parallelism: bool,
) -> Result<u64, AkdError> {
if !storage.has_cache() {
info!("No cache found, skipping preload");
return Ok(0);
}

let mut load_count: u64 = 0;
let mut current_nodes = vec![NodeKey(NodeLabel::root())];
// We clone and wrap AzksElementSet in an Arc so that it can be passed
// to another tokio task safely. The element set does not even need to
// be cloned, since preloading never modifies it. However, a clone helps
// avoid propagating the responsibility of creating an Arc to the caller.
// We can consider doing away with it in future.
let azks_element_set = Arc::new(azks_element_set.clone());
let epoch = self.get_latest_epoch();
let node_keys = vec![NodeKey(NodeLabel::root())];
let parallel_levels = if no_parallelism {
None
} else {
get_parallel_levels()
};

while !current_nodes.is_empty() {
let nodes =
TreeNode::batch_get_from_storage(storage, &current_nodes, self.get_latest_epoch())
.await?;
load_count += nodes.len() as u64;
let load_count = Azks::recursive_preload_nodes(
storage,
azks_element_set,
epoch,
node_keys,
parallel_levels,
)
.await?;

// Now that states are loaded in the cache, we can read and access them.
// Note, we perform directional loads to avoid accessing remote storage
// individually for each node's state.
current_nodes = nodes
.iter()
.filter(|node| azks_element_set.contains_prefix(&node.label))
.flat_map(|node| {
[Direction::Left, Direction::Right]
.iter()
.filter_map(|dir| node.get_child_label(*dir).map(NodeKey))
.collect::<Vec<NodeKey>>()
})
.collect();
debug!("Preload of tree ({} nodes) completed", load_count);

Ok(load_count)
}

#[async_recursion]
#[allow(clippy::multiple_bound_locations)]
async fn recursive_preload_nodes<S: Database + 'static>(
storage: &StorageManager<S>,
azks_element_set: Arc<AzksElementSet>,
epoch: u64,
node_keys: Vec<NodeKey>,
parallel_levels: Option<u8>,
) -> Result<u64, AkdError> {
if node_keys.is_empty() {
return Ok(0);
}

debug!("Preload of tree ({} nodes) completed", load_count);
let nodes = TreeNode::batch_get_from_storage(storage, &node_keys, epoch).await?;
let mut load_count = node_keys.len() as u64;

// Now that states are loaded in the cache, we can read and access them.
// Note, we perform directional loads to avoid accessing remote storage
// individually for each node's state.
let mut next_nodes: Vec<NodeKey> = nodes
.iter()
.filter(|node| azks_element_set.contains_prefix(&node.label))
.flat_map(|node| {
[Direction::Left, Direction::Right]
.iter()
.filter_map(|dir| node.get_child_label(*dir).map(NodeKey))
.collect::<Vec<NodeKey>>()
})
.collect();

if parallel_levels.is_some() {
// Divide work into two equivalent chunks.
let right_next_nodes = next_nodes.split_off(next_nodes.len() / 2);
let left_next_nodes = next_nodes;
let child_parallel_levels =
parallel_levels.and_then(|x| if x <= 1 { None } else { Some(x - 1) });

// Handle the left chunk in a different tokio task.
let storage_clone = storage.clone();
let azks_element_set_clone = azks_element_set.clone();
let left_future = async move {
Azks::recursive_preload_nodes(
&storage_clone,
azks_element_set_clone,
epoch,
left_next_nodes,
child_parallel_levels,
)
.await
};
let handle = tokio::task::spawn(left_future);

// Handle the right chunk in the current task.
let right_load_count = Azks::recursive_preload_nodes(
storage,
azks_element_set,
epoch,
right_next_nodes,
child_parallel_levels,
)
.await?;
load_count += right_load_count;

// Join on the handle for the left chunk.
let left_load_count = handle
.await
.map_err(|e| AkdError::Parallelism(ParallelismError::JoinErr(e.to_string())))??;
load_count += left_load_count;
} else {
// Perform all the work in the current task.
let next_load_count = Azks::recursive_preload_nodes(
storage,
azks_element_set,
epoch,
next_nodes,
parallel_levels,
)
.await?;
load_count += next_load_count;
}

Ok(load_count)
}
Expand Down Expand Up @@ -926,7 +1012,7 @@ impl Azks {
let maybe_task: Option<
tokio::task::JoinHandle<Result<(Vec<AzksElement>, Vec<AzksElement>), AkdError>>,
> = if let Some(left_child) = node.left_child {
#[cfg(feature = "parallel_insert")]
#[cfg(feature = "parallel_azks")]
{
if parallel_levels.map(|p| p as u64 > level).unwrap_or(false) {
// we can parallelise further!
Expand Down Expand Up @@ -975,7 +1061,7 @@ impl Azks {
}
}

#[cfg(not(feature = "parallel_insert"))]
#[cfg(not(feature = "parallel_azks"))]
{
// NO Parallelism, BAD! parallelism. Get your nose out of the garbage!
let child_node =
Expand Down Expand Up @@ -1560,7 +1646,18 @@ mod tests {
]);
let expected_preload_count = 3u64;
let actual_preload_count = azks
.preload_nodes(&storage_manager, &azks_element_set)
.preload_nodes(&storage_manager, &azks_element_set, false)
.await
.expect("Failed to preload nodes");

assert_eq!(
expected_preload_count, actual_preload_count,
"Preload count returned unexpected value!"
);

// Test preload with no_parallelism
let actual_preload_count = azks
.preload_nodes(&storage_manager, &azks_element_set, true)
.await
.expect("Failed to preload nodes");

Expand Down
2 changes: 1 addition & 1 deletion akd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@
//!
//! Performance optimizations:
//! - `parallel_vrf`: Enables the VRF computations to be run in parallel
//! - `parallel_insert`: Enables nodes to be inserted via multiple threads during a publish operation
//! - `parallel_azks`: Enables nodes to be fetched and inserted via multiple threads during a publish operation
//! - `preload_history`: Enables pre-loading of nodes when generating history proofs
//! - `greedy_lookup_preload`: Greedy loading of lookup proof nodes
//!
Expand Down