Skip to content

Commit 143b68a

Browse files
committed
Bugfix
1 parent 8ee7431 commit 143b68a

File tree

8 files changed

+188
-21
lines changed

8 files changed

+188
-21
lines changed

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,21 @@ public MetaPropagator(RuleContext ctx) {
2020
// TODO: Maybe automatically recompute hash codes?
2121
public RewriterStatement apply(RewriterStatement root) {
2222
//System.out.println("Propagating...");
23-
//System.out.println("--> " + root);
23+
//System.out.println("--> " + root.toParsableString(ctx));
2424
RewriterAssertions assertions = root.getAssertions(ctx);
2525
MutableObject<RewriterStatement> out = new MutableObject<>(root);
2626
HashMap<Object, RewriterStatement> literalMap = new HashMap<>();
2727
root.forEachPostOrderWithDuplicates((el, parent, pIdx) -> {
2828
//System.out.println("mAssertions: " + assertions);
29+
/*System.out.println("Assessing: " + el.toParsableString(ctx));
30+
if (parent != null)
31+
System.out.println("With parent: " + parent.toParsableString(ctx));*/
2932
RewriterStatement toSet = propagateDims(el, parent, pIdx, assertions);
3033

3134
if (toSet != null && toSet != el) {
3235
/*System.out.println("Set: " + toSet);
3336
System.out.println("Old: " + el);
34-
System.out.println("Parent: " + parent.toParsableString(ctx));
37+
System.out.println("Parent: " + parent);
3538
System.out.println("PIdx: " + pIdx);*/
3639
el = toSet;
3740
if (parent == null)
@@ -296,7 +299,7 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
296299
return null;
297300
}
298301

299-
throw new NotImplementedException("Unknown instruction: " + instr.trueTypedInstruction(ctx) + "\n" + instr.toString(ctx));
302+
throw new NotImplementedException("Unknown instruction: " + instr.trueTypedInstruction(ctx) + "\n" + instr.toParsableString(ctx));
300303
}
301304

302305
return null;

src/main/java/org/apache/sysds/hops/rewriter/RewriterAssertions.java

Lines changed: 129 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55
import java.util.Collections;
66
import java.util.HashMap;
77
import java.util.HashSet;
8+
import java.util.List;
89
import java.util.Map;
910
import java.util.Objects;
1011
import java.util.Set;
1112
import java.util.UUID;
13+
import java.util.function.Consumer;
1214
import java.util.stream.Collectors;
1315

1416
public class RewriterAssertions {
1517
private final RuleContext ctx;
1618
private Map<RewriterStatement, RewriterAssertion> assertionMatcher = new HashMap<>();
19+
// Tracks which statements are part of which assertions
20+
private Map<RewriterStatement, Set<RewriterAssertion>> partOfAssertion = new HashMap<>();
1721
private Set<RewriterAssertion> allAssertions = new HashSet<>();
1822

1923
public RewriterAssertions(final RuleContext ctx) {
@@ -40,6 +44,7 @@ public RewriterAssertions(final RuleContext ctx) {
4044
return assertions;
4145
}*/
4246

47+
// TODO: Add parts of assertions map
4348
public static RewriterAssertions copy(RewriterAssertions old, Map<RewriterStatement, RewriterStatement> createdObjects, boolean removeOthers) {
4449
//System.out.println("Copying: " + old);
4550
RewriterAssertions newAssertions = new RewriterAssertions(old.ctx);
@@ -69,10 +74,17 @@ public static RewriterAssertions copy(RewriterAssertions old, Map<RewriterStatem
6974
RewriterAssertion mapped = RewriterAssertion.from(newSet);
7075
if (assertion.stmt != null)
7176
mapped.stmt = createdObjects.get(assertion.stmt);
77+
if (assertion.backRef != null)
78+
mapped.backRef = createdObjects.get(assertion.backRef);
7279
mappedAssertions.put(assertion, mapped);
7380
return mapped;
7481
}).filter(Objects::nonNull).collect(Collectors.toSet());
7582

83+
newAssertions.partOfAssertion = old.partOfAssertion.entrySet().stream().collect(Collectors.toMap(
84+
v -> createdObjects.getOrDefault(v.getKey(), v.getKey()),
85+
v -> v.getValue().stream().map(mappedAssertions::get).collect(Collectors.toSet())
86+
));
87+
7688
if (removeOthers) {
7789
old.assertionMatcher.forEach((k, v) -> {
7890
RewriterStatement newK = createdObjects.get(k);
@@ -100,6 +112,7 @@ public static RewriterAssertions copy(RewriterAssertions old, Map<RewriterStatem
100112
}
101113

102114
//System.out.println("New: " + newAssertions);
115+
//System.out.println("New parts: " + newAssertions.partOfAssertion);
103116

104117
return newAssertions;
105118
}
@@ -153,6 +166,19 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
153166
allAssertions.add(newAssertion);
154167

155168
resolveCyclicAssertions(newAssertion);
169+
170+
forEachUniqueElementInAssertion(newAssertion, cur -> {
171+
partOfAssertion.compute(cur, (k, v) -> {
172+
if (v == null)
173+
v = new HashSet<>();
174+
175+
v.add(newAssertion);
176+
return v;
177+
});
178+
});
179+
180+
//System.out.println("MNew parts: " + partOfAssertion);
181+
156182
//System.out.println("New assertion1: " + newAssertion);
157183
return true;
158184
}
@@ -171,6 +197,18 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
171197
updateInstance(existingAssertion.stmt.getChild(0), existingAssertion.set);
172198

173199
resolveCyclicAssertions(existingAssertion);
200+
201+
toAssert.forEachPreOrder(cur -> {
202+
partOfAssertion.compute(cur, (k, v) -> {
203+
if (v == null)
204+
v = new HashSet<>();
205+
206+
v.add(existingAssertion);
207+
return v;
208+
});
209+
return true;
210+
});
211+
174212
//System.out.println("New assertion2: " + existingAssertion);
175213
return true;
176214
}
@@ -198,23 +236,53 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
198236
//System.out.println("New assertion3: " + stmt2Assertions);
199237
resolveCyclicAssertions(stmt2Assertions);
200238

239+
final RewriterAssertion assertionToRemove = stmt1Assertions;
240+
final RewriterAssertion assertionToExtend = stmt2Assertions;
241+
forEachUniqueElementInAssertion(stmt1Assertions, cur -> {
242+
Set<RewriterAssertion> v = partOfAssertion.get(cur);
243+
v.remove(assertionToRemove);
244+
v.add(assertionToExtend);
245+
});
246+
201247
return true;
202248
}
203249

250+
private void forEachUniqueElementInAssertion(RewriterAssertion assertion, Consumer<RewriterStatement> consumer) {
251+
Set<RewriterStatement> visited = new HashSet<>();
252+
for (RewriterStatement eq : assertion.set) {
253+
eq.forEachPreOrderWithDuplicates(cur -> {
254+
if (!visited.add(cur))
255+
return false;
256+
257+
consumer.accept(cur);
258+
return true;
259+
});
260+
}
261+
}
262+
204263
// Replace cycles with _backRef()
264+
// TODO: Also copy duplicate referenced sub-trees to avoid cycles (e.g. _EClass(a*b+c, a) and sqrt(a*b) => What to do with a in a*b? _backRef or _EClass?)
265+
// TODO: This requires a guarantee that reference counts are intact
205266
private void resolveCyclicAssertions(RewriterAssertion assertion) {
206267
if (assertion.stmt == null)
207268
return;
208269

209270
//System.out.println("Resolving cycles in: " + assertion);
210271

211-
String rType = assertion.stmt.getResultingDataType(ctx);
272+
RewriterStatement backref = assertion.getBackRef(ctx, this);
273+
//String rType = assertion.stmt.getResultingDataType(ctx);
212274

213-
RewriterStatement backref = new RewriterInstruction()
275+
/*RewriterStatement backref = new RewriterInstruction()
214276
.as(UUID.randomUUID().toString())
215277
.withInstruction("_backRef." + rType)
216278
.consolidate(ctx);
217-
backref.unsafePutMeta("_backRef", assertion.stmt);
279+
backref.unsafePutMeta("_backRef", assertion.stmt);*/
280+
281+
// Check if any sub-graph of the E-Graph is referenced outside the E-Class
282+
// If any child of the duplicate reference would create a back-reference, we need to copy the entire sub-graph
283+
//HashMap<RewriterStatement, Integer> refCtr = new HashMap<>();
284+
285+
218286

219287
for (RewriterStatement eq : assertion.set) {
220288
eq.forEachPreOrder((cur, parent, pIdx) -> {
@@ -241,24 +309,30 @@ public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterS
241309
//System.out.println("In: " + this);
242310
RewriterAssertion set = assertionMatcher.get(stmt);
243311

244-
if (set == null)
312+
if (set == null || set.getEClassStmt(ctx, this).getChild(0) == parent)
245313
return stmt;
246314

247-
RewriterStatement mstmt = set.stmt;
315+
//System.out.println("EClassStmt: " + set.getEClassStmt(ctx, this).getChild(0));
316+
if (parent != null && parent != set.getEClassStmt(ctx, this).getChild(0) && partOfAssertion.getOrDefault(parent, Collections.emptySet()).contains(set))
317+
return set.getBackRef(ctx, this);
248318

249-
if (mstmt == null) {
319+
/*RewriterStatement mstmt = set.stmt;
320+
321+
if (mstmt == null)
322+
mstmt = set.getEClassStmt(ctx, this);
323+
{
250324
// Then we create a new statement for it
251325
RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(set.set.toArray(RewriterStatement[]::new));
252326
mstmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_EClass").withOps(argList);
253327
mstmt.consolidate(ctx);
254328
set.stmt = mstmt;
255329
assertionMatcher.put(set.stmt, set);
256330
resolveCyclicAssertions(set);
257-
} else if (mstmt.getChild(0) == parent) {
331+
}*/ /*else if (mstmt.getChild(0) == parent) {
258332
return stmt;
259-
}
333+
}*/
260334

261-
return mstmt;
335+
return set.getEClassStmt(ctx, this);
262336
}
263337

264338
// TODO: This does not handle metadata
@@ -324,6 +398,52 @@ private void updateInstance(RewriterStatement stmt, Set<RewriterStatement> set)
324398
private static class RewriterAssertion {
325399
Set<RewriterStatement> set;
326400
RewriterStatement stmt;
401+
RewriterStatement backRef; // The back-reference to this assertion
402+
403+
RewriterStatement getEClassStmt(final RuleContext ctx, RewriterAssertions assertions) {
404+
if (stmt != null)
405+
return stmt;
406+
407+
RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(set.toArray(RewriterStatement[]::new));
408+
stmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_EClass").withOps(argList);
409+
stmt.consolidate(ctx);
410+
assertions.assertionMatcher.put(stmt, this);
411+
assertions.partOfAssertion.compute(stmt, (k, v) -> {
412+
if (v == null)
413+
v = new HashSet<>();
414+
415+
v.add(this);
416+
return v;
417+
});
418+
assertions.partOfAssertion.compute(argList, (k, v) -> {
419+
if (v == null)
420+
v = new HashSet<>();
421+
422+
v.add(this);
423+
return v;
424+
});
425+
assertions.resolveCyclicAssertions(this);
426+
return stmt;
427+
}
428+
429+
RewriterStatement getBackRef(final RuleContext ctx, RewriterAssertions assertions) {
430+
if (backRef != null)
431+
return backRef;
432+
433+
backRef = new RewriterInstruction()
434+
.as(UUID.randomUUID().toString())
435+
.withInstruction("_backRef." + getEClassStmt(ctx, assertions).getResultingDataType(ctx))
436+
.consolidate(ctx);
437+
backRef.unsafePutMeta("_backRef", getEClassStmt(ctx, assertions));
438+
assertions.partOfAssertion.compute(backRef, (k, v) -> {
439+
if (v == null)
440+
v = new HashSet<>();
441+
442+
v.add(this);
443+
return v;
444+
});
445+
return backRef;
446+
}
327447

328448
@Override
329449
public String toString() {

src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,8 @@ public static String getDefaultContextString() {
324324
});
325325

326326
builder.append("_m(INT,INT,FLOAT)::MATRIX\n");
327+
builder.append("_m(INT,INT,BOOL)::MATRIX\n");
328+
builder.append("_m(INT,INT,INT)::MATRIX\n");
327329
List.of("FLOAT", "INT", "BOOL").forEach(t -> {
328330
builder.append("_idxExpr(INT," + t + ")::" + t + "*\n");
329331
builder.append("_idxExpr(INT," + t + "*)::" + t + "*\n");

src/main/java/org/apache/sysds/hops/rewriter/RewriterHeuristic.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
5858
foundRewrite.setValue(true);
5959

6060
while (rule != null) {
61-
System.out.println("Pre-apply: " + rule.rule.getName());
61+
//System.out.println("Pre-apply: " + rule.rule.getName());
6262
/*if (currentStmt.toParsableString(ruleSet.getContext()).equals("%*%(X,[](B,1,ncol(X),1,ncol(B)))"))
6363
System.out.println("test");*/
6464
/*System.out.println("Expr: " + rule.matches.get(0).getExpressionRoot().toParsableString(ruleSet.getContext()));

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ public String getResultingDataType(final RuleContext ctx) {
4141
else
4242
returnType = ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx);
4343

44+
if (returnType == null)
45+
throw new IllegalArgumentException("Return type not found for: " + trueTypedInstruction(ctx));
46+
4447
return returnType;
4548
}
4649

@@ -446,8 +449,7 @@ public int toParsableString(StringBuilder sb, Map<RewriterRule.IdentityRewriterS
446449
}
447450

448451
sb.append(')');
449-
// TODO: Remove
450-
sb.append("::" + getResultingDataType(ctx));
452+
//sb.append("::" + getResultingDataType(ctx));
451453

452454
return maxRefId;
453455
}

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,13 +969,14 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
969969
.build()
970970
);
971971

972-
rules.add(new RewriterRuleBuilder(ctx, "_m(i::<const>, j::<const>, v) => v")
972+
// TODO: Deal with boolean or int matrices
973+
rules.add(new RewriterRuleBuilder(ctx, "_m(i::<const>, j::<const>, v) => cast.MATRIX(v)")
973974
.setUnidirectional(true)
974975
.parseGlobalVars("MATRIX:A,B")
975976
.parseGlobalVars("INT:i,j")
976977
.parseGlobalVars("FLOAT:v")
977978
.withParsedStatement("_m(i, j, v)", hooks)
978-
.toParsedStatement("v", hooks)
979+
.toParsedStatement("cast.MATRIX(v)", hooks)
979980
.iff(match -> {
980981
List<RewriterStatement> ops = match.getMatchRoot().getOperands();
981982

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ public void testExpressionClustering() {
6969
db.forEach(expr -> {
7070
if (ctr.incrementAndGet() % 10 == 0)
7171
System.out.println("Done: " + ctr.intValue() + " / " + size);
72-
if (ctr.intValue() > 1000)
73-
return; // Skip
72+
//if (ctr.intValue() > 1000)
73+
//return; // Skip
7474
// First, build all possible subtrees
75-
System.out.println("Eval: " + expr.toParsableString(ctx));
75+
//System.out.println("Eval:\n" + expr.toParsableString(ctx, true));
7676
List<RewriterStatement> subExprs = RewriterUtils.generateSubtrees(expr, ctx, 500);
7777
if (subExprs.size() > 100)
7878
System.out.println("Critical number of subtrees: " + subExprs.size());
79+
if (subExprs.size() > 2000) {
80+
System.out.println("Skipping subtrees...");
81+
subExprs = List.of(expr);
82+
}
7983
//List<RewriterStatement> subExprs = List.of(expr);
8084
long evaluationCtr = 0;
8185

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,14 @@ public void testSimpleSumPullOut() {
527527
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
528528
}
529529

530+
@Test
531+
public void testBackrefInequality() {
532+
// TODO: Not implemented yet; this placeholder always fails when assertions are enabled.
533+
// Intended case: an example where one _backRef() must not be treated as equal to another one,
534+
// because back-references have to be compared via their meta-data.
535+
assert false;
536+
}
537+
530538
@Test
531539
public void myTest() {
532540
RewriterStatement stmt1 = RewriterUtils.parse("sum(-(X, 7))", ctx, "MATRIX:X,Y", "LITERAL_INT:1,7", "INT:a", "LITERAL_FLOAT:7.0");
@@ -581,5 +589,32 @@ public void myTest6() {
581589
System.out.println(stmt.toParsableString(ctx, true));
582590
}
583591

584-
// TODO: There is a problem if e.g. _EClass(argList(ncol(X), +(ncol(X), 0))) as then ncol(X) will be replaced again with the _EClass
592+
@Test
593+
public void myTest7() {
594+
String stmtStr = "MATRIX:combined\n" +
595+
"FLOAT:int0,int496,int236,int618\n" +
596+
"LITERAL_INT:1,2\n" +
597+
"INT:parsertemp71754,int497,int280\n" +
598+
"&(RBind(!=([](combined,1,-(parsertemp71754,int497),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined))),rand(1,1,int0,int496)),RBind(rand(1,1,int618,int236),!=([](combined,1,-(parsertemp71754,int280),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined)))))";
599+
600+
RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx);
601+
stmt = canonicalConverter.apply(stmt);
602+
603+
System.out.println("==========");
604+
System.out.println(stmt.toParsableString(ctx, true));
605+
}
606+
607+
@Test
608+
public void myTest8() {
609+
String stmtStr = "MATRIX:prec_chol,X,mu\n" +
610+
"INT:i,k\n" +
611+
"LITERAL_INT:1,5\n" +
612+
"%*%(X,[](prec_chol,1,*(i,ncol(X)),1,5))";
613+
614+
RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx);
615+
stmt = canonicalConverter.apply(stmt);
616+
617+
System.out.println("==========");
618+
System.out.println(stmt.toParsableString(ctx, true));
619+
}
585620
}

0 commit comments

Comments
 (0)