Skip to content

Commit b7c86f4

Browse files
authored
chore: put RcTracker as part of the DIE context (#7309)
1 parent 826b18a commit b7c86f4

File tree

1 file changed

+56
-26
lines changed
  • compiler/noirc_evaluator/src/ssa/opt

1 file changed

+56
-26
lines changed

compiler/noirc_evaluator/src/ssa/opt/die.rs

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,8 @@ struct Context {
135135
/// them just yet.
136136
flattened: bool,
137137

138-
// When tracking mutations we consider arrays with the same type as all being possibly mutated.
139-
// This we consider to span all blocks of the functions.
140-
mutated_array_types: HashSet<Type>,
138+
/// Track IncrementRc instructions per block to determine whether they are useless.
139+
rc_tracker: RcTracker,
141140
}
142141

143142
impl Context {
@@ -167,10 +166,8 @@ impl Context {
167166
let block = &function.dfg[block_id];
168167
self.mark_terminator_values_as_used(function, block);
169168

170-
// Lend the shared array type to the tracker.
171-
let mut mutated_array_types = std::mem::take(&mut self.mutated_array_types);
172-
let mut rc_tracker = RcTracker::new(&mut mutated_array_types);
173-
rc_tracker.mark_terminator_arrays_as_used(function, block);
169+
self.rc_tracker.new_block();
170+
self.rc_tracker.mark_terminator_arrays_as_used(function, block);
174171

175172
let instructions_len = block.instructions().len();
176173

@@ -203,12 +200,11 @@ impl Context {
203200
}
204201
}
205202

206-
rc_tracker.track_inc_rcs_to_remove(*instruction_id, function);
203+
self.rc_tracker.track_inc_rcs_to_remove(*instruction_id, function);
207204
}
208205

209-
self.instructions_to_remove.extend(rc_tracker.get_non_mutated_arrays(&function.dfg));
210-
self.instructions_to_remove.extend(rc_tracker.rc_pairs_to_remove);
211-
206+
self.instructions_to_remove.extend(self.rc_tracker.get_non_mutated_arrays(&function.dfg));
207+
self.instructions_to_remove.extend(self.rc_tracker.rc_pairs_to_remove.drain());
212208
// If there are some instructions that might trigger an out of bounds error,
213209
// first add constrain checks. Then run the DIE pass again, which will remove those
214210
// but leave the constrains (any any value needed by those constrains)
@@ -228,9 +224,6 @@ impl Context {
228224
.instructions_mut()
229225
.retain(|instruction| !self.instructions_to_remove.contains(instruction));
230226

231-
// Take the mutated array back.
232-
self.mutated_array_types = mutated_array_types;
233-
234227
false
235228
}
236229

@@ -279,11 +272,15 @@ impl Context {
279272
let typ = typ.get_contained_array();
280273
// Want to store the array type which is being referenced,
281274
// because it's the underlying array that the `inc_rc` is associated with.
282-
self.mutated_array_types.insert(typ.clone());
275+
self.add_mutated_array_type(typ.clone());
283276
}
284277
}
285278
}
286279

280+
fn add_mutated_array_type(&mut self, typ: Type) {
281+
self.rc_tracker.mutated_array_types.insert(typ.get_contained_array().clone());
282+
}
283+
287284
/// Go through the RC instructions collected when we figured out which values were unused;
288285
/// for each RC that refers to an unused value, remove the RC as well.
289286
fn remove_rc_instructions(&self, dfg: &mut DataFlowGraph) {
@@ -615,8 +612,9 @@ fn apply_side_effects(
615612
(lhs, rhs)
616613
}
617614

615+
#[derive(Default)]
618616
/// Per block RC tracker.
619-
struct RcTracker<'a> {
617+
struct RcTracker {
620618
// We can track IncrementRc instructions per block to determine whether they are useless.
621619
// IncrementRc and DecrementRc instructions are normally side effectual instructions, but we remove
622620
// them if their value is not used anywhere in the function. However, even when their value is used, their existence
@@ -631,23 +629,21 @@ struct RcTracker<'a> {
631629
// If an array is the same type as one of those non-mutated array types, we can safely remove all IncrementRc instructions on that array.
632630
inc_rcs: HashMap<ValueId, HashSet<InstructionId>>,
633631
// Mutated arrays shared across the blocks of the function.
634-
mutated_array_types: &'a mut HashSet<Type>,
632+
// When tracking mutations we consider arrays with the same type as all being possibly mutated.
633+
mutated_array_types: HashSet<Type>,
635634
// The SSA often creates patterns where after simplifications we end up with repeat
636635
// IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc,
637636
// and if the current instruction is also an IncrementRc on the same value we remove the current instruction.
638637
// `None` if the previous instruction was anything other than an IncrementRc
639638
previous_inc_rc: Option<ValueId>,
640639
}
641640

642-
impl<'a> RcTracker<'a> {
643-
fn new(mutated_array_types: &'a mut HashSet<Type>) -> Self {
644-
Self {
645-
rcs_with_possible_pairs: Default::default(),
646-
rc_pairs_to_remove: Default::default(),
647-
inc_rcs: Default::default(),
648-
previous_inc_rc: Default::default(),
649-
mutated_array_types,
650-
}
641+
impl RcTracker {
642+
fn new_block(&mut self) {
643+
self.rcs_with_possible_pairs.clear();
644+
self.rc_pairs_to_remove.clear();
645+
self.inc_rcs.clear();
646+
self.previous_inc_rc = Default::default();
651647
}
652648

653649
fn mark_terminator_arrays_as_used(&mut self, function: &Function, block: &BasicBlock) {
@@ -1128,4 +1124,38 @@ mod test {
11281124
";
11291125
assert_normalized_ssa_equals(ssa, expected);
11301126
}
1127+
1128+
#[test]
1129+
fn do_not_remove_inc_rc_if_mutated_in_other_block() {
1130+
let src = "
1131+
brillig(inline) fn main f0 {
1132+
b0(v0: &mut [Field; 3]):
1133+
v1 = load v0 -> [Field; 3]
1134+
inc_rc v1
1135+
jmp b1()
1136+
b1():
1137+
v2 = load v0 -> [Field; 3]
1138+
v3 = array_set v2, index u32 0, value u32 0
1139+
store v3 at v0
1140+
return
1141+
}
1142+
";
1143+
let ssa = Ssa::from_str(src).unwrap();
1144+
1145+
let expected = "
1146+
brillig(inline) fn main f0 {
1147+
b0(v0: &mut [Field; 3]):
1148+
v1 = load v0 -> [Field; 3]
1149+
inc_rc v1
1150+
jmp b1()
1151+
b1():
1152+
v2 = load v0 -> [Field; 3]
1153+
v4 = array_set v2, index u32 0, value u32 0
1154+
store v4 at v0
1155+
return
1156+
}
1157+
";
1158+
let ssa = ssa.dead_instruction_elimination();
1159+
assert_normalized_ssa_equals(ssa, expected);
1160+
}
11311161
}

0 commit comments

Comments
 (0)