Skip to content

Commit 9c523e1

Browse files
committed
Bugfixes
1 parent 6e037e2 commit 9c523e1

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
9696

9797
StringBuilder sb = new StringBuilder();
9898

99-
for (RewriterStatement var : vars) {
99+
sb.append(generateDMLVariables(vars));
100+
/*for (RewriterStatement var : vars) {
100101
switch (var.getResultingDataType(ctx)) {
101102
case "MATRIX":
102103
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
@@ -113,7 +114,7 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
113114
default:
114115
throw new NotImplementedException(var.getResultingDataType(ctx));
115116
}
116-
}
117+
}*/
117118

118119
sb.append('\n');
119120
sb.append("R1 = ");
@@ -131,6 +132,41 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
131132
return sb.toString();
132133
}
133134

135+
public static String generateDMLVariables(RewriterStatement root) {
136+
Set<RewriterStatement> vars = new HashSet<>();
137+
root.forEachPostOrder((stmt, pred) -> {
138+
if (!stmt.isInstruction() && !stmt.isLiteral())
139+
vars.add(stmt);
140+
}, false);
141+
142+
return generateDMLVariables(vars);
143+
}
144+
145+
public static String generateDMLVariables(Set<RewriterStatement> vars) {
146+
StringBuilder sb = new StringBuilder();
147+
148+
for (RewriterStatement var : vars) {
149+
switch (var.getResultingDataType(ctx)) {
150+
case "MATRIX":
151+
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
152+
break;
153+
case "FLOAT":
154+
sb.append(var.getId() + " = as.scalar(rand())\n");
155+
break;
156+
case "INT":
157+
sb.append(var.getId() + " = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n");
158+
break;
159+
case "BOOL":
160+
sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n");
161+
break;
162+
default:
163+
throw new NotImplementedException(var.getResultingDataType(ctx));
164+
}
165+
}
166+
167+
return sb.toString();
168+
}
169+
134170
public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) {
135171
switch (dataType) {
136172
case "MATRIX":
@@ -145,6 +181,17 @@ public static String generateEqualityCheck(String stmt1Var, String stmt2Var, Str
145181
throw new NotImplementedException();
146182
}
147183

184+
public static String generateDMLDefs(RewriterStatement stmt) {
185+
Map<String, RewriterStatement> vars = new HashMap<>();
186+
187+
stmt.forEachPostOrder((cur, pred) -> {
188+
if (!cur.isInstruction() && !cur.isLiteral())
189+
vars.put(cur.getId(), cur);
190+
}, false);
191+
192+
return generateDMLDefs(vars);
193+
}
194+
148195
public static String generateDMLDefs(Map<String, RewriterStatement> defs) {
149196
StringBuilder sb = new StringBuilder();
150197

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R
138138
MutableBoolean isValid = new MutableBoolean(false);
139139
DMLExecutor.executeCode(code, DMLCodeGenerator.ruleValidationScript(sessionId, isValid::setValue));
140140

141-
String code2 = DMLCodeGenerator.generateDML(rule.getStmt1());
141+
String code2Header = DMLCodeGenerator.generateDMLVariables(rule.getStmt1());
142+
String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(rule.getStmt1()) + "\nprint(lineage(result))";
142143
RewriterRuntimeUtils.attachHopInterceptor(prog -> {
143144
DMLExecutor.println("HERE");
144-
DMLExecutor.println(prog.getStatementBlocks().get(0).getHops().get(0).getInput(0));
145+
DMLExecutor.println(prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0));
145146
List<RewriterStatement> topLevelStmts = RewriterRuntimeUtils.getTopLevelHops(prog, ctx);
146147
DMLExecutor.println(topLevelStmts);
147148
// TODO: Evaluate cost and if our rule can still be applied

0 commit comments

Comments
 (0)