Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions paddle/cinn/hlir/framework/accuracy_checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,22 @@ std::string AccuracyChecker::CheckBuffer(const cinn_buffer_t* buffer,

// Copies `numel` elements of type T from `src` into `dst`, selecting the copy
// mechanism from the architecture stored in `target_`:
//   - NVGPUArch: cudaMemcpy with cudaMemcpyDeviceToHost when the build defines
//     CINN_WITH_CUDA; otherwise CINN_NOT_IMPLEMENTED.
//   - X86Arch: plain element-by-element copy.
//   - UnknownArch / ARMArch: CINN_NOT_IMPLEMENTED.
//
// The dispatch is done solely through `target_.arch.Match`; the earlier
// `target_ == Default*Target()` comparisons are intentionally not kept, so the
// copy is issued exactly once per call.
template <typename T>
void AccuracyChecker::MemcpyDeviceToHost(const T* src, size_t numel, T* dst) {
  target_.arch.Match(
      [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
        // NOTE(review): the status returned by cudaMemcpy is not inspected
        // here; confirm whether a checked wrapper should be used instead.
        cudaMemcpy(dst, src, numel * sizeof(T), cudaMemcpyDeviceToHost);
#else
        CINN_NOT_IMPLEMENTED;
#endif
      },
      [&](common::X86Arch) {
        // `src` is directly addressable for this arch, so copy elementwise.
        for (size_t i = 0; i < numel; ++i) {
          dst[i] = src[i];
        }
      },
      [&](std::variant<common::UnknownArch, common::ARMArch>) {
        CINN_NOT_IMPLEMENTED;
      });
}

template <typename T>
Expand Down
25 changes: 15 additions & 10 deletions paddle/cinn/hlir/framework/accuracy_checker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@ void SetRandomTensor(Tensor tensor, Target target, bool generate_nan) {

std::vector<float> random_nan_vec(numel);
GenerateRandomData(random_nan_vec.data(), numel, generate_nan);

target.arch.Match(
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
if (target == cinn::common::DefaultNVGPUTarget()) {
cudaMemcpy(dst,
random_nan_vec.data(),
numel * sizeof(float),
cudaMemcpyHostToDevice);
}
cudaMemcpy(dst,
random_nan_vec.data(),
numel * sizeof(float),
cudaMemcpyHostToDevice);
#else
CINN_NOT_IMPLEMENTED;
#endif
if (target == cinn::common::DefaultHostTarget()) {
std::copy(random_nan_vec.begin(), random_nan_vec.end(), dst);
}
},
[&](common::X86Arch) {
std::copy(random_nan_vec.begin(), random_nan_vec.end(), dst);
},
[&](std::variant<common::UnknownArch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
});
}

TEST(AccuracyChecker, tensor) {
Expand Down
15 changes: 11 additions & 4 deletions paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
// Check whether the tensor has a buffer defined.
if (!temp.as_tensor_ref()->buffer.defined() ||
target != cinn::common::DefaultNVGPUTarget()) {
all_arg_tensors.push_back(temp.as_tensor_ref());
}
target.arch.Match(
[&](common::NVGPUArch) {
if (!temp.as_tensor_ref()->buffer.defined()) {
all_arg_tensors.push_back(temp.as_tensor_ref());
}
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
all_arg_tensors.push_back(temp.as_tensor_ref());
});
}

poly::StageMap stages = C.back();
Expand Down
74 changes: 46 additions & 28 deletions paddle/cinn/hlir/framework/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,17 @@ void Instruction::Run(
CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
"calling SetLoweredFunc method";
if (!dryrun) {
if (target_ == cinn::common::DefaultNVGPUTarget()) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
} else {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
target_.arch.Match(
[&](common::NVGPUArch) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
});
}
}
VLOG(3) << "Done Running extern function " << function_name_;
Expand All @@ -177,13 +181,17 @@ void Instruction::Run(
CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
"calling SetLoweredFunc method";
if (!dryrun) {
if (target_ == cinn::common::DefaultNVGPUTarget()) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
} else {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
target_.arch.Match(
[&](common::NVGPUArch) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
});
}
}
VLOG(3) << "Done Running extern function " << function_name_;
Expand Down Expand Up @@ -231,13 +239,19 @@ void Instruction::Run(
<< "The LoweredFunc address should be set first by "
"calling SetLoweredFunc method";
if (!dryrun) {
if (target_ == cinn::common::DefaultNVGPUTarget()) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
} else {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
}
target_.arch.Match(
[&](common::NVGPUArch) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()),
pod_args.size(),
stream);
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
});
}
}
VLOG(3) << "Done Running extern function " << function_name_;
Expand Down Expand Up @@ -390,13 +404,17 @@ void Instruction::Run(
CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
"calling SetLoweredFunc method";
if (!dryrun) {
if (target_ == cinn::common::DefaultNVGPUTarget()) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
} else {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
target_.arch.Match(
[&](common::NVGPUArch) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
});
}
}
VLOG(3) << "Done Running extern function " << function_name_;
Expand Down
22 changes: 11 additions & 11 deletions paddle/cinn/hlir/framework/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,21 @@ class Instruction {
auto& pod_args = args_cached_[idx];
CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first "
"by calling SetLoweredFunc method";
if (target_ == cinn::common::DefaultNVGPUTarget()) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
} else {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
target_.arch.Match(
[&](common::NVGPUArch) {
((lower_func_ptr_g)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size(), stream);
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaDeviceSynchronize());
#else
CINN_NOT_IMPLEMENTED;
#endif
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
((lower_func_ptr_t)fn_ptrs_[idx])(
static_cast<void*>(pod_args.data()), pod_args.size());
});
}
}
Expand Down
18 changes: 13 additions & 5 deletions paddle/cinn/hlir/framework/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,19 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
}

// Insert output tensors into function arg
if (!expr.as_tensor_ref()->buffer.defined() ||
this->target_ != cinn::common::DefaultNVGPUTarget()) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
}
target_.arch.Match(
[&](common::NVGPUArch) {
if (!expr.as_tensor_ref()->buffer.defined()) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
}
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
});
}

// 2.Do lower
Expand Down
100 changes: 54 additions & 46 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,58 +246,66 @@ void ParallelCompiler::Task::CodegenAndJit() {
}

auto ir_module = builder.Build();
if (context->target == cinn::common::DefaultNVGPUTarget()) {
context->target.arch.Match(
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);
auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);

VLOG(4) << "Host Code:\n" << hmodule;
VLOG(4) << "Device Code:\n" << dmodule;
std::string cuda_c;
if (context->attached_source_code.empty()) {
backends::CodeGenCUDA_Dev codegen(context->target);
cuda_c = codegen.Compile(dmodule);
} else {
VLOG(4) << "Codegen and jit with attached source code.";
cuda_c = context->attached_source_code;
}
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule;
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(
cuda_c, group_id, device_id);
pcompiler->result_.SetSourceCode(group_id, cuda_c);
VLOG(4) << "Host Code:\n" << hmodule;
VLOG(4) << "Device Code:\n" << dmodule;
std::string cuda_c;
if (context->attached_source_code.empty()) {
backends::CodeGenCUDA_Dev codegen(context->target);
cuda_c = codegen.Compile(dmodule);
} else {
VLOG(4) << "Codegen and jit with attached source code.";
cuda_c = context->attached_source_code;
}
CHECK(!cuda_c.empty())
<< "Compile CUDA C code failed from device module:\n"
<< dmodule;
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(
cuda_c, group_id, device_id);
pcompiler->result_.SetSourceCode(group_id, cuda_c);

cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);

using runtime::cuda::CUDAModule;
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(
ptx, group_id, device_id);
pcompiler->result_.SetSourcePtx(group_id, ptx);
// load cumodule
cumodule = std::make_unique<CUDAModule>(ptx,
compiler.compile_to_cubin()
? CUDAModule::Kind::CUBIN
: CUDAModule::Kind::PTX);
using runtime::cuda::CUDAModule;
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
<< cuda_c;
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(
ptx, group_id, device_id);
pcompiler->result_.SetSourcePtx(group_id, ptx);
// load cumodule
cumodule = std::make_unique<CUDAModule>(ptx,
compiler.compile_to_cubin()
? CUDAModule::Kind::CUBIN
: CUDAModule::Kind::PTX);

// register kernel
backends::RuntimeSymbols symbols;
for (auto& fn : dmodule.functions()) {
auto cufunc = cumodule->GetFunction(device_id, fn->name);
CHECK(cufunc);
symbols.RegisterVar(fn->name + "_ptr_", reinterpret_cast<void*>(cufunc));
}
engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(),
std::move(symbols));
engine->Link<backends::CodeGenCUDA_Host>(hmodule);
// register kernel
backends::RuntimeSymbols symbols;
for (auto& fn : dmodule.functions()) {
auto cufunc = cumodule->GetFunction(device_id, fn->name);
CHECK(cufunc);
symbols.RegisterVar(fn->name + "_ptr_",
reinterpret_cast<void*>(cufunc));
}
engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(),
std::move(symbols));
engine->Link<backends::CodeGenCUDA_Host>(hmodule);
#else
CINN_NOT_IMPLEMENTED;
#endif
} else {
engine = backends::ExecutionEngine::Create(backends::ExecutionOptions());
engine->Link<backends::CodeGenX86>(ir_module);
}
},
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
engine =
backends::ExecutionEngine::Create(backends::ExecutionOptions());
engine->Link<backends::CodeGenX86>(ir_module);
});
}

void ParallelCompiler::Task::BuildInstruction() {
Expand Down
22 changes: 15 additions & 7 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,13 +994,21 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
}

// Insert output tensors into function arg
if (!expr.as_tensor_ref()->buffer.defined() ||
this->target_ != cinn::common::DefaultNVGPUTarget()) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
} else {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
}
target_.arch.Match(
[&](common::NVGPUArch) {
if (!expr.as_tensor_ref()->buffer.defined()) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
} else {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
}
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
});
}

VLOG(4) << "op_func_arg_tensors.size(): " << op_func_arg_tensors->size();
Expand Down
Loading