Commit 4519bd1

Make FrontendCastDoubleToHalf match auto schedule
This was very similar to the FrontendAdd example since it also uses the pointwise scheduler.
1 parent: dd44075
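
For context, the matching tests in this file all follow the same pattern: define the fusion twice, let the automatic pointwise scheduler handle one copy (fauto), replay its decisions by hand on the other (fusion), and assert that the two agree. A condensed sketch of that skeleton, using the compare_* helpers from this test file, with the fusion definition and the manual transforms elided:

    Fusion fauto;
    { // scope the FusionGuard so it releases fauto before `fusion` is built
      FusionGuard fg(&fauto);
      // ... define inputs and ops on fauto ...
      auto params = getPointwiseHeuristics(&fauto, inputs);
      TORCH_CHECK(params, "Pointwise schedule was not generated!");
      schedulePointwise(&fauto, *params);
    }

    Fusion fusion;
    FusionGuard fg(&fusion);
    // ... define the identical fusion, then replay the scheduler's
    // caching, merge/split/reorder, parallelization, and inlining ...

    compare_ir_math(fusion, fauto);    // same tensor expressions
    compare_transforms(fusion, fauto); // same domain transforms
    compare_kernels(fusion, fauto);    // same generated kernel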

1 file changed (+81, -25 lines)

third_party/nvfuser/test/test_gpu_match_frontend.cpp

Lines changed: 81 additions & 25 deletions
@@ -724,7 +724,36 @@ TEST_F(NVFuserTest, FusionFrontendCastDoubleToHalf_CUDA) {
 
   std::vector<IValue> inputs = {t0, t1};
 
-  // Define fusion
+  Fusion fauto;
+  { // Do automatic scheduling on fauto
+    FusionGuard fg(&fauto);
+
+    auto tv0 = makeSymbolicTensor(2, DataType::Double);
+    auto tv1 = makeSymbolicTensor(2, DataType::Double);
+
+    fauto.addInput(tv0);
+    fauto.addInput(tv1);
+
+    auto tv2 = castOp(DataType::Half, tv0);
+    auto tv3 = castOp(DataType::Half, tv1);
+    // implicit casts
+    auto tv4 = castOp(DataType::Float, tv2);
+    auto tv5 = castOp(DataType::Float, tv3);
+    auto tv6 = add(tv4, tv5);
+    auto tv7 = relu(tv6);
+    auto tv8 = castOp(DataType::Half, tv7);
+
+    fauto.addOutput(tv8);
+
+    // Run automatic scheduler
+    auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
+    TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
+    schedulePointwise(&fauto, *pointwise_params);
+  }
+
+  // Re-define the fusion exactly for manual scheduling.
+  // This is necessary in order to catch all the constructors inside each
+  // Fusion independently.
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -734,37 +763,64 @@ TEST_F(NVFuserTest, FusionFrontendCastDoubleToHalf_CUDA) {
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
-  auto tv0h = castOp(DataType::Half, tv0);
-  auto tv1h = castOp(DataType::Half, tv1);
-  auto tv0f = castOp(DataType::Float, tv0h);
-  auto tv1f = castOp(DataType::Float, tv1h);
-  auto tv2 = add(tv0f, tv1f);
-  auto tv3 = relu(tv2);
-  auto tv4 = castOp(DataType::Half, tv3);
+  auto tv2 = castOp(DataType::Half, tv0);
+  auto tv3 = castOp(DataType::Half, tv1);
+  // implicit casts
+  auto tv4 = castOp(DataType::Float, tv2);
+  auto tv5 = castOp(DataType::Float, tv3);
+  auto tv6 = add(tv4, tv5);
+  auto tv7 = relu(tv6);
+  auto tv8 = castOp(DataType::Half, tv7);
 
-  fusion.addOutput(tv4);
-
-  // Run automatic scheduler
-  auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
-  auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
-  TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
-  schedulePointwise(&fauto, *pointwise_params);
+  fusion.addOutput(tv8);
 
   // Perform manual scheduling
-  tv4->merge(0, 1);
-  tv4->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::TIDx);
+
+  // Before schedulePointwise() is called, getPointwiseHeuristics() calls
+  // vectorize_helper::getExpandedVectorization(), which in turn calls:
+  //   vectorize_helper::getVectorizationSize
+  //   vectorize_helper::ProjectedExtent::getNumerator
+  //   vectorize_helper::ProjectedExtent::computeNumerDenomir
+  //   IrContainer::oneVal
+  // oneVal() creates an actual Val here to hold the denominator and
+  // initializes it to 1. Since this is reflected in the fusion log, I'm
+  // inserting it here even though it has no effect on the generated kernel.
+  fusion.oneVal();
+
+  tv0->cacheAfter(); // tv9
+  tv1->cacheAfter(); // tv10
+  auto tv11 = tv8->cacheBefore(); // tv11
+
+  tv8->merge(0, 1);
+  tv8->reorder({{0, -1}});
+  tv8->reorder({{-1, 0}});
+  tv8->split(0, 128);
+  tv8->split(0, 1);
+  tv8->split(0, 1);
+  tv8->axis(0)->parallelize(ParallelType::BIDx);
+  tv8->axis(1)->parallelize(ParallelType::Unswitch);
+  tv8->axis(3)->parallelize(ParallelType::TIDx);
 
   // propagate the mapping to other tensors
-  TransformPropagatorWithCheck propagator(tv4);
-  MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator);
-  scheduler_utils::parallelizeAllLike(
-      tv4, {tv0, tv1, tv0h, tv1h, tv0f, tv1f, tv2, tv3});
+  TransformPropagatorWithCheck propagator(tv8);
+  MaxRootDomainInfoSpanningTree(tv8).traverse(&propagator);
+  scheduler_utils::parallelizeAllLike(tv8);
 
-  inlineMost();
+  // Unlike the reduction scheduler, the pointwise scheduler does not call
+  // inlineMost() directly; it uses inlineAllAt() then inlineMost(innermost_tensors).
+  inlineAllAt(tv8, 2, true);
+  inlineMost(
+      std::vector<TensorView*>({tv0, tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv11}));
 
-  compare_ir(fusion, fauto);
+  // Note that inlineAllAt iterates through an unordered_set to do inlining, so
+  // it is not practical to match the fusion_debug log exactly when using the
+  // pointwise scheduler.
+  compare_ir_math(fusion, fauto);
+  compare_transforms(fusion, fauto);
+  // compare_fusion_debug(fusion, fauto);
+  compare_kernels(fusion, fauto);
+
+  // compare_ir(fusion, fauto);
 
   // Perform eager computation and verify
   auto t0h = t0.to(options.dtype(at::kHalf));
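
To run this test in isolation, standard googletest filtering applies; the binary name below (nvfuser_tests) is an assumption that depends on how the tree is built:

    ./nvfuser_tests --gtest_filter='*FusionFrontendCastDoubleToHalf*'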
