
Commit 97c6ee9

optionally enable epilog schedule (for now)
1 parent 77f831b commit 97c6ee9

2 files changed: +44, -27 lines


torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp

Lines changed: 41 additions & 27 deletions
@@ -311,8 +311,12 @@ void scheduleMatmul(
   // Setup accumulator register.
   auto cc = c->cacheBefore();
 
-  // Setup output smem buffer
-  auto c_smem = c->cacheBefore();
+  TensorView* c_smem = nullptr;
+
+  if (params.has_epilog) {
+    // Setup output smem buffer
+    c_smem = c->cacheBefore();
+  }
 
   // Get the input to the mma op.
   auto mma = dynamic_cast<MmaOp*>(cc->definition());
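
With this change, the output staging buffer exists only when the epilog schedule is requested, so every later use of c_smem has to be guarded. A minimal dataflow sketch, inferred from this diff rather than quoted from the nvFuser sources:

// Inferred dataflow (an assumption based on this diff, not verbatim nvFuser code):
//   params.has_epilog == true:
//     mma -> cc (accumulator registers) -> c_smem (shared memory) -> c (global)
//   params.has_epilog == false:
//     mma -> cc (accumulator registers) -> c (global)
// c_smem stays nullptr in the second case, which is why the hunks below
// wrap each c_smem access in if (params.has_epilog).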
@@ -483,7 +487,10 @@ void scheduleMatmul(
   // Set memory type:
   acw_smem->setMemoryType(MemoryType::Shared);
   bcw_smem->setMemoryType(MemoryType::Shared);
-  c_smem->setMemoryType(MemoryType::Shared);
+
+  if (params.has_epilog) {
+    c_smem->setMemoryType(MemoryType::Shared);
+  }
 
   // Set parallelization:
   // TODO: this section goes to a separate matmul util,
@@ -534,37 +541,44 @@ void scheduleMatmul(
     bcr->skewDoubleBuffer();
   }
 
+  auto output_buffer = params.has_epilog ? c_smem : c;
+
   scheduler_utils::BoundedDirectionalTransformPropagator::forward(
       cc,
       -1,
-      {c_smem},
-      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-          .propagateParallelType()
-          .propagateToBoundary());
-
-  // Epilog schedule:
-  scheduler_utils::BoundedDirectionalTransformPropagator::forward(
-      c_smem,
-      3,
-      {c},
+      {output_buffer},
       scheduler_utils::BoundedDirectionalTransformPropagator::Options()
           .propagateParallelType()
          .propagateToBoundary());
 
-  c_smem->computeAt(c, 3);
-  c->reorder({{-1, -2}, {-2, -1}});
-  // 16 x 128, with half of the warps:
-
-  // Output vectorize by 4:
-  c->split(-2, 2);
-  c->split(-1, 4);
-
-  // [8, 2, 32, 4]
-  c->axis(-3)->parallelize(ParallelType::TIDy);
-  c->axis(-2)->parallelize(ParallelType::TIDx);
-  c->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->doubleBuffer();
+  // Epilog schedule (To be built out):
+  if (params.has_epilog) {
+    scheduler_utils::BoundedDirectionalTransformPropagator::forward(
+        c_smem,
+        3,
+        {c},
+        scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+            .propagateParallelType()
+            .propagateToBoundary());
+
+    c_smem->computeAt(c, 3);
+    c->reorder({{-1, -2}, {-2, -1}});
+    // 16 x 128, with half of the warps:
+
+    // Output vectorize by 4:
+    c->split(-2, 2);
+    c->split(-1, 4);
+
+    // [8, 2, 32, 4]
+    c->axis(-3)->parallelize(ParallelType::TIDy);
+    c->axis(-2)->parallelize(ParallelType::TIDx);
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->doubleBuffer();
+  } else {
+    // Always vector
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
 
   if (params.index_lift_options.lift_gmem_read_address) {
     a->liftReadAddress();
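
As a worked check of the epilog branch above, the split and parallelization calls decompose the output tile as follows (a sketch; the 16 x 128 tile shape is taken from the diff's own comment):

// Tile arithmetic for the epilog branch, derived from the calls above:
//   start:            [16, 128]
//   c->split(-2, 2):  [8, 2, 128]    (16 = 8 * 2)
//   c->split(-1, 4):  [8, 2, 32, 4]  (128 = 32 * 4)
// Parallelization of [8, 2, 32, 4]:
//   axis(-3), extent 2  -> ParallelType::TIDy
//   axis(-2), extent 32 -> ParallelType::TIDx
//   axis(-1), extent 4  -> ParallelType::Vectorize (4-wide stores)
// That gives 2 * 32 = 64 threads, i.e. two warps ("half of the warps" of a
// 128-thread CTA), with the leading extent-8 axis left as a serial loop:
// 8 * 2 * 32 * 4 = 2048 = 16 * 128 elements, covering the whole tile.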

torch/csrc/jit/codegen/cuda/scheduler/matmul.h

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ class MatmulParam {
 
   //! Enables predicate peeling mainloop:
   bool peel_main_loop = true;
+
+  //! Enables an epilog schedule
+  bool has_epilog = false;
 };
 
 //! Prototype auto scheduling function.
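
A hypothetical caller-side sketch of opting in to the new flag; the exact scheduleMatmul signature is not shown in this diff, so the argument list below is an assumption:

// Hypothetical usage; scheduleMatmul's parameter list is assumed, and
// a, b, c stand for the operand/output TensorViews set up elsewhere.
MatmulParam params;
params.peel_main_loop = true; // existing option, defaults to true
params.has_epilog = true;     // opt in to the (prototype) epilog schedule
scheduleMatmul(c, a, b, params);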
