@@ -277,7 +277,8 @@ fn remove_first_class_functions_in_instruction(
277277/// Try to map the given function literal to a field, returning Some(field) on success.
278278/// Returns none if the given value was not a function or doesn't need to be mapped.
279279fn map_function_to_field ( func : & mut Function , value : ValueId ) -> Option < ValueId > {
280- if is_function_type ( func. dfg [ value] . get_type ( ) . as_ref ( ) ) {
280+ let typ = func. dfg [ value] . get_type ( ) ;
281+ if is_function_type ( typ. as_ref ( ) ) {
281282 match & func. dfg [ value] {
282283 // If the value is a static function, transform it to the function id
283284 Value :: Function ( id) => {
@@ -286,7 +287,7 @@ fn map_function_to_field(func: &mut Function, value: ValueId) -> Option<ValueId>
286287 }
287288 // If the value is a function used as value, just change the type of it
288289 Value :: Instruction { .. } | Value :: Param { .. } => {
289- func. dfg . set_type_of_value ( value, Type :: field ( ) ) ;
290+ func. dfg . set_type_of_value ( value, replacement_type ( typ . as_ref ( ) ) ) ;
290291 }
291292 _ => ( ) ,
292293 }
@@ -1247,4 +1248,58 @@ mod tests {
12471248 }
12481249 " ) ;
12491250 }
1251+
1252+ #[ test]
1253+ fn mut_ref_function_matching ( ) {
1254+ let src = "
1255+ brillig(inline) fn add_to_tally_public f0 {
1256+ b0():
1257+ v4 = allocate -> &mut function
1258+ store f2 at v4
1259+ v10 = call f10(v4, f33) -> Field
1260+ return
1261+ }
1262+ brillig(inline) fn lambda f2 {
1263+ b0():
1264+ return Field 1
1265+ }
1266+ brillig(inline) fn at f10 {
1267+ b0(v4: &mut function, v6: function):
1268+ v10 = call v6(v4) -> Field
1269+ return v10
1270+ }
1271+ brillig(inline) fn lambda f33 {
1272+ b0(v4: &mut function):
1273+ v10 = call v4() -> Field
1274+ return v10
1275+ }
1276+ " ;
1277+
1278+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
1279+ let ssa = ssa. defunctionalize ( ) ;
1280+
1281+ assert_ssa_snapshot ! ( ssa, @r"
1282+ brillig(inline) fn add_to_tally_public f0 {
1283+ b0():
1284+ v0 = allocate -> &mut Field
1285+ store Field 1 at v0
1286+ v4 = call f2(v0, Field 3) -> Field
1287+ return
1288+ }
1289+ brillig(inline) fn lambda f1 {
1290+ b0():
1291+ return Field 1
1292+ }
1293+ brillig(inline) fn at f2 {
1294+ b0(v0: &mut Field, v1: Field):
1295+ v3 = call f3(v0) -> Field
1296+ return v3
1297+ }
1298+ brillig(inline) fn lambda f3 {
1299+ b0(v0: &mut Field):
1300+ v2 = call f1() -> Field
1301+ return v2
1302+ }
1303+ " ) ;
1304+ }
12501305}
0 commit comments