Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 133 additions & 59 deletions akd/src/append_only_zks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ fn get_parallel_levels() -> Option<u8> {
let parallel_levels = (available_parallelism as f32).log2().ceil() as u8;

info!(
"Insert will be performed in parallel (available parallelism: {}, parallel levels: {})",
"Parallel levels requested (available parallelism: {}, parallel levels: {})",
available_parallelism, parallel_levels
);
Some(parallel_levels)
Expand Down Expand Up @@ -315,11 +315,20 @@ impl Azks {
let azks_element_set = AzksElementSet::from(nodes);

// preload the nodes that we will visit during the insertion
let (_, time_s) =
let (fallible_load_count, time_s) =
tic_toc(self.preload_nodes(storage, &azks_element_set, PreloadParallelism::Default))
.await;
let load_count = fallible_load_count?;
if let Some(time) = time_s {
info!("Preload of tree took {} s", time,);
info!(
"Preload of nodes for insert ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for insert ({} objects loaded) completed.",
load_count
);
Comment on lines +328 to +331
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to log this at INFO level? Although this path will arguably be much less hot than a read path, it may still be possible for this to become noisy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually intentional: when we switched the previous preload log to DEBUG, we lost visibility into how long node preloading was taking during sequencing (and how many objects were loaded).

}

// increment the current epoch
Expand Down Expand Up @@ -883,32 +892,32 @@ impl Azks {
// Suppose the epochs start_epoch and end_epoch exist in the set.
// This function should return the proof that nothing was removed/changed from the tree
// between these epochs.
let (fallible_load_count, time_s) = tic_toc(self.preload_audit_nodes::<_>(
storage,
latest_epoch,
start_epoch,
end_epoch,
PreloadParallelism::Default,
))
.await;
let load_count = fallible_load_count?;
if let Some(time) = time_s {
info!(
"Preload of nodes for audit ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for audit ({} objects loaded) completed.",
load_count
);
}
storage.log_metrics().await;

let node =
TreeNode::get_from_storage(storage, &NodeKey(NodeLabel::root()), latest_epoch).await?;

for ep in start_epoch..end_epoch {
let (fallable_load_count, time_s) = tic_toc(self.gather_audit_proof_nodes::<_>(
vec![node.clone()],
storage,
ep,
ep + 1,
))
.await;
let load_count = fallable_load_count?;
if let Some(time) = time_s {
info!(
"Preload of nodes for audit ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for audit ({} objects loaded) completed.",
load_count
);
}
storage.log_metrics().await;

let (unchanged, leaves) = Self::get_append_only_proof_helper::<TC, _>(
latest_epoch,
storage,
Expand All @@ -930,60 +939,125 @@ impl Azks {
Ok(AppendOnlyProof { proofs, epochs })
}

fn determine_retrieval_nodes(
node: &TreeNode,
async fn preload_audit_nodes<S: Database + 'static>(
&self,
storage: &StorageManager<S>,
latest_epoch: u64,
start_epoch: u64,
end_epoch: u64,
) -> Vec<NodeLabel> {
if node.node_type == TreeNodeType::Leaf {
return vec![];
parallelism: PreloadParallelism,
) -> Result<u64, AkdError> {
if !storage.has_cache() {
info!("No cache found, skipping preload");
return Ok(0);
}

if node.get_latest_epoch() <= start_epoch {
return vec![];
}
let node_keys = vec![NodeKey(NodeLabel::root())];
let parallel_levels = if parallelism == PreloadParallelism::Disabled {
None
} else {
get_parallel_levels()
};

if node.min_descendant_epoch > end_epoch {
return vec![];
}
let load_count = Azks::recursive_preload_audit_nodes(
storage,
node_keys,
latest_epoch,
start_epoch,
end_epoch,
parallel_levels,
)
.await?;

match (node.left_child, node.right_child) {
(Some(lc), None) => vec![lc],
(None, Some(rc)) => vec![rc],
(Some(lc), Some(rc)) => vec![lc, rc],
_ => vec![],
}
Ok(load_count)
}

async fn gather_audit_proof_nodes<S: Database>(
&self,
nodes: Vec<TreeNode>,
#[async_recursion]
#[allow(clippy::multiple_bound_locations)]
async fn recursive_preload_audit_nodes<S: Database + 'static>(
storage: &StorageManager<S>,
node_keys: Vec<NodeKey>,
latest_epoch: u64,
start_epoch: u64,
end_epoch: u64,
parallel_levels: Option<u8>,
) -> Result<u64, AkdError> {
let mut children_to_fetch: Vec<NodeKey> = nodes
if node_keys.is_empty() {
return Ok(0);
}

let nodes = TreeNode::batch_get_from_storage(storage, &node_keys, latest_epoch).await?;
let mut load_count = node_keys.len() as u64;

let mut next_nodes: Vec<NodeKey> = nodes
.iter()
.flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch))
.map(NodeKey)
.filter(|node| {
node.node_type != TreeNodeType::Leaf
&& node.get_latest_epoch() > start_epoch
&& node.min_descendant_epoch <= end_epoch
})
.flat_map(|node| {
[Direction::Left, Direction::Right]
.iter()
.filter_map(|dir| node.get_child_label(*dir).map(NodeKey))
.collect::<Vec<NodeKey>>()
})
.collect();

let mut element_count = 0u64;
while !children_to_fetch.is_empty() {
let got = TreeNode::batch_get_from_storage(
if parallel_levels.is_some() {
// Divide work into two equivalent chunks.
let right_next_nodes = next_nodes.split_off(next_nodes.len() / 2);
let left_next_nodes = next_nodes;
let child_parallel_levels =
parallel_levels.and_then(|x| if x <= 1 { None } else { Some(x - 1) });

// Handle the left chunk in a different tokio task.
let storage_clone = storage.clone();
let left_future = async move {
Azks::recursive_preload_audit_nodes(
&storage_clone,
left_next_nodes,
latest_epoch,
start_epoch,
end_epoch,
child_parallel_levels,
)
.await
};
let handle = tokio::task::spawn(left_future);

// Handle the right chunk in the current task.
let right_load_count = Azks::recursive_preload_audit_nodes(
storage,
&children_to_fetch,
self.get_latest_epoch(),
right_next_nodes,
latest_epoch,
start_epoch,
end_epoch,
child_parallel_levels,
)
.await?;
element_count += got.len() as u64;
children_to_fetch = got
.iter()
.flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch))
.map(NodeKey)
.collect();
load_count += right_load_count;

// Join on the handle for the left chunk.
let left_load_count = handle
.await
.map_err(|e| AkdError::Parallelism(ParallelismError::JoinErr(e.to_string())))??;
load_count += left_load_count;
} else {
// Perform all the work in the current task.
let next_load_count = Azks::recursive_preload_audit_nodes(
storage,
next_nodes,
latest_epoch,
start_epoch,
end_epoch,
parallel_levels,
)
.await?;
load_count += next_load_count;
}
Ok(element_count)

Ok(load_count)
}

#[async_recursion]
Expand Down