diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 5b1139e5b9c..e5a25dcfef1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -66,6 +66,8 @@ mod block; use std::collections::{BTreeMap, BTreeSet}; +use fxhash::FxHashMap as HashMap; + use crate::ssa::{ ir::{ basic_block::BasicBlockId, @@ -111,6 +113,10 @@ struct PerFunctionContext<'f> { /// We avoid removing individual instructions as we go since removing elements /// from the middle of Vecs many times will be slower than a single call to `retain`. instructions_to_remove: BTreeSet, + + /// Track a value's last load across all blocks. + /// If a value is not used in anymore loads we can remove the last store to that value. + last_loads: HashMap, } impl<'f> PerFunctionContext<'f> { @@ -124,6 +130,7 @@ impl<'f> PerFunctionContext<'f> { inserter: FunctionInserter::new(function), blocks: BTreeMap::new(), instructions_to_remove: BTreeSet::new(), + last_loads: HashMap::default(), } } @@ -140,6 +147,18 @@ impl<'f> PerFunctionContext<'f> { let references = self.find_starting_references(block); self.analyze_block(block, references); } + + // If we never load from an address within a function we can remove all stores to that address. + // This rule does not apply to reference parameters, which we must also check for before removing these stores. + for (block_id, block) in self.blocks.iter() { + let block_params = self.inserter.function.dfg.block_parameters(*block_id); + for (value, store_instruction) in block.last_stores.iter() { + let is_reference_param = block_params.contains(value); + if self.last_loads.get(value).is_none() && !is_reference_param { + self.instructions_to_remove.insert(*store_instruction); + } + } + } } /// The value of each reference at the start of the given block is the unification @@ -239,6 +258,8 @@ impl<'f> PerFunctionContext<'f> { self.instructions_to_remove.insert(instruction); } else { references.mark_value_used(address, self.inserter.function); + + self.last_loads.insert(address, instruction); } } Instruction::Store { address, value } => { @@ -594,10 +615,8 @@ mod tests { // acir fn main f0 { // b0(): // v7 = allocate - // store Field 5 at v7 // jmp b1(Field 5) // b1(v3: Field): - // store Field 6 at v7 // return v3, Field 5, Field 6 // } let ssa = ssa.mem2reg(); @@ -609,9 +628,9 @@ mod tests { assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); assert_eq!(count_loads(b1, &main.dfg), 0); - // Neither store is removed since they are each the last in the block and there are multiple blocks - assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); - assert_eq!(count_stores(b1, &main.dfg), 1); + // All stores are removed as there are no loads to the values being stored anywhere in the function. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 0); + assert_eq!(count_stores(b1, &main.dfg), 0); // The jmp to b1 should also be a constant 5 now match main.dfg[main.entry_block()].terminator() { @@ -641,8 +660,8 @@ mod tests { // b1(): // store Field 1 at v3 // store Field 2 at v4 - // v8 = load v3 - // v9 = eq v8, Field 2 + // v7 = load v3 + // v8 = eq v7, Field 2 // return // } let main_id = Id::test_new(0); @@ -681,12 +700,9 @@ mod tests { // acir fn main f0 { // b0(): // v9 = allocate - // store Field 0 at v9 // v10 = allocate - // store v9 at v10 // jmp b1() // b1(): - // store Field 2 at v9 // return // } let ssa = ssa.mem2reg(); @@ -698,14 +714,17 @@ mod tests { assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); assert_eq!(count_loads(b1, &main.dfg), 0); - // Only the first store in b1 is removed since there is another store to the same reference + // All stores should be removed. + // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. - assert_eq!(count_stores(main.entry_block(), &main.dfg), 2); - assert_eq!(count_stores(b1, &main.dfg), 1); + // The rest of the stores are also removed as no loads are done within any blocks + // to the stored values. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 0); + assert_eq!(count_stores(b1, &main.dfg), 0); let b1_instructions = main.dfg[b1].instructions(); // We expect the last eq to be optimized out - assert_eq!(b1_instructions.len(), 1); + assert_eq!(b1_instructions.len(), 0); } }