@@ -216,15 +216,18 @@ public static boolean validateRuleCorrectness(RewriterRule rule, final RuleConte
216216 }
217217
218218 public static boolean validateRuleApplicability (RewriterRule rule , final RuleContext ctx ) {
219+ RewriterStatement _mstmt = rule .getStmt1 ();
219220 if (ctx .metaPropagator != null )
220- ctx .metaPropagator .apply (rule . getStmt1 () );
221+ ctx .metaPropagator .apply (_mstmt );
221222
222- Set <RewriterStatement > vars = DMLCodeGenerator .getVariables (rule .getStmt1 ());
223+ final RewriterStatement stmt1 = RewriterUtils .unfuseOperators (_mstmt , ctx );
224+
225+ Set <RewriterStatement > vars = DMLCodeGenerator .getVariables (stmt1 );
223226 Set <String > varNames = vars .stream ().map (RewriterStatement ::getId ).collect (Collectors .toSet ());
224227 String code2Header = DMLCodeGenerator .generateDMLVariables (vars );
225- String code2 = code2Header + "\n result = " + DMLCodeGenerator .generateDML (rule . getStmt1 () );
228+ String code2 = code2Header + "\n result = " + DMLCodeGenerator .generateDML (stmt1 );
226229
227- boolean isMatrix = rule . getStmt1 () .getResultingDataType (ctx ).equals ("MATRIX" );
230+ boolean isMatrix = stmt1 .getResultingDataType (ctx ).equals ("MATRIX" );
228231
229232 if (isMatrix )
230233 code2 += "\n print(lineage(result))" ;
@@ -291,7 +294,7 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
291294
292295 Map <RewriterStatement , RewriterStatement > createdObjects = new HashMap <>();
293296
294- RewriterStatement stmt1ReplaceNCols = rule . getStmt1 () .nestedCopyOrInject (createdObjects , mstmt -> {
297+ RewriterStatement stmt1ReplaceNCols = stmt1 .nestedCopyOrInject (createdObjects , mstmt -> {
295298 if (mstmt .isInstruction () && (mstmt .trueInstruction ().equals ("ncol" ) || mstmt .trueInstruction ().equals ("nrow" )))
296299 return RewriterStatement .literal (ctx , DMLCodeGenerator .MATRIX_DIMS );
297300 return null ;
0 commit comments