1818//! - A reference with 0 aliases means we were unable to find which reference this reference
1919//! refers to. If such a reference is stored to, we must conservatively invalidate every
2020//! reference in the current block.
21+ //! - We also track the last load instruction to each address per block.
2122//!
2223//! From there, to figure out the value of each reference at the end of block, iterate each instruction:
2324//! - On `Instruction::Allocate`:
2829//! - Furthermore, if the result of the load is a reference, mark the result as an alias
2930//! of the reference it dereferences to (if known).
3031//! - If which reference it dereferences to is not known, this load result has no aliases.
32+ //! - We also track the last instance of a load instruction to each address in a block.
33+ //! If we see that the last load instruction was from the same address as the current load instruction,
34+ //! we move to replace the result of the current load with the result of the previous load.
35+ //! This removal requires a couple conditions:
36+ //! - No store occurs to that address before the next load,
37+ //! - The address is not used as an argument to a call
38+ //! This optimization helps us remove repeated loads for which there are not known values.
3139//! - On `Instruction::Store { address, value }`:
3240//! - If the address of the store is known:
3341//! - If the address has exactly 1 alias:
4048//! - Conservatively mark every alias in the block to `Unknown`.
4149//! - Additionally, if there were no Loads to any alias of the address between this Store and
4250//! the previous Store to the same address, the previous store can be removed.
51+ //! - Remove the instance of the last load instruction to the address and its aliases
4352//! - On `Instruction::Call { arguments }`:
4453//! - If any argument of the call is a reference, set the value of each alias of that
4554//! reference to `Unknown`
4655//! - Any builtin functions that may return aliases if their input also contains a
4756//! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc.
57+ //! - Remove the instance of the last load instruction for any reference arguments and their aliases
4858//!
4959//! On a terminator instruction:
5060//! - If the terminator is a `Jmp`:
@@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> {
274284 if let Some ( first_predecessor) = predecessors. next ( ) {
275285 let mut first = self . blocks . get ( & first_predecessor) . cloned ( ) . unwrap_or_default ( ) ;
276286 first. last_stores . clear ( ) ;
287+ // Last loads are tracked per block. During unification we are creating a new block from the current one,
288+ // so we must clear the last loads of the current block before we return the new block.
289+ first. last_loads . clear ( ) ;
277290
278291 // Note that we have to start folding with the first block as the accumulator.
279292 // If we started with an empty block, an empty block union'd with any other block
@@ -410,6 +423,28 @@ impl<'f> PerFunctionContext<'f> {
410423
411424 self . last_loads . insert ( address, ( instruction, block_id) ) ;
412425 }
426+
427+ // Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads).
428+ // If we do have a repeat load, we can remove the current load and map its result to the previous load's result.
429+ if let Some ( last_load) = references. last_loads . get ( & address) {
430+ let Instruction :: Load { address : previous_address } =
431+ & self . inserter . function . dfg [ * last_load]
432+ else {
433+ panic ! ( "Expected a Load instruction here" ) ;
434+ } ;
435+ let result = self . inserter . function . dfg . instruction_results ( instruction) [ 0 ] ;
436+ let previous_result =
437+ self . inserter . function . dfg . instruction_results ( * last_load) [ 0 ] ;
438+ if * previous_address == address {
439+ self . inserter . map_value ( result, previous_result) ;
440+ self . instructions_to_remove . insert ( instruction) ;
441+ }
442+ }
443+ // We want to set the load for every load even if the address has a known value
444+ // and the previous load instruction was removed.
445+ // We are safe to still remove a repeat load in this case as we are mapping from the current load's
446+ // result to the previous load, which if it was removed should already have a mapping to the known value.
447+ references. set_last_load ( address, instruction) ;
413448 }
414449 Instruction :: Store { address, value } => {
415450 let address = self . inserter . function . dfg . resolve ( * address) ;
@@ -435,6 +470,8 @@ impl<'f> PerFunctionContext<'f> {
435470 }
436471
437472 references. set_known_value ( address, value) ;
473+ // If we see a store to an address, the last load to that address needs to remain.
474+ references. keep_last_load_for ( address, self . inserter . function ) ;
438475 references. last_stores . insert ( address, instruction) ;
439476 }
440477 Instruction :: Allocate => {
@@ -542,6 +579,9 @@ impl<'f> PerFunctionContext<'f> {
542579 let value = self . inserter . function . dfg . resolve ( * value) ;
543580 references. set_unknown ( value) ;
544581 references. mark_value_used ( value, self . inserter . function ) ;
582+
583+ // If a reference is an argument to a call, the last load to that address and its aliases needs to remain.
584+ references. keep_last_load_for ( value, self . inserter . function ) ;
545585 }
546586 }
547587 }
@@ -572,6 +612,12 @@ impl<'f> PerFunctionContext<'f> {
572612 let destination_parameters = self . inserter . function . dfg [ * destination] . parameters ( ) ;
573613 assert_eq ! ( destination_parameters. len( ) , arguments. len( ) ) ;
574614
615+ // If we have multiple parameters that alias that same argument value,
616+ // then those parameters also alias each other.
617+ // We save parameters with repeat arguments to later mark those
618+ // parameters as aliasing one another.
619+ let mut arg_set: HashMap < ValueId , BTreeSet < ValueId > > = HashMap :: default ( ) ;
620+
575621 // Add an alias for each reference parameter
576622 for ( parameter, argument) in destination_parameters. iter ( ) . zip ( arguments) {
577623 if self . inserter . function . dfg . value_is_reference ( * parameter) {
@@ -581,10 +627,27 @@ impl<'f> PerFunctionContext<'f> {
581627 if let Some ( aliases) = references. aliases . get_mut ( expression) {
582628 // The argument reference is possibly aliased by this block parameter
583629 aliases. insert ( * parameter) ;
630+
631+ // Check if we have seen the same argument
632+ let seen_parameters = arg_set. entry ( argument) . or_default ( ) ;
633+ // Add the current parameter to the parameters we have seen for this argument.
634+ // The previous parameters and the current one alias one another.
635+ seen_parameters. insert ( * parameter) ;
584636 }
585637 }
586638 }
587639 }
640+
641+ // Set the aliases of the parameters
642+ for ( _, aliased_params) in arg_set {
643+ for param in aliased_params. iter ( ) {
644+ self . set_aliases (
645+ references,
646+ * param,
647+ AliasSet :: known_multiple ( aliased_params. clone ( ) ) ,
648+ ) ;
649+ }
650+ }
588651 }
589652 TerminatorInstruction :: Return { return_values, .. } => {
590653 // Removing all `last_stores` for each returned reference is more important here
@@ -900,7 +963,7 @@ mod tests {
900963 // v10 = eq v9, Field 2
901964 // constrain v9 == Field 2
902965 // v11 = load v2
903- // v12 = load v10
966+ // v12 = load v11
904967 // v13 = eq v12, Field 2
905968 // constrain v11 == Field 2
906969 // return
@@ -959,7 +1022,7 @@ mod tests {
9591022 let main = ssa. main ( ) ;
9601023 assert_eq ! ( main. reachable_blocks( ) . len( ) , 4 ) ;
9611024
962- // The store from the original SSA should remain
1025+ // The stores from the original SSA should remain
9631026 assert_eq ! ( count_stores( main. entry_block( ) , & main. dfg) , 2 ) ;
9641027 assert_eq ! ( count_stores( b2, & main. dfg) , 1 ) ;
9651028
@@ -1006,4 +1069,160 @@ mod tests {
10061069 let main = ssa. main ( ) ;
10071070 assert_eq ! ( count_loads( main. entry_block( ) , & main. dfg) , 1 ) ;
10081071 }
1072+
1073+ #[ test]
1074+ fn remove_repeat_loads ( ) {
1075+ // This tests starts with two loads from the same unknown load.
1076+ // Specifically you should look for `load v2` in `b3`.
1077+ // We should be able to remove the second repeated load.
1078+ let src = "
1079+ acir(inline) fn main f0 {
1080+ b0():
1081+ v0 = allocate -> &mut Field
1082+ store Field 0 at v0
1083+ v2 = allocate -> &mut &mut Field
1084+ store v0 at v2
1085+ jmp b1(Field 0)
1086+ b1(v3: Field):
1087+ v4 = eq v3, Field 0
1088+ jmpif v4 then: b2, else: b3
1089+ b2():
1090+ v5 = load v2 -> &mut Field
1091+ store Field 2 at v5
1092+ v8 = add v3, Field 1
1093+ jmp b1(v8)
1094+ b3():
1095+ v9 = load v0 -> Field
1096+ v10 = eq v9, Field 2
1097+ constrain v9 == Field 2
1098+ v11 = load v2 -> &mut Field
1099+ v12 = load v2 -> &mut Field
1100+ v13 = load v12 -> Field
1101+ v14 = eq v13, Field 2
1102+ constrain v13 == Field 2
1103+ return
1104+ }
1105+ " ;
1106+
1107+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
1108+
1109+ // The repeated load from v3 should be removed
1110+ // b3 should only have three loads now rather than four previously
1111+ //
1112+ // All stores are expected to remain.
1113+ let expected = "
1114+ acir(inline) fn main f0 {
1115+ b0():
1116+ v1 = allocate -> &mut Field
1117+ store Field 0 at v1
1118+ v3 = allocate -> &mut &mut Field
1119+ store v1 at v3
1120+ jmp b1(Field 0)
1121+ b1(v0: Field):
1122+ v4 = eq v0, Field 0
1123+ jmpif v4 then: b3, else: b2
1124+ b3():
1125+ v11 = load v3 -> &mut Field
1126+ store Field 2 at v11
1127+ v13 = add v0, Field 1
1128+ jmp b1(v13)
1129+ b2():
1130+ v5 = load v1 -> Field
1131+ v7 = eq v5, Field 2
1132+ constrain v5 == Field 2
1133+ v8 = load v3 -> &mut Field
1134+ v9 = load v8 -> Field
1135+ v10 = eq v9, Field 2
1136+ constrain v9 == Field 2
1137+ return
1138+ }
1139+ " ;
1140+
1141+ let ssa = ssa. mem2reg ( ) ;
1142+ assert_normalized_ssa_equals ( ssa, expected) ;
1143+ }
1144+
1145+ #[ test]
1146+ fn keep_repeat_loads_passed_to_a_call ( ) {
1147+ // The test is the exact same as `remove_repeat_loads` above except with the call
1148+ // to `f1` between the repeated loads.
1149+ let src = "
1150+ acir(inline) fn main f0 {
1151+ b0():
1152+ v1 = allocate -> &mut Field
1153+ store Field 0 at v1
1154+ v3 = allocate -> &mut &mut Field
1155+ store v1 at v3
1156+ jmp b1(Field 0)
1157+ b1(v0: Field):
1158+ v4 = eq v0, Field 0
1159+ jmpif v4 then: b3, else: b2
1160+ b3():
1161+ v13 = load v3 -> &mut Field
1162+ store Field 2 at v13
1163+ v15 = add v0, Field 1
1164+ jmp b1(v15)
1165+ b2():
1166+ v5 = load v1 -> Field
1167+ v7 = eq v5, Field 2
1168+ constrain v5 == Field 2
1169+ v8 = load v3 -> &mut Field
1170+ call f1(v3)
1171+ v10 = load v3 -> &mut Field
1172+ v11 = load v10 -> Field
1173+ v12 = eq v11, Field 2
1174+ constrain v11 == Field 2
1175+ return
1176+ }
1177+ acir(inline) fn foo f1 {
1178+ b0(v0: &mut Field):
1179+ return
1180+ }
1181+ " ;
1182+
1183+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
1184+
1185+ let ssa = ssa. mem2reg ( ) ;
1186+ // We expect the program to be unchanged
1187+ assert_normalized_ssa_equals ( ssa, src) ;
1188+ }
1189+
1190+ #[ test]
1191+ fn keep_repeat_loads_with_alias_store ( ) {
1192+ // v7, v8, and v9 alias one another. We want to make sure that a repeat load to v7 with a store
1193+ // to its aliases in between the repeat loads does not remove those loads.
1194+ let src = "
1195+ acir(inline) fn main f0 {
1196+ b0(v0: u1):
1197+ jmpif v0 then: b2, else: b1
1198+ b2():
1199+ v6 = allocate -> &mut Field
1200+ store Field 0 at v6
1201+ jmp b3(v6, v6, v6)
1202+ b3(v1: &mut Field, v2: &mut Field, v3: &mut Field):
1203+ v8 = load v1 -> Field
1204+ store Field 2 at v2
1205+ v10 = load v1 -> Field
1206+ store Field 1 at v3
1207+ v11 = load v1 -> Field
1208+ store Field 3 at v3
1209+ v13 = load v1 -> Field
1210+ constrain v8 == Field 0
1211+ constrain v10 == Field 2
1212+ constrain v11 == Field 1
1213+ constrain v13 == Field 3
1214+ return
1215+ b1():
1216+ v4 = allocate -> &mut Field
1217+ store Field 1 at v4
1218+ jmp b3(v4, v4, v4)
1219+ }
1220+ " ;
1221+
1222+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
1223+
1224+ let ssa = ssa. mem2reg ( ) ;
1225+ // We expect the program to be unchanged
1226+ assert_normalized_ssa_equals ( ssa, src) ;
1227+ }
10091228}
0 commit comments