@@ -724,7 +724,36 @@ TEST_F(NVFuserTest, FusionFrontendCastDoubleToHalf_CUDA) {
 
   std::vector<IValue> inputs = {t0, t1};
 
-  // Define fusion
+  Fusion fauto;
+  { // Do automatic scheduling on fauto
+    FusionGuard fg(&fauto);
+
+    auto tv0 = makeSymbolicTensor(2, DataType::Double);
+    auto tv1 = makeSymbolicTensor(2, DataType::Double);
+
+    fauto.addInput(tv0);
+    fauto.addInput(tv1);
+
+    auto tv2 = castOp(DataType::Half, tv0);
+    auto tv3 = castOp(DataType::Half, tv1);
+    // implicit casts
+    auto tv4 = castOp(DataType::Float, tv2);
+    auto tv5 = castOp(DataType::Float, tv3);
+    auto tv6 = add(tv4, tv5);
+    auto tv7 = relu(tv6);
+    auto tv8 = castOp(DataType::Half, tv7);
+
+    fauto.addOutput(tv8);
+
+    // Run automatic scheduler
+    auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
+    TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
+    schedulePointwise(&fauto, *pointwise_params);
+  }
+
+  // Re-define the fusion exactly for manual scheduling. This is necessary
+  // in order to capture the constructor calls inside each Fusion
+  // independently.
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -734,37 +763,64 @@ TEST_F(NVFuserTest, FusionFrontendCastDoubleToHalf_CUDA) {
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
-  auto tv0h = castOp(DataType::Half, tv0);
-  auto tv1h = castOp(DataType::Half, tv1);
-  auto tv0f = castOp(DataType::Float, tv0h);
-  auto tv1f = castOp(DataType::Float, tv1h);
-  auto tv2 = add(tv0f, tv1f);
-  auto tv3 = relu(tv2);
-  auto tv4 = castOp(DataType::Half, tv3);
+  auto tv2 = castOp(DataType::Half, tv0);
+  auto tv3 = castOp(DataType::Half, tv1);
+  // implicit casts
+  auto tv4 = castOp(DataType::Float, tv2);
+  auto tv5 = castOp(DataType::Float, tv3);
+  auto tv6 = add(tv4, tv5);
+  auto tv7 = relu(tv6);
+  auto tv8 = castOp(DataType::Half, tv7);
 
-  fusion.addOutput(tv4);
-
-  // Run automatic scheduler
-  auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
-  auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
-  TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
-  schedulePointwise(&fauto, *pointwise_params);
+  fusion.addOutput(tv8);
 
   // Perform manual scheduling
-  tv4->merge(0, 1);
-  tv4->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::TIDx);
+
+  // Before schedulePointwise() is called, getPointwiseHeuristics() calls
+  // vectorize_helper::getExpandedVectorization(), which in turn calls:
+  //   vectorize_helper::getVectorizationSize
+  //   vectorize_helper::ProjectedExtent::getNumerator
+  //   vectorize_helper::ProjectedExtent::computeNumerDenomir
+  //   IrContainer::oneVal
+  // oneVal() creates an actual Val here to hold the denominator and
+  // initializes it to 1. Since this is reflected in the fusion log, I'm
+  // inserting it here even though it has no effect on the generated kernel.
+  fusion.oneVal();
+
+  tv0->cacheAfter(); // tv9
+  tv1->cacheAfter(); // tv10
+  auto tv11 = tv8->cacheBefore(); // tv11
+
+  tv8->merge(0, 1);
+  tv8->reorder({{0, -1}});
+  tv8->reorder({{-1, 0}});
+  tv8->split(0, 128);
+  tv8->split(0, 1);
+  tv8->split(0, 1);
+  tv8->axis(0)->parallelize(ParallelType::BIDx);
+  tv8->axis(1)->parallelize(ParallelType::Unswitch);
+  tv8->axis(3)->parallelize(ParallelType::TIDx);
 
   // propagate the mapping to other tensors
-  TransformPropagatorWithCheck propagator(tv4);
-  MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator);
-  scheduler_utils::parallelizeAllLike(
-      tv4, {tv0, tv1, tv0h, tv1h, tv0f, tv1f, tv2, tv3});
+  TransformPropagatorWithCheck propagator(tv8);
+  MaxRootDomainInfoSpanningTree(tv8).traverse(&propagator);
+  scheduler_utils::parallelizeAllLike(tv8);
 
-  inlineMost();
+  // The pointwise scheduler does not use inlineMost() as the reduction
+  // scheduler does; it uses inlineAllAt() then inlineMost(innermost_tensors).
+  inlineAllAt(tv8, 2, true);
+  inlineMost(
+      std::vector<TensorView*>({tv0, tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv11}));
 
-  compare_ir(fusion, fauto);
+  // Note that inlineAllAt() iterates through an unordered_set to do inlining,
+  // so it is not practical to match the fusion_debug log exactly when using
+  // the pointwise scheduler.
+  compare_ir_math(fusion, fauto);
+  compare_transforms(fusion, fauto);
+  // compare_fusion_debug(fusion, fauto);
+  compare_kernels(fusion, fauto);
+
+  // compare_ir(fusion, fauto);
 
   // Perform eager computation and verify
   auto t0h = t0.to(options.dtype(at::kHalf));