@@ -192,8 +192,77 @@ public static void forAllHops(DMLProgram program, Consumer<Hop> consumer) {
192192 sb .getHops ().forEach (consumer );
193193 }
194194
195- public static RewriterStatement buildDAGFromHop (Hop hop , int maxDepth , final RuleContext ctx ) {
196- return buildDAGRecursively (hop , null , new HashMap <>(), 0 , maxDepth , ctx );
195+ public static RewriterStatement buildDAGFromHop (Hop hop , int maxDepth , boolean mindDataCharacteristics , final RuleContext ctx ) {
196+ RewriterStatement out = buildDAGRecursively (hop , null , new HashMap <>(), 0 , maxDepth , ctx );
197+
198+ if (mindDataCharacteristics )
199+ return populateDataCharacteristics (out , ctx );
200+
201+ return out ;
202+ }
203+
204+ public static RewriterStatement populateDataCharacteristics (RewriterStatement stmt , final RuleContext ctx ) {
205+ if (stmt == null )
206+ return null ;
207+
208+ if (stmt instanceof RewriterDataType && stmt .getResultingDataType (ctx ).equals ("MATRIX" )) {
209+ Long nrow = (Long ) stmt .getMeta ("_actualNRow" );
210+ Long ncol = (Long ) stmt .getMeta ("_actualNCol" );
211+ int matType = 0 ;
212+
213+ // TODO: what if matrix consists of one single element?
214+ if (nrow != null && nrow == 1L ) {
215+ matType = 1 ;
216+ } else if (ncol != null && ncol == 1L ) {
217+ matType = 2 ;
218+ }
219+
220+ if (matType > 0 ) {
221+ return new RewriterInstruction ()
222+ .as (UUID .randomUUID ().toString ())
223+ .withInstruction (matType == 1 ? "rowVec" : "colVec" )
224+ .withOps (stmt )
225+ .consolidate (ctx );
226+ }
227+ }
228+
229+ Map <RewriterStatement , RewriterStatement > createdObjects = new HashMap <>();
230+
231+ stmt .forEachPostOrder ((cur , pred ) -> {
232+ for (int i = 0 ; i < cur .getOperands ().size (); i ++) {
233+ RewriterStatement child = cur .getChild (i );
234+
235+ if (child instanceof RewriterDataType && child .getResultingDataType (ctx ).equals ("MATRIX" )) {
236+ Long nrow = (Long ) child .getMeta ("_actualNRow" );
237+ Long ncol = (Long ) child .getMeta ("_actualNCol" );
238+ int matType = 0 ;
239+
240+ // TODO: what if matrix consists of one single element?
241+ if (nrow != null && nrow == 1L ) {
242+ matType = 1 ;
243+ } else if (ncol != null && ncol == 1L ) {
244+ matType = 2 ;
245+ }
246+
247+ if (matType > 0 ) {
248+ RewriterStatement created = createdObjects .get (child );
249+
250+ if (created == null ) {
251+ created = new RewriterInstruction ()
252+ .as (UUID .randomUUID ().toString ())
253+ .withInstruction (matType == 1 ? "rowVec" : "colVec" )
254+ .withOps (child )
255+ .consolidate (ctx );
256+ createdObjects .put (child , created );
257+ }
258+
259+ cur .getOperands ().set (i , created );
260+ }
261+ }
262+ }
263+ }, false );
264+
265+ return stmt ;
197266 }
198267
199268 public static void forAllUniqueTranslatableStatements (DMLProgram program , int maxDepth , Consumer <RewriterStatement > stmt , RewriterDatabase db , final RuleContext ctx ) {
@@ -364,6 +433,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
364433 if (stmt == null )
365434 return buildLeaf (next , expectedType , ctx );
366435
436+ insertDataCharacteristics (next , stmt , ctx );
437+
367438 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
368439 return stmt ;
369440
@@ -377,6 +448,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
377448 if (stmt == null )
378449 return buildLeaf (next , expectedType , ctx );
379450
451+ insertDataCharacteristics (next , stmt , ctx );
452+
380453 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
381454 return stmt ;
382455
@@ -390,6 +463,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
390463 if (stmt == null )
391464 return buildLeaf (next , expectedType , ctx );
392465
466+ insertDataCharacteristics (next , stmt , ctx );
467+
393468 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
394469 return stmt ;
395470
@@ -403,6 +478,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
403478 if (stmt == null )
404479 return buildLeaf (next , expectedType , ctx );
405480
481+ insertDataCharacteristics (next , stmt , ctx );
482+
406483 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
407484 return stmt ;
408485
@@ -416,6 +493,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
416493 if (stmt == null )
417494 return buildLeaf (next , expectedType , ctx );
418495
496+ insertDataCharacteristics (next , stmt , ctx );
497+
419498 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
420499 return stmt ;
421500
@@ -429,6 +508,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
429508 if (stmt == null )
430509 return buildLeaf (next , expectedType , ctx );
431510
511+ insertDataCharacteristics (next , stmt , ctx );
512+
432513 if (buildInputs (stmt , next .getInput (), cache , true , depth , maxDepth , ctx ))
433514 return stmt ;
434515
@@ -443,6 +524,8 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
443524 if (stmt == null )
444525 return buildLeaf (next , expectedType , ctx );
445526
527+ insertDataCharacteristics (next , stmt , ctx );
528+
446529 if (buildInputs (stmt , interestingHops , cache , true , depth , maxDepth , ctx ))
447530 return stmt ;
448531
@@ -465,6 +548,19 @@ private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String
465548 return null ;
466549 }
467550
551+ private static void insertDataCharacteristics (Hop hop , RewriterStatement stmt , final RuleContext ctx ) {
552+ if (stmt .getResultingDataType (ctx ).equals ("MATRIX" )) {
553+ if (hop .getDataCharacteristics () != null ) {
554+ long nrows = hop .getDataCharacteristics ().getRows ();
555+ long ncols = hop .getDataCharacteristics ().getCols ();
556+ if (nrows > 0 )
557+ stmt .unsafePutMeta ("_actualNRow" , RewriterStatement .literal (ctx , nrows ));
558+ if (ncols > 0 )
559+ stmt .unsafePutMeta ("_actualNCol" , RewriterStatement .literal (ctx , nrows ));
560+ }
561+ }
562+ }
563+
468564 // TODO: Maybe introduce other implicit conversions if types mismatch
469565 private static RewriterStatement checkForCorrectTypes (RewriterStatement stmt , @ Nullable String expectedType , Hop hop , final RuleContext ctx ) {
470566 if (stmt == null )
@@ -500,14 +596,19 @@ private static RewriterStatement buildLeaf(Hop hop, @Nullable String expectedTyp
500596 if (RewriterUtils .DOUBLE_PATTERN .matcher (hopName ).matches () || RewriterUtils .SPECIAL_FLOAT_PATTERN .matcher (hopName ).matches ())
501597 hopName = "float" + new Random ().nextInt (1000 );
502598
503- if (expectedType != null )
504- return RewriterUtils .parse (hopName , ctx , expectedType + ":" + hopName );
599+ if (expectedType != null ) {
600+ RewriterStatement stmt = RewriterUtils .parse (hopName , ctx , expectedType + ":" + hopName );
601+ insertDataCharacteristics (hop , stmt , ctx );
602+ return stmt ;
603+ }
505604
506605 switch (hop .getDataType ()) {
507606 case SCALAR :
508607 return buildScalarLeaf (hop , hopName , ctx );
509608 case MATRIX :
510- return RewriterUtils .parse (hopName , ctx , "MATRIX:" + hopName );
609+ RewriterStatement stmt = RewriterUtils .parse (hopName , ctx , "MATRIX:" + hopName );
610+ insertDataCharacteristics (hop , stmt , ctx );
611+ return stmt ;
511612 }
512613
513614 return null ; // Not supported then
0 commit comments