Skip to content

Commit 99e9d9a

Browse files
committed
Some more improvements
1 parent 0f9f5ad commit 99e9d9a

5 files changed

Lines changed: 213 additions & 6 deletions

File tree

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,15 @@ public boolean match(final MatcherContext mCtx) {
138138
if (mCtx.isDebug())
139139
System.out.println("Matching: " + this.toString(ctx) + " <=> " + stmt.toString(ctx));
140140

141+
// Check for some meta information
142+
if (mCtx.statementsCanBeVariables && getResultingDataType(ctx).equals("MATRIX")) {
143+
if (trueInstruction().equals("rowVec") && stmt.isRowVector()) {
144+
return true;
145+
} else if (trueInstruction().equals("colVec") && stmt.isColVector()) {
146+
return true;
147+
}
148+
}
149+
141150
if (stmt instanceof RewriterInstruction && getResultingDataType(ctx).equals(stmt.getResultingDataType(ctx))) {
142151
RewriterInstruction inst = (RewriterInstruction)stmt;
143152

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
192192
else
193193
hop = prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0).getInput(0);
194194

195-
RewriterStatement stmt = RewriterRuntimeUtils.buildDAGFromHop(hop, 1000, ctx);
195+
RewriterStatement stmt = RewriterRuntimeUtils.buildDAGFromHop(hop, 1000, true, ctx);
196196

197197
if (stmt == null)
198198
return false;

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,77 @@ public static void forAllHops(DMLProgram program, Consumer<Hop> consumer) {
192192
sb.getHops().forEach(consumer);
193193
}
194194

195-
public static RewriterStatement buildDAGFromHop(Hop hop, int maxDepth, final RuleContext ctx) {
196-
return buildDAGRecursively(hop, null, new HashMap<>(), 0, maxDepth, ctx);
195+
public static RewriterStatement buildDAGFromHop(Hop hop, int maxDepth, boolean mindDataCharacteristics, final RuleContext ctx) {
196+
RewriterStatement out = buildDAGRecursively(hop, null, new HashMap<>(), 0, maxDepth, ctx);
197+
198+
if (mindDataCharacteristics)
199+
return populateDataCharacteristics(out, ctx);
200+
201+
return out;
202+
}
203+
204+
public static RewriterStatement populateDataCharacteristics(RewriterStatement stmt, final RuleContext ctx) {
205+
if (stmt == null)
206+
return null;
207+
208+
if (stmt instanceof RewriterDataType && stmt.getResultingDataType(ctx).equals("MATRIX")) {
209+
Long nrow = (Long) stmt.getMeta("_actualNRow");
210+
Long ncol = (Long) stmt.getMeta("_actualNCol");
211+
int matType = 0;
212+
213+
// TODO: what if matrix consists of one single element?
214+
if (nrow != null && nrow == 1L) {
215+
matType = 1;
216+
} else if (ncol != null && ncol == 1L) {
217+
matType = 2;
218+
}
219+
220+
if (matType > 0) {
221+
return new RewriterInstruction()
222+
.as(UUID.randomUUID().toString())
223+
.withInstruction(matType == 1 ? "rowVec" : "colVec")
224+
.withOps(stmt)
225+
.consolidate(ctx);
226+
}
227+
}
228+
229+
Map<RewriterStatement, RewriterStatement> createdObjects = new HashMap<>();
230+
231+
stmt.forEachPostOrder((cur, pred) -> {
232+
for (int i = 0; i < cur.getOperands().size(); i++) {
233+
RewriterStatement child = cur.getChild(i);
234+
235+
if (child instanceof RewriterDataType && child.getResultingDataType(ctx).equals("MATRIX")) {
236+
Long nrow = (Long) child.getMeta("_actualNRow");
237+
Long ncol = (Long) child.getMeta("_actualNCol");
238+
int matType = 0;
239+
240+
// TODO: what if matrix consists of one single element?
241+
if (nrow != null && nrow == 1L) {
242+
matType = 1;
243+
} else if (ncol != null && ncol == 1L) {
244+
matType = 2;
245+
}
246+
247+
if (matType > 0) {
248+
RewriterStatement created = createdObjects.get(child);
249+
250+
if (created == null) {
251+
created = new RewriterInstruction()
252+
.as(UUID.randomUUID().toString())
253+
.withInstruction(matType == 1 ? "rowVec" : "colVec")
254+
.withOps(child)
255+
.consolidate(ctx);
256+
createdObjects.put(child, created);
257+
}
258+
259+
cur.getOperands().set(i, created);
260+
}
261+
}
262+
}
263+
}, false);
264+
265+
return stmt;
197266
}
198267

199268
public static void forAllUniqueTranslatableStatements(DMLProgram program, int maxDepth, Consumer<RewriterStatement> stmt, RewriterDatabase db, final RuleContext ctx) {
@@ -364,6 +433,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
364433
if (stmt == null)
365434
return buildLeaf(next, expectedType, ctx);
366435

436+
insertDataCharacteristics(next, stmt, ctx);
437+
367438
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
368439
return stmt;
369440

@@ -377,6 +448,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
377448
if (stmt == null)
378449
return buildLeaf(next, expectedType, ctx);
379450

451+
insertDataCharacteristics(next, stmt, ctx);
452+
380453
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
381454
return stmt;
382455

@@ -390,6 +463,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
390463
if (stmt == null)
391464
return buildLeaf(next, expectedType, ctx);
392465

466+
insertDataCharacteristics(next, stmt, ctx);
467+
393468
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
394469
return stmt;
395470

@@ -403,6 +478,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
403478
if (stmt == null)
404479
return buildLeaf(next, expectedType, ctx);
405480

481+
insertDataCharacteristics(next, stmt, ctx);
482+
406483
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
407484
return stmt;
408485

@@ -416,6 +493,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
416493
if (stmt == null)
417494
return buildLeaf(next, expectedType, ctx);
418495

496+
insertDataCharacteristics(next, stmt, ctx);
497+
419498
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
420499
return stmt;
421500

@@ -429,6 +508,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
429508
if (stmt == null)
430509
return buildLeaf(next, expectedType, ctx);
431510

511+
insertDataCharacteristics(next, stmt, ctx);
512+
432513
if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx))
433514
return stmt;
434515

@@ -443,6 +524,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
443524
if (stmt == null)
444525
return buildLeaf(next, expectedType, ctx);
445526

527+
insertDataCharacteristics(next, stmt, ctx);
528+
446529
if (buildInputs(stmt, interestingHops, cache, true, depth, maxDepth, ctx))
447530
return stmt;
448531

@@ -465,6 +548,19 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
465548
return null;
466549
}
467550

551+
private static void insertDataCharacteristics(Hop hop, RewriterStatement stmt, final RuleContext ctx) {
552+
if (stmt.getResultingDataType(ctx).equals("MATRIX")) {
553+
if (hop.getDataCharacteristics() != null) {
554+
long nrows = hop.getDataCharacteristics().getRows();
555+
long ncols = hop.getDataCharacteristics().getCols();
556+
if (nrows > 0)
557+
stmt.unsafePutMeta("_actualNRow", RewriterStatement.literal(ctx, nrows));
558+
if (ncols > 0)
559+
stmt.unsafePutMeta("_actualNCol", RewriterStatement.literal(ctx, nrows));
560+
}
561+
}
562+
}
563+
468564
// TODO: Maybe introduce other implicit conversions if types mismatch
469565
private static RewriterStatement checkForCorrectTypes(RewriterStatement stmt, @Nullable String expectedType, Hop hop, final RuleContext ctx) {
470566
if (stmt == null)
@@ -500,14 +596,19 @@ private static RewriterStatement buildLeaf(Hop hop, @Nullable String expectedTyp
500596
if (RewriterUtils.DOUBLE_PATTERN.matcher(hopName).matches() || RewriterUtils.SPECIAL_FLOAT_PATTERN.matcher(hopName).matches())
501597
hopName = "float" + new Random().nextInt(1000);
502598

503-
if (expectedType != null)
504-
return RewriterUtils.parse(hopName, ctx, expectedType + ":" + hopName);
599+
if (expectedType != null) {
600+
RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, expectedType + ":" + hopName);
601+
insertDataCharacteristics(hop, stmt, ctx);
602+
return stmt;
603+
}
505604

506605
switch (hop.getDataType()) {
507606
case SCALAR:
508607
return buildScalarLeaf(hop, hopName, ctx);
509608
case MATRIX:
510-
return RewriterUtils.parse(hopName, ctx, "MATRIX:" + hopName);
609+
RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, "MATRIX:" + hopName);
610+
insertDataCharacteristics(hop, stmt, ctx);
611+
return stmt;
511612
}
512613

513614
return null; // Not supported then

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,36 @@ public String toString() {
966966
return toString(RuleContext.currentContext);
967967
}
968968

969+
public boolean isColVector() {
970+
RewriterStatement nrow = getNRow();
971+
972+
if (nrow == null)
973+
return false;
974+
975+
if (nrow.isLiteral() && nrow.getLiteral().equals(1L))
976+
return true;
977+
978+
if (nrow.isEClass() && nrow.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L)))
979+
return true;
980+
981+
return false;
982+
}
983+
984+
public boolean isRowVector() {
985+
RewriterStatement ncol = getNCol();
986+
987+
if (ncol == null)
988+
return false;
989+
990+
if (ncol.isLiteral() && ncol.getLiteral().equals(1L))
991+
return true;
992+
993+
if (ncol.isEClass() && ncol.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L)))
994+
return true;
995+
996+
return false;
997+
}
998+
969999
public List<String> toExecutableString(final RuleContext ctx) {
9701000
ArrayList<String> defList = new ArrayList<>();
9711001
prepareDefinitions(ctx, defList, new HashSet<>());
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package org.apache.sysds.test.component.codegen.rewrite.functions;
2+
3+
import org.apache.sysds.hops.rewriter.RewriterRule;
4+
import org.apache.sysds.hops.rewriter.RewriterRuleBuilder;
5+
import org.apache.sysds.hops.rewriter.RewriterRuleCreator;
6+
import org.apache.sysds.hops.rewriter.RewriterRuleSet;
7+
import org.apache.sysds.hops.rewriter.RewriterStatement;
8+
import org.apache.sysds.hops.rewriter.RewriterUtils;
9+
import org.apache.sysds.hops.rewriter.RuleContext;
10+
import org.junit.BeforeClass;
11+
import org.junit.Test;
12+
13+
import java.util.List;
14+
import java.util.function.Function;
15+
16+
public class TestRuleSet {
17+
private static RuleContext ctx;
18+
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;
19+
20+
@BeforeClass
21+
public static void setup() {
22+
ctx = RewriterUtils.buildDefaultContext();
23+
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
24+
}
25+
26+
@Test
27+
public void test1() {
28+
RewriterRule rule = new RewriterRuleBuilder(ctx)
29+
.setUnidirectional(true)
30+
.parseGlobalVars("MATRIX:A,B")
31+
.withParsedStatement("sum(%*%(A, t(B)))")
32+
.toParsedStatement("sum(*(A, B))")
33+
.build();
34+
35+
RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule));
36+
37+
RewriterStatement stmt = RewriterUtils.parse("sum(%*%(colVec(A), t(colVec(B))))", ctx, "MATRIX:A,B");
38+
39+
RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt);
40+
41+
assert ar != null;
42+
43+
stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false);
44+
System.out.println(stmt.toParsableString(ctx));
45+
}
46+
47+
@Test
48+
public void test2() {
49+
RewriterRule rule = new RewriterRuleBuilder(ctx)
50+
.setUnidirectional(true)
51+
.parseGlobalVars("MATRIX:A,B")
52+
.withParsedStatement("as.matrix(sum(colVec(A)))")
53+
.toParsedStatement("rowSums(colVec(A))")
54+
.build();
55+
56+
RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule));
57+
58+
RewriterStatement stmt = RewriterUtils.parse("as.matrix(sum(t(colVec(A))))", ctx, "MATRIX:A,B");
59+
60+
RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt);
61+
62+
assert ar != null;
63+
64+
stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false);
65+
System.out.println(stmt.toParsableString(ctx));
66+
}
67+
}

0 commit comments

Comments
 (0)