Skip to content

Commit e84eb00

Browse files
committed
Bugfix
1 parent 522c46b commit e84eb00

File tree

5 files changed

+32
-1
lines changed

5 files changed

+32
-1
lines changed

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ public RewriterStatement apply(RewriterStatement root) {
7171
}
7272

7373
private RewriterStatement propagateDims(RewriterStatement root, RewriterStatement parent, int pIdx, RewriterAssertions assertions) {
74+
if (root.getResultingDataType(ctx) == null)
75+
throw new IllegalArgumentException("Null type: " + root.toParsableString(ctx));
7476
if (!root.getResultingDataType(ctx).startsWith("MATRIX")) {
7577
if (root.isInstruction()) {
7678
String ti = root.trueTypedInstruction(ctx);

src/main/java/org/apache/sysds/hops/rewriter/RewriterHeuristic.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
5858
foundRewrite.setValue(true);
5959

6060
while (rule != null) {
61-
//System.out.println("Pre-apply: " + rule.rule.getName());
61+
System.out.println("Pre-apply: " + rule.rule.getName());
6262
/*if (currentStmt.toParsableString(ruleSet.getContext()).equals("%*%(X,[](B,1,ncol(X),1,ncol(B)))"))
6363
System.out.println("test");*/
6464
/*System.out.println("Expr: " + rule.matches.get(0).getExpressionRoot().toParsableString(ruleSet.getContext()));

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ public RewriterStatement copyNode() {
189189
mCopy.costFunction = costFunction;
190190
mCopy.consolidated = consolidated;
191191
mCopy.operands = new ArrayList<>(operands);
192+
mCopy.returnType = returnType;
192193
if (meta != null)
193194
mCopy.meta = new HashMap<>(meta);
194195
else
@@ -215,6 +216,7 @@ public RewriterStatement nestedCopyOrInject(Map<RewriterStatement, RewriterState
215216
mCopy.costFunction = costFunction;
216217
mCopy.consolidated = consolidated;
217218
mCopy.operands = new ArrayList<>(operands.size());
219+
mCopy.returnType = returnType;
218220
mCopy.hashCode = hashCode;
219221
if (meta != null)
220222
mCopy.meta = new HashMap<>(meta);
@@ -245,6 +247,7 @@ public boolean isInstruction() {
245247
return true;
246248
}
247249

250+
@Deprecated
248251
@Override
249252
public RewriterStatement clone() {
250253
RewriterInstruction mClone = new RewriterInstruction();
@@ -259,6 +262,7 @@ public RewriterStatement clone() {
259262
mClone.operands = clonedOperands;
260263
mClone.costFunction = costFunction;
261264
mClone.consolidated = consolidated;
265+
mClone.returnType = returnType;
262266
mClone.meta = meta;
263267
return mClone;
264268
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ public void testExpressionClustering() {
6969
db.forEach(expr -> {
7070
if (ctr.incrementAndGet() % 10 == 0)
7171
System.out.println("Done: " + ctr.intValue() + " / " + size);
72+
if (ctr.intValue() > 1000)
73+
return; // Skip
7274
// First, build all possible subtrees
7375
System.out.println("Eval: " + expr.toParsableString(ctx));
7476
List<RewriterStatement> subExprs = RewriterUtils.generateSubtrees(expr, ctx, 500);
@@ -163,12 +165,26 @@ private boolean checkRelevance(List<RewriterStatement> stmts) {
163165
TopologicalSort.sort(stmt2, ctx);
164166

165167
if (!stmt1.match(RewriterStatement.MatcherContext.exactMatchWithDifferentLiteralValues(ctx, stmt2))) {
168+
// TODO: Minimal difference can still prune valid rewrites (e.g. sum(A %*% B) -> sum(A * t(B)))
166169
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmts.get(j));
167170
stmts.get(i).match(mCtx);
168171
Tuple2<RewriterStatement, RewriterStatement> minimalDifference = mCtx.getFirstMismatch();
169172

170173
if (minimalDifference._1 == stmts.get(i))
171174
match = false;
175+
else {
176+
// Otherwise we need to work ourselves backwards to the root if both canonical forms don't match now
177+
RewriterStatement minStmt1 = minimalDifference._1.nestedCopy();
178+
RewriterStatement minStmt2 = minimalDifference._2.nestedCopy();
179+
minStmt1 = converter.apply(minStmt1);
180+
minStmt2 = converter.apply(minStmt2);
181+
182+
if (minStmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, minStmt2))) {
183+
// Then the minimal difference does not imply equivalence
184+
// For now, just keep every result then
185+
match = false;
186+
}
187+
}
172188
}
173189
}
174190
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,5 +556,14 @@ public void myTest5() {
556556
System.out.println(stmt.toParsableString(ctx, true));
557557
}
558558

559+
@Test
560+
public void myTest6() {
561+
RewriterStatement stmt = RewriterUtils.parse("rowSums(<=(D,minD))", ctx, "MATRIX:D,minD");
562+
stmt = canonicalConverter.apply(stmt);
563+
564+
System.out.println("==========");
565+
System.out.println(stmt.toParsableString(ctx, true));
566+
}
567+
559568
// TODO: There is a problem if e.g. _EClass(argList(ncol(X), +(ncol(X), 0))) as then ncol(X) will be replaced again with the _EClass
560569
}

0 commit comments

Comments
 (0)