Skip to content

Commit 16d595c

Browse files
authored
support dynamic created attr (PaddlePaddle#40)
1 parent 39e2529 commit 16d595c

5 files changed

Lines changed: 87 additions & 49 deletions

File tree

paddle/fluid/framework/op_desc.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,17 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
495495
attr_type != proto::AttrType::BLOCKS &&
496496
attr_type != proto::AttrType::VAR &&
497497
attr_type != proto::AttrType::VARS) {
498-
auto iter = runtime_attrs_.find(attr_name);
499-
if (iter == runtime_attrs_.end()) {
500-
attrs_[attr_name] = GetAttrValue(attr);
498+
auto default_attr_iter = runtime_attrs_.find(attr_name);
499+
const auto &extra_dynamic_attr_map =
500+
operators::ExtraInfoUtils::Instance().GetExtraDynamicAttrsMap(Type());
501+
auto dynamic_attr_iter = extra_dynamic_attr_map.find(attr_name);
502+
503+
if (default_attr_iter != runtime_attrs_.end()) {
504+
default_attr_iter->second = GetAttrValue(attr);
505+
} else if (dynamic_attr_iter != extra_dynamic_attr_map.end()) {
506+
runtime_attrs_[attr_name] = GetAttrValue(attr);
501507
} else {
502-
iter->second = GetAttrValue(attr);
508+
attrs_[attr_name] = GetAttrValue(attr);
503509
}
504510
}
505511
}
@@ -665,8 +671,13 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
665671

666672
const auto &extra_attr_map =
667673
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
674+
const auto &extra_dynamic_attr_map =
675+
operators::ExtraInfoUtils::Instance().GetExtraDynamicAttrsMap(Type());
668676
auto extra_attr_iter = extra_attr_map.find(name);
669-
if (extra_attr_iter != extra_attr_map.end()) {
677+
auto extra_dynamic_attr_iter = extra_dynamic_attr_map.find(name);
678+
// auto pass_attr_iter = operators::extra_attr_properties.find(name);
679+
if (extra_attr_iter != extra_attr_map.end() ||
680+
extra_dynamic_attr_iter != extra_dynamic_attr_map.end()) {
670681
is_runtime_attr = true;
671682
attrs_ptr = &(this->runtime_attrs_);
672683
}

paddle/fluid/operators/ops_extra_info.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
9090
// ONEDNN dedicated attributes
9191
{"Activation_scale", ExtraAttrProperty::ONEDNN},
9292
{"Bias", ExtraAttrProperty::ONEDNN},
93+
{"Bias_scales", ExtraAttrProperty::ONEDNN},
9394
{"data_format", ExtraAttrProperty::ONEDNN},
9495
{"force_fp32_output", ExtraAttrProperty::ONEDNN},
9596
{"fuse_activation", ExtraAttrProperty::ONEDNN},
@@ -188,6 +189,15 @@ class ExtraInfoUtils {
188189
return empty_extra_attrs_map_;
189190
}
190191

192+
const paddle::framework::AttributeMap& GetExtraDynamicAttrsMap(
193+
const std::string& op_type) const {
194+
auto iter = g_extra_dynamic_attrs_map_.find(op_type);
195+
if (iter != g_extra_dynamic_attrs_map_.end()) {
196+
return iter->second;
197+
}
198+
return empty_extra_attrs_map_;
199+
}
200+
191201
const std::vector<std::function<void(framework::AttributeMap*, bool)>>&
192202
GetExtraAttrsChecker(const std::string& op_type) const {
193203
auto iter = g_extra_attrs_checker_.find(op_type);
@@ -211,6 +221,8 @@ class ExtraInfoUtils {
211221

212222
std::unordered_map<std::string, paddle::framework::AttributeMap>
213223
g_extra_attrs_map_;
224+
std::unordered_map<std::string, paddle::framework::AttributeMap>
225+
g_extra_dynamic_attrs_map_;
214226
paddle::framework::AttributeMap empty_extra_attrs_map_{};
215227
std::unordered_map<
216228
std::string,

paddle/phi/api/yaml/generator/ops_extra_info_gen.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import argparse
1818

1919

20-
def map_code_template(attrs_str, attrs_checker_str):
20+
def map_code_template(default_attrs_str, dynamic_attr_str, attrs_checker_str):
2121
return f"""// This file is generated by paddle/phi/api/yaml/generator/ops_extra_info_gen.py
2222
#include "paddle/fluid/operators/ops_extra_info.h"
2323
@@ -28,7 +28,11 @@ def map_code_template(attrs_str, attrs_checker_str):
2828
2929
ExtraInfoUtils::ExtraInfoUtils() {{
3030
g_extra_attrs_map_ = {{
31-
{attrs_str}
31+
{default_attrs_str}
32+
}};
33+
34+
g_extra_dynamic_attrs_map_ = {{
35+
{dynamic_attr_str}
3236
}};
3337
3438
g_extra_attrs_checker_ = {{
@@ -64,10 +68,7 @@ def parse_attr(attr_str):
6468
'name'), result.group('default_val')
6569

6670

67-
def generate_extra_info(op_compat_yaml_path, ops_extra_info_path):
68-
compat_apis = []
69-
with open(op_compat_yaml_path, 'rt') as f:
70-
compat_apis = yaml.safe_load(f)
71+
def generate_attr_info(attr_type, op_compat_args):
7172

7273
def get_op_name(api_item):
7374
names = api_item.split('(')
@@ -76,49 +77,62 @@ def get_op_name(api_item):
7677
else:
7778
return names[1].split(')')[0].strip()
7879

79-
extra_map_str_list = []
80-
extra_checker_str_list = []
80+
attr_map_str_list = []
81+
attr_checker_str_list = []
82+
extra_args_map = op_compat_args['extra']
83+
if attr_type in extra_args_map:
84+
attr_map_list = []
85+
attr_checker_func_list = []
86+
for attr in extra_args_map[attr_type]:
87+
attr_type, attr_name, default_val = parse_attr(attr)
88+
attr_checker_func_list.append(
89+
f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}"
90+
)
91+
if attr_type.startswith("std::vector"):
92+
attr_map_list.append(
93+
f"{{\"{attr_name}\", {attr_type}{default_val}}}")
94+
else:
95+
attr_map_list.append(
96+
f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}")
97+
api_extra_attr_map = ", ".join(attr_map_list)
98+
api_extra_attr_checkers = ",\n ".join(attr_checker_func_list)
99+
attr_map_str_list.append(
100+
f"{{\"{get_op_name(op_compat_args['op'])}\", {{ {api_extra_attr_map} }}}}"
101+
)
102+
attr_checker_str_list.append(
103+
f"{{\"{get_op_name(op_compat_args['op'])}\", {{ {api_extra_attr_checkers} }}}}"
104+
)
105+
if 'backward' in op_compat_args:
106+
for bw_item in op_compat_args['backward'].split(','):
107+
bw_op_name = get_op_name(bw_item)
108+
attr_map_str_list.append(
109+
f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}")
110+
attr_checker_str_list.append(
111+
f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}")
112+
return attr_map_str_list, attr_checker_str_list
81113

114+
115+
def generate_extra_info(op_compat_yaml_path, ops_extra_info_path):
116+
compat_apis = []
117+
with open(op_compat_yaml_path, 'rt') as f:
118+
compat_apis = yaml.safe_load(f)
119+
extra_default_attr_str_list = []
120+
extra_dynamic_attr_str_list = []
121+
extra_checker_str_list = []
82122
for op_compat_args in compat_apis:
83123
if 'extra' in op_compat_args:
84-
extra_args_map = op_compat_args['extra']
85124
# TODO(chenweihang): add inputs and outputs
86-
if 'attrs' in extra_args_map:
87-
attr_map_list = []
88-
attr_checker_func_list = []
89-
for attr in extra_args_map['attrs']:
90-
attr_type, attr_name, default_val = parse_attr(attr)
91-
attr_checker_func_list.append(
92-
f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}"
93-
)
94-
if attr_type.startswith("std::vector"):
95-
attr_map_list.append(
96-
f"{{\"{attr_name}\", {attr_type}{default_val}}}")
97-
else:
98-
attr_map_list.append(
99-
f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}"
100-
)
101-
api_extra_attr_map = ", ".join(attr_map_list)
102-
api_extra_attr_checkers = ",\n ".join(
103-
attr_checker_func_list)
104-
extra_map_str_list.append(
105-
f"{{\"{get_op_name(op_compat_args['op'])}\", {{ {api_extra_attr_map} }}}}"
106-
)
107-
extra_checker_str_list.append(
108-
f"{{\"{get_op_name(op_compat_args['op'])}\", {{ {api_extra_attr_checkers} }}}}"
109-
)
110-
if 'backward' in op_compat_args:
111-
for bw_item in op_compat_args['backward'].split(','):
112-
bw_op_name = get_op_name(bw_item)
113-
extra_map_str_list.append(
114-
f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}")
115-
extra_checker_str_list.append(
116-
f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}"
117-
)
118-
125+
default_attr_map_str, default_attr_checker_str = generate_attr_info(
126+
'attrs', op_compat_args)
127+
dynamic_attr_map_str, _ = generate_attr_info(
128+
'dynamic_attrs', op_compat_args)
129+
extra_default_attr_str_list.extend(default_attr_map_str)
130+
extra_dynamic_attr_str_list.extend(dynamic_attr_map_str)
131+
extra_checker_str_list.extend(default_attr_checker_str)
119132
ops_extra_info_file = open(ops_extra_info_path, 'w')
120133
ops_extra_info_file.write(
121-
map_code_template(",\n ".join(extra_map_str_list),
134+
map_code_template(",\n ".join(extra_default_attr_str_list),
135+
",\n ".join(extra_dynamic_attr_str_list),
122136
",\n ".join(extra_checker_str_list)))
123137
ops_extra_info_file.close()
124138

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
113113
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
114114
int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]
115+
dynamic_attrs : ['float[] Bias_scales = {1.0f}', float Sum_scale = 1.0f, 'float[] Output_shift_scale = {1.0f}', float Activation_scale = 1.0f]
115116

116117
- op : conv2d_fusion
117118
extra :

paddle/phi/kernels/onednn/conv_handler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ class ConvOneDNNHandlerT
406406

407407
// Scales for int8 bias are to be cached to avoid
408408
// computing them each iteration
409+
groups = std::max(groups, 1);
409410
auto bias_scale_tuple =
410411
std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
411412
this->dev_ctx_.GetBlob(key_bs));
@@ -709,7 +710,6 @@ class ConvOneDNNHandlerT
709710
LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype();
710711
}
711712
const K_Bias* bias_data = bias->data<K_Bias>();
712-
713713
return this->AcquireMemoryWithReorder(
714714
bias->mem_desc(),
715715
this->fwd_pd_->bias_desc(),

0 commit comments

Comments
 (0)