diff --git a/paddle/cinn/backends/CMakeLists.txt b/paddle/cinn/backends/CMakeLists.txt
index c746886a43d9b8..c09b95e8d85814 100755
--- a/paddle/cinn/backends/CMakeLists.txt
+++ b/paddle/cinn/backends/CMakeLists.txt
@@ -13,11 +13,12 @@ gather_srcs(
   extern_func_protos.cc
   extern_func_jit_register.cc
   modular.cc
-  compiler.cc)
+  compiler.cc
+  codegen_device_util.cc)
 
 if(WITH_CUDA)
   add_subdirectory(nvrtc)
-  list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc codegen_cuda_util.cc)
+  list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc)
 endif()
 
 if(WITH_OPENMP)
diff --git a/paddle/cinn/backends/codegen_cuda_generate_test.cc b/paddle/cinn/backends/codegen_cuda_generate_test.cc
index a70099943284f2..b3e9f85614087d 100644
--- a/paddle/cinn/backends/codegen_cuda_generate_test.cc
+++ b/paddle/cinn/backends/codegen_cuda_generate_test.cc
@@ -21,7 +21,7 @@
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/extern_func_jit_register.h"
 #include "paddle/cinn/backends/llvm/execution_engine.h"
 #include "paddle/cinn/backends/llvm/simple_jit.h"
diff --git a/paddle/cinn/backends/codegen_cuda_host.cc b/paddle/cinn/backends/codegen_cuda_host.cc
index 3b7235819661fa..b888db7c7c7264 100644
--- a/paddle/cinn/backends/codegen_cuda_host.cc
+++ b/paddle/cinn/backends/codegen_cuda_host.cc
@@ -18,7 +18,7 @@
 #include
 #include
 
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/extern_func_emitter_builtin.h"
 #include "paddle/cinn/backends/extern_func_jit_register.h"
 #include "paddle/cinn/backends/llvm/llvm_util.h"
diff --git a/paddle/cinn/backends/codegen_cuda_util.cc b/paddle/cinn/backends/codegen_device_util.cc
similarity index 85%
rename from paddle/cinn/backends/codegen_cuda_util.cc
rename to paddle/cinn/backends/codegen_device_util.cc
index 729dcca7be745b..73a4bf1358b2c3 100644
--- a/paddle/cinn/backends/codegen_cuda_util.cc
+++ b/paddle/cinn/backends/codegen_device_util.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/common/cas.h" @@ -22,7 +22,7 @@ PD_DECLARE_bool(cinn_bucket_compile); namespace cinn { namespace backends { -std::tuple SplitCudaAndHostModule(ir::Module module) { +std::tuple SplitDeviceAndHostModule(ir::Module module) { if (FLAGS_cinn_bucket_compile) { detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name); Expr expr(module); @@ -91,7 +91,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate), type_of()); - Expr shared_mem_bytes = CalculateSharedMemory(func); + std::optional shared_mem_bytes; + cinn::common::DefaultDeviceTarget().arch.Match( + [&](std::variant) { + CINN_NOT_IMPLEMENTED; + }, + [&](common::NVGPUArch) { +#ifdef CINN_WITH_CUDA + shared_mem_bytes = CalculateSharedMemory(func); +#endif + }); VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n" << "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", " @@ -100,10 +109,18 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( << "block_dim: (" << func_node->cuda_axis_info.block_dim(0) << ", " << func_node->cuda_axis_info.block_dim(1) << ", " << func_node->cuda_axis_info.block_dim(2) << "), " - << "shared_mem: " << shared_mem_bytes; + << "shared_mem: " << shared_mem_bytes.value(); + std::optional call_kernel; + cinn::common::DefaultDeviceTarget().arch.Match( + [&](std::variant) { + CINN_NOT_IMPLEMENTED; + }, + [&](common::NVGPUArch) { + call_kernel = runtime::intrinsic::call_cuda_kernel; + }); ir::Expr call_extern_api = ir::Call::Make(Void(), - runtime::intrinsic::call_cuda_kernel, + call_kernel.value(), {kernel_ptr, kernel_args_, kernel_args_num_, @@ -113,7 +130,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( func_node->cuda_axis_info.block_dim(0), // block_x func_node->cuda_axis_info.block_dim(1), // block_y func_node->cuda_axis_info.block_dim(2), // block_z - shared_mem_bytes, // shared_mem + shared_mem_bytes.value(), // shared_mem kernel_stream_}, {}, ir::CallType::Extern, diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_device_util.h similarity index 87% rename from paddle/cinn/backends/codegen_cuda_util.h rename to paddle/cinn/backends/codegen_device_util.h index eddaca176b6c93..caada3153e63bb 100644 --- a/paddle/cinn/backends/codegen_cuda_util.h +++ b/paddle/cinn/backends/codegen_device_util.h @@ -19,12 +19,14 @@ #include #include #include - +#ifdef CINN_WITH_CUDA #include "paddle/cinn/backends/codegen_cuda_dev.h" +#endif #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/runtime/flags.h" namespace cinn { namespace backends { @@ -43,7 +45,7 @@ namespace backends { * - replace the original kernel function with a Call node and add it to the * first module, add a device kernel function to the second module. 
diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_device_util.h
similarity index 87%
rename from paddle/cinn/backends/codegen_cuda_util.h
rename to paddle/cinn/backends/codegen_device_util.h
index eddaca176b6c93..caada3153e63bb 100644
--- a/paddle/cinn/backends/codegen_cuda_util.h
+++ b/paddle/cinn/backends/codegen_device_util.h
@@ -19,12 +19,14 @@
 #include
 #include
 #include
-
+#ifdef CINN_WITH_CUDA
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
+#endif
 #include "paddle/cinn/cinn.h"
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
+#include "paddle/cinn/runtime/flags.h"
 
 namespace cinn {
 namespace backends {
@@ -43,7 +45,7 @@ namespace backends {
  * - replace the original kernel function with a Call node and add it to the
  *   first module, add a device kernel function to the second module.
 */
-std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module);
+std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module);
 
 namespace detail {
 
@@ -52,7 +54,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
       : host_module_builder(module_name + "_host",
                             cinn::common::DefaultHostTarget()),
         device_module_builder(module_name + "_gpu_device",
-                              cinn::common::DefaultNVGPUTarget()) {}
+                              cinn::common::DefaultDeviceTarget()) {}
 
   std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
     ir::IRMutator<>::Visit(expr, expr);
@@ -109,9 +111,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
     // shared_mem_bytes can be calculated after codegen_cuda_dev buffer
    // creation; however, this makes CodeGenCUDA_Dev run before splitting the
    // host and device module. Maybe we could reorder the process.
-    CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
-    codegen_dev.Compile(ir::LoweredFunc(func));
-    Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
+    std::optional<Expr> shared_mem_bytes;
+    cinn::common::DefaultDeviceTarget().arch.Match(
+        [&](std::variant<common::UnknownArch,
+                         common::X86Arch,
+                         common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
+        [&](common::NVGPUArch) {
+#ifdef CINN_WITH_CUDA
+          CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
+          codegen_dev.Compile(ir::LoweredFunc(func));
+          shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
+#endif
+        });
 
     VLOG(6) << "Add a call node for func->name " << func->name << "\n"
             << "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "
@@ -120,10 +131,20 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
             << "block_dim: (" << func->cuda_axis_info.block_dim(0) << ", "
             << func->cuda_axis_info.block_dim(1) << ", "
             << func->cuda_axis_info.block_dim(2) << "), "
-            << "shared_mem: " << shared_mem_bytes;
+            << "shared_mem: " << shared_mem_bytes.value();
+
+    std::optional<std::string> call_kernel;
+    cinn::common::DefaultDeviceTarget().arch.Match(
+        [&](std::variant<common::UnknownArch,
+                         common::X86Arch,
+                         common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
+        [&](common::NVGPUArch) {
+          call_kernel = runtime::intrinsic::call_cuda_kernel;
+        });
+
     auto call_extern_api =
         ir::Call::Make(Void(),
-                       runtime::intrinsic::call_cuda_kernel,
+                       call_kernel.value(),
                        {kernel_ptr,
                         kernel_args,
                         kernel_args_num,
@@ -133,7 +154,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
                         func->cuda_axis_info.block_dim(0),  // block_x
                         func->cuda_axis_info.block_dim(1),  // block_y
                         func->cuda_axis_info.block_dim(2),  // block_z
-                        shared_mem_bytes,
+                        shared_mem_bytes.value(),
                         kernel_stream},
                        {},
                        ir::CallType::Extern,
diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc
index b37090a74fbe15..eebcea6aeaa84e 100644
--- a/paddle/cinn/backends/compiler.cc
+++ b/paddle/cinn/backends/compiler.cc
@@ -24,7 +24,7 @@
 #ifdef CINN_WITH_CUDA
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
 #include "paddle/cinn/runtime/cuda/cuda_module.h"
 #include "paddle/cinn/runtime/cuda/cuda_util.h"
@@ -246,7 +246,7 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
       [&](common::NVGPUArch) -> std::string {
 #ifdef CINN_WITH_CUDA
         auto _host_module_device_module_ =
-            SplitCudaAndHostModule(module);  // NOLINT
+            SplitDeviceAndHostModule(module);  // NOLINT
         auto& host_module = std::get<0>(_host_module_device_module_);
         auto& device_module = std::get<1>(_host_module_device_module_);
         CodeGenCUDA_Dev codegen(target_);
@@ -270,7 +270,8 @@ void Compiler::BuildDefault(const Module& module) {
 
 void Compiler::CompileCudaModule(const Module& module,
                                  const std::string& code) {
 #ifdef CINN_WITH_CUDA
-  auto _host_module_device_module_ = SplitCudaAndHostModule(module);  // NOLINT
+  auto _host_module_device_module_ =
+      SplitDeviceAndHostModule(module);  // NOLINT
   auto& host_module = std::get<0>(_host_module_device_module_);
   auto& device_module = std::get<1>(_host_module_device_module_);
   VLOG(3) << "[CUDA] host module:\n" << host_module;
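Note on the `.value()` unwraps above: they are safe only because the non-NVGPU arms abort via `CINN_NOT_IMPLEMENTED` before the unwrap runs. The `#ifdef CINN_WITH_CUDA` guard inside the NVGPU arm, however, can leave the optional empty if that arm is ever matched in a CUDA-less build, and an empty `std::optional` throws on `.value()`. A standalone demo of that failure mode (not CINN code):

```cpp
#include <iostream>
#include <optional>

int main() {
  std::optional<int> shared_mem_bytes;  // never set, as in a CUDA-less arm
  try {
    std::cout << shared_mem_bytes.value();  // throws std::bad_optional_access
  } catch (const std::bad_optional_access& e) {
    std::cout << "empty optional: " << e.what() << "\n";
  }
}
```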
diff --git a/paddle/cinn/common/arch.h b/paddle/cinn/common/arch.h
index e43dbeadc97ab7..768c628f2df576 100644
--- a/paddle/cinn/common/arch.h
+++ b/paddle/cinn/common/arch.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include "paddle/common/overloaded.h"
 
 namespace cinn {
 namespace common {
@@ -45,6 +46,8 @@ struct Arch final : public ArchBase {
     return static_cast<const ArchBase&>(*this);
   }
 
+  DEFINE_MATCH_METHOD();
+
   bool operator==(const auto& other) const {
     return this->index() == other.index();
   }
diff --git a/paddle/cinn/common/cuda_test_helper.cc b/paddle/cinn/common/cuda_test_helper.cc
index f43678266daa59..e49b2e1d16184d 100644
--- a/paddle/cinn/common/cuda_test_helper.cc
+++ b/paddle/cinn/common/cuda_test_helper.cc
@@ -16,7 +16,7 @@
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
 #include "paddle/cinn/runtime/cuda/cuda_module.h"
 #include "paddle/cinn/runtime/cuda/cuda_util.h"
@@ -28,7 +28,7 @@ namespace common {
 void CudaModuleTester::Compile(const ir::Module& m,
                                const std::string& rewrite_cuda_code) {
   auto _host_module_device_module_ =
-      backends::SplitCudaAndHostModule(m);  // NOLINT
+      backends::SplitDeviceAndHostModule(m);  // NOLINT
   auto& host_module = std::get<0>(_host_module_device_module_);
   auto& device_module = std::get<1>(_host_module_device_module_);
   CHECK(!host_module.functions().empty());
diff --git a/paddle/cinn/common/target.cc b/paddle/cinn/common/target.cc
index 57657d01d45a85..b36975a7b0090e 100644
--- a/paddle/cinn/common/target.cc
+++ b/paddle/cinn/common/target.cc
@@ -249,6 +249,12 @@ const Target &DefaultNVGPUTarget() {
   return target;
 }
 
+const Target &DefaultDeviceTarget() {
+#ifdef CINN_WITH_CUDA
+  return DefaultNVGPUTarget();
+#endif
+}
+
 int GetMaxThreads() {
   // cudaDeviceGetAttribute ( int* value, cudaDeviceAttr attr, int device )
   int max_threads = 1;
diff --git a/paddle/cinn/common/target.h b/paddle/cinn/common/target.h
index 6df1d1ece8c5f8..693f561e3185bd 100644
--- a/paddle/cinn/common/target.h
+++ b/paddle/cinn/common/target.h
@@ -100,6 +100,8 @@ const Target& DefaultHostTarget();
 
 const Target& DefaultNVGPUTarget();
 
+const Target& DefaultDeviceTarget();
+
 const Target& DefaultTarget();
 
 int GetMaxThreads();
diff --git a/paddle/cinn/frontend/paddle/model_parser.cc b/paddle/cinn/frontend/paddle/model_parser.cc
index ad028bff1c8093..8819fe5df5c3e5 100644
--- a/paddle/cinn/frontend/paddle/model_parser.cc
+++ b/paddle/cinn/frontend/paddle/model_parser.cc
@@ -19,7 +19,7 @@
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/cuda_util.h"
 #include "paddle/cinn/common/common.h"
 #include "paddle/cinn/frontend/paddle/compatible_pb.h"
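Note on the new `DefaultDeviceTarget()`: as written, it falls off the end of the function when `CINN_WITH_CUDA` is undefined, which is undefined behavior if the call is ever reached in such a build. A guarded variant would fail loudly instead. Hypothetical sketch with a stand-in `Target`, not part of this diff:

```cpp
#include <stdexcept>

struct Target {};  // stand-in for cinn::common::Target

const Target& DefaultNVGPUTargetStub() {
  static Target t;  // mirrors the static-local pattern in target.cc
  return t;
}

const Target& DefaultDeviceTargetGuarded() {
#ifdef CINN_WITH_CUDA
  return DefaultNVGPUTargetStub();
#else
  // Explicit failure instead of falling off the end of the function.
  throw std::runtime_error("DefaultDeviceTarget: no device backend built in");
#endif
}

int main() { return 0; }
```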
diff --git a/paddle/cinn/hlir/framework/graph_compiler_util.h b/paddle/cinn/hlir/framework/graph_compiler_util.h
index 6ef9afa32fcacf..634fee4c67f5bf 100644
--- a/paddle/cinn/hlir/framework/graph_compiler_util.h
+++ b/paddle/cinn/hlir/framework/graph_compiler_util.h
@@ -91,7 +91,7 @@ struct CompilationContext {
   void* stream = nullptr;
 
   // Set attached source code, if code is not empty, these codes will replace
-  // the device_module code after SplitCudaAndHostModule.
+  // the device_module code after SplitDeviceAndHostModule.
   void ApplySourceCode(const std::string& code);
   // Apply results of auto-tune to compile.
   // Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
diff --git a/paddle/cinn/hlir/framework/op_lowering_test.cc b/paddle/cinn/hlir/framework/op_lowering_test.cc
index be33fa25125d28..3ee00b5e07e6ac 100644
--- a/paddle/cinn/hlir/framework/op_lowering_test.cc
+++ b/paddle/cinn/hlir/framework/op_lowering_test.cc
@@ -18,7 +18,7 @@
 #include "paddle/cinn/backends/codegen_c_x86.h"
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/cuda_util.h"
 #include "paddle/cinn/backends/llvm/execution_engine.h"
 #include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
diff --git a/paddle/cinn/hlir/framework/parallel_compiler.cc b/paddle/cinn/hlir/framework/parallel_compiler.cc
index 4e5ec751f0a8ac..ff52c8f570c026 100644
--- a/paddle/cinn/hlir/framework/parallel_compiler.cc
+++ b/paddle/cinn/hlir/framework/parallel_compiler.cc
@@ -20,7 +20,7 @@
 #include "paddle/cinn/backends/codegen_cuda_dev.h"
 #include "paddle/cinn/backends/codegen_cuda_host.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/backends/compiler.h"
 #include "paddle/cinn/backends/llvm/codegen_x86.h"
 #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
@@ -238,7 +238,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
   auto ir_module = builder.Build();
   if (context->target == cinn::common::DefaultNVGPUTarget()) {
 #ifdef CINN_WITH_CUDA
-    auto splited_module = backends::SplitCudaAndHostModule(ir_module);
+    auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
     auto hmodule = std::get<0>(splited_module);
     auto dmodule = std::get<1>(splited_module);
diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
index a87bbf953de1be..d0e4aee10f31ad 100644
--- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -18,7 +18,7 @@
 #include "paddle/cinn/adt/map_expr_ctx.h"
 #include "paddle/cinn/ast_gen_ius/tensor_group.h"
-#include "paddle/cinn/backends/codegen_cuda_util.h"
+#include "paddle/cinn/backends/codegen_device_util.h"
 #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
 #include "paddle/cinn/hlir/framework/compile_error.h"
 #include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
diff --git a/paddle/cinn/hlir/op/custom_call.cc b/paddle/cinn/hlir/op/custom_call.cc
index fc84e4cc9eb1a6..c090e165066600 100644
--- a/paddle/cinn/hlir/op/custom_call.cc
+++ b/paddle/cinn/hlir/op/custom_call.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" diff --git a/paddle/cinn/hlir/op/reduction_test.cc b/paddle/cinn/hlir/op/reduction_test.cc index 5586c323462ac6..dab984922fdef4 100644 --- a/paddle/cinn/hlir/op/reduction_test.cc +++ b/paddle/cinn/hlir/op/reduction_test.cc @@ -22,7 +22,7 @@ #include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/backends/codegen_cuda_host.h" -#include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/backends/llvm/execution_engine.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" @@ -116,7 +116,7 @@ std::pair GenReduceCode( // now. auto module = builder.Build(); auto host_module_device_module = - backends::SplitCudaAndHostModule(module); // NOLINT + backends::SplitDeviceAndHostModule(module); // NOLINT auto& host_module = std::get<0>(host_module_device_module); auto& device_module = std::get<1>(host_module_device_module); diff --git a/paddle/cinn/hlir/op/transform_test.cc b/paddle/cinn/hlir/op/transform_test.cc index 0e9b6d86d2ece6..a55224c9954365 100644 --- a/paddle/cinn/hlir/op/transform_test.cc +++ b/paddle/cinn/hlir/op/transform_test.cc @@ -21,7 +21,7 @@ #include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/backends/codegen_cuda_host.h" -#include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/backends/llvm/execution_engine.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" diff --git a/paddle/cinn/hlir/pe/pe_transform_test.cc b/paddle/cinn/hlir/pe/pe_transform_test.cc index 852cc26211298e..601882227f6282 100644 --- a/paddle/cinn/hlir/pe/pe_transform_test.cc +++ b/paddle/cinn/hlir/pe/pe_transform_test.cc @@ -15,7 +15,7 @@ #include #include "paddle/cinn/backends/codegen_cuda_dev.h" -#include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/backends/codegen_device_util.h" #include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/backends/llvm/execution_engine.h" #include "paddle/cinn/backends/nvrtc/nvrtc_util.h" @@ -132,7 +132,7 @@ TEST(ScatterAssign, ScatterAssign) { builder.AddFunction(func); auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); + auto host_module_device_module = backends::SplitDeviceAndHostModule(module); auto &host_module = std::get<0>(host_module_device_module); auto &device_module = std::get<1>(host_module_device_module); @@ -176,7 +176,7 @@ TEST(SliceAssign, SliceAssign) { builder.AddFunction(func); auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); + auto host_module_device_module = backends::SplitDeviceAndHostModule(module); auto &host_module = std::get<0>(host_module_device_module); auto &device_module = std::get<1>(host_module_device_module); @@ -217,7 +217,7 @@ TEST(Concat, ConcatCase0) { builder.AddFunction(func); auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); + auto host_module_device_module = backends::SplitDeviceAndHostModule(module); auto &host_module = std::get<0>(host_module_device_module); auto &device_module = std::get<1>(host_module_device_module);