Skip to content

Commit edc3f17

Browse files
committed
NNZ estimator
1 parent f7d4cb4 commit edc3f17

File tree

16 files changed

+121
-24
lines changed

16 files changed

+121
-24
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import org.apache.commons.collections4.bidimap.DualHashBidiMap;
44
import org.apache.commons.lang3.mutable.MutableBoolean;
55
import org.apache.sysds.hops.Hop;
6-
import scala.App;
6+
import org.apache.sysds.hops.rewriter.dml.DMLCodeGenerator;
7+
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
8+
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
79

810
import java.util.ArrayList;
911
import java.util.Collections;

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
import org.apache.sysds.hops.OptimizerUtils;
1414
import org.apache.sysds.hops.ReorgOp;
1515
import org.apache.sysds.hops.UnaryOp;
16+
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
1617
import org.apache.sysds.parser.DMLProgram;
1718
import org.apache.sysds.parser.ForStatement;
1819
import org.apache.sysds.parser.ForStatementBlock;
1920
import org.apache.sysds.parser.FunctionStatement;
2021
import org.apache.sysds.parser.FunctionStatementBlock;
2122
import org.apache.sysds.parser.IfStatement;
2223
import org.apache.sysds.parser.IfStatementBlock;
23-
import org.apache.sysds.parser.Statement;
2424
import org.apache.sysds.parser.StatementBlock;
2525
import org.apache.sysds.parser.WhileStatement;
2626
import org.apache.sysds.parser.WhileStatementBlock;
2727

2828
import javax.annotation.Nullable;
29-
import javax.validation.constraints.Null;
3029
import java.io.BufferedReader;
3130
import java.io.BufferedWriter;
3231
import java.io.FileReader;

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.apache.commons.lang3.mutable.MutableInt;
77
import org.apache.logging.log4j.util.TriConsumer;
88
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
9+
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
910
import scala.Tuple2;
1011

1112
import java.util.ArrayList;
@@ -1171,6 +1172,10 @@ public static RewriterStatement ensureFloat(final RuleContext ctx, RewriterState
11711172
return castFloat(ctx, stmt);
11721173
}
11731174

1175+
public static RewriterStatement nnz(RewriterStatement of, final RuleContext ctx) {
1176+
return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_nnz").withOps(of).consolidate(ctx);
1177+
}
1178+
11741179
public static RewriterStatement literal(final RuleContext ctx, Object literal) {
11751180
if (literal instanceof Double) {
11761181
return new RewriterDataType().as(literal.toString()).ofType("FLOAT").asLiteral(literal).consolidate(ctx);

src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import org.apache.sysds.hops.Hop;
55
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
6-
import org.apache.sysds.hops.rewriter.RewriterCostEstimator;
6+
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
77
import org.apache.sysds.hops.rewriter.RewriterDataType;
88
import org.apache.sysds.hops.rewriter.RewriterRule;
99
import org.apache.sysds.hops.rewriter.RewriterStatement;

src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java renamed to src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
package org.apache.sysds.hops.rewriter;
1+
package org.apache.sysds.hops.rewriter.dml;
22

33
import org.apache.commons.lang3.NotImplementedException;
44
import org.apache.commons.lang3.function.TriFunction;
55
import org.apache.commons.lang3.mutable.MutableInt;
6+
import org.apache.sysds.hops.rewriter.RewriterInstruction;
7+
import org.apache.sysds.hops.rewriter.RewriterRule;
8+
import org.apache.sysds.hops.rewriter.RewriterStatement;
9+
import org.apache.sysds.hops.rewriter.RewriterUtils;
10+
import org.apache.sysds.hops.rewriter.RuleContext;
611
import scala.Tuple2;
712

813
import java.util.ArrayList;
@@ -13,9 +18,7 @@
1318
import java.util.Map;
1419
import java.util.Random;
1520
import java.util.Set;
16-
import java.util.function.BiFunction;
1721
import java.util.function.Consumer;
18-
import java.util.function.Function;
1922
import java.util.stream.Collectors;
2023

2124
public class DMLCodeGenerator {

src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java renamed to src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.apache.sysds.hops.rewriter;
1+
package org.apache.sysds.hops.rewriter.dml;
22

33
import org.apache.sysds.api.DMLScript;
44

src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java renamed to src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
package org.apache.sysds.hops.rewriter;
1+
package org.apache.sysds.hops.rewriter.estimators;
22

33
import org.apache.commons.lang3.mutable.MutableLong;
44
import org.apache.commons.lang3.mutable.MutableObject;
5+
import org.apache.sysds.hops.rewriter.RewriterInstruction;
6+
import org.apache.sysds.hops.rewriter.RewriterRule;
7+
import org.apache.sysds.hops.rewriter.RewriterStatement;
8+
import org.apache.sysds.hops.rewriter.RewriterUtils;
9+
import org.apache.sysds.hops.rewriter.RuleContext;
510
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils;
611
import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
712
import scala.Tuple2;
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package org.apache.sysds.hops.rewriter.estimators;
2+
3+
import org.apache.sysds.hops.rewriter.ConstantFoldingFunctions;
4+
import org.apache.sysds.hops.rewriter.RewriterInstruction;
5+
import org.apache.sysds.hops.rewriter.RewriterStatement;
6+
import org.apache.sysds.hops.rewriter.RuleContext;
7+
import org.apache.sysds.hops.rewriter.utils.StatementUtils;
8+
9+
import java.util.Map;
10+
import java.util.UUID;
11+
12+
public class RewriterSparsityEstimator {
13+
public static RewriterStatement estimateNNZ(RewriterStatement stmt, Map<RewriterStatement, Long> matrixNNZs, final RuleContext ctx) {
14+
long[] nnzs = stmt.getOperands().stream().mapToLong(matrixNNZs::get).toArray();
15+
16+
switch (stmt.trueInstruction()) {
17+
case "%*%":
18+
return new RewriterInstruction("*", ctx, StatementUtils.min(ctx, new RewriterInstruction("*", ctx, stmt.getNRow(), stmt.getNCol()), RewriterStatement.nnz(stmt.getChild(1), ctx)), new RewriterInstruction("*", ctx, stmt.getNRow(), stmt.getNCol()), RewriterStatement.nnz(stmt.getChild(0), ctx));
19+
}
20+
21+
switch (stmt.trueTypedInstruction(ctx)) {
22+
case "*(MATRIX,MATRIX)":
23+
return StatementUtils.min(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx));
24+
case "*(MATRIX,FLOAT)":
25+
if (stmt.getChild(1).isLiteral() && ConstantFoldingFunctions.overwritesLiteral(((Float) stmt.getChild(1).getLiteral()), "*", ctx) != null)
26+
return RewriterStatement.literal(ctx, 0L);
27+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
28+
case "*(FLOAT,MATRIX)":
29+
if (stmt.getChild(0).isLiteral() && ConstantFoldingFunctions.overwritesLiteral(((Float) stmt.getChild(0).getLiteral()), "*", ctx) != null)
30+
return RewriterStatement.literal(ctx, 0L);
31+
return RewriterStatement.nnz(stmt.getChild(1), ctx);
32+
case "+(MATRIX,MATRIX)":
33+
case "-(MATRIX,MATRIX)":
34+
return StatementUtils.min(ctx, new RewriterInstruction("+", ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)), StatementUtils.length(ctx, stmt));
35+
case "+(MATRIX,FLOAT)":
36+
case "-(MATRIX,FLOAT)":
37+
if (stmt.getChild(1).isLiteral() && ConstantFoldingFunctions.isNeutralElement(stmt.getChild(1).getLiteral(), "+"))
38+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
39+
return StatementUtils.length(ctx, stmt);
40+
case "+(FLOAT,MATRIX)":
41+
case "-(FLOAT,MATRIX)":
42+
if (stmt.getChild(0).isLiteral() && ConstantFoldingFunctions.isNeutralElement(stmt.getChild(0).getLiteral(), "+"))
43+
return RewriterStatement.nnz(stmt.getChild(1), ctx);
44+
return StatementUtils.length(ctx, stmt);
45+
case "!=(MATRIX,MATRIX)":
46+
if (stmt.getChild(0).equals(stmt.getChild(1)))
47+
return RewriterStatement.literal(ctx, 0L);
48+
return StatementUtils.length(ctx, stmt);
49+
50+
case "sqrt(MATRIX)":
51+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
52+
}
53+
54+
return StatementUtils.length(ctx, stmt);
55+
}
56+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package org.apache.sysds.hops.rewriter.utils;
2+
3+
import org.apache.sysds.hops.rewriter.RewriterInstruction;
4+
import org.apache.sysds.hops.rewriter.RewriterStatement;
5+
import org.apache.sysds.hops.rewriter.RuleContext;
6+
7+
public class StatementUtils {
8+
public static RewriterStatement max(final RuleContext ctx, RewriterStatement... of) {
9+
if (of.length == 1)
10+
return of[0];
11+
12+
if (of.length == 2)
13+
return new RewriterInstruction("max", ctx, of);
14+
15+
throw new UnsupportedOperationException();
16+
}
17+
18+
public static RewriterStatement min(final RuleContext ctx, RewriterStatement... of) {
19+
if (of.length == 1)
20+
return of[0];
21+
22+
if (of.length == 2)
23+
return new RewriterInstruction("min", ctx, of);
24+
25+
throw new UnsupportedOperationException();
26+
}
27+
28+
public static RewriterStatement length(final RuleContext ctx, RewriterStatement matrix) {
29+
if (!matrix.getResultingDataType(ctx).equals("MATRIX"))
30+
throw new IllegalArgumentException();
31+
32+
return new RewriterInstruction("*", ctx, matrix.getNRow(), matrix.getNCol());
33+
}
34+
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import org.apache.commons.lang3.mutable.MutableInt;
44
import org.apache.sysds.hops.rewriter.RewriteAutomaticallyGenerated;
55
import org.apache.sysds.hops.rewriter.RewriterAlphabetEncoder;
6-
import org.apache.sysds.hops.rewriter.RewriterCostEstimator;
6+
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
77
import org.apache.sysds.hops.rewriter.RewriterDatabase;
88
import org.apache.sysds.hops.rewriter.RewriterEquivalenceDatabase;
99
import org.apache.sysds.hops.rewriter.RewriterHeuristic;
10-
import org.apache.sysds.hops.rewriter.RewriterInstruction;
1110
import org.apache.sysds.hops.rewriter.RewriterRule;
1211
import org.apache.sysds.hops.rewriter.RewriterRuleCollection;
1312
import org.apache.sysds.hops.rewriter.RewriterRuleCreator;
@@ -18,7 +17,6 @@
1817
import org.apache.sysds.hops.rewriter.RuleContext;
1918
import org.apache.sysds.hops.rewriter.TopologicalSort;
2019
import scala.Tuple2;
21-
import scala.Tuple3;
2220
import scala.Tuple4;
2321
import scala.Tuple5;
2422

@@ -31,7 +29,6 @@
3129
import java.util.Comparator;
3230
import java.util.List;
3331
import java.util.Random;
34-
import java.util.UUID;
3532
import java.util.concurrent.atomic.AtomicLong;
3633
import java.util.function.Function;
3734
import java.util.stream.Collectors;

0 commit comments

Comments
 (0)