Commit b2e0661

chenwhql, MingMingShangTian, zyfncg, YuanRisheng, Shixiaowei02 authored and committed
Paddle Tensor Operation Library initial implementation (PaddlePaddle#34425)
* initial tensor design & sign kernel demo
* add move constructor for meta & add lodtensor
* add dirs & sign xpu kernel
* add mean cpu&cuda kernel impl
* move sign & mean xpu & npu kernel
* add selected_rows basic impl
* refactor design, BaseTensor to DenseTensor, etc.
* add scale mkldnn kernel
* polish xpu & npu impl details
* fix mkldnn reuse compile failed
* change tensor operation lib name
* rename util filename
* add more comments
* change TensorImplInterface to TensorInterface
* add kernel key and factory
* remove MKLDNNTensorMeta, add MKLDNNDenseTensor
* change XXDeviceContext to XXContext
* add base kernel registrar utils & test on sign
* replace boost::any by paddle::any
* fix several ci failed
* fix npu compile error
* add ordered map util
* fix multiple ordered_map compile errors
* move dev into include dir
* support sign op in static op run
* fix static op run error
* fix new executor compile failed
* add dygraph branch & remove sign_op.h
* fix test_infer_no_need_buffer_slots
* fix rocm compile link error
* fix unitybuild error & clear glog
* fix npu compile failed
* skip quant trans test
* fix part windows compile problem
* fix xpu enforce error
* fix inference test failed
* remove ordered_map to solve quant failed
* fix part of rcom compile faild
* add more register kernels
* revert scale kernel temporarily
* fix code format error
* add new kernel registrar marco
* rename top to tcmpt
* revert xpu, npu, mkldnn impl & remove op def
* add kernel args parse functor to auto parse args
* revert some change & add scale kernels
* add op proto in dygraph kernelcontext building
* polish kernel dispatch logic & nameing rule
* fix scale kernel match error
* fix scale test failed
* add mean API and unittest
* test mean api success
* add branch to solve compiled error
* skip clang format error
* add mean skip rule in op_library
* add dot kernel, api and unittest (PaddlePaddle#6)
* remove old kernel and add symbol link
* fix dot compiled failed
* add merco for module declare
* fix npu and xpu compile error
* revert sign, mean, scale, dot kernel removing
* add comment for keeping old kernel impl
* fix mutable_data error
* fix bfloat16 conflit
* fix inference undef error
* adapt to msvc compile rules
* polish comment for template inst
* add cmake template instantiation for win
* fix backend to place device id bug
* fix ifdef error
* Op2functor (PaddlePaddle#7)
* add kernel args maker class
* make args maker non-const
* remove debug log
* modify codes by review options
* split constructPrKernelContext function
* fix output name bug
* fix test_mean_op test_sign_op failed
* fill_any_like kernel refactor (PaddlePaddle#10)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* skip dtype for fill_any_like
* add attrs for kernel key constrcut
* add use_pt_kernel Flags to control whether to use pt kernel (PaddlePaddle#13)
* add use_pt_kernel Flags to control whether to use pt kernel
* change the default value to true for cheking pt kernels
* fix mutable_data cuda place error
* move high level apis into hapi
* remove selectedrows adapting temporarily
* Support Scalar in Tensor Compute Library (PaddlePaddle#14)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* remove mkldnn tensor & polish details
* use flat_hash_map and small_vector in kernel factory
* Refactor flatten kernel (PaddlePaddle#12)
* refactor flatten kernel
* update infershape function
* fix compile bugs
* fix bugs when merge
* fix compiler bugs
* fix bugs when run test_flatten_api
* fix bugs when run test
* Revert "use flat_hash_map and small_vector in kernel factory" — This reverts commit 2309149.
* Move cpu, cuda and other device code into kernels (PaddlePaddle#15)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Perfect unitests (PaddlePaddle#16)
* perfect unittest
* update license
* replace with flat_hash_map, small_vector (PaddlePaddle#19)
* fix small_vector build error on windows platform
* replace with flat_hash_map, small_vector
* remove todo
* Perfect unitests (PaddlePaddle#20)
* perfect unittest
* update license
* fix bug when run tcmpt_utils_test
* refactor execution adapting impl
* fix insert conflit
* Fix CI bug of test_yolov3 (PaddlePaddle#21)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Fix CI bug of test_yolov3
* add the tensor base class, test=develop (PaddlePaddle#17)
* update the tensor base class, test=develop
* remove two funcs, test=develop
* update the error msg, test=develop
  Co-authored-by: Chen Weihang <[email protected]>
* [no-verify] commit backend and tensor signature changes
* Rename tcmpt to pten (PaddlePaddle#23)
* rename tcmpt to pten
* update omitted files for rename to pten
* update omitted file for rename to pten
* remove k of all enum var
* remove kernel_instantiate (PaddlePaddle#26)
* remove symbols and spatial_tensor
* change common to functions
* readd share tensor impl methods
* add a candidate dense tensor class, test=develop (PaddlePaddle#28)
* change all Pt to Pten
* resolve conflit with xiaowei
* Op2functor opt1 (PaddlePaddle#27)
* replace to small vector and change to const &
* add std::move
  Co-authored-by: Chen Weihang <[email protected]>
* polish kernel factory and kernel registry
* fix operator test error msg mismatch
* remove tensor signature and backend set member
* move scalar and polish enforce
* revert dtype layout change to fix error
* fix enum operator override error
* add several base unittests
* add pten utils tests
* polish some details
* Dev/op2func refactor 3 (PaddlePaddle#30)
* add a candidate dense tensor class, test=develop
* remove TensorBase::backend(), test=develop
* remove some ops, test=develop
* cherry-pick the pr of tensor meta, test=develop
* moves the dense tensor and some ops, test=develop
* update the linalg operator, test=develop
* update other operators, test=develop
* fix errors, test=develop
* fix bugs, test=develop
* try to resolve the problem of windows ci, test=develop
* updates codes, test=develop
* fix the tensor_utils.cc, test=develop
* modify the dense tensor, test=develop
* fix the data type, test=develop
  Co-authored-by: shixiaowei02 <[email protected]>
* polish some details
* polish kernel signature details
* fix a bug about offsets of the tensor, test=develop (PaddlePaddle#31)
  Co-authored-by: shixiaowei02 <[email protected]>
* polish some details

Co-authored-by: chentianyu03 <[email protected]>
Co-authored-by: zyfncg <[email protected]>
Co-authored-by: YuanRisheng <[email protected]>
Co-authored-by: 石晓伟 <[email protected]>
1 parent 80fa602 commit b2e0661

File tree

147 files changed: +8516 −195 lines


cmake/generic.cmake

Lines changed: 17 additions & 0 deletions

@@ -116,6 +116,20 @@ function(find_fluid_modules TARGET_NAME)
   endif()
 endfunction(find_fluid_modules)
 
+set_property(GLOBAL PROPERTY PTEN_MODULES "")
+# find all pten modules is used for paddle static library
+# for building inference libs
+function(find_pten_modules TARGET_NAME)
+  get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
+  string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
+  string(FIND "${__target_path}" "pten" pos)
+  if(pos GREATER 1)
+    get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES)
+    set(pten_modules ${pten_modules} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY PTEN_MODULES "${pten_modules}")
+  endif()
+endfunction(find_pten_modules)
+
 function(common_link TARGET_NAME)
   if (WITH_PROFILER)
     target_link_libraries(${TARGET_NAME} gperftools::profiler)
@@ -310,6 +324,7 @@ function(cc_library TARGET_NAME)
   else()
     add_library(${TARGET_NAME} STATIC ${cc_library_SRCS})
     find_fluid_modules(${TARGET_NAME})
+    find_pten_modules(${TARGET_NAME})
   endif()
   if(cc_library_DEPS)
     # Don't need link libwarpctc.so
@@ -482,6 +497,7 @@ function(nv_library TARGET_NAME)
   else()
     add_library(${TARGET_NAME} STATIC ${nv_library_SRCS})
     find_fluid_modules(${TARGET_NAME})
+    find_pten_modules(${TARGET_NAME})
   endif()
   if (nv_library_DEPS)
     add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
@@ -572,6 +588,7 @@ function(hip_library TARGET_NAME)
   else()
     hip_add_library(${TARGET_NAME} STATIC ${hip_library_SRCS})
     find_fluid_modules(${TARGET_NAME})
+    find_pten_modules(${TARGET_NAME})
   endif()
   if (hip_library_DEPS)
     add_dependencies(${TARGET_NAME} ${hip_library_DEPS})

paddle/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 add_subdirectory(scripts)
 add_subdirectory(testing)
 set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
+add_subdirectory(pten)
 add_subdirectory(fluid)

paddle/fluid/framework/CMakeLists.txt

Lines changed: 7 additions & 2 deletions

@@ -197,10 +197,12 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
 
 IF(WITH_XPU)
 cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
-    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
+    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
+    pten pten_utils kernel_factory)
 ELSE()
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
-    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
+    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
+    pten pten_utils kernel_factory)
 ENDIF()
 
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@@ -394,6 +396,8 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
 cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
 cc_library(generator SRCS generator.cc DEPS enforce place)
 
+cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils)
+
 # Get the current working branch
 execute_process(
   COMMAND git rev-parse --abbrev-ref HEAD
@@ -456,3 +460,4 @@ if(WITH_TESTING AND TEST selected_rows_test)
 endif()
 
 cc_test(scope_guard_test SRCS scope_guard_test.cc)
+cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils)

paddle/fluid/framework/operator.cc

Lines changed: 191 additions & 32 deletions

@@ -29,6 +29,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/profiler.h"
+#include "paddle/pten/common/scalar.h"
 
 namespace paddle {
 namespace framework {
@@ -49,6 +50,7 @@ DECLARE_bool(check_nan_inf);
 DECLARE_bool(enable_unused_var_check);
 PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
                              "number of threads for inner op");
+DECLARE_bool(run_pten_kernel);
 
 namespace paddle {
 namespace framework {
@@ -1120,8 +1122,24 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 #endif
 
-  if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
-    ChooseKernel(*runtime_ctx, scope, place);
+  auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);
+
+  // TODO(chenweihang): Now we are still reusing a lot of the original fluid
+  // implementation, this is a gradual replacement process
+  // TODO(chenweihang): in the first phase of project, we only support CPU, CUDA
+  // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
+  // phase
+  if (FLAGS_run_pten_kernel &&
+      pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
+    if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
+      ChoosePtenKernel(exe_ctx);
+    }
+    run_pten_kernel_ = pt_kernel_->IsValid();
+  }
+  if (!run_pten_kernel_) {
+    if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
+      ChooseKernel(exe_ctx);
+    }
   }
 
   // do data transformScope &transfer_scope;
@@ -1159,8 +1177,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   {
     platform::RecordEvent record_event("compute",
                                        platform::EventRole::kInnerOp);
-    (*kernel_func_)(
-        ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
+    if (run_pten_kernel_) {
+      auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
+      (*pt_kernel_)(&op_kernel_ctx);
+    } else {
+      (*kernel_func_)(
+          ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
+    }
   }
 
   if (!transfered_inplace_vars.empty()) {
@@ -1208,25 +1231,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }
 
-void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
-                                      const Scope& scope,
-                                      const platform::Place& place) const {
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  auto* dev_ctx = pool.Get(place);
-
-  // check if op[type] has kernel registered.
-  auto& all_op_kernels = AllOpKernels();
-  auto kernels_iter = all_op_kernels.find(type_);
-  PADDLE_ENFORCE_NE(
-      kernels_iter, all_op_kernels.end(),
-      platform::errors::Unavailable(
-          "There are no kernels which are registered in the %s operator.",
-          type_));
-
-  OpKernelMap& kernels = kernels_iter->second;
+OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
+    const ExecutionContext& ctx) const {
+  auto& dev_ctx = ctx.device_context();
 
-  auto expected_kernel_key = this->GetExpectedKernelType(
-      ExecutionContext(*this, scope, *dev_ctx, ctx));
+  auto expected_kernel_key = this->GetExpectedKernelType(ctx);
   if (HasAttr("op_device")) {
     if (Attr<std::string>("op_device") == "cpu") {
       expected_kernel_key.place_ = platform::CPUPlace();
@@ -1243,9 +1252,9 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
       // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
       // will be executed and a warning will be given at the same time.
       if (SupportGPU()) {
-        expected_kernel_key.place_ = dev_ctx->GetPlace();
+        expected_kernel_key.place_ = dev_ctx.GetPlace();
       } else if (SupportNPU()) {
-        expected_kernel_key.place_ = dev_ctx->GetPlace();
+        expected_kernel_key.place_ = dev_ctx.GetPlace();
       } else {
         expected_kernel_key.place_ = platform::CPUPlace();
         LOG_FIRST_N(WARNING, 1)
@@ -1256,6 +1265,47 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
   }
   VLOG(3) << "op type:" << type_
           << ", expected_kernel_key:" << expected_kernel_key;
+  return expected_kernel_key;
+}
+
+void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
+  pt_kernel_signature_.reset(
+      new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
+
+  VLOG(1) << KernelSignatureToString(*pt_kernel_signature_.get());
+
+  kernel_type_.reset(
+      new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
+
+  auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->name);
+  auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
+  pt_kernel_.reset(
+      new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
+          pt_kernel_name, pt_kernel_key)));
+
+  if (pt_kernel_->IsValid()) {
+    VLOG(1) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name
+            << " | kernel key: " << pt_kernel_key
+            << " | kernel: " << *pt_kernel_;
+  } else {
+    VLOG(1) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
+            << "` not found.";
+  }
+}
+
+void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
+  // check if op[type] has kernel registered.
+  auto& all_op_kernels = AllOpKernels();
+  auto kernels_iter = all_op_kernels.find(type_);
+  PADDLE_ENFORCE_NE(
+      kernels_iter, all_op_kernels.end(),
+      platform::errors::Unavailable(
+          "There are no kernels which are registered in the %s operator.",
+          type_));
+
+  OpKernelMap& kernels = kernels_iter->second;
+
+  auto expected_kernel_key = InnerGetExpectedKernelType(ctx);
 
   auto kernel_iter = kernels.find(expected_kernel_key);
 #ifdef PADDLE_WITH_MKLDNN
@@ -1562,11 +1612,10 @@ Scope* OperatorWithKernel::PrepareData(
 }
 
 void OperatorWithKernel::ParseInputDataType(
-    const ExecutionContext& ctx, const std::string& name,
+    const std::vector<Variable*>& vars, const std::string& name,
     proto::VarType::Type* data_type) const {
   proto::VarType::Type default_data_type =
       static_cast<proto::VarType::Type>(-1);
-  const std::vector<Variable*> vars = ctx.MultiInputVar(name);
   for (size_t i = 0; i < vars.size(); ++i) {
     const Variable* var = vars[i];
     if (var != nullptr) {
@@ -1588,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
       if (t != nullptr) {
         PADDLE_ENFORCE_EQ(
             t->IsInitialized(), true,
-            platform::errors::InvalidArgument(
-                "The Tensor in the %s Op's Input Variable %s(%s) is "
-                "not initialized.",
-                Type(), name, ctx.InputNames(name).at(i)));
+            platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
+                                              "contains uninitialized Tensor.",
+                                              Type(), name));
         proto::VarType::Type tmp = t->type();
         PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
                        platform::errors::InvalidArgument(
@@ -1614,7 +1662,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
       static_cast<proto::VarType::Type>(-1);
   proto::VarType::Type data_type = dafault_data_type;
   for (auto& input : ctx.InNameList()) {
-    ParseInputDataType(ctx, input, &data_type);
+    const std::vector<Variable*> vars = ctx.MultiInputVar(input);
+    ParseInputDataType(vars, input, &data_type);
   }
   PADDLE_ENFORCE_NE(
       data_type, dafault_data_type,
@@ -1628,7 +1677,7 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
   proto::VarType::Type dafault_data_type =
       static_cast<proto::VarType::Type>(-1);
   proto::VarType::Type data_type = dafault_data_type;
-  ParseInputDataType(ctx, name, &data_type);
+  ParseInputDataType(ctx.MultiInputVar(name), name, &data_type);
   PADDLE_ENFORCE_NE(
       data_type, dafault_data_type,
       platform::errors::InvalidArgument(
@@ -1711,5 +1760,115 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
                       tensor.layout());
 }
 
+KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
+    const ExecutionContext& ctx) const {
+  if (!KernelSignatureMap::Instance().Has(Type())) {
+    // TODO(chenweihang): we can generate this map by proto info in compile time
+    KernelArgsNameMakerByOpProto maker(Info().proto_);
+    KernelSignatureMap::Instance().Emplace(
+        Type(), std::move(maker.GetKernelSignature()));
+  }
+  return KernelSignatureMap::Instance().Get(Type());
+}
+
+pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
+    const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
+  // TODO(chenweihang): now only work for very simple case,
+  // many cases need to be deal with later:
+  // 1. the input and output are not tensor
+  // 2. the dispensbale, duplicable input and output
+  // 3. needless attributes remove
+  // 4. use pt Tensor directly
+  // 5. kernel input is not DenseTensor
+  pten::KernelContext op_kernel_ctx(dev_ctx);
+
+  auto& input_names = std::get<0>(pt_kernel_signature_->args);
+  auto& attr_names = std::get<1>(pt_kernel_signature_->args);
+  auto& output_names = std::get<2>(pt_kernel_signature_->args);
+
+  auto input_defs = pt_kernel_->args_def().input_defs();
+  auto attr_defs = pt_kernel_->args_def().attribute_defs();
+  auto output_defs = pt_kernel_->args_def().output_defs();
+
+  PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
+                    platform::errors::InvalidArgument(
+                        "The size of inputs_args names (%d) must be equal to "
+                        "the size of kernel input_defs (%d).",
+                        input_names.size(), input_defs.size()));
+
+  PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
+                    platform::errors::InvalidArgument(
+                        "The size of outputs_args names (%d) must be equal to "
+                        "the size of kernel output_defs (%d).",
+                        output_names.size(), output_defs.size()));
+
+  PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
+                    platform::errors::InvalidArgument(
+                        "The size of attribute_args names (%d) must be equal "
+                        "to the size of kernel attribute_defs (%d).",
+                        attr_names.size(), attr_defs.size()));
+
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto in_def = input_defs.at(i);
+    VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
+            << in_def.layout;
+
+    auto ins_vector = ctx.inputs.at(input_names[i]);
+
+    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
+    for (auto var : ins_vector) {
+      tmp_inputs.emplace_back(
+          experimental::MakePtenTensorBaseFromVar(*var, in_def));
+    }
+    op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
+  }
+
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    auto out_def = output_defs.at(i);
+    auto outs_vector = ctx.outputs.at(output_names[i]);
+
+    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
+    for (auto var : outs_vector) {
+      tmp_outputs.emplace_back(
+          experimental::MakePtenTensorBaseFromVar(var, out_def));
+    }
+    op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
+  }
+
+  for (size_t i = 0; i < attr_names.size(); ++i) {
+    auto& attr = Attrs().at(attr_names[i]);
+    if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
+      // TODO(chenweihang): support other attrs later
+      // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
+      // attribtue type by attr_defs
+      if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
+        op_kernel_ctx.EmplaceBackAttr(
+            std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "unsupported cast op attribute `%s` to Scalar when construct "
+            "KernelContext.",
+            attr_names[i]));
+      }
+    } else {
+      // TODO(chenweihang): support other attrs later
+      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
+      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+      } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "unsupported cast op attribute `%s` when construct "
+            "KernelContext.",
+            attr_names[i]));
+      }
+    }
+  }
+
+  return op_kernel_ctx;
+}
+
 }  // namespace framework
 }  // namespace paddle
