diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 9d6582c0db7..3d98f4126cf 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -116,7 +116,7 @@ struct PerFunctionContext<'f> { /// Track a value's last load across all blocks. /// If a value is not used in anymore loads we can remove the last store to that value. - last_loads: HashMap, + last_loads: HashMap, } impl<'f> PerFunctionContext<'f> { @@ -152,9 +152,31 @@ impl<'f> PerFunctionContext<'f> { // This rule does not apply to reference parameters, which we must also check for before removing these stores. for (block_id, block) in self.blocks.iter() { let block_params = self.inserter.function.dfg.block_parameters(*block_id); - for (value, store_instruction) in block.last_stores.iter() { - let is_reference_param = block_params.contains(value); - if self.last_loads.get(value).is_none() && !is_reference_param { + for (store_address, store_instruction) in block.last_stores.iter() { + let is_reference_param = block_params.contains(store_address); + let terminator = self.inserter.function.dfg[*block_id].unwrap_terminator(); + + let is_return = matches!(terminator, TerminatorInstruction::Return { .. }); + let remove_load = if is_return { + // Determine whether the last store is used in the return value + let mut is_return_value = false; + terminator.for_each_value(|return_value| { + is_return_value = return_value == *store_address || is_return_value; + }); + + // If the last load of a store is not part of the block with a return terminator, + // we can safely remove this store. + let last_load_not_in_return = self + .last_loads + .get(store_address) + .map(|(_, last_load_block)| *last_load_block != *block_id) + .unwrap_or(true); + !is_return_value && last_load_not_in_return + } else { + self.last_loads.get(store_address).is_none() + }; + + if remove_load && !is_reference_param { self.instructions_to_remove.insert(*store_instruction); } } @@ -259,7 +281,7 @@ impl<'f> PerFunctionContext<'f> { } else { references.mark_value_used(address, self.inserter.function); - self.last_loads.insert(address, instruction); + self.last_loads.insert(address, (instruction, block_id)); } } Instruction::Store { address, value } => {