Skip to content
98 changes: 78 additions & 20 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//! This includes branches in the CFG with non-constant conditions. Flattening these requires
//! special handling for operations with side-effects and can lead to a loss of information since
//! the jmpif will no longer be in the program. As a result, this pass should usually be towards or
//! at the end of the optimization passes. Note that this pass will also perform unexpectedly if
//! loops are still present in the program. Since the pass sees a normal jmpif, it will attempt to
//! merge both blocks, but no actual looping will occur.
//! at the end of the optimization passes.
//! Furthermore, this pass assumes that no loops are present in the program: it treats every
//! jmpif as a branch point and attempts to merge both blocks, so no actual looping will occur.
//!
//! This pass is also known to produce some extra instructions which may go unused (usually 'Not')
//! while merging branches. These extra instructions can be cleaned up by a later dead instruction
Expand Down Expand Up @@ -218,7 +218,9 @@ pub(crate) struct Context<'f> {
not_instructions: HashMap<ValueId, ValueId>,

/// Flag to tell the context to not issue 'enable_side_effect' instructions during flattening.
/// This should be set to true only by flatten_single(), when no instruction is known to fail.
/// It is set with an attribute when defining a function that cannot fail whatsoever to avoid
/// the overhead of handling side effects.
/// It can also be set to true by flatten_single(), when no instruction is known to fail.
pub(crate) no_predicate: bool,
}

Expand Down Expand Up @@ -247,13 +249,19 @@ struct ConditionalContext {
call_stack: CallStackId,
}

/// Flattens the control flow graph of the function such that it is left with a
/// single block containing all instructions and no more control-flow.
fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap<FunctionId, bool>) {
// This pass may run forever on a brillig function.
// Analyze will check if the predecessors have been processed and push the block to the back of
// the queue. This loops forever if there are still any loops present in the program.
if matches!(function.runtime(), RuntimeType::Brillig(_)) {
return;
}

// Creates a context that will perform the flattening
// We give it the map of the conditional branches in the CFG
// and the target block where the flattened instructions should be added.
let cfg = ControlFlowGraph::with_function(function);
let branch_ends = branch_analysis::find_branch_ends(function, &cfg);
let target_block = function.entry_block();
Expand Down Expand Up @@ -293,16 +301,29 @@ impl<'f> Context<'f> {
}
}

/// Flatten the CFG by inlining all instructions from the queued blocks
/// until all blocks have been flattened.
/// We follow the terminator of each block to determine which blocks to
/// process next:
/// If the terminator is a 'JumpIf', we assume we are entering a conditional statement and
/// add the start blocks of the 'then_branch', 'else_branch' and the 'exit' block to the queue.
/// Other blocks will have only one successor, so we will process them iteratively,
/// until we reach a block already in the queue, i.e. one added when entering a conditional statement
/// (the 'else_branch' or the 'exit'). In that case we switch to the next block in the queue, instead
/// of the successor.
/// This process ensures that the blocks are always processed in this order:
/// if_entry -> then_branch -> else_branch -> exit
/// In case of nested if statements, for instance in the 'then_branch', it will be:
/// if_entry -> then_branch -> if_entry_2 -> then_branch_2 -> exit_2 -> else_branch -> exit
/// Information about the nested if statements is stored in the 'condition_stack' which
/// is pushed/popped when entering/leaving a conditional statement.
pub(crate) fn flatten(&mut self, no_predicates: &HashMap<FunctionId, bool>) {
// Flatten the CFG by inlining all instructions from the queued blocks
// until all blocks have been flattened.
// We follow the terminator of each block to determine which blocks to
// process next
let mut queue = vec![self.target_block];
while let Some(block) = queue.pop() {
self.inline_block(block, no_predicates);
let to_process = self.handle_terminator(block, &queue);
for incoming_block in to_process {
// Do not add blocks already in the queue
if !queue.contains(&incoming_block) {
queue.push(incoming_block);
}
Expand All @@ -326,6 +347,10 @@ impl<'f> Context<'f> {
}

/// Returns the current condition
/// The conditions are in a stack, they are added as conditional branches are encountered
/// so the last one is the current condition.
/// When processing a conditional branch, we first follow the 'then' branch and only after we
/// process the 'else' branch. At that point, the ConditionalContext has the 'else_branch' set,
/// and its condition is the one returned.
fn get_last_condition(&self) -> Option<ValueId> {
self.condition_stack.last().map(|context| match &context.else_branch {
Some(else_branch) => else_branch.condition,
Expand All @@ -348,7 +373,11 @@ impl<'f> Context<'f> {
result
}

// Inline all instructions from the given block into the target block, and track slice capacities
/// Inline all instructions from the given block into the target block, and track slice capacities
/// This is done by processing every instruction in the block and using the flattening context
/// to push them in the target block
///
/// - `no_predicates` indicates which functions have no predicates and for which we disable the handling of side effects
pub(crate) fn inline_block(
&mut self,
block: BasicBlockId,
Expand Down Expand Up @@ -388,6 +417,12 @@ impl<'f> Context<'f> {
/// For a normal block, it would be its successor
/// For blocks related to a conditional statement, we ensure to process
/// the 'then-branch', then the 'else-branch' (if it exists), and finally the end block
/// The update of the context is done by the functions 'if_start', 'then_stop' and 'else_stop'
/// which perform the business logic when entering a conditional statement, finishing the 'then-branch'
/// and the 'else-branch', respectively.
/// We know if a block is related to the conditional statement if it is referenced by the 'work_list'.
/// Indeed, the start blocks of the 'then_branch' and 'else_branch' are added to the 'work_list' when
/// starting to process a conditional statement.
pub(crate) fn handle_terminator(
&mut self,
block: BasicBlockId,
Expand Down Expand Up @@ -430,7 +465,11 @@ impl<'f> Context<'f> {
}
}

/// Process a conditional statement
/// Process a conditional statement by creating a 'ConditionalContext'
/// with information about the branch, and storing it in the dedicated stack.
/// Local allocations are moved to the 'then_branch' of the ConditionalContext.
/// Returns the blocks corresponding to the 'then_branch', 'else_branch', and exit block of the conditional statement,
/// so that they will be processed in this order.
fn if_start(
&mut self,
condition: &ValueId,
Expand Down Expand Up @@ -472,7 +511,11 @@ impl<'f> Context<'f> {
vec![self.branch_ends[if_entry], *else_destination, *then_destination]
}

/// Switch context to the 'else-branch'
/// Switch context to the 'else-branch':
/// - Negates the condition for the 'else_branch' and sets it in the ConditionalContext
/// - Moves the local allocations to the 'else_branch'
/// - Issues the 'enable_side_effect' instruction
/// - Returns the exit block of the conditional statement
fn then_stop(&mut self, block: &BasicBlockId) -> Vec<BasicBlockId> {
let mut cond_context = self.condition_stack.pop().unwrap();
cond_context.then_branch.last_block = *block;
Expand Down Expand Up @@ -500,6 +543,7 @@ impl<'f> Context<'f> {
vec![self.cfg.successors(*block).next().unwrap()]
}

/// Negates a boolean value by inserting a Not instruction
fn not_instruction(&mut self, condition: ValueId, call_stack: CallStackId) -> ValueId {
if let Some(existing) = self.not_instructions.get(&condition) {
return *existing;
Expand All @@ -510,7 +554,10 @@ impl<'f> Context<'f> {
not
}

/// Process the 'exit' block of a conditional statement
/// Process the 'exit' block of a conditional statement:
/// - Retrieves the local allocations from the Conditional Context
/// - Issues the 'enable_side_effect' instruction
/// - Joins the arguments from both branches
fn else_stop(&mut self, block: &BasicBlockId) -> Vec<BasicBlockId> {
let mut cond_context = self.condition_stack.pop().unwrap();
if cond_context.else_branch.is_none() {
Expand Down Expand Up @@ -547,8 +594,9 @@ impl<'f> Context<'f> {
/// all of the join point's predecessors, and it must handle any differing side effects from
/// each branch.
///
/// Afterwards, continues inlining recursively until it finds the next end block or finds the
/// end of the function.
/// The merge of arguments is done by inserting an 'IfElse' instruction which returns
/// the argument from the then_branch or the else_branch depending on the condition.
/// They are added to the 'arguments_stack' instead of the arguments of the 2 branches.
///
/// Returns the final block that was inlined.
fn inline_branch_end(
Expand Down Expand Up @@ -678,8 +726,10 @@ impl<'f> Context<'f> {
}
}

/// If we are currently in a branch, we need to modify constrain instructions
/// to multiply them by the branch's condition (see optimization #1 in the module comment).
/// If we are currently in a branch, we need to modify instructions that have side effects
/// (e.g. constraints, stores, range checks) to ensure that the side effect is only applied
/// if their branch is taken.
/// For instance we multiply constrain instructions by the branch's condition (see optimization #1 in the module comment).
fn handle_instruction_side_effects(
&mut self,
instruction: Instruction,
Expand All @@ -698,12 +748,12 @@ impl<'f> Context<'f> {
Instruction::Constrain(lhs, rhs, message)
}
Instruction::Store { address, value } => {
// If this instruction immediately follows an allocate, and stores to that
// address there is no previous value to load and we don't need a merge anyway.
// There is no side effect for 'Store' instructions on arrays allocated in the current branch, because
// these arrays do not exist in other branches.
if self.local_allocations.contains(&address) {
Instruction::Store { address, value }
} else {
// Instead of storing `value`, store `if condition { value } else { previous_value }`
// Instead of storing `value`, we store: `if condition { value } else { previous_value }`
let typ = self.inserter.function.dfg.type_of_value(value);
let load = Instruction::Load { address };
let previous_value = self
Expand Down Expand Up @@ -734,6 +784,8 @@ impl<'f> Context<'f> {
}
Instruction::Call { func, mut arguments } => match self.inserter.function.dfg[func]
{
// A ToBits (or ToRadix in general) can fail if the input has more bits than the target.
// We ensure it does not fail by multiplying the input by the condition.
Value::Intrinsic(Intrinsic::ToBits(_) | Intrinsic::ToRadix(_)) => {
let field = arguments[0];
let casted_condition =
Expand All @@ -744,13 +796,14 @@ impl<'f> Context<'f> {

Instruction::Call { func, arguments }
}
//Issue #5045: We set curve points to infinity if condition is false
// Issue #5045: We set curve points to infinity if the condition is false, to ensure that they are on the curve; otherwise the addition may fail.
Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::EmbeddedCurveAdd)) => {
arguments[2] = self.var_or_one(arguments[2], condition, call_stack);
arguments[5] = self.var_or_one(arguments[5], condition, call_stack);

Instruction::Call { func, arguments }
}
// For MSM, we also ensure the inputs are on the curve if the predicate is false.
Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)) => {
let points_array_idx = if matches!(
self.inserter.function.dfg.type_of_value(arguments[0]),
Expand Down Expand Up @@ -782,6 +835,10 @@ impl<'f> Context<'f> {
}
}

/// 'Cast' the 'condition' to 'value' type
/// This is needed because we need to multiply the condition with several values
/// in order to 'nullify' side-effects when the 'condition' is false (in 'handle_instruction_side_effects()' function).
/// Since the condition is a boolean, it can be safely cast to any other type.
fn cast_condition_to_value_type(
&mut self,
condition: ValueId,
Expand All @@ -793,6 +850,7 @@ impl<'f> Context<'f> {
self.insert_instruction(cast, call_stack)
}

/// Insert a multiplication between 'condition' and 'value'
fn mul_by_condition(
&mut self,
value: ValueId,
Expand Down
Loading