Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 133 additions & 59 deletions akd/src/append_only_zks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ fn get_parallel_levels() -> Option<u8> {
let parallel_levels = (available_parallelism as f32).log2().ceil() as u8;

info!(
"Insert will be performed in parallel (available parallelism: {}, parallel levels: {})",
"Parallel levels requested (available parallelism: {}, parallel levels: {})",
available_parallelism, parallel_levels
);
Some(parallel_levels)
Expand Down Expand Up @@ -315,11 +315,20 @@ impl Azks {
let azks_element_set = AzksElementSet::from(nodes);

// preload the nodes that we will visit during the insertion
let (_, time_s) =
let (fallable_load_count, time_s) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the first item in the tuple is meant to be named in a way that indicates it can fail, I believe we should rename it to `fallible_load_count`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from elsewhere in the file, but I will edit both to use the right spelling :)

tic_toc(self.preload_nodes(storage, &azks_element_set, PreloadParallelism::Default))
.await;
let load_count = fallable_load_count?;
if let Some(time) = time_s {
info!("Preload of tree took {} s", time,);
info!(
"Preload of nodes for insert ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for insert ({} objects loaded) completed.",
load_count
);
Comment on lines +328 to +331
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to log this as an INFO level log? Although this will arguably be much less hot than a read path, it may be possible for this to become noisy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually intentional: when we switched the previous preload log to DEBUG, we lost visibility into how long node preloading was taking during sequencing (and how many objects were loaded).

}

// increment the current epoch
Expand Down Expand Up @@ -883,32 +892,32 @@ impl Azks {
// Suppose the epochs start_epoch and end_epoch exist in the set.
// This function should return the proof that nothing was removed/changed from the tree
// between these epochs.
let (fallable_load_count, time_s) = tic_toc(self.preload_audit_nodes::<_>(
storage,
latest_epoch,
start_epoch,
end_epoch,
PreloadParallelism::Default,
))
.await;
let load_count = fallable_load_count?;
if let Some(time) = time_s {
info!(
"Preload of nodes for audit ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for audit ({} objects loaded) completed.",
load_count
);
}
storage.log_metrics().await;

let node =
TreeNode::get_from_storage(storage, &NodeKey(NodeLabel::root()), latest_epoch).await?;

for ep in start_epoch..end_epoch {
let (fallable_load_count, time_s) = tic_toc(self.gather_audit_proof_nodes::<_>(
vec![node.clone()],
storage,
ep,
ep + 1,
))
.await;
let load_count = fallable_load_count?;
if let Some(time) = time_s {
info!(
"Preload of nodes for audit ({} objects loaded), took {} s",
load_count, time,
);
} else {
info!(
"Preload of nodes for audit ({} objects loaded) completed.",
load_count
);
}
storage.log_metrics().await;

let (unchanged, leaves) = Self::get_append_only_proof_helper::<TC, _>(
latest_epoch,
storage,
Expand All @@ -930,60 +939,125 @@ impl Azks {
Ok(AppendOnlyProof { proofs, epochs })
}

fn determine_retrieval_nodes(
node: &TreeNode,
async fn preload_audit_nodes<S: Database + 'static>(
&self,
storage: &StorageManager<S>,
latest_epoch: u64,
start_epoch: u64,
end_epoch: u64,
) -> Vec<NodeLabel> {
if node.node_type == TreeNodeType::Leaf {
return vec![];
parallelism: PreloadParallelism,
) -> Result<u64, AkdError> {
if !storage.has_cache() {
info!("No cache found, skipping preload");
return Ok(0);
}

if node.get_latest_epoch() <= start_epoch {
return vec![];
}
let node_keys = vec![NodeKey(NodeLabel::root())];
let parallel_levels = if parallelism == PreloadParallelism::Disabled {
None
} else {
get_parallel_levels()
};

if node.min_descendant_epoch > end_epoch {
return vec![];
}
let load_count = Azks::recursive_preload_audit_nodes(
storage,
node_keys,
latest_epoch,
start_epoch,
end_epoch,
parallel_levels,
)
.await?;

match (node.left_child, node.right_child) {
(Some(lc), None) => vec![lc],
(None, Some(rc)) => vec![rc],
(Some(lc), Some(rc)) => vec![lc, rc],
_ => vec![],
}
Ok(load_count)
}

async fn gather_audit_proof_nodes<S: Database>(
&self,
nodes: Vec<TreeNode>,
#[async_recursion]
#[allow(clippy::multiple_bound_locations)]
async fn recursive_preload_audit_nodes<S: Database + 'static>(
storage: &StorageManager<S>,
node_keys: Vec<NodeKey>,
latest_epoch: u64,
start_epoch: u64,
end_epoch: u64,
parallel_levels: Option<u8>,
) -> Result<u64, AkdError> {
let mut children_to_fetch: Vec<NodeKey> = nodes
if node_keys.is_empty() {
return Ok(0);
}

let nodes = TreeNode::batch_get_from_storage(storage, &node_keys, latest_epoch).await?;
let mut load_count = node_keys.len() as u64;

let mut next_nodes: Vec<NodeKey> = nodes
.iter()
.flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch))
.map(NodeKey)
.filter(|node| {
node.node_type != TreeNodeType::Leaf
&& node.get_latest_epoch() > start_epoch
&& node.min_descendant_epoch < end_epoch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be `<=`? Previously, when determining retrieval nodes, we returned an empty vec in cases where the min descendant epoch was strictly greater than the end epoch. Since we're checking the inverse condition here, shouldn't the comparison be less than or equal to?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch!

})
.flat_map(|node| {
[Direction::Left, Direction::Right]
.iter()
.filter_map(|dir| node.get_child_label(*dir).map(NodeKey))
.collect::<Vec<NodeKey>>()
})
.collect();

let mut element_count = 0u64;
while !children_to_fetch.is_empty() {
let got = TreeNode::batch_get_from_storage(
if parallel_levels.is_some() {
// Divide work into two equivalent chunks.
let right_next_nodes = next_nodes.split_off(next_nodes.len() / 2);
let left_next_nodes = next_nodes;
let child_parallel_levels =
parallel_levels.and_then(|x| if x <= 1 { None } else { Some(x - 1) });

// Handle the left chunk in a different tokio task.
let storage_clone = storage.clone();
let left_future = async move {
Azks::recursive_preload_audit_nodes(
&storage_clone,
left_next_nodes,
latest_epoch,
start_epoch,
end_epoch,
child_parallel_levels,
)
.await
};
let handle = tokio::task::spawn(left_future);

// Handle the right chunk in the current task.
let right_load_count = Azks::recursive_preload_audit_nodes(
storage,
&children_to_fetch,
self.get_latest_epoch(),
right_next_nodes,
latest_epoch,
start_epoch,
end_epoch,
child_parallel_levels,
)
.await?;
element_count += got.len() as u64;
children_to_fetch = got
.iter()
.flat_map(|node| Self::determine_retrieval_nodes(node, start_epoch, end_epoch))
.map(NodeKey)
.collect();
load_count += right_load_count;

// Join on the handle for the left chunk.
let left_load_count = handle
.await
.map_err(|e| AkdError::Parallelism(ParallelismError::JoinErr(e.to_string())))??;
load_count += left_load_count;
} else {
// Perform all the work in the current task.
let next_load_count = Azks::recursive_preload_audit_nodes(
storage,
next_nodes,
latest_epoch,
start_epoch,
end_epoch,
parallel_levels,
)
.await?;
load_count += next_load_count;
}
Ok(element_count)

Ok(load_count)
}

#[async_recursion]
Expand Down