Skip to content

Commit 81f1b0f

Browse files
committed
Some fixes
1 parent ff36f12 commit 81f1b0f

File tree

5 files changed

+50
-27
lines changed

5 files changed

+50
-27
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,14 @@ public static RewriterRuleSet deserialize(String[] data, final RuleContext ctx)
335335
currentLines.clear();
336336
}
337337

338-
for (RewriterRule rule : rules)
339-
rule.determineConditionalApplicability();
338+
for (RewriterRule rule : rules) {
339+
try {
340+
rule.determineConditionalApplicability();
341+
} catch (Exception e) {
342+
System.err.println("Error while determining the conditional ability of " + rule.toString());
343+
e.printStackTrace();
344+
}
345+
}
340346

341347
return new RewriterRuleSet(ctx, rules);
342348
}

src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.apache.sysds.hops.rewriter.codegen;
22

33
import org.apache.commons.lang3.NotImplementedException;
4+
import org.apache.sysds.common.Types;
45
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
56
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
67
import org.apache.sysds.hops.rewriter.RewriterStatement;
@@ -69,6 +70,8 @@ public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) {
6970
return "Types.OpOp1.MULT2";
7071
case "cast.MATRIX":
7172
return "Types.OpOp1.CAST_AS_MATRIX";
73+
case "cast.FLOAT":
74+
return "Types.OpOp1.CAST_AS_SCALAR";
7275
case "const":
7376
return "Types.OpOpDG.RAND";
7477
}
@@ -203,6 +206,7 @@ public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) {
203206
case "round":
204207
case "*2":
205208
case "cast.MATRIX":
209+
case "cast.FLOAT":
206210
case "nrow":
207211
case "ncol":
208212
return "UnaryOp";

src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import org.apache.sysds.hops.LiteralOp;
77
import org.apache.sysds.hops.UnaryOp;
88
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
9+
import org.apache.sysds.hops.rewriter.RewriteAutomaticallyGenerated;
10+
import org.apache.sysds.hops.rewriter.RewriterRuleSet;
911
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
1012
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
1113
import org.apache.sysds.hops.rewriter.RewriterDataType;
@@ -15,6 +17,10 @@
1517
import org.codehaus.janino.SimpleCompiler;
1618
import scala.Tuple2;
1719

20+
import java.io.FileWriter;
21+
import java.io.IOException;
22+
import java.nio.file.Files;
23+
import java.nio.file.Paths;
1824
import java.util.AbstractCollection;
1925
import java.util.ArrayList;
2026
import java.util.HashMap;
@@ -28,6 +34,29 @@
2834
public class RewriterCodeGen {
2935
public static boolean DEBUG = true;
3036

37+
public static String generateRewritesFromFiles(List<String> filePaths, String targetFile, boolean optimize, final RuleContext ctx) throws IOException {
38+
return generateRewritesFromFiles(filePaths, targetFile, optimize, 2, true, true, ctx);
39+
}
40+
41+
public static String generateRewritesFromFiles(List<String> filePaths, String targetFile, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, boolean maintainStatistics, final RuleContext ctx) throws IOException {
42+
List<String> lines = new ArrayList<>();
43+
44+
for (String path : filePaths) {
45+
lines.addAll(Files.readAllLines(Paths.get(path)));
46+
}
47+
48+
RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx);
49+
String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", optimize, maxOptimizationDepth, includePackageInfo, true, maintainStatistics);
50+
51+
try (FileWriter writer = new FileWriter(targetFile)) {
52+
writer.write(javaCode);
53+
} catch (IOException e) {
54+
throw e;
55+
}
56+
57+
return javaCode;
58+
}
59+
3160
public static Function<Hop, Hop> compileRewrites(String className, List<Tuple2<String, RewriterRule>> rewrites, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) throws Exception {
3261
String code = generateClass(className, rewrites, false, false, ctx, ignoreErrors, printErrors);
3362
System.out.println("Compiling code:\n" + code);
@@ -265,7 +294,7 @@ private static void buildMatchingSequence(String name, RewriterStatement from, L
265294
for (int i = 1; i < msb.size(); i++) {
266295
indent(indentation+1, sb);
267296
sb.append("case " + i + ": {");
268-
buildNewHop(name, from, tos.get(i-1), sb, combinedAssertions, vars, ctx, indentation+2, maintainRewriteStats);
297+
buildNewHop(name, from, tos.get(i-1), sb, combinedAssertions, new HashMap<>(vars), ctx, indentation+2, maintainRewriteStats);
269298
indent(indentation+1, sb);
270299
sb.append("}\n");
271300
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -268,30 +268,14 @@ public void testLiteral() {
268268

269269
@Test
270270
public void codeGen() {
271-
try {
272-
List<String> lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH_MB));
273-
RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx);
274-
275-
RewriterRuntimeUtils.printUnknowns = false;
276-
/*Set<RewriterRule> invalid_unoptimized = ruleSet.generateCodeAndTest(false, true);
277-
Set<RewriterRule> invalid_optimized = ruleSet.generateCodeAndTest(true, true);
278-
System.out.println("========== DIFF ===========");
279-
invalid_optimized.removeAll(invalid_unoptimized);
280-
for (RewriterRule rule : invalid_optimized) {
281-
System.out.println(rule);
282-
}*/
271+
List<String> files = List.of(RewriteAutomaticallyGenerated.FILE_PATH, RewriteAutomaticallyGenerated.FILE_PATH_CONDITIONAL);
272+
String targetPath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java";
283273

274+
try {
275+
// This is to specify that the generated code should print to the console if it modifies the DAG
276+
// This should be disabled when generating production code
284277
RewriterCodeGen.DEBUG = true;
285-
String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", true, 2, true, true, true);
286-
String filePath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java";
287-
288-
try (FileWriter writer = new FileWriter(filePath)) {
289-
writer.write(javaCode);
290-
} catch (IOException e) {
291-
e.printStackTrace();
292-
}
293-
294-
//System.out.println(javaCode);
278+
RewriterCodeGen.generateRewritesFromFiles(files, targetPath, true, ctx);
295279
} catch (IOException e) {
296280
e.printStackTrace();
297281
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ public void validationTest3() {
101101
.withParsedStatement("cast.MATRIX(sum(colVec(A)))")
102102
.toParsedStatement("rowSums(colVec(A))")
103103
.build();
104-
104+
105105
assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx);
106-
//assert RewriterRuleCreator.validateRuleApplicability(rule, ctx);
106+
assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx);
107107
}
108108

109109
@Test

0 commit comments

Comments
 (0)