Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/triton.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/translation.cc
${PYTHON_SRC_PATH}/runtime.cc
${PYTHON_SRC_PATH}/interpreter.cc)
include_directories("." ${PYTHON_SRC_PATH})

Expand Down
35 changes: 19 additions & 16 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,37 @@
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace mlir {
namespace triton {
namespace gpu {

std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
int numWarps = 4,
int numCTAs = 1,
int computeCapability = 80);
std::unique_ptr<Pass> createPipelinePass(int numStages = 3, int numWarps = 4,
int numCTAs = 1,
int computeCapability = 80);

std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
std::unique_ptr<Pass> createAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createPrefetchPass();

std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createCanonicalizeLoopsPass();

std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createCoalescePass();

std::unique_ptr<Pass> createTritonGPUReorderInstructionsPass();
std::unique_ptr<Pass> createReorderInstructionsPass();

std::unique_ptr<Pass> createTritonGPUDecomposeConversionsPass();
std::unique_ptr<Pass> createDecomposeConversionsPass();

std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();
std::unique_ptr<Pass> createRemoveLayoutConversionsPass();

std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createVerifier();

std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
std::unique_ptr<Pass> createOptimizeDotOperandsPass();

std::unique_ptr<Pass> createTritonGPUOptimizeEpiloguePass();
std::unique_ptr<Pass> createOptimizeEpiloguePass();

std::unique_ptr<Pass> createTritonGPUOptimizeThreadLocalityPass();
std::unique_ptr<Pass> createOptimizeThreadLocalityPass();

} // namespace gpu
} // namespace triton

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
20 changes: 10 additions & 10 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
needed at the next iteration
}];

let constructor = "mlir::createTritonGPUPipelinePass()";
let constructor = "mlir::triton::gpu::createPipelinePass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
Expand Down Expand Up @@ -42,7 +42,7 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
that may have their operands constructed at the end of the previous iteration
}];

let constructor = "mlir::createTritonGPUPrefetchPass()";
let constructor = "mlir::triton::gpu::createPrefetchPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
Expand All @@ -57,7 +57,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
(e.g., Nvidia tensor cores)
}];

let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
let constructor = "mlir::triton::gpu::createAccelerateMatmulPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
Expand All @@ -78,7 +78,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir
hardware-accelerated transpositions.
}];

let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
let constructor = "mlir::triton::gpu::createOptimizeDotOperandsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
Expand All @@ -92,7 +92,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
TODO
}];

let constructor = "mlir::createTritonGPUCoalescePass()";
let constructor = "mlir::triton::gpu::createCoalescePass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
Expand All @@ -104,7 +104,7 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions
let description = [{
}];

let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()";
let constructor = "mlir::triton::gpu::createRemoveLayoutConversionsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
Expand All @@ -117,7 +117,7 @@ def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::Modul
let description = [{
}];

let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()";
let constructor = "mlir::triton::gpu::createOptimizeEpiloguePass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
Expand All @@ -131,7 +131,7 @@ def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality",
Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes.
}];

let constructor = "mlir::createTritonGPUOptimizeThreadLocalityPass()";
let constructor = "mlir::triton::gpu::createOptimizeThreadLocalityPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
Expand All @@ -144,7 +144,7 @@ def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::
"conversions from shared memory before their first use) and (2) promote LLVM instruction "
"order more friendly to `ptxas`.";

let constructor = "mlir::createTritonGPUReorderInstructionsPass()";
let constructor = "mlir::triton::gpu::createReorderInstructionsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
Expand All @@ -155,7 +155,7 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir

let description = "Decomposing conversions this way makes it possible to use CSE and re-use #shared tensors";

let constructor = "mlir::createTritonGPUDecomposeConversionsPass()";
let constructor = "mlir::triton::gpu::createDecomposeConversionsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,6 @@ class TritonGPUAccelerateMatmulPass
};

std::unique_ptr<Pass>
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
mlir::triton::gpu::createAccelerateMatmulPass(int computeCapability) {
return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,6 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
}
};

std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
std::unique_ptr<Pass> mlir::triton::gpu::createCoalescePass() {
return std::make_unique<CoalescePass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ class TritonGPUDecomposeConversionsPass
}
};

std::unique_ptr<Pass> mlir::createTritonGPUDecomposeConversionsPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createDecomposeConversionsPass() {
return std::make_unique<TritonGPUDecomposeConversionsPass>();
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,6 @@ class TritonGPUOptimizeDotOperandsPass
}
};

std::unique_ptr<Pass> mlir::createTritonGPUOptimizeDotOperandsPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createOptimizeDotOperandsPass() {
return std::make_unique<TritonGPUOptimizeDotOperandsPass>();
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,6 @@ class TritonGPUOptimizeEpiloguePass
}
};

std::unique_ptr<Pass> mlir::createTritonGPUOptimizeEpiloguePass() {
std::unique_ptr<Pass> mlir::triton::gpu::createOptimizeEpiloguePass() {
return std::make_unique<TritonGPUOptimizeEpiloguePass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,6 @@ class TritonGPUOptimizeThreadLocalityPass
}
};

std::unique_ptr<Pass> mlir::createTritonGPUOptimizeThreadLocalityPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createOptimizeThreadLocalityPass() {
return std::make_unique<TritonGPUOptimizeThreadLocalityPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages,
int numWarps,
int numCTAs,
int computeCapability) {
std::unique_ptr<Pass>
mlir::triton::gpu::createPipelinePass(int numStages, int numWarps, int numCTAs,
int computeCapability) {
return std::make_unique<PipelinePass>(numStages, numWarps, numCTAs,
computeCapability);
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,6 @@ struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {

} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createPrefetchPass() {
return std::make_unique<PrefetchPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,6 @@ class TritonGPURemoveLayoutConversionsPass
}
};

std::unique_ptr<Pass> mlir::createTritonGPURemoveLayoutConversionsPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createRemoveLayoutConversionsPass() {
return std::make_unique<TritonGPURemoveLayoutConversionsPass>();
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ class TritonGPUReorderInstructionsPass
}
};

std::unique_ptr<Pass> mlir::createTritonGPUReorderInstructionsPass() {
std::unique_ptr<Pass> mlir::triton::gpu::createReorderInstructionsPass() {
return std::make_unique<TritonGPUReorderInstructionsPass>();
}
Loading