Skip to content

Commit b8d3424

Browse files
committed
Some more updates
1 parent 05b39d8 commit b8d3424

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,15 @@ public class DMLCodeGenerator {
9393
});
9494

9595
customEncoders.put("cast.FLOAT", (stmt, sb, tmpVars) -> {
96-
sb.append("as.scalar(");
97-
appendExpression(stmt.getChild(0), sb, tmpVars);
98-
sb.append(')');
96+
if (stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX")) {
97+
sb.append("as.scalar(");
98+
appendExpression(stmt.getChild(0), sb, tmpVars);
99+
sb.append(')');
100+
} else {
101+
sb.append("as.double(");
102+
appendExpression(stmt.getChild(0), sb, tmpVars);
103+
sb.append(')');
104+
}
99105

100106
return true;
101107
});

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,18 @@ public static boolean validateRuleCorrectness(RewriterRule rule, final RuleConte
216216
}
217217

218218
public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx) {
219+
RewriterStatement _mstmt = rule.getStmt1();
219220
if (ctx.metaPropagator != null)
220-
ctx.metaPropagator.apply(rule.getStmt1());
221+
ctx.metaPropagator.apply(_mstmt);
221222

222-
Set<RewriterStatement> vars = DMLCodeGenerator.getVariables(rule.getStmt1());
223+
final RewriterStatement stmt1 = RewriterUtils.unfuseOperators(_mstmt, ctx);
224+
225+
Set<RewriterStatement> vars = DMLCodeGenerator.getVariables(stmt1);
223226
Set<String> varNames = vars.stream().map(RewriterStatement::getId).collect(Collectors.toSet());
224227
String code2Header = DMLCodeGenerator.generateDMLVariables(vars);
225-
String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(rule.getStmt1());
228+
String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(stmt1);
226229

227-
boolean isMatrix = rule.getStmt1().getResultingDataType(ctx).equals("MATRIX");
230+
boolean isMatrix = stmt1.getResultingDataType(ctx).equals("MATRIX");
228231

229232
if (isMatrix)
230233
code2 += "\nprint(lineage(result))";
@@ -291,7 +294,7 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
291294

292295
Map<RewriterStatement, RewriterStatement> createdObjects = new HashMap<>();
293296

294-
RewriterStatement stmt1ReplaceNCols = rule.getStmt1().nestedCopyOrInject(createdObjects, mstmt -> {
297+
RewriterStatement stmt1ReplaceNCols = stmt1.nestedCopyOrInject(createdObjects, mstmt -> {
295298
if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow")))
296299
return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS);
297300
return null;

0 commit comments

Comments
 (0)