@@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
7373 return std::find (vec.cbegin (), vec.cend (), name) != vec.cend ();
7474}
7575
76+ std::vector<std::string> ParseAttrStr (const std::string& attr) {
77+ auto split_pos = attr.find_first_of (" :" );
78+ PADDLE_ENFORCE_NE (split_pos, std::string::npos,
79+ platform::errors::InvalidArgument (
80+ " Invalid attribute string format. Attribute string "
81+ " format is `<name>:<type>`." ));
82+
83+ std::vector<std::string> rlt;
84+ // 1. name
85+ rlt.emplace_back (string::trim_spaces (attr.substr (0 , split_pos)));
86+ // 2. type
87+ rlt.emplace_back (string::trim_spaces (attr.substr (split_pos + 1 )));
88+
89+ VLOG (1 ) << " attr name: " << rlt[0 ] << " , attr type str: " << rlt[1 ];
90+
91+ return rlt;
92+ }
93+
7694} // namespace detail
7795
7896// //////////////// Kernel Define ////////////////////
@@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
8199static void RunKernelFunc (const framework::ExecutionContext& ctx,
82100 const paddle::KernelFunc& func,
83101 const std::vector<std::string>& inputs,
84- const std::vector<std::string>& outputs) {
102+ const std::vector<std::string>& outputs,
103+ const std::vector<std::string>& attrs) {
85104 VLOG (1 ) << " Custom Operator: Start run KernelFunc." ;
86105 std::vector<paddle::Tensor> custom_ins;
87106 for (auto & in_name : inputs) {
@@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
98117 custom_ins.emplace_back (custom_in);
99118 }
100119
101- std::vector<boost::any> attrs;
120+ std::vector<boost::any> custom_attrs;
121+ for (auto & attr_str : attrs) {
122+ auto attr_name_and_type = detail::ParseAttrStr (attr_str);
123+ auto attr_name = attr_name_and_type[0 ];
124+ auto attr_type_str = attr_name_and_type[1 ];
125+ if (attr_type_str == " bool" ) {
126+ custom_attrs.emplace_back (ctx.Attr <bool >(attr_name));
127+ } else if (attr_type_str == " int" ) {
128+ custom_attrs.emplace_back (ctx.Attr <int >(attr_name));
129+ } else if (attr_type_str == " float" ) {
130+ custom_attrs.emplace_back (ctx.Attr <float >(attr_name));
131+ } else if (attr_type_str == " int64_t" ) {
132+ custom_attrs.emplace_back (ctx.Attr <int64_t >(attr_name));
133+ } else if (attr_type_str == " std::string" ) {
134+ custom_attrs.emplace_back (ctx.Attr <std::string>(attr_name));
135+ } else if (attr_type_str == " std::vector<int>" ) {
136+ custom_attrs.emplace_back (ctx.Attr <std::vector<int >>(attr_name));
137+ } else if (attr_type_str == " std::vector<float>" ) {
138+ custom_attrs.emplace_back (ctx.Attr <std::vector<float >>(attr_name));
139+ } else if (attr_type_str == " std::vector<int64_t>" ) {
140+ custom_attrs.emplace_back (ctx.Attr <std::vector<int64_t >>(attr_name));
141+ } else if (attr_type_str == " std::vector<std::string>" ) {
142+ custom_attrs.emplace_back (ctx.Attr <std::vector<std::string>>(attr_name));
143+ } else {
144+ PADDLE_THROW (platform::errors::Unimplemented (
145+ " Unsupported `%s` type value as custom attribute now. "
146+ " Supported data types include `bool`, `int`, `float`, "
147+ " `int64_t`, `std::string`, `std::vector<int>`, "
148+ " `std::vector<float>`, `std::vector<int64_t>, "
149+ " `std::vector<std::string>`, Please check whether "
150+ " the attribute data type and data type string are matched." ,
151+ attr_type_str));
152+ }
153+ }
102154
103155 VLOG (1 ) << " Run ComputeFunc." ;
104- auto outs = func (custom_ins, attrs );
156+ auto outs = func (custom_ins, custom_attrs );
105157
106158 VLOG (1 ) << " Custom Operator: Share outputs into ExecutionContext." ;
107159 for (size_t i = 0 ; i < outputs.size (); ++i) {
@@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
164216 for (auto & out_name : outputs_) {
165217 AddOutput (out_name, " The output " + out_name + " of Custom Operator." );
166218 }
167- // TODO(chenweihang): support attrs in later PR
219+ for (auto & attr : attrs_) {
220+ auto attr_name_and_type = detail::ParseAttrStr (attr);
221+ auto attr_name = attr_name_and_type[0 ];
222+ auto attr_type_str = attr_name_and_type[1 ];
223+ if (attr_type_str == " bool" ) {
224+ AddAttr<bool >(attr_name, " custom operator bool attribute." )
225+ .SetDefault (false );
226+ } else if (attr_type_str == " int" ) {
227+ AddAttr<int >(attr_name, " custom operator int attribute." ).SetDefault (1 );
228+ } else if (attr_type_str == " float" ) {
229+ AddAttr<float >(attr_name, " custom operator float attribute." )
230+ .SetDefault (1 .0f );
231+ } else if (attr_type_str == " int64_t" ) {
232+ AddAttr<int64_t >(attr_name, " custom operator int64_t attribute." )
233+ .SetDefault (1 );
234+ } else if (attr_type_str == " std::string" ) {
235+ AddAttr<std::string>(attr_name, " custom operator int attribute." )
236+ .SetDefault (" " );
237+ } else if (attr_type_str == " std::vector<int>" ) {
238+ AddAttr<std::vector<int >>(attr_name,
239+ " custom operator std::vector<int> attribute." )
240+ .SetDefault ({});
241+ } else if (attr_type_str == " std::vector<float>" ) {
242+ AddAttr<std::vector<float >>(
243+ attr_name, " custom operator std::vector<float> attribute." )
244+ .SetDefault ({});
245+ } else if (attr_type_str == " std::vector<int64_t>" ) {
246+ AddAttr<std::vector<int64_t >>(
247+ attr_name, " custom operator std::vector<int64_t> attribute." )
248+ .SetDefault ({});
249+ } else if (attr_type_str == " std::vector<std::string>" ) {
250+ AddAttr<std::vector<std::string>>(
251+ attr_name, " custom operator std::vector<std::string> attribute." )
252+ .SetDefault ({});
253+ } else {
254+ PADDLE_THROW (platform::errors::Unimplemented (
255+ " Unsupported `%s` type value as custom attribute now. "
256+ " Supported data types include `bool`, `int`, `float`, "
257+ " `int64_t`, `std::string`, `std::vector<int>`, "
258+ " `std::vector<float>`, `std::vector<int64_t>, "
259+ " `std::vector<std::string>`, Please check whether "
260+ " the attribute data type and data type string are matched." ,
261+ attr_type_str));
262+ }
263+ }
168264 AddComment (R"DOC(
169265Custom Operator.
170266
@@ -227,7 +323,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
227323 VLOG (1 ) << " Custom Operator: GradOpDescMaker - output: " << out_name;
228324 grad_op->SetOutput (out_name, this ->InputGrad (detail::NoGrad (out_name)));
229325 }
230- // TODO(chenweihang): support attrs in later PR
326+ grad_op-> SetAttrMap ( this -> Attrs ());
231327 }
232328
233329 private:
@@ -287,7 +383,7 @@ class CustomGradOpMaker<imperative::OpBase>
287383 VLOG (1 ) << " Custom Operator: GradOpBaseMaker - output: " << out_name;
288384 grad_op->SetOutput (out_name, this ->InputGrad (detail::NoGrad (out_name)));
289385 }
290- // TODO(chenweihang): support attrs in later PR
386+ grad_op-> SetAttrMap ( this -> Attrs ());
291387 }
292388
293389 private:
@@ -303,31 +399,36 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
303399 const proto::VarType::Type type,
304400 const PlaceType& place,
305401 const std::vector<std::string>& inputs,
306- const std::vector<std::string>& outputs) {
402+ const std::vector<std::string>& outputs,
403+ const std::vector<std::string>& attrs) {
307404 OpKernelType key (type,
308405 CustomTensorUtils::ConvertEnumPlaceToInnerPlace (place));
309406 VLOG (1 ) << " Custom Operator: op kernel key: " << key;
310407 OperatorWithKernel::AllOpKernels ()[name][key] =
311- [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
408+ [kernel_func, inputs, outputs,
409+ attrs](const framework::ExecutionContext& ctx) {
312410 VLOG (1 ) << " Custom Operator: run custom kernel func in lambda." ;
313- RunKernelFunc (ctx, kernel_func, inputs, outputs);
411+ RunKernelFunc (ctx, kernel_func, inputs, outputs, attrs );
314412 };
315413}
316414
317415void RegisterOperatorKernel (const std::string& name,
318416 const paddle::KernelFunc& kernel_func,
319417 const std::vector<std::string>& inputs,
320- const std::vector<std::string>& outputs) {
418+ const std::vector<std::string>& outputs,
419+ const std::vector<std::string>& attrs) {
321420 VLOG (1 ) << " Custom Operator: op name in kernel: " << name;
322421 // NOTE [ Dummy Op Kernel Key ]
323422 // TODO(chenweihang): Because execute engine need get device context based
324423 // op_kernel_key.place_, so we should register kernel for each
325424 // device. But this is not entirely correct, if user only give a cpu kernel,
326425 // but call api in gpu device, it will cause error.
327426 RegisterOperatorKernelWithPlace (name, kernel_func, proto::VarType::RAW,
328- PlaceType::kCPU , inputs, outputs);
427+ PlaceType::kCPU , inputs, outputs, attrs);
428+ #ifdef PADDLE_WITH_CUDA
329429 RegisterOperatorKernelWithPlace (name, kernel_func, proto::VarType::RAW,
330- PlaceType::kGPU , inputs, outputs);
430+ PlaceType::kGPU , inputs, outputs, attrs);
431+ #endif
331432}
332433
333434void RegisterOperatorWithMetaInfo (
@@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo(
350451 << string::join_strings (op_inputs, ' ,' );
351452 VLOG (1 ) << " Custom Operator: forward, op outputs: "
352453 << string::join_strings (op_outputs, ' ,' );
454+ VLOG (1 ) << " Custom Operator: forward, op attrs: "
455+ << string::join_strings (op_attrs, ' ,' );
353456
354457 // Op
355458 info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
@@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo(
426529 };
427530
428531 // Kernel func
429- RegisterOperatorKernel (op_name, kernel_fn, op_inputs, op_outputs);
532+ RegisterOperatorKernel (op_name, kernel_fn, op_inputs, op_outputs, op_attrs );
430533
431534 // If grad op or double grad op exists
432535 std::string cur_op_name = op_name;
@@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo(
436539 auto & grad_op_name = OpMetaInfoHelper::GetOpName (cur_grad_op);
437540 auto & grad_op_inputs = OpMetaInfoHelper::GetInputs (cur_grad_op);
438541 auto & grad_op_outputs = OpMetaInfoHelper::GetOutputs (cur_grad_op);
542+ auto & grad_op_attrs = OpMetaInfoHelper::GetAttrs (cur_grad_op);
439543 auto & grad_kernel_fn = OpMetaInfoHelper::GetKernelFn (cur_grad_op);
440544
441545 VLOG (1 ) << " Custom Operator: backward, op name: " << grad_op_name;
@@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo(
489593
490594 // Kernel func
491595 RegisterOperatorKernel (grad_op_name, grad_kernel_fn, grad_op_inputs,
492- grad_op_outputs);
596+ grad_op_outputs, grad_op_attrs );
493597
494598 // update current info
495599 OpInfoMap::Instance ().Insert (cur_op_name, info);
0 commit comments