diff --git a/compiler/noirc_evaluator/src/ssa/ir/call_graph.rs b/compiler/noirc_evaluator/src/ssa/ir/call_graph.rs index d55ace80b52..093dac3a452 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/call_graph.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/call_graph.rs @@ -14,7 +14,7 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use petgraph::{ algo::kosaraju_scc, graph::{DiGraph, NodeIndex as PetGraphIndex}, - visit::EdgeRef, + visit::{Dfs, EdgeRef, Walker}, }; use crate::ssa::ssa_gen::Ssa; @@ -236,6 +236,28 @@ impl CallGraph { counts } + + /// Returns all functions reachable from the provided root(s). + /// + /// This function uses DFS internally to find all nodes reachable from the provided root(s). + pub(crate) fn reachable_from( + &self, + roots: impl IntoIterator, + ) -> HashSet { + let mut reachable = HashSet::default(); + + for root in roots { + // If the root does not exist, skip it. + let Some(&start_index) = self.ids_to_indices.get(&root) else { + continue; + }; + // Use DFS to determine all reachable nodes from the root + let dfs = Dfs::new(&self.graph, start_index); + reachable.extend(dfs.iter(&self.graph).map(|index| self.indices_to_ids[&index])); + } + + reachable + } } /// Utility function to find out the direct calls of a function. diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs index ce94f3bd4d1..ca0e36a98f6 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs @@ -11,18 +11,19 @@ //! Even if not immediately called, it may later be dynamically loaded and invoked. //! This marking is conservative but ensures correctness. We should instead rely on [mem2reg][crate::ssa::opt::mem2reg] //! for resolving loads/stores. +//! - A function is reachable if it is used in a block terminator (e.g., returned from a function) //! -//! The pass performs a recursive traversal starting from all entry points and marks -//! any transitively reachable functions. It then discards the rest. +//! The pass builds a call graph based upon the definition of reachability above. +//! It then identifies all entry points and uses the [CallGraph::reachable_from] utility +//! to mark all transitively reachable functions. It then discards the rest. //! //! This pass helps shrink the SSA before compilation stages like inlining and dead code elimination. use std::collections::BTreeSet; -use fxhash::FxHashSet as HashSet; - use crate::ssa::{ ir::{ + call_graph::CallGraph, function::{Function, FunctionId}, instruction::Instruction, value::Value, @@ -33,20 +34,22 @@ use crate::ssa::{ impl Ssa { /// See [`remove_unreachable`][self] module for more information. pub(crate) fn remove_unreachable_functions(mut self) -> Self { - let mut reachable_functions = HashSet::default(); - - // Go through all the functions, and if we have an entry point, extend the set of all - // functions which are reachable. - for (id, function) in self.functions.iter() { + // Identify entry points + let entry_points = self.functions.iter().filter_map(|(&id, func)| { // Not using `Ssa::is_entry_point` because it could leave Brillig functions that nobody calls in the SSA, // because it considers every Brillig function as an entry point. - let is_entry_point = function.id() == self.main_id - || function.runtime().is_acir() && function.runtime().is_entry_point(); + let is_entry_point = + id == self.main_id || func.runtime().is_acir() && func.runtime().is_entry_point(); + is_entry_point.then_some(id) + }); - if is_entry_point { - collect_reachable_functions(&self, *id, &mut reachable_functions); - } - } + // Build call graph dependencies using this passes definition of reachability. + let dependencies = + self.functions.iter().map(|(&id, func)| (id, used_functions(func))).collect(); + let call_graph = CallGraph::from_deps(dependencies); + + // Traverse the call graph from all entry points + let reachable_functions = call_graph.reachable_from(entry_points); // Discard all functions not marked as reachable self.functions.retain(|id, _| reachable_functions.contains(id)); @@ -54,45 +57,6 @@ impl Ssa { } } -/// Recursively determine the reachable functions from a given function. -/// This function is only intended to be called on functions that are already known -/// to be entry points or transitively reachable from one. -/// -/// # Arguments -/// - `ssa`: The full [Ssa] structure containing all functions. -/// - `current_func_id`: The [FunctionId] from which to begin a traversal. -/// - `reachable_functions`: A mutable set used to collect all reachable functions. -/// It serves both as the final output of this traversal and as a visited set -/// to prevent cycles and redundant recursion. -fn collect_reachable_functions( - ssa: &Ssa, - current_func_id: FunctionId, - reachable_functions: &mut HashSet, -) { - // If this function has already been determine as reachable, then we have already - // processed the given function and we can simply return. - if reachable_functions.contains(¤t_func_id) { - return; - } - // Mark the given function as reachable - reachable_functions.insert(current_func_id); - - // If the debugger is used, its possible for function inlining - // to remove functions that the debugger still references - let Some(func) = ssa.functions.get(¤t_func_id) else { - return; - }; - - // Get the set of reachable functions from the given function - let used_functions = used_functions(func); - - // For each reachable function within the given function recursively collect - // any more reachable functions. - for called_func_id in used_functions.iter() { - collect_reachable_functions(ssa, *called_func_id, reachable_functions); - } -} - /// Identifies all reachable function IDs within a given function. /// This includes: /// - Function calls (functions used via `Call` instructions)