@@ -32,11 +32,11 @@ public class RewriterCostEstimator {
3232 public static final BiFunction <RewriterStatement , Tuple2 <Long , Long >, Long > DEFAULT_NNZ_FN = (el , tpl ) -> tpl ._1 * tpl ._2 ;
3333
3434 // Computes the cost of an expression using different matrix dimensions and sparsities
35- public static void compareCosts (RewriterStatement stmt1 , RewriterStatement stmt2 , final RuleContext ctx ) {
35+ public static void compareCosts (RewriterStatement stmt1 , RewriterStatement stmt2 , RewriterAssertions jointAssertions , final RuleContext ctx ) {
3636 Map <RewriterStatement , RewriterStatement > estimates1 = RewriterSparsityEstimator .estimateAllNNZ (stmt1 , ctx );
3737 Map <RewriterStatement , RewriterStatement > estimates2 = RewriterSparsityEstimator .estimateAllNNZ (stmt2 , ctx );
3838
39- MutableObject <RewriterAssertions > assertionRef = new MutableObject <>();
39+ MutableObject <RewriterAssertions > assertionRef = new MutableObject <>(jointAssertions );
4040 RewriterStatement costFn1 = getRawCostFunction (stmt1 , ctx , assertionRef );
4141 RewriterStatement costFn2 = getRawCostFunction (stmt2 , ctx , assertionRef );
4242
@@ -46,9 +46,12 @@ public static void compareCosts(RewriterStatement stmt1, RewriterStatement stmt2
4646 long [] dimVals = new long [] {1 , 5000 };
4747 double [] sparsities = new double [] {1.0D , 0.2D , 0.001D };
4848
49+
50+ costFn1 .unsafePutMeta ("_assertions" , jointAssertions );
4951 Map <RewriterStatement , RewriterStatement > createdObjects = new HashMap <>();
5052 RewriterStatement costFn1Cpy = costFn1 .nestedCopy (true , createdObjects );
51- RewriterStatement costFn2Cpy = costFn2 .nestedCopy (true , createdObjects );
53+ RewriterStatement costFn2Cpy = costFn2 .nestedCopy (false , createdObjects );
54+ costFn2Cpy .unsafePutMeta ("_assertions" , costFn1Cpy .getAssertions (ctx ));
5255
5356 Set <RewriterStatement > dimsToPopulate = new HashSet <>();
5457 Set <RewriterStatement > nnzsToPopulate = new HashSet <>();
0 commit comments