Merged
Changes from 3 commits
5 changes: 3 additions & 2 deletions paddle/cinn/backends/CMakeLists.txt
@@ -13,11 +13,12 @@ gather_srcs(
extern_func_protos.cc
extern_func_jit_register.cc
modular.cc
compiler.cc)
compiler.cc
codegen_device_util.cc)

if(WITH_CUDA)
add_subdirectory(nvrtc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc codegen_cuda_util.cc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc)
endif()

if(WITH_OPENMP)
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_generate_test.cc
@@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.cc
@@ -18,7 +18,7 @@
#include <string>
#include <unordered_map>

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_emitter_builtin.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
paddle/cinn/backends/{codegen_cuda_util.cc → codegen_device_util.cc}
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/cas.h"
@@ -22,7 +22,7 @@ PD_DECLARE_bool(cinn_bucket_compile);
namespace cinn {
namespace backends {

std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module) {
if (FLAGS_cinn_bucket_compile) {
detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name);
Expr expr(module);
@@ -91,7 +91,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate),
type_of<std::string>());

Expr shared_mem_bytes = CalculateSharedMemory(func);
Expr shared_mem_bytes;
Contributor:
std::optional shared_mem_bytes;
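A minimal sketch of that std::optional form, assuming the element type is ir::Expr (as in the new code here) and keeping the same arch.Match dispatch; illustrative only, not part of this diff, and it needs <optional>:

std::optional<ir::Expr> shared_mem_bytes;  // stays unset unless a device arch is matched
cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
    [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
      CINN_NOT_IMPLEMENTED;
    },
    [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
      shared_mem_bytes = CalculateSharedMemory(func);
#endif
    });
CHECK(shared_mem_bytes.has_value()) << "shared_mem_bytes was never computed for this arch";

Reading shared_mem_bytes.value() afterwards then fails loudly instead of silently passing along a default-constructed Expr.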

cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
shared_mem_bytes = CalculateSharedMemory(func);
#endif
});

VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n"
<< "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", "
@@ -101,9 +110,17 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
<< func_node->cuda_axis_info.block_dim(1) << ", "
<< func_node->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;
const char *call_kernel;
cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
Contributor:
The name cinn::runtime::CurrentTarget::GetCurrentTarget is quite awkward. The original namespace was cinn::common::, so why not keep this under the common directory?

[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});
ir::Expr call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel,
{kernel_ptr,
kernel_args_,
kernel_args_num_,
paddle/cinn/backends/{codegen_cuda_util.h → codegen_device_util.h}
@@ -19,12 +19,14 @@
#include <string>
#include <tuple>
#include <vector>

#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#endif
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"

namespace cinn {
namespace backends {
@@ -43,16 +45,17 @@ namespace backends {
* - replace the original kernel function with a Call node and add it to the
* first module, add a device kernel function to the second module.
*/
std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module);
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module);
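A short usage sketch of this entry point, mirroring how call sites elsewhere in this PR consume the result (the variable names here are illustrative only):

// Split a lowered ir::Module into a host half and a device half.
auto host_and_device = backends::SplitDeviceAndHostModule(module);  // NOLINT
auto& host_module = std::get<0>(host_and_device);
auto& device_module = std::get<1>(host_and_device);
VLOG(3) << "host module:\n" << host_module;
VLOG(3) << "device module:\n" << device_module;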

namespace detail {

struct CollectHostFunctionVisitor : public ir::IRMutator<> {
explicit CollectHostFunctionVisitor(const std::string& module_name)
: host_module_builder(module_name + "_host",
cinn::common::DefaultHostTarget()),
device_module_builder(module_name + "_gpu_device",
cinn::common::DefaultNVGPUTarget()) {}
device_module_builder(
module_name + "_gpu_device",
cinn::runtime::CurrentTarget::GetCurrentTarget()) {}
Contributor:
These changes do not follow the like-for-like migration principle. What is the relationship between cinn::common::DefaultNVGPUTarget and cinn::runtime::CurrentTarget::GetCurrentTarget? If this is a replacement, it should be done in a separate PR. And even if this logic has to be migrated at the same time, shouldn't the counterpart of the old cinn::common::DefaultNVGPUTarget be cinn::common::DefaultDeviceTarget?

Contributor Author:
done
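For illustration of the naming the reviewer has in mind, one possible shape of such a helper; the name cinn::common::DefaultDeviceTarget and its body are hypothetical, not something this PR defines, and it assumes the existing helpers return cinn::common::Target:

namespace cinn::common {
// Hypothetical: choose the default device target for the current build,
// falling back to the host target when no device backend is compiled in.
inline Target DefaultDeviceTarget() {
#ifdef CINN_WITH_CUDA
  return DefaultNVGPUTarget();
#else
  return DefaultHostTarget();
#endif
}
}  // namespace cinn::common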


std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
@@ -109,9 +112,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
// shared_mem_bytes Can be calculated after codegen_cuda_dev buffer creation
// however, this make CodeGenCUDA_Dev before spliting the host and device
// module Maybe we could reorder the process.
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
Expr shared_mem_bytes;
cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
});

VLOG(6) << "Add a call node for func->name " << func->name << "\n"
<< "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "
@@ -121,9 +133,19 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
<< func->cuda_axis_info.block_dim(1) << ", "
<< func->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;

const char* call_kernel;
Contributor:
const char* call_kernel = nullptr;
Primitive types should always be given a default value. Or, for a more modern C++ feel:

std::optional<const char*> call_kernel;

Contributor Author:
done
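A sketch of the std::optional alternative in the context of this hunk, reusing the Match dispatch shown above; illustrative only, not part of the diff, and it needs <optional>:

std::optional<const char*> call_kernel;
cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
    [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
      CINN_NOT_IMPLEMENTED;
    },
    [&](common::NVGPUArch) {
      call_kernel = runtime::intrinsic::call_cuda_kernel;
    });
CHECK(call_kernel.has_value()) << "no kernel-call intrinsic registered for this arch";

The ir::Call::Make below would then receive call_kernel.value() rather than a possibly uninitialized raw pointer.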

cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});

auto call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel,
{kernel_ptr,
kernel_args,
kernel_args_num,
7 changes: 4 additions & 3 deletions paddle/cinn/backends/compiler.cc
@@ -24,7 +24,7 @@
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
@@ -246,7 +246,7 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
[&](common::NVGPUArch) -> std::string {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ =
SplitCudaAndHostModule(module); // NOLINT
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CodeGenCUDA_Dev codegen(target_);
@@ -270,7 +270,8 @@ void Compiler::BuildDefault(const Module& module) {
void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[CUDA] host module:\n" << host_module;
3 changes: 3 additions & 0 deletions paddle/cinn/common/arch.h
@@ -17,6 +17,7 @@
#include <functional>
#include <ostream>
#include <variant>
#include "paddle/common/overloaded.h"

namespace cinn {
namespace common {
@@ -45,6 +46,8 @@ struct Arch final : public ArchBase {
return static_cast<const ArchBase&>(*this);
}

DEFINE_MATCH_METHOD();

bool operator==(const auto& other) const {
return this->index() == other.index();
}
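Assuming DEFINE_MATCH_METHOD() from paddle/common/overloaded.h is what provides the Match member used throughout this PR, i.e. it dispatches a set of overloaded lambdas over the underlying std::variant, a call site looks like this sketch:

// Illustrative only: dispatch on the architecture of the current target.
cinn::runtime::CurrentTarget::GetCurrentTarget().arch.Match(
    [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
      CINN_NOT_IMPLEMENTED;  // host-only arches: nothing device-specific to do
    },
    [&](common::NVGPUArch) {
      // CUDA-specific handling goes here.
    });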
4 changes: 2 additions & 2 deletions paddle/cinn/common/cuda_test_helper.cc
@@ -16,7 +16,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
@@ -28,7 +28,7 @@ namespace common {
void CudaModuleTester::Compile(const ir::Module& m,
const std::string& rewrite_cuda_code) {
auto _host_module_device_module_ =
backends::SplitCudaAndHostModule(m); // NOLINT
backends::SplitDeviceAndHostModule(m); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CHECK(!host_module.functions().empty());
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/paddle/model_parser.cc
@@ -19,7 +19,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/frontend/paddle/compatible_pb.h"
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/graph_compiler_util.h
@@ -91,7 +91,7 @@ struct CompilationContext {
void* stream = nullptr;

// Set attached source code, if code is not empty, these codes will replace
// the device_module code after SplitCudaAndHostModule.
// the device_module code after SplitDeviceAndHostModule.
void ApplySourceCode(const std::string& code);
// Apply results of auto-tune to compile.
// Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_test.cc
@@ -18,7 +18,7 @@

#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
@@ -20,7 +20,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
@@ -238,7 +238,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
auto ir_module = builder.Build();
if (context->target == cinn::common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitCudaAndHostModule(ir_module);
auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);

2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -18,7 +18,7 @@

#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/custom_call.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/reduction_test.cc
@@ -22,7 +22,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
@@ -116,7 +116,7 @@ std::pair<ir::Module, std::string> GenReduceCode(
// now.
auto module = builder.Build();
auto host_module_device_module =
backends::SplitCudaAndHostModule(module); // NOLINT
backends::SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(host_module_device_module);
auto& device_module = std::get<1>(host_module_device_module);

2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/transform_test.cc
@@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/pe/pe_transform_test.cc
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
@@ -132,7 +132,7 @@ TEST(ScatterAssign, ScatterAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

@@ -176,7 +176,7 @@ TEST(SliceAssign, SliceAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

@@ -217,7 +217,7 @@ TEST(Concat, ConcatCase0) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);
