
Commit e8cdb49

[CustomOp] Support attributes as func input in custom op (#31128)
* add simple attr support and test
* add int, float attr support
* support other attribute
* add custom attrs test in cmake
* polish details
* fix test failed
* add backward test
* update test flags
1 parent ffbf713 commit e8cdb49
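
For orientation, a hedged sketch of the kind of user code this change enables: attributes are declared on the op via the new `Attrs` call and are then received by the kernel function as ordinary parameters after the `Tensor` inputs. The op name, attribute names, and kernel below are illustrative only (this is not the test added in this PR), and the exact spelling of the registration macro in this era may differ slightly.

// Illustrative only: a custom op whose kernel takes attributes as function inputs.
#include <string>
#include <vector>
#include "paddle/extension.h"

std::vector<paddle::Tensor> AttrTestForward(const paddle::Tensor& x,
                                            bool bool_attr,
                                            int int_attr,
                                            float float_attr,
                                            std::string str_attr,
                                            std::vector<int> vec_int_attr) {
  // At this commit attributes arrive by value (see the
  // PD_SPECIALIZE_ComputeCallHelper specializations below).
  // A real kernel would compute an output Tensor from x and the attribute
  // values here; returning the input is just a placeholder for the sketch.
  return {x};
}

// Each attribute is declared as "<name>:<type>" and should be listed in the
// same order as it appears in the kernel signature.
PD_BUILD_OP("attr_test")
    .Inputs({"X"})
    .Outputs({"Out"})
    .Attrs({"bool_attr: bool",
            "int_attr: int",
            "float_attr: float",
            "str_attr: std::string",
            "vec_int_attr: std::vector<int>"})
    .SetKernelFn(PD_KERNEL(AttrTestForward));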

7 files changed: +458, -43 lines changed


paddle/fluid/extension/include/op_meta_info.h

Lines changed: 48 additions & 20 deletions
@@ -81,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
 using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
                                            std::vector<boost::any> attrs);
 
+#define PD_SPECIALIZE_ComputeCallHelper(attr_type)                          \
+  template <typename... Tail>                                               \
+  struct ComputeCallHelper<attr_type, Tail...> {                            \
+    template <int in_idx, int attr_idx, typename... PreviousArgs>           \
+    static Return Compute(std::vector<Tensor> inputs,                       \
+                          std::vector<boost::any> attrs,                    \
+                          const PreviousArgs&... pargs) {                   \
+      try {                                                                 \
+        attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]);        \
+        return ComputeCallHelper<Tail...>::template Compute<in_idx,         \
+                                                            attr_idx + 1>(  \
+            inputs, attrs, pargs..., arg);                                  \
+      } catch (boost::bad_any_cast&) {                                      \
+        PD_THROW(                                                           \
+            "Attribute cast error in custom operator. Expected " #attr_type \
+            " value.");                                                     \
+      }                                                                     \
+    }                                                                       \
+  }
+
 template <typename T>
 struct TypeTag {};
@@ -114,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
     }
   };
 
-  // TODO(chenweihang): add support for attribute input
-  // int attribute input (not used now)
-  template <typename... Tail>
-  struct ComputeCallHelper<int, Tail...> {
-    template <int in_idx, int attr_idx, typename... PreviousArgs>
-    static Return Compute(std::vector<Tensor> inputs,
-                          std::vector<boost::any> attrs,
-                          const PreviousArgs&... pargs) {
-      try {
-        int arg = boost::any_cast<int>(attrs[attr_idx]);
-        return ComputeCallHelper<Tail...>::template Compute<in_idx,
-                                                            attr_idx + 1>(
-            inputs, attrs, pargs..., arg);
-      } catch (boost::bad_any_cast&) {
-        PD_THROW(
-            "Attribute cast error in custom operator. Expected int value.");
-      }
-    }
-  };
-
+  PD_SPECIALIZE_ComputeCallHelper(bool);
+  PD_SPECIALIZE_ComputeCallHelper(int);
+  PD_SPECIALIZE_ComputeCallHelper(float);
+  PD_SPECIALIZE_ComputeCallHelper(int64_t);
+  PD_SPECIALIZE_ComputeCallHelper(std::string);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
+  // TODO(chenweihang): support other attribute type if needed.
+  // Why not support other attribute type here?
+  // - boost::blank, std::vector<bool> and std::vector<double>
+  //   are not used in op
+  // - BlockDesc* and std::vector<BlockDesc*> are used in framework
   // end: base template
   template <typename T>
   struct ComputeCallHelper<TypeTag<T>> {
@@ -245,10 +259,23 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
 class PD_DLL_DECL OpMetaInfo {
  public:
   explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
+
+  // format: {"<name1>", "<name2>", ...}
   OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
+
+  // format: {"<name1>", "<name2>", ...}
   OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
+
+  // format: {"<name1>:<type1>", "<name2>:<type2>", ...}
+  OpMetaInfo& Attrs(std::vector<std::string>&& attrs);
+
+  // format: PD_KERNEL(...)
   OpMetaInfo& SetKernelFn(KernelFunc&& func);
+
+  // format: PD_INFER_SHAPE(...)
   OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
+
+  // format: PD_INFER_DTYPE(...)
   OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
 
  private:
@@ -297,6 +324,7 @@ class PD_DLL_DECL OpMetaInfoBuilder {
   explicit OpMetaInfoBuilder(std::string&& name);
   OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
   OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
+  OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
   OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
   OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
   OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
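
Each specialization generated by `PD_SPECIALIZE_ComputeCallHelper` pulls the next `boost::any` out of `attrs` and casts it back to the concrete attribute type; the `try`/`catch` exists because `boost::any_cast` throws `boost::bad_any_cast` when the stored type and the requested type disagree. A minimal standalone illustration of that failure mode, independent of Paddle:

#include <boost/any.hpp>
#include <iostream>

int main() {
  boost::any attr = 42;                 // attribute stored type-erased, as the custom-op path does
  int ok = boost::any_cast<int>(attr);  // matching type: succeeds
  std::cout << "int attr = " << ok << std::endl;
  try {
    boost::any_cast<float>(attr);       // mismatched type: throws
  } catch (const boost::bad_any_cast& e) {
    std::cout << "bad_any_cast caught: " << e.what() << std::endl;
  }
  return 0;
}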

paddle/fluid/extension/src/op_meta_info.cc

Lines changed: 9 additions & 0 deletions
@@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
   outputs_ = std::forward<std::vector<std::string>>(outputs);
   return *this;
 }
+OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
+  attrs_ = std::forward<std::vector<std::string>>(attrs);
+  return *this;
+}
 OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
   kernel_fn_ = std::forward<KernelFunc>(func);
   return *this;
@@ -78,6 +82,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
   return *this;
 }
 
+OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
+  info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
+  return *this;
+}
+
 OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
   info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
   return *this;
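
The new `Attrs` setter follows the same pattern as `Inputs` and `Outputs`: an rvalue-reference parameter forwarded into the stored member, with `*this` returned so the builder can be chained. A minimal standalone sketch of that pattern (the class and attribute names here are illustrative, not Paddle code):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

class MetaDemo {
 public:
  // Rvalue-reference parameter + std::forward moves the list into the member,
  // avoiding a copy of the attribute-name strings.
  MetaDemo& Attrs(std::vector<std::string>&& attrs) {
    attrs_ = std::forward<std::vector<std::string>>(attrs);
    return *this;  // returning *this enables builder-style chaining
  }
  const std::vector<std::string>& attrs() const { return attrs_; }

 private:
  std::vector<std::string> attrs_;
};

int main() {
  MetaDemo meta;
  meta.Attrs({"int_attr: int", "fp32_attr: float"});
  for (const auto& a : meta.attrs()) std::cout << a << std::endl;
  return 0;
}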

paddle/fluid/framework/custom_operator.cc

Lines changed: 118 additions & 14 deletions
@@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
   return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
 }
 
+std::vector<std::string> ParseAttrStr(const std::string& attr) {
+  auto split_pos = attr.find_first_of(":");
+  PADDLE_ENFORCE_NE(split_pos, std::string::npos,
+                    platform::errors::InvalidArgument(
+                        "Invalid attribute string format. Attribute string "
+                        "format is `<name>:<type>`."));
+
+  std::vector<std::string> rlt;
+  // 1. name
+  rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
+  // 2. type
+  rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
+
+  VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
+
+  return rlt;
+}
+
 }  // namespace detail
 
 ////////////////// Kernel Define ////////////////////
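
`ParseAttrStr` splits the declaration string at the first `:` and trims spaces, which is why type spellings such as `std::vector<int>` keep their own `::` intact. A standalone sketch of the same rule using only the standard library (the real code uses `paddle::string::trim_spaces` and `PADDLE_ENFORCE_NE` for error handling; the demo names are hypothetical):

#include <iostream>
#include <string>
#include <vector>

// Split "<name>:<type>" at the FIRST ':' and trim surrounding spaces.
std::vector<std::string> ParseAttrStrDemo(const std::string& attr) {
  auto trim = [](const std::string& s) {
    auto b = s.find_first_not_of(' ');
    auto e = s.find_last_not_of(' ');
    return b == std::string::npos ? std::string() : s.substr(b, e - b + 1);
  };
  auto split_pos = attr.find_first_of(":");
  return {trim(attr.substr(0, split_pos)), trim(attr.substr(split_pos + 1))};
}

int main() {
  auto parsed = ParseAttrStrDemo("vec_attr: std::vector<int>");
  std::cout << "name: " << parsed[0] << ", type: " << parsed[1] << std::endl;
  // prints: name: vec_attr, type: std::vector<int>
  return 0;
}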
@@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
 static void RunKernelFunc(const framework::ExecutionContext& ctx,
                           const paddle::KernelFunc& func,
                           const std::vector<std::string>& inputs,
-                          const std::vector<std::string>& outputs) {
+                          const std::vector<std::string>& outputs,
+                          const std::vector<std::string>& attrs) {
   VLOG(1) << "Custom Operator: Start run KernelFunc.";
   std::vector<paddle::Tensor> custom_ins;
   for (auto& in_name : inputs) {
@@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
     custom_ins.emplace_back(custom_in);
   }
 
-  std::vector<boost::any> attrs;
+  std::vector<boost::any> custom_attrs;
+  for (auto& attr_str : attrs) {
+    auto attr_name_and_type = detail::ParseAttrStr(attr_str);
+    auto attr_name = attr_name_and_type[0];
+    auto attr_type_str = attr_name_and_type[1];
+    if (attr_type_str == "bool") {
+      custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
+    } else if (attr_type_str == "int") {
+      custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
+    } else if (attr_type_str == "float") {
+      custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
+    } else if (attr_type_str == "int64_t") {
+      custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
+    } else if (attr_type_str == "std::string") {
+      custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
+    } else if (attr_type_str == "std::vector<int>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
+    } else if (attr_type_str == "std::vector<float>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
+    } else if (attr_type_str == "std::vector<int64_t>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
+    } else if (attr_type_str == "std::vector<std::string>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported `%s` type value as custom attribute now. "
+          "Supported data types include `bool`, `int`, `float`, "
+          "`int64_t`, `std::string`, `std::vector<int>`, "
+          "`std::vector<float>`, `std::vector<int64_t>, "
+          "`std::vector<std::string>`, Please check whether "
+          "the attribute data type and data type string are matched.",
+          attr_type_str));
+    }
+  }
 
   VLOG(1) << "Run ComputeFunc.";
-  auto outs = func(custom_ins, attrs);
+  auto outs = func(custom_ins, custom_attrs);
 
   VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
   for (size_t i = 0; i < outputs.size(); ++i) {
@@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
     for (auto& out_name : outputs_) {
       AddOutput(out_name, "The output " + out_name + "of Custom Operator.");
     }
-    // TODO(chenweihang): support attrs in later PR
+    for (auto& attr : attrs_) {
+      auto attr_name_and_type = detail::ParseAttrStr(attr);
+      auto attr_name = attr_name_and_type[0];
+      auto attr_type_str = attr_name_and_type[1];
+      if (attr_type_str == "bool") {
+        AddAttr<bool>(attr_name, "custom operator bool attribute.")
+            .SetDefault(false);
+      } else if (attr_type_str == "int") {
+        AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
+      } else if (attr_type_str == "float") {
+        AddAttr<float>(attr_name, "custom operator float attribute.")
+            .SetDefault(1.0f);
+      } else if (attr_type_str == "int64_t") {
+        AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
+            .SetDefault(1);
+      } else if (attr_type_str == "std::string") {
+        AddAttr<std::string>(attr_name, "custom operator int attribute.")
+            .SetDefault("");
+      } else if (attr_type_str == "std::vector<int>") {
+        AddAttr<std::vector<int>>(attr_name,
+                                  "custom operator std::vector<int> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<float>") {
+        AddAttr<std::vector<float>>(
+            attr_name, "custom operator std::vector<float> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<int64_t>") {
+        AddAttr<std::vector<int64_t>>(
+            attr_name, "custom operator std::vector<int64_t> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<std::string>") {
+        AddAttr<std::vector<std::string>>(
+            attr_name, "custom operator std::vector<std::string> attribute.")
+            .SetDefault({});
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Unsupported `%s` type value as custom attribute now. "
+            "Supported data types include `bool`, `int`, `float`, "
+            "`int64_t`, `std::string`, `std::vector<int>`, "
+            "`std::vector<float>`, `std::vector<int64_t>, "
+            "`std::vector<std::string>`, Please check whether "
+            "the attribute data type and data type string are matched.",
+            attr_type_str));
+      }
+    }
     AddComment(R"DOC(
 Custom Operator.
@@ -227,7 +323,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
       VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
       grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
     }
-    // TODO(chenweihang): support attrs in later PR
+    grad_op->SetAttrMap(this->Attrs());
   }
 
  private:
@@ -287,7 +383,7 @@ class CustomGradOpMaker<imperative::OpBase>
       VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
       grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
     }
-    // TODO(chenweihang): support attrs in later PR
+    grad_op->SetAttrMap(this->Attrs());
   }
 
  private:
@@ -303,31 +399,36 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
                                      const proto::VarType::Type type,
                                      const PlaceType& place,
                                      const std::vector<std::string>& inputs,
-                                     const std::vector<std::string>& outputs) {
+                                     const std::vector<std::string>& outputs,
+                                     const std::vector<std::string>& attrs) {
   OpKernelType key(type,
                    CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
   VLOG(1) << "Custom Operator: op kernel key: " << key;
   OperatorWithKernel::AllOpKernels()[name][key] =
-      [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
+      [kernel_func, inputs, outputs,
+       attrs](const framework::ExecutionContext& ctx) {
        VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
-        RunKernelFunc(ctx, kernel_func, inputs, outputs);
+        RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
      };
 }
 
 void RegisterOperatorKernel(const std::string& name,
                             const paddle::KernelFunc& kernel_func,
                             const std::vector<std::string>& inputs,
-                            const std::vector<std::string>& outputs) {
+                            const std::vector<std::string>& outputs,
+                            const std::vector<std::string>& attrs) {
   VLOG(1) << "Custom Operator: op name in kernel: " << name;
   // NOTE [ Dummy Op Kernel Key ]
   // TODO(chenweihang): Because execute engine need get device context based
   // op_kernel_key.place_, so we should register kernel for each
   // device. But this is not entirely correct, if user only give a cpu kernel,
   // but call api in gpu device, it will cause error.
   RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
-                                  PlaceType::kCPU, inputs, outputs);
+                                  PlaceType::kCPU, inputs, outputs, attrs);
+#ifdef PADDLE_WITH_CUDA
   RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
-                                  PlaceType::kGPU, inputs, outputs);
+                                  PlaceType::kGPU, inputs, outputs, attrs);
+#endif
 }
 
 void RegisterOperatorWithMetaInfo(
@@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo(
           << string::join_strings(op_inputs, ',');
   VLOG(1) << "Custom Operator: forward, op outputs: "
           << string::join_strings(op_outputs, ',');
+  VLOG(1) << "Custom Operator: forward, op attrs: "
+          << string::join_strings(op_attrs, ',');
 
   // Op
   info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
@@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo(
   };
 
   // Kernel func
-  RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs);
+  RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
 
   // If grad op or double grad op exists
   std::string cur_op_name = op_name;
@@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo(
     auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op);
     auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
     auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
+    auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
     auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
 
     VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
@@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo(
 
     // Kernel func
     RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
-                           grad_op_outputs);
+                           grad_op_outputs, grad_op_attrs);
 
     // update current info
     OpInfoMap::Instance().Insert(cur_op_name, info);
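
`RegisterOperatorKernelWithPlace` stores a lambda in `AllOpKernels` that now captures `attrs` by value alongside `inputs` and `outputs`, so the attribute name/type strings remain valid for every later kernel invocation. A minimal standalone sketch of that capture-by-value pattern (the names here are illustrative, not Paddle APIs):

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Like the kernel lambda stored in AllOpKernels, this lambda copies the
  // attribute list into itself, so it outlives the scope that created it.
  std::map<std::string, std::function<void()>> kernels;
  {
    std::vector<std::string> attrs = {"int_attr: int", "fp32_attr: float"};
    kernels["custom_attr_op"] = [attrs]() {
      for (const auto& a : attrs) std::cout << a << std::endl;
    };
  }  // `attrs` is destroyed here; the copy inside the lambda survives.
  kernels["custom_attr_op"]();  // still prints both attribute strings
  return 0;
}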

python/paddle/fluid/tests/custom_op/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
@@ -13,10 +13,13 @@ py_test(test_sysconfig SRCS test_sysconfig.py)
 
 # 'test_dispatch' compile .cc file
 py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
-set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
+set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)
 
 py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
-set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
+set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
+
+py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
+set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
 
 if(NOT LINUX)
   return()
