@@ -96,7 +96,8 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
9696
9797 StringBuilder sb = new StringBuilder ();
9898
99- for (RewriterStatement var : vars ) {
99+ sb .append (generateDMLVariables (vars ));
100+ /*for (RewriterStatement var : vars) {
100101 switch (var.getResultingDataType(ctx)) {
101102 case "MATRIX":
102103 sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
@@ -113,7 +114,7 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
113114 default:
114115 throw new NotImplementedException(var.getResultingDataType(ctx));
115116 }
116- }
117+ }*/
117118
118119 sb .append ('\n' );
119120 sb .append ("R1 = " );
@@ -131,6 +132,41 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
131132 return sb .toString ();
132133 }
133134
135+ public static String generateDMLVariables (RewriterStatement root ) {
136+ Set <RewriterStatement > vars = new HashSet <>();
137+ root .forEachPostOrder ((stmt , pred ) -> {
138+ if (!stmt .isInstruction () && !stmt .isLiteral ())
139+ vars .add (stmt );
140+ }, false );
141+
142+ return generateDMLVariables (vars );
143+ }
144+
145+ public static String generateDMLVariables (Set <RewriterStatement > vars ) {
146+ StringBuilder sb = new StringBuilder ();
147+
148+ for (RewriterStatement var : vars ) {
149+ switch (var .getResultingDataType (ctx )) {
150+ case "MATRIX" :
151+ sb .append (var .getId () + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n " );
152+ break ;
153+ case "FLOAT" :
154+ sb .append (var .getId () + " = as.scalar(rand())\n " );
155+ break ;
156+ case "INT" :
157+ sb .append (var .getId () + " = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n " );
158+ break ;
159+ case "BOOL" :
160+ sb .append (var .getId () + " = as.scalar(rand()) < 0.5\n " );
161+ break ;
162+ default :
163+ throw new NotImplementedException (var .getResultingDataType (ctx ));
164+ }
165+ }
166+
167+ return sb .toString ();
168+ }
169+
134170 public static String generateEqualityCheck (String stmt1Var , String stmt2Var , String dataType , double eps ) {
135171 switch (dataType ) {
136172 case "MATRIX" :
@@ -145,6 +181,17 @@ public static String generateEqualityCheck(String stmt1Var, String stmt2Var, Str
145181 throw new NotImplementedException ();
146182 }
147183
184+ public static String generateDMLDefs (RewriterStatement stmt ) {
185+ Map <String , RewriterStatement > vars = new HashMap <>();
186+
187+ stmt .forEachPostOrder ((cur , pred ) -> {
188+ if (!cur .isInstruction () && !cur .isLiteral ())
189+ vars .put (cur .getId (), cur );
190+ }, false );
191+
192+ return generateDMLDefs (vars );
193+ }
194+
148195 public static String generateDMLDefs (Map <String , RewriterStatement > defs ) {
149196 StringBuilder sb = new StringBuilder ();
150197
0 commit comments