diff --git a/akd/src/append_only_zks.rs b/akd/src/append_only_zks.rs index 2a54b4f1..2ab9450a 100644 --- a/akd/src/append_only_zks.rs +++ b/akd/src/append_only_zks.rs @@ -75,7 +75,7 @@ fn get_parallel_levels() -> Option { let parallel_levels = (available_parallelism as f32).log2().ceil() as u8; info!( - "Insert will be performed in parallel (available parallelism: {}, parallel levels: {})", + "Parallel levels requested (available parallelism: {}, parallel levels: {})", available_parallelism, parallel_levels ); Some(parallel_levels) @@ -315,11 +315,20 @@ impl Azks { let azks_element_set = AzksElementSet::from(nodes); // preload the nodes that we will visit during the insertion - let (_, time_s) = + let (fallible_load_count, time_s) = tic_toc(self.preload_nodes(storage, &azks_element_set, PreloadParallelism::Default)) .await; + let load_count = fallible_load_count?; if let Some(time) = time_s { - info!("Preload of tree took {} s", time,); + info!( + "Preload of nodes for insert ({} objects loaded), took {} s", + load_count, time, + ); + } else { + info!( + "Preload of nodes for insert ({} objects loaded) completed.", + load_count + ); } // increment the current epoch @@ -883,32 +892,32 @@ impl Azks { // Suppose the epochs start_epoch and end_epoch exist in the set. // This function should return the proof that nothing was removed/changed from the tree // between these epochs. + let (fallible_load_count, time_s) = tic_toc(self.preload_audit_nodes::<_>( + storage, + latest_epoch, + start_epoch, + end_epoch, + PreloadParallelism::Default, + )) + .await; + let load_count = fallible_load_count?; + if let Some(time) = time_s { + info!( + "Preload of nodes for audit ({} objects loaded), took {} s", + load_count, time, + ); + } else { + info!( + "Preload of nodes for audit ({} objects loaded) completed.", + load_count + ); + } + storage.log_metrics().await; let node = TreeNode::get_from_storage(storage, &NodeKey(NodeLabel::root()), latest_epoch).await?; for ep in start_epoch..end_epoch { - let (fallable_load_count, time_s) = tic_toc(self.gather_audit_proof_nodes::<_>( - vec![node.clone()], - storage, - ep, - ep + 1, - )) - .await; - let load_count = fallable_load_count?; - if let Some(time) = time_s { - info!( - "Preload of nodes for audit ({} objects loaded), took {} s", - load_count, time, - ); - } else { - info!( - "Preload of nodes for audit ({} objects loaded) completed.", - load_count - ); - } - storage.log_metrics().await; - let (unchanged, leaves) = Self::get_append_only_proof_helper::( latest_epoch, storage, @@ -930,60 +939,125 @@ impl Azks { Ok(AppendOnlyProof { proofs, epochs }) } - fn determine_retrieval_nodes( - node: &TreeNode, + async fn preload_audit_nodes( + &self, + storage: &StorageManager, + latest_epoch: u64, start_epoch: u64, end_epoch: u64, - ) -> Vec { - if node.node_type == TreeNodeType::Leaf { - return vec![]; + parallelism: PreloadParallelism, + ) -> Result { + if !storage.has_cache() { + info!("No cache found, skipping preload"); + return Ok(0); } - if node.get_latest_epoch() <= start_epoch { - return vec![]; - } + let node_keys = vec![NodeKey(NodeLabel::root())]; + let parallel_levels = if parallelism == PreloadParallelism::Disabled { + None + } else { + get_parallel_levels() + }; - if node.min_descendant_epoch > end_epoch { - return vec![]; - } + let load_count = Azks::recursive_preload_audit_nodes( + storage, + node_keys, + latest_epoch, + start_epoch, + end_epoch, + parallel_levels, + ) + .await?; - match (node.left_child, node.right_child) { - (Some(lc), None) => vec![lc], - (None, Some(rc)) => vec![rc], - (Some(lc), Some(rc)) => vec![lc, rc], - _ => vec![], - } + Ok(load_count) } - async fn gather_audit_proof_nodes( - &self, - nodes: Vec, + #[async_recursion] + #[allow(clippy::multiple_bound_locations)] + async fn recursive_preload_audit_nodes( storage: &StorageManager, + node_keys: Vec, + latest_epoch: u64, start_epoch: u64, end_epoch: u64, + parallel_levels: Option, ) -> Result { - let mut children_to_fetch: Vec = nodes + if node_keys.is_empty() { + return Ok(0); + } + + let nodes = TreeNode::batch_get_from_storage(storage, &node_keys, latest_epoch).await?; + let mut load_count = node_keys.len() as u64; + + let mut next_nodes: Vec = nodes .iter() - .flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch)) - .map(NodeKey) + .filter(|node| { + node.node_type != TreeNodeType::Leaf + && node.get_latest_epoch() > start_epoch + && node.min_descendant_epoch <= end_epoch + }) + .flat_map(|node| { + [Direction::Left, Direction::Right] + .iter() + .filter_map(|dir| node.get_child_label(*dir).map(NodeKey)) + .collect::>() + }) .collect(); - let mut element_count = 0u64; - while !children_to_fetch.is_empty() { - let got = TreeNode::batch_get_from_storage( + if parallel_levels.is_some() { + // Divide work into two equivalent chunks. + let right_next_nodes = next_nodes.split_off(next_nodes.len() / 2); + let left_next_nodes = next_nodes; + let child_parallel_levels = + parallel_levels.and_then(|x| if x <= 1 { None } else { Some(x - 1) }); + + // Handle the left chunk in a different tokio task. + let storage_clone = storage.clone(); + let left_future = async move { + Azks::recursive_preload_audit_nodes( + &storage_clone, + left_next_nodes, + latest_epoch, + start_epoch, + end_epoch, + child_parallel_levels, + ) + .await + }; + let handle = tokio::task::spawn(left_future); + + // Handle the right chunk in the current task. + let right_load_count = Azks::recursive_preload_audit_nodes( storage, - &children_to_fetch, - self.get_latest_epoch(), + right_next_nodes, + latest_epoch, + start_epoch, + end_epoch, + child_parallel_levels, ) .await?; - element_count += got.len() as u64; - children_to_fetch = got - .iter() - .flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch)) - .map(NodeKey) - .collect(); + load_count += right_load_count; + + // Join on the handle for the left chunk. + let left_load_count = handle + .await + .map_err(|e| AkdError::Parallelism(ParallelismError::JoinErr(e.to_string())))??; + load_count += left_load_count; + } else { + // Perform all the work in the current task. + let next_load_count = Azks::recursive_preload_audit_nodes( + storage, + next_nodes, + latest_epoch, + start_epoch, + end_epoch, + parallel_levels, + ) + .await?; + load_count += next_load_count; } - Ok(element_count) + + Ok(load_count) } #[async_recursion]