Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/call_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use petgraph::{
algo::kosaraju_scc,
graph::{DiGraph, NodeIndex as PetGraphIndex},
visit::EdgeRef,
visit::{Dfs, EdgeRef, Walker},
};

use crate::ssa::ssa_gen::Ssa;
Expand Down Expand Up @@ -236,6 +236,28 @@ impl CallGraph {

counts
}

/// Returns all functions reachable from the provided root(s).
///
/// This function uses DFS internally to find all nodes reachable from the provided root(s).
pub(crate) fn reachable_from(
&self,
roots: impl IntoIterator<Item = FunctionId>,
) -> HashSet<FunctionId> {
let mut reachable = HashSet::default();

for root in roots {
// If the root does not exist, skip it.
let Some(&start_index) = self.ids_to_indices.get(&root) else {
continue;
};
// Use DFS to determine all reachable nodes from the root
let dfs = Dfs::new(&self.graph, start_index);
reachable.extend(dfs.iter(&self.graph).map(|index| self.indices_to_ids[&index]));
}

reachable
}
}

/// Utility function to find out the direct calls of a function.
Expand Down
80 changes: 25 additions & 55 deletions compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
//! Even if not immediately called, it may later be dynamically loaded and invoked.
//! This marking is conservative but ensures correctness. We should instead rely on [mem2reg][crate::ssa::opt::mem2reg]
//! for resolving loads/stores.
//! - A function is reachable if it is used in a block terminator (e.g., returned from a function)
//!
//! The pass performs a recursive traversal starting from all entry points and marks
//! any transitively reachable functions. It then discards the rest.
//! The pass builds a call graph based upon the definition of reachability above.
//! It then identifies all entry points and uses the [CallGraph::reachable_from] utility
//! to mark all transitively reachable functions. It then discards the rest.
//!
//! This pass helps shrink the SSA before compilation stages like inlining and dead code elimination.

use std::collections::BTreeSet;

use fxhash::FxHashSet as HashSet;

use crate::ssa::{
ir::{
call_graph::CallGraph,
function::{Function, FunctionId},
instruction::Instruction,
value::Value,
Expand All @@ -33,66 +34,35 @@ use crate::ssa::{
impl Ssa {
/// See [`remove_unreachable`][self] module for more information.
pub(crate) fn remove_unreachable_functions(mut self) -> Self {
let mut reachable_functions = HashSet::default();

// Go through all the functions, and if we have an entry point, extend the set of all
// functions which are reachable.
for (id, function) in self.functions.iter() {
// Identify entry points
let entry_points = self.functions.iter().filter_map(|(&id, func)| {
// Not using `Ssa::is_entry_point` because it could leave Brillig functions that nobody calls in the SSA,
// because it considers every Brillig function as an entry point.
let is_entry_point = function.id() == self.main_id
|| function.runtime().is_acir() && function.runtime().is_entry_point();

if is_entry_point {
collect_reachable_functions(&self, *id, &mut reachable_functions);
}
}
let is_entry_point =
id == self.main_id || func.runtime().is_acir() && func.runtime().is_entry_point();
is_entry_point.then_some(id)
});

// Build call graph dependencies using this passes definition of reachability.
let dependencies = self
.functions
.iter()
.map(|(&id, func)| {
let used = used_functions(func);
(id, used)
})
.collect();
let call_graph = CallGraph::from_deps(dependencies);

// Traverse the call graph from all entry points
let reachable_functions = call_graph.reachable_from(entry_points);

// Discard all functions not marked as reachable
self.functions.retain(|id, _| reachable_functions.contains(id));
self
}
}

/// Recursively determine the reachable functions from a given function.
/// This function is only intended to be called on functions that are already known
/// to be entry points or transitively reachable from one.
///
/// # Arguments
/// - `ssa`: The full [Ssa] structure containing all functions.
/// - `current_func_id`: The [FunctionId] from which to begin a traversal.
/// - `reachable_functions`: A mutable set used to collect all reachable functions.
/// It serves both as the final output of this traversal and as a visited set
/// to prevent cycles and redundant recursion.
fn collect_reachable_functions(
ssa: &Ssa,
current_func_id: FunctionId,
reachable_functions: &mut HashSet<FunctionId>,
) {
// If this function has already been determine as reachable, then we have already
// processed the given function and we can simply return.
if reachable_functions.contains(&current_func_id) {
return;
}
// Mark the given function as reachable
reachable_functions.insert(current_func_id);

// If the debugger is used, its possible for function inlining
// to remove functions that the debugger still references
let Some(func) = ssa.functions.get(&current_func_id) else {
return;
};

// Get the set of reachable functions from the given function
let used_functions = used_functions(func);

// For each reachable function within the given function recursively collect
// any more reachable functions.
for called_func_id in used_functions.iter() {
collect_reachable_functions(ssa, *called_func_id, reachable_functions);
}
}

/// Identifies all reachable function IDs within a given function.
/// This includes:
/// - Function calls (functions used via `Call` instructions)
Expand Down
Loading