Skip to content

Commit cb05568

Browse files
committed
Some improvements
1 parent 8cbddbf commit cb05568

File tree

3 files changed

+78
-9
lines changed

3 files changed

+78
-9
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterAssertions.java

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import org.apache.commons.lang3.NotImplementedException;
44
import org.apache.commons.lang3.function.TriFunction;
5+
import org.apache.commons.lang3.mutable.MutableObject;
56
import scala.Tuple2;
67

78
import java.util.ArrayList;
@@ -279,14 +280,14 @@ public void updateAssertionContents(Function<RewriterStatement, RewriterStatemen
279280
// TODO: What about backRef?
280281
}
281282

282-
public Stream<RewriterStatement> streamOfContents() {
283+
public Stream<Tuple2<RewriterStatement, RewriterStatement.RewriterPredecessor>> streamOfContents() {
283284
return allAssertions.stream().flatMap(assertion -> {
284285
if (assertion.stmt != null) {
285286
if (assertion.backRef != null)
286-
return Stream.of(assertion.stmt, assertion.backRef);
287-
return Stream.of(assertion.stmt);
287+
return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion)), new Tuple2<>(assertion.backRef, new RewriterStatement.RewriterPredecessor(this, assertion)));
288+
return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion)));
288289
} else {
289-
return assertion.set.stream();
290+
return assertion.set.stream().map(stmt -> new Tuple2<>(stmt, new RewriterStatement.RewriterPredecessor(this, assertion)));
290291
}
291292
});
292293
}
@@ -585,6 +586,72 @@ else if (root.getMeta("_assertions") != null)
585586
return eClass;
586587
}
587588

589+
// This removes E-Classes that are not actually E-Classes like _EClass(argList(nrow(A), nrow(A))), or _EClass(argList(nrow(A), _backRef.INT()))
590+
public RewriterStatement cleanupEClasses(RewriterStatement expressionRoot) {
591+
Set<RewriterAssertion> toRemoveList = new HashSet<>();
592+
Map<RewriterStatement, RewriterStatement> toRemove = new HashMap<>();
593+
594+
for (RewriterAssertion assertion : allAssertions) {
595+
int previousSize = assertion.set.size();
596+
if (assertion.stmt != null) {
597+
// Eliminate top-level back-refs
598+
assertion.set.removeIf(el -> el.isInstruction() && el.trueInstruction().startsWith("_backRef") && el.getMeta("_backRef").equals(assertion.stmt));
599+
}
600+
601+
if (assertion.set.size() < 2) {
602+
toRemoveList.add(assertion);
603+
604+
if (assertion.stmt != null)
605+
toRemove.put(assertion.stmt, assertion.set.stream().findFirst().get());
606+
}
607+
608+
if (previousSize != assertion.set.size() && assertion.stmt != null) {
609+
// Then we need to update the EClass
610+
assertion.stmt.getChild(0).getOperands().removeIf(el -> !assertion.set.contains(el));
611+
612+
if (assertion.stmt.getChild(0).getOperands().size() != assertion.set.size()) {
613+
// Then there are still duplicates which we need to rule out
614+
Set<RewriterStatement> visited = new HashSet<>();
615+
List<RewriterStatement> eItems = assertion.stmt.getChild(0).getOperands();
616+
for (int i = 0; i < eItems.size(); i++) {
617+
if (!visited.add(eItems.get(i)))
618+
eItems.remove(i--);
619+
}
620+
}
621+
}
622+
}
623+
624+
if (!toRemoveList.isEmpty()) {
625+
allAssertions.removeAll(toRemoveList);
626+
627+
if (!toRemove.isEmpty()) {
628+
if (expressionRoot.isEClass()) {
629+
RewriterStatement mNew = toRemove.get(expressionRoot);
630+
631+
if (mNew != null)
632+
expressionRoot = mNew;
633+
}
634+
635+
expressionRoot.forEachPostOrder((cur, pred) -> {
636+
cur.allChildren().forEach(t -> {
637+
if (t._1.isEClass()) {
638+
RewriterStatement mNew = toRemove.get(t._1);
639+
if (mNew != null) {
640+
if (t._2.isOperand()) {
641+
cur.getOperands().set(t._2.getIndex(), mNew);
642+
} else if (t._2.isMetaObject()) {
643+
cur.unsafePutMeta(t._2.getMetaKey(), mNew);
644+
}
645+
}
646+
}
647+
});
648+
}, true);
649+
}
650+
}
651+
652+
return expressionRoot;
653+
}
654+
588655
private void updateRecursively(RewriterStatement cur) {
589656
for (int i = 0; i < cur.getOperands().size(); i++) {
590657
RewriterStatement child = cur.getChild(i);

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.function.Function;
2828
import java.util.regex.Pattern;
2929
import java.util.stream.Collectors;
30+
import java.util.stream.IntStream;
3031
import java.util.stream.Stream;
3132

3233
public abstract class RewriterStatement {
@@ -940,18 +941,18 @@ protected void nestedCopyOrInjectMetaStatements(Map<RewriterStatement, RewriterS
940941

941942
// This returns a stream of all children including metadata and assertions if available
942943
// This may contain loops in case of back references
943-
public Stream<RewriterStatement> allChildren() {
944-
Stream<RewriterStatement> stream = getOperands().stream();
944+
public Stream<Tuple2<RewriterStatement, RewriterPredecessor>> allChildren() {
945+
Stream<Tuple2<RewriterStatement, RewriterPredecessor>> stream = IntStream.range(0, getOperands().size()).mapToObj(i -> new Tuple2<>(getOperands().get(i), new RewriterPredecessor(this, i)));
945946
RewriterStatement ncol = getNCol();
946947
RewriterStatement nrow = getNRow();
947948
RewriterStatement backRef = getBackRef();
948949

949950
if (ncol != null)
950-
stream = Stream.concat(stream, Stream.of(ncol));
951+
stream = Stream.concat(stream, Stream.of(new Tuple2<>(ncol, new RewriterPredecessor(this, "ncol"))));
951952
if (nrow != null)
952-
stream = Stream.concat(stream, Stream.of(nrow));
953+
stream = Stream.concat(stream, Stream.of(new Tuple2<>(nrow, new RewriterPredecessor(this, "nrow"))));
953954
if (backRef != null)
954-
stream = Stream.concat(stream, Stream.of(backRef));
955+
stream = Stream.concat(stream, Stream.of(new Tuple2<>(backRef, new RewriterPredecessor(this, "_backRef"))));
955956

956957
RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions");
957958

src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,6 +1264,7 @@ public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormC
12641264
}, debug);
12651265

12661266
stmt = foldConstants(stmt, ctx);
1267+
stmt = stmt.getAssertions(ctx).cleanupEClasses(stmt);
12671268

12681269
// TODO: After this, stuff like CSE, A-A = 0, etc. must still be applied
12691270

0 commit comments

Comments
 (0)