@@ -311,8 +311,12 @@ void scheduleMatmul(
311311 // Setup accumulator register.
312312 auto cc = c->cacheBefore ();
313313
314- // Setup output smem buffer
315- auto c_smem = c->cacheBefore ();
314+ TensorView* c_smem = nullptr ;
315+
316+ if (params.has_epilog ) {
317+ // Setup output smem buffer
318+ c_smem = c->cacheBefore ();
319+ }
316320
317321 // Get the input to the mma op.
318322 auto mma = dynamic_cast <MmaOp*>(cc->definition ());
@@ -483,7 +487,10 @@ void scheduleMatmul(
483487 // Set memory type:
484488 acw_smem->setMemoryType (MemoryType::Shared);
485489 bcw_smem->setMemoryType (MemoryType::Shared);
486- c_smem->setMemoryType (MemoryType::Shared);
490+
491+ if (params.has_epilog ) {
492+ c_smem->setMemoryType (MemoryType::Shared);
493+ }
487494
488495 // Set parallelization:
489496 // TODO: this section goes to a separate matmul util,
@@ -534,37 +541,44 @@ void scheduleMatmul(
534541 bcr->skewDoubleBuffer ();
535542 }
536543
544+ auto output_buffer = params.has_epilog ? c_smem : c;
545+
537546 scheduler_utils::BoundedDirectionalTransformPropagator::forward (
538547 cc,
539548 -1 ,
540- {c_smem},
541- scheduler_utils::BoundedDirectionalTransformPropagator::Options ()
542- .propagateParallelType ()
543- .propagateToBoundary ());
544-
545- // Epilog schedule:
546- scheduler_utils::BoundedDirectionalTransformPropagator::forward (
547- c_smem,
548- 3 ,
549- {c},
549+ {output_buffer},
550550 scheduler_utils::BoundedDirectionalTransformPropagator::Options ()
551551 .propagateParallelType ()
552552 .propagateToBoundary ());
553553
554- c_smem->computeAt (c, 3 );
555- c->reorder ({{-1 , -2 }, {-2 , -1 }});
556- // 16 x 128, with half of the warps:
557-
558- // Output vectorize by 4:
559- c->split (-2 , 2 );
560- c->split (-1 , 4 );
561-
562- // [8, 2, 32, 4]
563- c->axis (-3 )->parallelize (ParallelType::TIDy);
564- c->axis (-2 )->parallelize (ParallelType::TIDx);
565- c->axis (-1 )->parallelize (ParallelType::Vectorize);
566- c_smem->axis (-1 )->parallelize (ParallelType::Vectorize);
567- c_smem->doubleBuffer ();
554+ // Epilog schedule (To be built out):
555+ if (params.has_epilog ) {
556+ scheduler_utils::BoundedDirectionalTransformPropagator::forward (
557+ c_smem,
558+ 3 ,
559+ {c},
560+ scheduler_utils::BoundedDirectionalTransformPropagator::Options ()
561+ .propagateParallelType ()
562+ .propagateToBoundary ());
563+
564+ c_smem->computeAt (c, 3 );
565+ c->reorder ({{-1 , -2 }, {-2 , -1 }});
566+ // 16 x 128, with half of the warps:
567+
568+ // Output vectorize by 4:
569+ c->split (-2 , 2 );
570+ c->split (-1 , 4 );
571+
572+ // [8, 2, 32, 4]
573+ c->axis (-3 )->parallelize (ParallelType::TIDy);
574+ c->axis (-2 )->parallelize (ParallelType::TIDx);
575+ c->axis (-1 )->parallelize (ParallelType::Vectorize);
576+ c_smem->axis (-1 )->parallelize (ParallelType::Vectorize);
577+ c_smem->doubleBuffer ();
578+ } else {
579+ // Always vector
580+ c->axis (-1 )->parallelize (ParallelType::Vectorize);
581+ }
568582
569583 if (params.index_lift_options .lift_gmem_read_address ) {
570584 a->liftReadAddress ();
0 commit comments