Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 66 additions & 30 deletions paddle/fluid/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,35 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);

class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), attrs_default_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名有些奇怪,建议attrs_default -> default_attrs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同感

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

: attrs_(attrs), attrs_default_(&attrs_default) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (it == attrs_.end()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use found instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

if (attrs_default_ != nullptr) {
it = attrs_default_->find(name);
found = it != attrs_default_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));

Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
}

private:
const AttributeMap& attrs_;
const AttributeMap* attrs_default_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不同样使用const &类型,这里有什么考虑吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同感

Copy link
Contributor Author

@wanghuancoder wanghuancoder Jun 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

引用必须在构造函数的时候初始化。考虑到动态图静态图都使用AttrReader他的构造函数有2种:
AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}

};

// check whether a value(attribute) fit a certain limit
Expand All @@ -234,8 +247,8 @@ class GreaterThanChecker {
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}

private:
Expand Down Expand Up @@ -332,9 +345,9 @@ class TypedAttrChecker {
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
Expand All @@ -345,30 +358,41 @@ class TypedAttrChecker {
return *this;
}

void operator()(AttributeMap* attr_map,
bool get_default_value_only = false) const {
void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool without_default_value = false) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without_default_value这个参数的命名理解起来比较困难,是不是改成only_check_exist_value之类的比较容易理解

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without_default_value 让人有点儿容易和get_default_value_only 的作用混淆....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}

auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
if (without_default_value) {
auto it = attr_map->find(attr_name_);
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}

Expand All @@ -380,7 +404,7 @@ class TypedAttrChecker {

// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker;
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;

public:
template <typename T>
Expand All @@ -390,18 +414,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}

void Check(AttributeMap* attr_map, bool explicit_only = false) const {
void Check(AttributeMap* attr_map, bool explicit_only = false,
bool without_default_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false);
attr_checkers_[i](attr_map, false, without_default_value);
}
}

AttributeMap GetAttrsDefaultValuesMap() const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetDefaultAttrsValuesMap()?

AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true);
checker(&default_values_map, true, false);
}
return default_values_map;
}
Expand All @@ -410,15 +435,26 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}

void InitDefaultMap() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InitDefaultMap意义不明,InitDefaultMap -> InitDefaultAttributeMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

for (const auto& checker : attr_checkers_) {
checker(&attrs_default_, true, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attrs_default_ -> default_attrs_

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

}
}

const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetAttrDefaultMap -> GetDefaultAttrMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!


private:
std::vector<AttrChecker> attr_checkers_;

AttributeMap attrs_default_;

// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
// op makers, usually it's defined in the overloaded Make method.
// for implicit attribute, we mean the attribute added outside of the Make
// method like "op_role", "op_role_var", and they are useless in dynamic graph
// method like "op_role", "op_role_var", and they are useless in dynamic
// graph
// mode
size_t explicit_checker_num_;
};
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphAttrsDefaultMap(attrs_default);
return maker();
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphAttrsDefaultMap(attrs_default);
return maker();
};
}
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/grad_op_desc_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;

virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->AttrsDefault().find(name);
PADDLE_ENFORCE_EQ(it != this->AttrsDefault().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}

return it->second;
}

std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
Expand All @@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{
imperative::TracedGradOp traced_grad_op(node);
try {
traced_grad_op.SetAttrDefaultMap(this->AttrsDefault());
this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultMap();

AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*attributes default*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;

using InferVarTypeFN =
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/basic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,10 @@ void BasicEngine::Execute() {
try {
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
cur_op.AttrsDefault(), cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place());
cur_op.Attrs(), cur_op.AttrsDefault(), cur_op.place());
}
} catch (platform::EnforceNotMet& exception) {
Clear();
Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/imperative/dygraph_grad_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp;
}

// Only for dygraph
void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDygraphAttrsDefaultMap -> SetDygraphDefaultAttrsMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

attrs_default_ = &attrs_default;
}

const framework::AttributeMap& AttrsDefault() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DefaultAttrsMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

return *attrs_default_;
}

const framework::AttributeMap& Attrs() const { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const {
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
Expand Down Expand Up @@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* attrs_default_;
const std::map<std::string, std::string>& inplace_map_;
};

Expand Down Expand Up @@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs);
}

void SetAttrDefaultMap(const framework::AttributeMap& attrs) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDefaultAttrMap?

return op_->SetAttrDefaultMap(attrs);
}

void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
Expand Down
19 changes: 13 additions & 6 deletions paddle/fluid/imperative/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
attrs_default_(attrs_default) {}

std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
Expand Down Expand Up @@ -92,17 +94,21 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}

bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || attrs_default_.count(name) != 0;
}

const framework::AttributeMap& Attrs() const override { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);

PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
if (it == attrs_.end()) {
it = attrs_default_.find(name);
if (it == attrs_default_.end()) {
PADDLE_THROW(
platform::errors::NotFound("Can not find [%s] in attrs", name));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can not find [%s] in attributes of op %s.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

}
}

return it->second;
}
Expand Down Expand Up @@ -192,6 +198,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& attrs_default_;
};

} // namespace imperative
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* attr_default,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
attrs_default_(attr_default),
op_type_(op_type) {}

bool HasInput(const std::string& name) const override {
Expand Down Expand Up @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}

framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
return framework::AttrReader(*attrs_, *attrs_default_);
}

std::vector<std::string> Inputs(const std::string& name) const override {
Expand Down Expand Up @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const framework::AttributeMap* attrs_default_;
const std::string op_type_;
};

Expand Down
Loading