-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[CINN][New Hardware Update] extend SplitCudaAndHostModule #64345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/cinn/backends/codegen_cuda_util.h" | ||
| #include "paddle/cinn/backends/codegen_device_util.h" | ||
|
|
||
| #include "paddle/cinn/backends/cuda_util.h" | ||
| #include "paddle/cinn/common/cas.h" | ||
|
|
@@ -22,7 +22,7 @@ PD_DECLARE_bool(cinn_bucket_compile); | |
| namespace cinn { | ||
| namespace backends { | ||
|
|
||
| std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) { | ||
| std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module) { | ||
| if (FLAGS_cinn_bucket_compile) { | ||
| detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name); | ||
| Expr expr(module); | ||
|
|
@@ -91,7 +91,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( | |
| ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate), | ||
| type_of<std::string>()); | ||
|
|
||
| Expr shared_mem_bytes = CalculateSharedMemory(func); | ||
| Expr shared_mem_bytes; | ||
| cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match( | ||
| [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) { | ||
| CINN_NOT_IMPLEMENTED; | ||
| }, | ||
| [&](common::NVGPUArch) { | ||
| #ifdef CINN_WITH_CUDA | ||
| shared_mem_bytes = CalculateSharedMemory(func); | ||
| #endif | ||
| }); | ||
|
|
||
| VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n" | ||
| << "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", " | ||
|
|
@@ -101,9 +110,17 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( | |
| << func_node->cuda_axis_info.block_dim(1) << ", " | ||
| << func_node->cuda_axis_info.block_dim(2) << "), " | ||
| << "shared_mem: " << shared_mem_bytes; | ||
| const char *call_kernel; | ||
| cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match( | ||
|
||
| [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) { | ||
| CINN_NOT_IMPLEMENTED; | ||
| }, | ||
| [&](common::NVGPUArch) { | ||
| call_kernel = runtime::intrinsic::call_cuda_kernel; | ||
| }); | ||
| ir::Expr call_extern_api = | ||
| ir::Call::Make(Void(), | ||
| runtime::intrinsic::call_cuda_kernel, | ||
| call_kernel, | ||
| {kernel_ptr, | ||
| kernel_args_, | ||
| kernel_args_num_, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,12 +19,14 @@ | |
| #include <string> | ||
| #include <tuple> | ||
| #include <vector> | ||
|
|
||
| #ifdef CINN_WITH_CUDA | ||
| #include "paddle/cinn/backends/codegen_cuda_dev.h" | ||
| #endif | ||
| #include "paddle/cinn/cinn.h" | ||
| #include "paddle/cinn/ir/ir.h" | ||
| #include "paddle/cinn/ir/ir_mutator.h" | ||
| #include "paddle/cinn/ir/utils/ir_copy.h" | ||
| #include "paddle/cinn/runtime/flags.h" | ||
|
|
||
| namespace cinn { | ||
| namespace backends { | ||
|
|
@@ -43,16 +45,17 @@ namespace backends { | |
| * - replace the original kernel function with a Call node and add it to the | ||
| * first module, add a device kernel function to the second module. | ||
| */ | ||
| std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module); | ||
| std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module); | ||
|
|
||
| namespace detail { | ||
|
|
||
| struct CollectHostFunctionVisitor : public ir::IRMutator<> { | ||
| explicit CollectHostFunctionVisitor(const std::string& module_name) | ||
| : host_module_builder(module_name + "_host", | ||
| cinn::common::DefaultHostTarget()), | ||
| device_module_builder(module_name + "_gpu_device", | ||
| cinn::common::DefaultNVGPUTarget()) {} | ||
| device_module_builder( | ||
| module_name + "_gpu_device", | ||
| cinn::runtime::CurrentTarget::GetCurrentTarget()) {} | ||
|
||
|
|
||
| std::tuple<ir::Module, ir::Module> operator()(Expr* expr) { | ||
| ir::IRMutator<>::Visit(expr, expr); | ||
|
|
@@ -109,9 +112,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> { | |
| // shared_mem_bytes can be calculated after codegen_cuda_dev buffer creation; | ||
| // however, this makes CodeGenCUDA_Dev run before splitting the host and device | ||
| // modules. Maybe we could reorder the process. | ||
| CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget()); | ||
| codegen_dev.Compile(ir::LoweredFunc(func)); | ||
| Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset(); | ||
| Expr shared_mem_bytes; | ||
| cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match( | ||
| [&](std::variant<common::UnknownArch, | ||
| common::X86Arch, | ||
| common::ARMArch>) { CINN_NOT_IMPLEMENTED; }, | ||
| [&](common::NVGPUArch) { | ||
| #ifdef CINN_WITH_CUDA | ||
| CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget()); | ||
| codegen_dev.Compile(ir::LoweredFunc(func)); | ||
| shared_mem_bytes = codegen_dev.GetDynSharedMemOffset(); | ||
| #endif | ||
| }); | ||
|
|
||
| VLOG(6) << "Add a call node for func->name " << func->name << "\n" | ||
| << "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", " | ||
|
|
@@ -121,9 +133,19 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> { | |
| << func->cuda_axis_info.block_dim(1) << ", " | ||
| << func->cuda_axis_info.block_dim(2) << "), " | ||
| << "shared_mem: " << shared_mem_bytes; | ||
|
|
||
| const char* call_kernel; | ||
|
||
| cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match( | ||
| [&](std::variant<common::UnknownArch, | ||
| common::X86Arch, | ||
| common::ARMArch>) { CINN_NOT_IMPLEMENTED; }, | ||
| [&](common::NVGPUArch) { | ||
| call_kernel = runtime::intrinsic::call_cuda_kernel; | ||
| }); | ||
|
|
||
| auto call_extern_api = | ||
| ir::Call::Make(Void(), | ||
| runtime::intrinsic::call_cuda_kernel, | ||
| call_kernel, | ||
| {kernel_ptr, | ||
| kernel_args, | ||
| kernel_args_num, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::optional<ir::Expr> shared_mem_bytes;