55import java .util .Collections ;
66import java .util .HashMap ;
77import java .util .HashSet ;
8+ import java .util .List ;
89import java .util .Map ;
910import java .util .Objects ;
1011import java .util .Set ;
1112import java .util .UUID ;
13+ import java .util .function .Consumer ;
1214import java .util .stream .Collectors ;
1315
1416public class RewriterAssertions {
1517 private final RuleContext ctx ;
1618 private Map <RewriterStatement , RewriterAssertion > assertionMatcher = new HashMap <>();
19+ // Tracks which statements are part of which assertions
20+ private Map <RewriterStatement , Set <RewriterAssertion >> partOfAssertion = new HashMap <>();
1721 private Set <RewriterAssertion > allAssertions = new HashSet <>();
1822
1923 public RewriterAssertions (final RuleContext ctx ) {
@@ -40,6 +44,7 @@ public RewriterAssertions(final RuleContext ctx) {
4044 return assertions;
4145 }*/
4246
47+ // TODO: Add parts of assertions map
4348 public static RewriterAssertions copy (RewriterAssertions old , Map <RewriterStatement , RewriterStatement > createdObjects , boolean removeOthers ) {
4449 //System.out.println("Copying: " + old);
4550 RewriterAssertions newAssertions = new RewriterAssertions (old .ctx );
@@ -69,10 +74,17 @@ public static RewriterAssertions copy(RewriterAssertions old, Map<RewriterStatem
6974 RewriterAssertion mapped = RewriterAssertion .from (newSet );
7075 if (assertion .stmt != null )
7176 mapped .stmt = createdObjects .get (assertion .stmt );
77+ if (assertion .backRef != null )
78+ mapped .backRef = createdObjects .get (assertion .backRef );
7279 mappedAssertions .put (assertion , mapped );
7380 return mapped ;
7481 }).filter (Objects ::nonNull ).collect (Collectors .toSet ());
7582
83+ newAssertions .partOfAssertion = old .partOfAssertion .entrySet ().stream ().collect (Collectors .toMap (
84+ v -> createdObjects .getOrDefault (v .getKey (), v .getKey ()),
85+ v -> v .getValue ().stream ().map (mappedAssertions ::get ).collect (Collectors .toSet ())
86+ ));
87+
7688 if (removeOthers ) {
7789 old .assertionMatcher .forEach ((k , v ) -> {
7890 RewriterStatement newK = createdObjects .get (k );
@@ -100,6 +112,7 @@ public static RewriterAssertions copy(RewriterAssertions old, Map<RewriterStatem
100112 }
101113
102114 //System.out.println("New: " + newAssertions);
115+ //System.out.println("New parts: " + newAssertions.partOfAssertion);
103116
104117 return newAssertions ;
105118 }
@@ -153,6 +166,19 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
153166 allAssertions .add (newAssertion );
154167
155168 resolveCyclicAssertions (newAssertion );
169+
170+ forEachUniqueElementInAssertion (newAssertion , cur -> {
171+ partOfAssertion .compute (cur , (k , v ) -> {
172+ if (v == null )
173+ v = new HashSet <>();
174+
175+ v .add (newAssertion );
176+ return v ;
177+ });
178+ });
179+
180+ //System.out.println("MNew parts: " + partOfAssertion);
181+
156182 //System.out.println("New assertion1: " + newAssertion);
157183 return true ;
158184 }
@@ -171,6 +197,18 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
171197 updateInstance (existingAssertion .stmt .getChild (0 ), existingAssertion .set );
172198
173199 resolveCyclicAssertions (existingAssertion );
200+
201+ toAssert .forEachPreOrder (cur -> {
202+ partOfAssertion .compute (cur , (k , v ) -> {
203+ if (v == null )
204+ v = new HashSet <>();
205+
206+ v .add (existingAssertion );
207+ return v ;
208+ });
209+ return true ;
210+ });
211+
174212 //System.out.println("New assertion2: " + existingAssertion);
175213 return true ;
176214 }
@@ -198,23 +236,53 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
198236 //System.out.println("New assertion3: " + stmt2Assertions);
199237 resolveCyclicAssertions (stmt2Assertions );
200238
239+ final RewriterAssertion assertionToRemove = stmt1Assertions ;
240+ final RewriterAssertion assertionToExtend = stmt2Assertions ;
241+ forEachUniqueElementInAssertion (stmt1Assertions , cur -> {
242+ Set <RewriterAssertion > v = partOfAssertion .get (cur );
243+ v .remove (assertionToRemove );
244+ v .add (assertionToExtend );
245+ });
246+
201247 return true ;
202248 }
203249
250+ private void forEachUniqueElementInAssertion (RewriterAssertion assertion , Consumer <RewriterStatement > consumer ) {
251+ Set <RewriterStatement > visited = new HashSet <>();
252+ for (RewriterStatement eq : assertion .set ) {
253+ eq .forEachPreOrderWithDuplicates (cur -> {
254+ if (!visited .add (cur ))
255+ return false ;
256+
257+ consumer .accept (cur );
258+ return true ;
259+ });
260+ }
261+ }
262+
204263 // Replace cycles with _backRef()
264+ // TODO: Also copy duplicate referenced sub-trees to avoid cycles (e.g. _EClass(a*b+c, a) and sqrt(a*b) => What to do with a in a*b? _backRef or _EClass?)
265+ // TODO: This requires a guarantee that reference counts are intact
205266 private void resolveCyclicAssertions (RewriterAssertion assertion ) {
206267 if (assertion .stmt == null )
207268 return ;
208269
209270 //System.out.println("Resolving cycles in: " + assertion);
210271
211- String rType = assertion .stmt .getResultingDataType (ctx );
272+ RewriterStatement backref = assertion .getBackRef (ctx , this );
273+ //String rType = assertion.stmt.getResultingDataType(ctx);
212274
213- RewriterStatement backref = new RewriterInstruction ()
275+ /* RewriterStatement backref = new RewriterInstruction()
214276 .as(UUID.randomUUID().toString())
215277 .withInstruction("_backRef." + rType)
216278 .consolidate(ctx);
217- backref .unsafePutMeta ("_backRef" , assertion .stmt );
279+ backref.unsafePutMeta("_backRef", assertion.stmt);*/
280+
281+ // Check if any sub-graph of the E-Graph is referenced outside the E-Class
282+ // If any child of the duplicate reference would create a back-reference, we need to copy the entire sub-graph
283+ //HashMap<RewriterStatement, Integer> refCtr = new HashMap<>();
284+
285+
218286
219287 for (RewriterStatement eq : assertion .set ) {
220288 eq .forEachPreOrder ((cur , parent , pIdx ) -> {
@@ -241,24 +309,30 @@ public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterS
241309 //System.out.println("In: " + this);
242310 RewriterAssertion set = assertionMatcher .get (stmt );
243311
244- if (set == null )
312+ if (set == null || set . getEClassStmt ( ctx , this ). getChild ( 0 ) == parent )
245313 return stmt ;
246314
247- RewriterStatement mstmt = set .stmt ;
315+ //System.out.println("EClassStmt: " + set.getEClassStmt(ctx, this).getChild(0));
316+ if (parent != null && parent != set .getEClassStmt (ctx , this ).getChild (0 ) && partOfAssertion .getOrDefault (parent , Collections .emptySet ()).contains (set ))
317+ return set .getBackRef (ctx , this );
248318
249- if (mstmt == null ) {
319+ /*RewriterStatement mstmt = set.stmt;
320+
321+ if (mstmt == null)
322+ mstmt = set.getEClassStmt(ctx, this);
323+ {
250324 // Then we create a new statement for it
251325 RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(set.set.toArray(RewriterStatement[]::new));
252326 mstmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_EClass").withOps(argList);
253327 mstmt.consolidate(ctx);
254328 set.stmt = mstmt;
255329 assertionMatcher.put(set.stmt, set);
256330 resolveCyclicAssertions(set);
257- } else if (mstmt .getChild (0 ) == parent ) {
331+ }*/ /* else if (mstmt.getChild(0) == parent) {
258332 return stmt;
259- }
333+ }*/
260334
261- return mstmt ;
335+ return set . getEClassStmt ( ctx , this ) ;
262336 }
263337
264338 // TODO: This does not handle metadata
@@ -324,6 +398,52 @@ private void updateInstance(RewriterStatement stmt, Set<RewriterStatement> set)
324398 private static class RewriterAssertion {
325399 Set <RewriterStatement > set ;
326400 RewriterStatement stmt ;
401+ RewriterStatement backRef ; // The back-reference to this assertion
402+
403+ RewriterStatement getEClassStmt (final RuleContext ctx , RewriterAssertions assertions ) {
404+ if (stmt != null )
405+ return stmt ;
406+
407+ RewriterStatement argList = new RewriterInstruction ().as (UUID .randomUUID ().toString ()).withInstruction ("argList" ).withOps (set .toArray (RewriterStatement []::new ));
408+ stmt = new RewriterInstruction ().as (UUID .randomUUID ().toString ()).withInstruction ("_EClass" ).withOps (argList );
409+ stmt .consolidate (ctx );
410+ assertions .assertionMatcher .put (stmt , this );
411+ assertions .partOfAssertion .compute (stmt , (k , v ) -> {
412+ if (v == null )
413+ v = new HashSet <>();
414+
415+ v .add (this );
416+ return v ;
417+ });
418+ assertions .partOfAssertion .compute (argList , (k , v ) -> {
419+ if (v == null )
420+ v = new HashSet <>();
421+
422+ v .add (this );
423+ return v ;
424+ });
425+ assertions .resolveCyclicAssertions (this );
426+ return stmt ;
427+ }
428+
429+ RewriterStatement getBackRef (final RuleContext ctx , RewriterAssertions assertions ) {
430+ if (backRef != null )
431+ return backRef ;
432+
433+ backRef = new RewriterInstruction ()
434+ .as (UUID .randomUUID ().toString ())
435+ .withInstruction ("_backRef." + getEClassStmt (ctx , assertions ).getResultingDataType (ctx ))
436+ .consolidate (ctx );
437+ backRef .unsafePutMeta ("_backRef" , getEClassStmt (ctx , assertions ));
438+ assertions .partOfAssertion .compute (backRef , (k , v ) -> {
439+ if (v == null )
440+ v = new HashSet <>();
441+
442+ v .add (this );
443+ return v ;
444+ });
445+ return backRef ;
446+ }
327447
328448 @ Override
329449 public String toString () {
0 commit comments