Skip to content

Commit 624ae6c

Browse files
vezenovmTomAFrenchjfecher
authored
feat(perf): Track last loads per block in mem2reg and remove them if possible (#6088)
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> Co-authored-by: jfecher <jake@aztecprotocol.com>
1 parent 6491175 commit 624ae6c

4 files changed

Lines changed: 239 additions & 3 deletions

File tree

compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs

Lines changed: 221 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! - A reference with 0 aliases means we were unable to find which reference this reference
1919
//! refers to. If such a reference is stored to, we must conservatively invalidate every
2020
//! reference in the current block.
21+
//! - We also track the last load instruction to each address per block.
2122
//!
2223
//! From there, to figure out the value of each reference at the end of block, iterate each instruction:
2324
//! - On `Instruction::Allocate`:
@@ -28,6 +29,13 @@
2829
//! - Furthermore, if the result of the load is a reference, mark the result as an alias
2930
//! of the reference it dereferences to (if known).
3031
//! - If which reference it dereferences to is not known, this load result has no aliases.
32+
//! - We also track the last instance of a load instruction to each address in a block.
33+
//! If we see that the last load instruction was from the same address as the current load instruction,
34+
//! we move to replace the result of the current load with the result of the previous load.
35+
//! This removal requires a couple conditions:
36+
//! - No store occurs to that address before the next load,
37+
//! - The address is not used as an argument to a call
38+
//! This optimization helps us remove repeated loads for which there are not known values.
3139
//! - On `Instruction::Store { address, value }`:
3240
//! - If the address of the store is known:
3341
//! - If the address has exactly 1 alias:
@@ -40,11 +48,13 @@
4048
//! - Conservatively mark every alias in the block to `Unknown`.
4149
//! - Additionally, if there were no Loads to any alias of the address between this Store and
4250
//! the previous Store to the same address, the previous store can be removed.
51+
//! - Remove the instance of the last load instruction to the address and its aliases
4352
//! - On `Instruction::Call { arguments }`:
4453
//! - If any argument of the call is a reference, set the value of each alias of that
4554
//! reference to `Unknown`
4655
//! - Any builtin functions that may return aliases if their input also contains a
4756
//! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc.
57+
//! - Remove the instance of the last load instruction for any reference arguments and their aliases
4858
//!
4959
//! On a terminator instruction:
5060
//! - If the terminator is a `Jmp`:
@@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> {
274284
if let Some(first_predecessor) = predecessors.next() {
275285
let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default();
276286
first.last_stores.clear();
287+
// Last loads are tracked per block. During unification we are creating a new block from the current one,
288+
// so we must clear the last loads of the current block before we return the new block.
289+
first.last_loads.clear();
277290

278291
// Note that we have to start folding with the first block as the accumulator.
279292
// If we started with an empty block, an empty block union'd with any other block
@@ -410,6 +423,28 @@ impl<'f> PerFunctionContext<'f> {
410423

411424
self.last_loads.insert(address, (instruction, block_id));
412425
}
426+
427+
// Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads).
428+
// If we do have a repeat load, we can remove the current load and map its result to the previous load's result.
429+
if let Some(last_load) = references.last_loads.get(&address) {
430+
let Instruction::Load { address: previous_address } =
431+
&self.inserter.function.dfg[*last_load]
432+
else {
433+
panic!("Expected a Load instruction here");
434+
};
435+
let result = self.inserter.function.dfg.instruction_results(instruction)[0];
436+
let previous_result =
437+
self.inserter.function.dfg.instruction_results(*last_load)[0];
438+
if *previous_address == address {
439+
self.inserter.map_value(result, previous_result);
440+
self.instructions_to_remove.insert(instruction);
441+
}
442+
}
443+
// We want to set the load for every load even if the address has a known value
444+
// and the previous load instruction was removed.
445+
// We are safe to still remove a repeat load in this case as we are mapping from the current load's
446+
// result to the previous load, which if it was removed should already have a mapping to the known value.
447+
references.set_last_load(address, instruction);
413448
}
414449
Instruction::Store { address, value } => {
415450
let address = self.inserter.function.dfg.resolve(*address);
@@ -435,6 +470,8 @@ impl<'f> PerFunctionContext<'f> {
435470
}
436471

437472
references.set_known_value(address, value);
473+
// If we see a store to an address, the last load to that address needs to remain.
474+
references.keep_last_load_for(address, self.inserter.function);
438475
references.last_stores.insert(address, instruction);
439476
}
440477
Instruction::Allocate => {
@@ -542,6 +579,9 @@ impl<'f> PerFunctionContext<'f> {
542579
let value = self.inserter.function.dfg.resolve(*value);
543580
references.set_unknown(value);
544581
references.mark_value_used(value, self.inserter.function);
582+
583+
// If a reference is an argument to a call, the last load to that address and its aliases needs to remain.
584+
references.keep_last_load_for(value, self.inserter.function);
545585
}
546586
}
547587
}
@@ -572,6 +612,12 @@ impl<'f> PerFunctionContext<'f> {
572612
let destination_parameters = self.inserter.function.dfg[*destination].parameters();
573613
assert_eq!(destination_parameters.len(), arguments.len());
574614

615+
// If we have multiple parameters that alias that same argument value,
616+
// then those parameters also alias each other.
617+
// We save parameters with repeat arguments to later mark those
618+
// parameters as aliasing one another.
619+
let mut arg_set: HashMap<ValueId, BTreeSet<ValueId>> = HashMap::default();
620+
575621
// Add an alias for each reference parameter
576622
for (parameter, argument) in destination_parameters.iter().zip(arguments) {
577623
if self.inserter.function.dfg.value_is_reference(*parameter) {
@@ -581,10 +627,27 @@ impl<'f> PerFunctionContext<'f> {
581627
if let Some(aliases) = references.aliases.get_mut(expression) {
582628
// The argument reference is possibly aliased by this block parameter
583629
aliases.insert(*parameter);
630+
631+
// Check if we have seen the same argument
632+
let seen_parameters = arg_set.entry(argument).or_default();
633+
// Add the current parameter to the parameters we have seen for this argument.
634+
// The previous parameters and the current one alias one another.
635+
seen_parameters.insert(*parameter);
584636
}
585637
}
586638
}
587639
}
640+
641+
// Set the aliases of the parameters
642+
for (_, aliased_params) in arg_set {
643+
for param in aliased_params.iter() {
644+
self.set_aliases(
645+
references,
646+
*param,
647+
AliasSet::known_multiple(aliased_params.clone()),
648+
);
649+
}
650+
}
588651
}
589652
TerminatorInstruction::Return { return_values, .. } => {
590653
// Removing all `last_stores` for each returned reference is more important here
@@ -900,7 +963,7 @@ mod tests {
900963
// v10 = eq v9, Field 2
901964
// constrain v9 == Field 2
902965
// v11 = load v2
903-
// v12 = load v10
966+
// v12 = load v11
904967
// v13 = eq v12, Field 2
905968
// constrain v11 == Field 2
906969
// return
@@ -959,7 +1022,7 @@ mod tests {
9591022
let main = ssa.main();
9601023
assert_eq!(main.reachable_blocks().len(), 4);
9611024

962-
// The store from the original SSA should remain
1025+
// The stores from the original SSA should remain
9631026
assert_eq!(count_stores(main.entry_block(), &main.dfg), 2);
9641027
assert_eq!(count_stores(b2, &main.dfg), 1);
9651028

@@ -1006,4 +1069,160 @@ mod tests {
10061069
let main = ssa.main();
10071070
assert_eq!(count_loads(main.entry_block(), &main.dfg), 1);
10081071
}
1072+
1073+
#[test]
1074+
fn remove_repeat_loads() {
1075+
// This tests starts with two loads from the same unknown load.
1076+
// Specifically you should look for `load v2` in `b3`.
1077+
// We should be able to remove the second repeated load.
1078+
let src = "
1079+
acir(inline) fn main f0 {
1080+
b0():
1081+
v0 = allocate -> &mut Field
1082+
store Field 0 at v0
1083+
v2 = allocate -> &mut &mut Field
1084+
store v0 at v2
1085+
jmp b1(Field 0)
1086+
b1(v3: Field):
1087+
v4 = eq v3, Field 0
1088+
jmpif v4 then: b2, else: b3
1089+
b2():
1090+
v5 = load v2 -> &mut Field
1091+
store Field 2 at v5
1092+
v8 = add v3, Field 1
1093+
jmp b1(v8)
1094+
b3():
1095+
v9 = load v0 -> Field
1096+
v10 = eq v9, Field 2
1097+
constrain v9 == Field 2
1098+
v11 = load v2 -> &mut Field
1099+
v12 = load v2 -> &mut Field
1100+
v13 = load v12 -> Field
1101+
v14 = eq v13, Field 2
1102+
constrain v13 == Field 2
1103+
return
1104+
}
1105+
";
1106+
1107+
let ssa = Ssa::from_str(src).unwrap();
1108+
1109+
// The repeated load from v3 should be removed
1110+
// b3 should only have three loads now rather than four previously
1111+
//
1112+
// All stores are expected to remain.
1113+
let expected = "
1114+
acir(inline) fn main f0 {
1115+
b0():
1116+
v1 = allocate -> &mut Field
1117+
store Field 0 at v1
1118+
v3 = allocate -> &mut &mut Field
1119+
store v1 at v3
1120+
jmp b1(Field 0)
1121+
b1(v0: Field):
1122+
v4 = eq v0, Field 0
1123+
jmpif v4 then: b3, else: b2
1124+
b3():
1125+
v11 = load v3 -> &mut Field
1126+
store Field 2 at v11
1127+
v13 = add v0, Field 1
1128+
jmp b1(v13)
1129+
b2():
1130+
v5 = load v1 -> Field
1131+
v7 = eq v5, Field 2
1132+
constrain v5 == Field 2
1133+
v8 = load v3 -> &mut Field
1134+
v9 = load v8 -> Field
1135+
v10 = eq v9, Field 2
1136+
constrain v9 == Field 2
1137+
return
1138+
}
1139+
";
1140+
1141+
let ssa = ssa.mem2reg();
1142+
assert_normalized_ssa_equals(ssa, expected);
1143+
}
1144+
1145+
#[test]
1146+
fn keep_repeat_loads_passed_to_a_call() {
1147+
// The test is the exact same as `remove_repeat_loads` above except with the call
1148+
// to `f1` between the repeated loads.
1149+
let src = "
1150+
acir(inline) fn main f0 {
1151+
b0():
1152+
v1 = allocate -> &mut Field
1153+
store Field 0 at v1
1154+
v3 = allocate -> &mut &mut Field
1155+
store v1 at v3
1156+
jmp b1(Field 0)
1157+
b1(v0: Field):
1158+
v4 = eq v0, Field 0
1159+
jmpif v4 then: b3, else: b2
1160+
b3():
1161+
v13 = load v3 -> &mut Field
1162+
store Field 2 at v13
1163+
v15 = add v0, Field 1
1164+
jmp b1(v15)
1165+
b2():
1166+
v5 = load v1 -> Field
1167+
v7 = eq v5, Field 2
1168+
constrain v5 == Field 2
1169+
v8 = load v3 -> &mut Field
1170+
call f1(v3)
1171+
v10 = load v3 -> &mut Field
1172+
v11 = load v10 -> Field
1173+
v12 = eq v11, Field 2
1174+
constrain v11 == Field 2
1175+
return
1176+
}
1177+
acir(inline) fn foo f1 {
1178+
b0(v0: &mut Field):
1179+
return
1180+
}
1181+
";
1182+
1183+
let ssa = Ssa::from_str(src).unwrap();
1184+
1185+
let ssa = ssa.mem2reg();
1186+
// We expect the program to be unchanged
1187+
assert_normalized_ssa_equals(ssa, src);
1188+
}
1189+
1190+
#[test]
1191+
fn keep_repeat_loads_with_alias_store() {
1192+
// v7, v8, and v9 alias one another. We want to make sure that a repeat load to v7 with a store
1193+
// to its aliases in between the repeat loads does not remove those loads.
1194+
let src = "
1195+
acir(inline) fn main f0 {
1196+
b0(v0: u1):
1197+
jmpif v0 then: b2, else: b1
1198+
b2():
1199+
v6 = allocate -> &mut Field
1200+
store Field 0 at v6
1201+
jmp b3(v6, v6, v6)
1202+
b3(v1: &mut Field, v2: &mut Field, v3: &mut Field):
1203+
v8 = load v1 -> Field
1204+
store Field 2 at v2
1205+
v10 = load v1 -> Field
1206+
store Field 1 at v3
1207+
v11 = load v1 -> Field
1208+
store Field 3 at v3
1209+
v13 = load v1 -> Field
1210+
constrain v8 == Field 0
1211+
constrain v10 == Field 2
1212+
constrain v11 == Field 1
1213+
constrain v13 == Field 3
1214+
return
1215+
b1():
1216+
v4 = allocate -> &mut Field
1217+
store Field 1 at v4
1218+
jmp b3(v4, v4, v4)
1219+
}
1220+
";
1221+
1222+
let ssa = Ssa::from_str(src).unwrap();
1223+
1224+
let ssa = ssa.mem2reg();
1225+
// We expect the program to be unchanged
1226+
assert_normalized_ssa_equals(ssa, src);
1227+
}
10091228
}

compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ impl AliasSet {
2424
Self { aliases: Some(aliases) }
2525
}
2626

27+
pub(super) fn known_multiple(values: BTreeSet<ValueId>) -> AliasSet {
28+
Self { aliases: Some(values) }
29+
}
30+
2731
/// In rare cases, such as when creating an empty array of references, the set of aliases for a
2832
/// particular value will be known to be zero, which is distinct from being unknown and
2933
/// possibly referring to any alias.

compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub(super) struct Block {
3434

3535
/// The last instance of a `Store` instruction to each address in this block
3636
pub(super) last_stores: im::OrdMap<ValueId, InstructionId>,
37+
38+
// The last instance of a `Load` instruction to each address in this block
39+
pub(super) last_loads: im::OrdMap<ValueId, InstructionId>,
3740
}
3841

3942
/// An `Expression` here is used to represent a canonical key
@@ -237,4 +240,14 @@ impl Block {
237240

238241
Cow::Owned(AliasSet::unknown())
239242
}
243+
244+
pub(super) fn set_last_load(&mut self, address: ValueId, instruction: InstructionId) {
245+
self.last_loads.insert(address, instruction);
246+
}
247+
248+
pub(super) fn keep_last_load_for(&mut self, address: ValueId, function: &Function) {
249+
let address = function.dfg.resolve(address);
250+
self.last_loads.remove(&address);
251+
self.for_each_alias_of(address, |block, alias| block.last_loads.remove(&alias));
252+
}
240253
}

tooling/debugger/tests/debug.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ mod tests {
1212
let nargo_bin =
1313
cargo_bin("nargo").into_os_string().into_string().expect("Cannot parse nargo path");
1414

15-
let timeout_seconds = 25;
15+
let timeout_seconds = 30;
1616
let mut dbg_session =
1717
spawn_bash(Some(timeout_seconds * 1000)).expect("Could not start bash session");
1818

0 commit comments

Comments
 (0)