-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimize attr default value #33357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
optimize attr default value #33357
Changes from 9 commits
dd3329d
c8b557b
8b1cdab
0cf6374
2facfca
911300c
3598f3c
2ff8504
e22ea5f
d44aeb6
0132959
32e2f60
eb9b97e
e681acd
ea470d7
f54c359
2d4c9dc
806564c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -208,22 +208,37 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); | |
|
|
||
| class AttrReader { | ||
| public: | ||
| explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
| explicit AttrReader(const AttributeMap& attrs) | ||
| : attrs_(attrs), attrs_default_(nullptr) {} | ||
|
|
||
| AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default) | ||
| : attrs_(attrs), attrs_default_(&attrs_default) {} | ||
|
|
||
| template <typename T> | ||
| inline const T& Get(const std::string& name) const { | ||
| PADDLE_ENFORCE_NE(attrs_.count(name), 0, | ||
| platform::errors::NotFound( | ||
| "Attribute (%s) should be in AttributeMap.", name)); | ||
| auto it = attrs_.find(name); | ||
| if (it == attrs_.end()) { | ||
|
||
| if (attrs_default_ != nullptr) { | ||
| it = attrs_default_->find(name); | ||
| if (it == attrs_default_->end()) { | ||
|
||
| PADDLE_THROW(platform::errors::NotFound( | ||
| "Attribute (%s) should be in AttributeMap.", name)); | ||
| } | ||
| } else { | ||
| PADDLE_THROW(platform::errors::NotFound( | ||
| "Attribute (%s) should be in AttributeMap.", name)); | ||
| } | ||
| } | ||
|
|
||
| Attribute& attr = const_cast<Attribute&>(attrs_.at(name)); | ||
| Attribute& attr = const_cast<Attribute&>(it->second); | ||
| ExtractAttribute<T> extract_attr(name); | ||
| T* attr_value = extract_attr(attr); | ||
| return *attr_value; | ||
| } | ||
|
|
||
| private: | ||
| const AttributeMap& attrs_; | ||
| const AttributeMap* attrs_default_; | ||
|
||
| }; | ||
|
|
||
| // check whether a value(attribute) fit a certain limit | ||
|
|
@@ -345,30 +360,41 @@ class TypedAttrChecker { | |
| return *this; | ||
| } | ||
|
|
||
| void operator()(AttributeMap* attr_map, | ||
| bool get_default_value_only = false) const { | ||
| void operator()(AttributeMap* attr_map, bool get_default_value_only = false, | ||
| bool no_default_value = false) const { | ||
| if (get_default_value_only) { | ||
| if (!default_value_setter_.empty()) { | ||
| attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| auto it = attr_map->find(attr_name_); | ||
| if (it == attr_map->end()) { | ||
| // user do not set this attr | ||
| PADDLE_ENFORCE_EQ( | ||
| default_value_setter_.empty(), false, | ||
| platform::errors::InvalidArgument( | ||
| "Attribute (%s) is not set correctly.", attr_name_)); | ||
| // default_value_setter_ has no more than one element | ||
| attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
| } | ||
| it = attr_map->find(attr_name_); | ||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| if (no_default_value) { | ||
| auto it = attr_map->find(attr_name_); | ||
| if (it != attr_map->end()) { | ||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| } | ||
| } | ||
| } else { | ||
| auto it = attr_map->find(attr_name_); | ||
| if (it == attr_map->end()) { | ||
| // user do not set this attr | ||
| PADDLE_ENFORCE_EQ( | ||
| default_value_setter_.empty(), false, | ||
| platform::errors::InvalidArgument( | ||
| "Attribute (%s) is not set correctly.", attr_name_)); | ||
| // default_value_setter_ has no more than one element | ||
| attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
|
||
| } | ||
| it = attr_map->find(attr_name_); | ||
|
||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -380,7 +406,7 @@ class TypedAttrChecker { | |
|
|
||
| // check whether op's all attributes fit their own limits | ||
| class OpAttrChecker { | ||
| typedef std::function<void(AttributeMap*, bool)> AttrChecker; | ||
| typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker; | ||
|
|
||
| public: | ||
| template <typename T> | ||
|
|
@@ -390,18 +416,19 @@ class OpAttrChecker { | |
| return *(checker.target<TypedAttrChecker<T>>()); | ||
| } | ||
|
|
||
| void Check(AttributeMap* attr_map, bool explicit_only = false) const { | ||
| void Check(AttributeMap* attr_map, bool explicit_only = false, | ||
| bool no_default_value = false) const { | ||
| auto checker_num = attr_checkers_.size(); | ||
| if (explicit_only) checker_num = explicit_checker_num_; | ||
| for (size_t i = 0; i < checker_num; ++i) { | ||
| attr_checkers_[i](attr_map, false); | ||
| attr_checkers_[i](attr_map, false, no_default_value); | ||
| } | ||
| } | ||
|
|
||
| AttributeMap GetAttrsDefaultValuesMap() const { | ||
|
||
| AttributeMap default_values_map; | ||
| for (const auto& checker : attr_checkers_) { | ||
| checker(&default_values_map, true); | ||
| checker(&default_values_map, true, false); | ||
| } | ||
| return default_values_map; | ||
| } | ||
|
|
@@ -410,9 +437,19 @@ class OpAttrChecker { | |
| explicit_checker_num_ = attr_checkers_.size(); | ||
| } | ||
|
|
||
| void InitDefaultMap() { | ||
|
||
| for (const auto& checker : attr_checkers_) { | ||
| checker(&attrs_default_, true, false); | ||
|
||
| } | ||
| } | ||
|
|
||
| const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; } | ||
|
||
|
|
||
| private: | ||
| std::vector<AttrChecker> attr_checkers_; | ||
|
|
||
| AttributeMap attrs_default_; | ||
|
|
||
| // in order to improve the efficiency of dynamic graph mode, | ||
| // we divede the attribute into explicit type and implicit type. | ||
| // for explicit attribute, we mean the attribute added in the customized | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,6 +93,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, | |
| AddAttr<std::string>(OpDeviceAttrName(), "Device type of this operator.") | ||
| .SetDefault(""); | ||
| Validate(); | ||
|
|
||
| op_checker_->InitDefaultMap(); | ||
|
||
| } | ||
|
|
||
| } // namespace framework | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,9 +113,18 @@ class GradOpBaseMakerBase { | |
| return vec_temp; | ||
| } | ||
|
|
||
| // Only for dygraph | ||
| void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) { | ||
|
||
| attrs_default_ = &attrs_default; | ||
| } | ||
|
|
||
| const framework::AttributeMap& AttrsDefault() const { | ||
|
||
| return *attrs_default_; | ||
| } | ||
|
|
||
| const framework::AttributeMap& Attrs() const { return attrs_; } | ||
|
|
||
| const framework::Attribute& GetAttr(const std::string& name) const { | ||
| virtual const framework::Attribute& GetAttr(const std::string& name) const { | ||
| auto it = attrs_.find(name); | ||
| PADDLE_ENFORCE_EQ( | ||
| it != attrs_.end(), true, | ||
|
|
@@ -199,6 +208,7 @@ class GradOpBaseMakerBase { | |
| const NameVarBaseMap& var_base_map_in_; | ||
| const NameVarBaseMap& var_base_map_out_; | ||
| const framework::AttributeMap& attrs_; | ||
| const framework::AttributeMap* attrs_default_; | ||
| const std::map<std::string, std::string>& inplace_map_; | ||
| }; | ||
|
|
||
|
|
@@ -285,6 +295,10 @@ class TracedGradOp { | |
| return op_->SetAttrMap(attrs); | ||
| } | ||
|
|
||
| void SetAttrDefaultMap(const framework::AttributeMap& attrs) { | ||
|
||
| return op_->SetAttrDefaultMap(attrs); | ||
| } | ||
|
|
||
| void SetAttr(const std::string& name, const framework::Attribute& v) { | ||
| op_->SetAttr(name, v); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| const framework::RuntimeContext& ctx, | ||
| const NameVarMap<VarType>& var_base_map_in, | ||
| const NameVarMap<VarType>& var_base_map_out, | ||
| const framework::AttributeMap& attrs) | ||
| const framework::AttributeMap& attrs, | ||
| const framework::AttributeMap& attrs_default) | ||
| : ExecutionContext(op, scope, device_context, ctx), | ||
| var_base_map_in_(var_base_map_in), | ||
| var_base_map_out_(var_base_map_out), | ||
| attrs_(attrs) {} | ||
| attrs_(attrs), | ||
| attrs_default_(attrs_default) {} | ||
|
|
||
| std::string InputName(const std::string& name) const override { | ||
| auto it = var_base_map_in_.find(name); | ||
|
|
@@ -92,17 +94,21 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| } | ||
|
|
||
| bool HasAttr(const std::string& name) const override { | ||
| return attrs_.count(name) != 0; | ||
| return attrs_.count(name) != 0 || attrs_default_.count(name) != 0; | ||
| } | ||
|
|
||
| const framework::AttributeMap& Attrs() const override { return attrs_; } | ||
|
|
||
| const framework::Attribute& GetAttr(const std::string& name) const override { | ||
| auto it = attrs_.find(name); | ||
|
|
||
| PADDLE_ENFORCE_NE( | ||
| it, attrs_.end(), | ||
| platform::errors::NotFound("can not find [%s] in attrs", name)); | ||
| if (it == attrs_.end()) { | ||
| it = attrs_default_.find(name); | ||
| if (it == attrs_default_.end()) { | ||
| PADDLE_THROW( | ||
| platform::errors::NotFound("Can not find [%s] in attrs", name)); | ||
|
||
| } | ||
| } | ||
|
|
||
| return it->second; | ||
| } | ||
|
|
@@ -192,6 +198,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| const NameVarMap<VarType>& var_base_map_in_; | ||
| const NameVarMap<VarType>& var_base_map_out_; | ||
| const framework::AttributeMap& attrs_; | ||
| const framework::AttributeMap& attrs_default_; | ||
| }; | ||
|
|
||
| } // namespace imperative | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| DygraphInferShapeContext(const NameVarMap<VarType>* in, | ||
| const NameVarMap<VarType>* out, | ||
| const framework::AttributeMap* attr, | ||
| const framework::AttributeMap* attr_default, | ||
|
||
| const std::string op_type) | ||
| : var_base_map_in_(in), | ||
| var_base_map_out_(out), | ||
| attrs_(attr), | ||
| attrs_default_(attr_default), | ||
| op_type_(op_type) {} | ||
|
|
||
| bool HasInput(const std::string& name) const override { | ||
|
|
@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| } | ||
|
|
||
| framework::AttrReader Attrs() const override { | ||
| return framework::AttrReader(*attrs_); | ||
| return framework::AttrReader(*attrs_, *attrs_default_); | ||
| } | ||
|
|
||
| std::vector<std::string> Inputs(const std::string& name) const override { | ||
|
|
@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| const NameVarMap<VarType>* var_base_map_in_; | ||
| const NameVarMap<VarType>* var_base_map_out_; | ||
| const framework::AttributeMap* attrs_; | ||
| const framework::AttributeMap* attrs_default_; | ||
| const std::string op_type_; | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { | |
| public: | ||
| RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs, | ||
| const NameVarMap<VarType>& outputs, | ||
| const framework::AttributeMap& attrs_map) | ||
| const framework::AttributeMap& attrs_map, | ||
| const framework::AttributeMap& attrs_map_default) | ||
| : InferVarTypeContext(nullptr, nullptr), | ||
| inputs_(inputs), | ||
| outputs_(outputs), | ||
| attrs_(attrs_map) {} | ||
| attrs_(attrs_map), | ||
| attrs_default_(attrs_map_default) {} | ||
|
|
||
| virtual ~RuntimeInferVarTypeContext() {} | ||
|
|
||
| framework::Attribute GetAttr(const std::string& name) const override { | ||
| auto iter = attrs_.find(name); | ||
| PADDLE_ENFORCE_EQ( | ||
| iter != attrs_.end(), true, | ||
| platform::errors::NotFound("Cannot find attribute %s", name)); | ||
| return iter->second; | ||
| auto it = attrs_.find(name); | ||
|
|
||
| if (it == attrs_.end()) { | ||
| it = attrs_default_.find(name); | ||
| if (it == attrs_default_.end()) { | ||
| PADDLE_THROW( | ||
| platform::errors::NotFound("Can not find [%s] in attrs", name)); | ||
|
||
| } | ||
| } | ||
|
|
||
| return it->second; | ||
| } | ||
|
|
||
| bool HasInput(const std::string& name) const override { | ||
|
|
@@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { | |
| const NameVarMap<VarType>& inputs_; | ||
| const NameVarMap<VarType>& outputs_; | ||
| const framework::AttributeMap& attrs_; | ||
| const framework::AttributeMap& attrs_default_; | ||
| }; | ||
|
|
||
| } // namespace imperative | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名有些奇怪,建议
attrs_default->default_attrsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同感
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,thx!