Skip to content

Commit e1848ce

Browse files
committed
Bugfix
1 parent ed84ce7 commit e1848ce

2 files changed

Lines changed: 8 additions & 5 deletions

File tree

src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ public class RewriterCostEstimator {
3232
public static final BiFunction<RewriterStatement, Tuple2<Long, Long>, Long> DEFAULT_NNZ_FN = (el, tpl) -> tpl._1 * tpl._2;
3333

3434
// Computes the cost of an expression using different matrix dimensions and sparsities
35-
public static void compareCosts(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) {
35+
public static void compareCosts(RewriterStatement stmt1, RewriterStatement stmt2, RewriterAssertions jointAssertions, final RuleContext ctx) {
3636
Map<RewriterStatement, RewriterStatement> estimates1 = RewriterSparsityEstimator.estimateAllNNZ(stmt1, ctx);
3737
Map<RewriterStatement, RewriterStatement> estimates2 = RewriterSparsityEstimator.estimateAllNNZ(stmt2, ctx);
3838

39-
MutableObject<RewriterAssertions> assertionRef = new MutableObject<>();
39+
MutableObject<RewriterAssertions> assertionRef = new MutableObject<>(jointAssertions);
4040
RewriterStatement costFn1 = getRawCostFunction(stmt1, ctx, assertionRef);
4141
RewriterStatement costFn2 = getRawCostFunction(stmt2, ctx, assertionRef);
4242

@@ -46,9 +46,12 @@ public static void compareCosts(RewriterStatement stmt1, RewriterStatement stmt2
4646
long[] dimVals = new long[] {1, 5000};
4747
double[] sparsities = new double[] {1.0D, 0.2D, 0.001D};
4848

49+
50+
costFn1.unsafePutMeta("_assertions", jointAssertions);
4951
Map<RewriterStatement, RewriterStatement> createdObjects = new HashMap<>();
5052
RewriterStatement costFn1Cpy = costFn1.nestedCopy(true, createdObjects);
51-
RewriterStatement costFn2Cpy = costFn2.nestedCopy(true, createdObjects);
53+
RewriterStatement costFn2Cpy = costFn2.nestedCopy(false, createdObjects);
54+
costFn2Cpy.unsafePutMeta("_assertions", costFn1Cpy.getAssertions(ctx));
5255

5356
Set<RewriterStatement> dimsToPopulate = new HashSet<>();
5457
Set<RewriterStatement> nnzsToPopulate = new HashSet<>();

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ public void test4() {
9393

9494
RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), ctx);
9595
RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx);
96-
rule.getStmt2().unsafePutMeta("_assertions", rule.getStmt1().getAssertions(ctx));
96+
//rule.getStmt2().unsafePutMeta("_assertions", rule.getStmt1().getAssertions(ctx));
9797

98-
RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), ctx);
98+
RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx);
9999
}
100100
}

0 commit comments

Comments
 (0)