-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimize attr default value #33357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
optimize attr default value #33357
Changes from 16 commits
dd3329d
c8b557b
8b1cdab
0cf6374
2facfca
911300c
3598f3c
2ff8504
e22ea5f
d44aeb6
0132959
32e2f60
eb9b97e
e681acd
ea470d7
f54c359
2d4c9dc
806564c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -208,22 +208,35 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); | |
|
|
||
| class AttrReader { | ||
| public: | ||
| explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
| explicit AttrReader(const AttributeMap& attrs) | ||
| : attrs_(attrs), attrs_default_(nullptr) {} | ||
|
|
||
| AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default) | ||
| : attrs_(attrs), attrs_default_(&attrs_default) {} | ||
|
|
||
| template <typename T> | ||
| inline const T& Get(const std::string& name) const { | ||
| PADDLE_ENFORCE_NE(attrs_.count(name), 0, | ||
| auto it = attrs_.find(name); | ||
| bool found = it != attrs_.end(); | ||
| if (it == attrs_.end()) { | ||
|
||
| if (attrs_default_ != nullptr) { | ||
| it = attrs_default_->find(name); | ||
| found = it != attrs_default_->end(); | ||
| } | ||
| } | ||
| PADDLE_ENFORCE_EQ(found, true, | ||
| platform::errors::NotFound( | ||
| "Attribute (%s) should be in AttributeMap.", name)); | ||
|
|
||
| Attribute& attr = const_cast<Attribute&>(attrs_.at(name)); | ||
| Attribute& attr = const_cast<Attribute&>(it->second); | ||
| ExtractAttribute<T> extract_attr(name); | ||
| T* attr_value = extract_attr(attr); | ||
| return *attr_value; | ||
| } | ||
|
|
||
| private: | ||
| const AttributeMap& attrs_; | ||
| const AttributeMap* attrs_default_; | ||
|
||
| }; | ||
|
|
||
| // check whether a value(attribute) fit a certain limit | ||
|
|
@@ -234,8 +247,8 @@ class GreaterThanChecker { | |
| void operator()(const T& value) const { | ||
| PADDLE_ENFORCE_GT( | ||
| value, lower_bound_, | ||
| platform::errors::OutOfRange( | ||
| "Check for attribute value greater than a certain value failed.")); | ||
| platform::errors::OutOfRange("Check for attribute value greater than " | ||
| "a certain value failed.")); | ||
| } | ||
|
|
||
| private: | ||
|
|
@@ -332,9 +345,9 @@ class TypedAttrChecker { | |
| TypedAttrChecker& SetDefault(const T& default_value) { | ||
| PADDLE_ENFORCE_EQ( | ||
| default_value_setter_.empty(), true, | ||
| platform::errors::AlreadyExists( | ||
| "Attribute (%s) has a default value and cannot be set repeatedly.", | ||
| attr_name_)); | ||
| platform::errors::AlreadyExists("Attribute (%s) has a default value " | ||
| "and cannot be set repeatedly.", | ||
| attr_name_)); | ||
| default_value_setter_.push_back(DefaultValueSetter<T>(default_value)); | ||
| return *this; | ||
| } | ||
|
|
@@ -345,30 +358,41 @@ class TypedAttrChecker { | |
| return *this; | ||
| } | ||
|
|
||
| void operator()(AttributeMap* attr_map, | ||
| bool get_default_value_only = false) const { | ||
| void operator()(AttributeMap* attr_map, bool get_default_value_only = false, | ||
| bool without_default_value = false) const { | ||
|
||
| if (get_default_value_only) { | ||
| if (!default_value_setter_.empty()) { | ||
| attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| auto it = attr_map->find(attr_name_); | ||
| if (it == attr_map->end()) { | ||
| // user do not set this attr | ||
| PADDLE_ENFORCE_EQ( | ||
| default_value_setter_.empty(), false, | ||
| platform::errors::InvalidArgument( | ||
| "Attribute (%s) is not set correctly.", attr_name_)); | ||
| // default_value_setter_ has no more than one element | ||
| attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
| } | ||
| it = attr_map->find(attr_name_); | ||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| if (without_default_value) { | ||
| auto it = attr_map->find(attr_name_); | ||
| if (it != attr_map->end()) { | ||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| } | ||
| } | ||
| } else { | ||
| auto it = attr_map->find(attr_name_); | ||
| if (it == attr_map->end()) { | ||
| // user do not set this attr | ||
| PADDLE_ENFORCE_EQ( | ||
| default_value_setter_.empty(), false, | ||
| platform::errors::InvalidArgument( | ||
| "Attribute (%s) is not set correctly.", attr_name_)); | ||
| // default_value_setter_ has no more than one element | ||
| auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]()); | ||
| it = tmp.first; | ||
| } | ||
| ExtractAttribute<T> extract_attr(attr_name_); | ||
| T* attr_value = extract_attr(it->second); | ||
| for (const auto& checker : value_checkers_) { | ||
| checker(*attr_value); | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -380,7 +404,7 @@ class TypedAttrChecker { | |
|
|
||
| // check whether op's all attributes fit their own limits | ||
| class OpAttrChecker { | ||
| typedef std::function<void(AttributeMap*, bool)> AttrChecker; | ||
| typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker; | ||
|
|
||
| public: | ||
| template <typename T> | ||
|
|
@@ -390,18 +414,19 @@ class OpAttrChecker { | |
| return *(checker.target<TypedAttrChecker<T>>()); | ||
| } | ||
|
|
||
| void Check(AttributeMap* attr_map, bool explicit_only = false) const { | ||
| void Check(AttributeMap* attr_map, bool explicit_only = false, | ||
| bool without_default_value = false) const { | ||
| auto checker_num = attr_checkers_.size(); | ||
| if (explicit_only) checker_num = explicit_checker_num_; | ||
| for (size_t i = 0; i < checker_num; ++i) { | ||
| attr_checkers_[i](attr_map, false); | ||
| attr_checkers_[i](attr_map, false, without_default_value); | ||
| } | ||
| } | ||
|
|
||
| AttributeMap GetAttrsDefaultValuesMap() const { | ||
|
||
| AttributeMap default_values_map; | ||
| for (const auto& checker : attr_checkers_) { | ||
| checker(&default_values_map, true); | ||
| checker(&default_values_map, true, false); | ||
| } | ||
| return default_values_map; | ||
| } | ||
|
|
@@ -410,15 +435,26 @@ class OpAttrChecker { | |
| explicit_checker_num_ = attr_checkers_.size(); | ||
| } | ||
|
|
||
| void InitDefaultMap() { | ||
|
||
| for (const auto& checker : attr_checkers_) { | ||
| checker(&attrs_default_, true, false); | ||
|
||
| } | ||
| } | ||
|
|
||
| const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; } | ||
|
||
|
|
||
| private: | ||
| std::vector<AttrChecker> attr_checkers_; | ||
|
|
||
| AttributeMap attrs_default_; | ||
|
|
||
| // in order to improve the efficiency of dynamic graph mode, | ||
| // we divede the attribute into explicit type and implicit type. | ||
| // for explicit attribute, we mean the attribute added in the customized | ||
| // op makers, usually it's defined in the overloaded Make method. | ||
| // for implicit attribute, we mean the attribute added outside of the Make | ||
| // method like "op_role", "op_role_var", and they are useless in dynamic graph | ||
| // method like "op_role", "op_role_var", and they are useless in dynamic | ||
| // graph | ||
| // mode | ||
| size_t explicit_checker_num_; | ||
| }; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,9 +113,18 @@ class GradOpBaseMakerBase { | |
| return vec_temp; | ||
| } | ||
|
|
||
| // Only for dygraph | ||
| void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) { | ||
|
||
| attrs_default_ = &attrs_default; | ||
| } | ||
|
|
||
| const framework::AttributeMap& AttrsDefault() const { | ||
|
||
| return *attrs_default_; | ||
| } | ||
|
|
||
| const framework::AttributeMap& Attrs() const { return attrs_; } | ||
|
|
||
| const framework::Attribute& GetAttr(const std::string& name) const { | ||
| virtual const framework::Attribute& GetAttr(const std::string& name) const { | ||
| auto it = attrs_.find(name); | ||
| PADDLE_ENFORCE_EQ( | ||
| it != attrs_.end(), true, | ||
|
|
@@ -199,6 +208,7 @@ class GradOpBaseMakerBase { | |
| const NameVarBaseMap& var_base_map_in_; | ||
| const NameVarBaseMap& var_base_map_out_; | ||
| const framework::AttributeMap& attrs_; | ||
| const framework::AttributeMap* attrs_default_; | ||
| const std::map<std::string, std::string>& inplace_map_; | ||
| }; | ||
|
|
||
|
|
@@ -285,6 +295,10 @@ class TracedGradOp { | |
| return op_->SetAttrMap(attrs); | ||
| } | ||
|
|
||
| void SetAttrDefaultMap(const framework::AttributeMap& attrs) { | ||
|
||
| return op_->SetAttrDefaultMap(attrs); | ||
| } | ||
|
|
||
| void SetAttr(const std::string& name, const framework::Attribute& v) { | ||
| op_->SetAttr(name, v); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| const framework::RuntimeContext& ctx, | ||
| const NameVarMap<VarType>& var_base_map_in, | ||
| const NameVarMap<VarType>& var_base_map_out, | ||
| const framework::AttributeMap& attrs) | ||
| const framework::AttributeMap& attrs, | ||
| const framework::AttributeMap& attrs_default) | ||
| : ExecutionContext(op, scope, device_context, ctx), | ||
| var_base_map_in_(var_base_map_in), | ||
| var_base_map_out_(var_base_map_out), | ||
| attrs_(attrs) {} | ||
| attrs_(attrs), | ||
| attrs_default_(attrs_default) {} | ||
|
|
||
| std::string InputName(const std::string& name) const override { | ||
| auto it = var_base_map_in_.find(name); | ||
|
|
@@ -92,17 +94,21 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| } | ||
|
|
||
| bool HasAttr(const std::string& name) const override { | ||
| return attrs_.count(name) != 0; | ||
| return attrs_.count(name) != 0 || attrs_default_.count(name) != 0; | ||
| } | ||
|
|
||
| const framework::AttributeMap& Attrs() const override { return attrs_; } | ||
|
|
||
| const framework::Attribute& GetAttr(const std::string& name) const override { | ||
| auto it = attrs_.find(name); | ||
|
|
||
| PADDLE_ENFORCE_NE( | ||
| it, attrs_.end(), | ||
| platform::errors::NotFound("can not find [%s] in attrs", name)); | ||
| if (it == attrs_.end()) { | ||
| it = attrs_default_.find(name); | ||
| if (it == attrs_default_.end()) { | ||
| PADDLE_THROW( | ||
| platform::errors::NotFound("Can not find [%s] in attrs", name)); | ||
|
||
| } | ||
| } | ||
|
|
||
| return it->second; | ||
| } | ||
|
|
@@ -192,6 +198,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { | |
| const NameVarMap<VarType>& var_base_map_in_; | ||
| const NameVarMap<VarType>& var_base_map_out_; | ||
| const framework::AttributeMap& attrs_; | ||
| const framework::AttributeMap& attrs_default_; | ||
| }; | ||
|
|
||
| } // namespace imperative | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| DygraphInferShapeContext(const NameVarMap<VarType>* in, | ||
| const NameVarMap<VarType>* out, | ||
| const framework::AttributeMap* attr, | ||
| const framework::AttributeMap* attr_default, | ||
|
||
| const std::string op_type) | ||
| : var_base_map_in_(in), | ||
| var_base_map_out_(out), | ||
| attrs_(attr), | ||
| attrs_default_(attr_default), | ||
| op_type_(op_type) {} | ||
|
|
||
| bool HasInput(const std::string& name) const override { | ||
|
|
@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| } | ||
|
|
||
| framework::AttrReader Attrs() const override { | ||
| return framework::AttrReader(*attrs_); | ||
| return framework::AttrReader(*attrs_, *attrs_default_); | ||
| } | ||
|
|
||
| std::vector<std::string> Inputs(const std::string& name) const override { | ||
|
|
@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
| const NameVarMap<VarType>* var_base_map_in_; | ||
| const NameVarMap<VarType>* var_base_map_out_; | ||
| const framework::AttributeMap* attrs_; | ||
| const framework::AttributeMap* attrs_default_; | ||
| const std::string op_type_; | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名有些奇怪,建议
attrs_default->default_attrsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同感
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,thx!