Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions paddle/fluid/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,37 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);

class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), attrs_default_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名有些奇怪,建议attrs_default -> default_attrs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同感

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

: attrs_(attrs), attrs_default_(&attrs_default) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
auto it = attrs_.find(name);
if (it == attrs_.end()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use found instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

if (attrs_default_ != nullptr) {
it = attrs_default_->find(name);
if (it == attrs_default_->end()) {
Copy link
Collaborator

@phlrain phlrain Jun 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bool found = false;

if ( it == attrs_ends() )
{
if (attrs_default_ != nullptr) {
it = attrs_default_->find(name);
if (it == attrs_default_->end())
{
found = false;
}
}

PADDLE_ENFORE( found == true, "")

PADDLE_THROW(platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
}
} else {
PADDLE_THROW(platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
}
}

Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
}

private:
const AttributeMap& attrs_;
const AttributeMap* attrs_default_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不同样使用const &类型,这里有什么考虑吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同感

Copy link
Contributor Author

@wanghuancoder wanghuancoder Jun 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

引用必须在构造函数的时候初始化。考虑到动态图静态图都使用AttrReader他的构造函数有2种:
AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}

};

// check whether a value(attribute) fit a certain limit
Expand Down Expand Up @@ -345,30 +360,41 @@ class TypedAttrChecker {
return *this;
}

void operator()(AttributeMap* attr_map,
bool get_default_value_only = false) const {
void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool no_default_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}

auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
if (no_default_value) {
auto it = attr_map->find(attr_name_);
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto res = attr_map->emplace(attr_name_, default_value_setter_0);
it = res.first;

emplace 之后会返回一个pair, 这个pair的第一个元素是新的iterator

}
it = attr_map->find(attr_name_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有了上面的操作后,这个find可以省略

ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}

Expand All @@ -380,7 +406,7 @@ class TypedAttrChecker {

// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker;
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;

public:
template <typename T>
Expand All @@ -390,18 +416,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}

void Check(AttributeMap* attr_map, bool explicit_only = false) const {
void Check(AttributeMap* attr_map, bool explicit_only = false,
bool no_default_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false);
attr_checkers_[i](attr_map, false, no_default_value);
}
}

AttributeMap GetAttrsDefaultValuesMap() const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetDefaultAttrsValuesMap()?

AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true);
checker(&default_values_map, true, false);
}
return default_values_map;
}
Expand All @@ -410,9 +437,19 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}

void InitDefaultMap() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InitDefaultMap意义不明,InitDefaultMap -> InitDefaultAttributeMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

for (const auto& checker : attr_checkers_) {
checker(&attrs_default_, true, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attrs_default_ -> default_attrs_

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

}
}

const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetAttrDefaultMap -> GetDefaultAttrMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!


private:
std::vector<AttrChecker> attr_checkers_;

AttributeMap attrs_default_;

// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphAttrsDefaultMap(attrs_default);
return maker();
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphAttrsDefaultMap(attrs_default);
return maker();
};
}
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/grad_op_desc_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;

virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->AttrsDefault().find(name);
PADDLE_ENFORCE_EQ(it != this->AttrsDefault().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}

return it->second;
}

std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
Expand All @@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{
imperative::TracedGradOp traced_grad_op(node);
try {
traced_grad_op.SetAttrDefaultMap(this->AttrsDefault());
this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::string>(OpDeviceAttrName(), "Device type of this operator.")
.SetDefault("");
Validate();

op_checker_->InitDefaultMap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个Init可以放在68行后面,后面这些信息动态图都是不需要的

}

} // namespace framework
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*attributes default*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;

using InferVarTypeFN =
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/basic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,10 @@ void BasicEngine::Execute() {
try {
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
cur_op.AttrsDefault(), cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place());
cur_op.Attrs(), cur_op.AttrsDefault(), cur_op.place());
}
} catch (platform::EnforceNotMet& exception) {
Clear();
Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/imperative/dygraph_grad_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp;
}

// Only for dygraph
void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDygraphAttrsDefaultMap -> SetDygraphDefaultAttrsMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

attrs_default_ = &attrs_default;
}

const framework::AttributeMap& AttrsDefault() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DefaultAttrsMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

return *attrs_default_;
}

const framework::AttributeMap& Attrs() const { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const {
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
Expand Down Expand Up @@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* attrs_default_;
const std::map<std::string, std::string>& inplace_map_;
};

Expand Down Expand Up @@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs);
}

void SetAttrDefaultMap(const framework::AttributeMap& attrs) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDefaultAttrMap?

return op_->SetAttrDefaultMap(attrs);
}

void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
Expand Down
19 changes: 13 additions & 6 deletions paddle/fluid/imperative/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& attrs_default)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
attrs_default_(attrs_default) {}

std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
Expand Down Expand Up @@ -92,17 +94,21 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}

bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || attrs_default_.count(name) != 0;
}

const framework::AttributeMap& Attrs() const override { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);

PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
if (it == attrs_.end()) {
it = attrs_default_.find(name);
if (it == attrs_default_.end()) {
PADDLE_THROW(
platform::errors::NotFound("Can not find [%s] in attrs", name));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can not find [%s] in attributes of op %s.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

}
}

return it->second;
}
Expand Down Expand Up @@ -192,6 +198,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& attrs_default_;
};

} // namespace imperative
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* attr_default,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
attrs_default_(attr_default),
op_type_(op_type) {}

bool HasInput(const std::string& name) const override {
Expand Down Expand Up @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}

framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
return framework::AttrReader(*attrs_, *attrs_default_);
}

std::vector<std::string> Inputs(const std::string& name) const override {
Expand Down Expand Up @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const framework::AttributeMap* attrs_default_;
const std::string op_type_;
};

Expand Down
23 changes: 16 additions & 7 deletions paddle/fluid/imperative/infer_var_type_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
const framework::AttributeMap& attrs_map,
const framework::AttributeMap& attrs_map_default)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map) {}
attrs_(attrs_map),
attrs_default_(attrs_map_default) {}

virtual ~RuntimeInferVarTypeContext() {}

framework::Attribute GetAttr(const std::string& name) const override {
auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(
iter != attrs_.end(), true,
platform::errors::NotFound("Cannot find attribute %s", name));
return iter->second;
auto it = attrs_.find(name);

if (it == attrs_.end()) {
it = attrs_default_.find(name);
if (it == attrs_default_.end()) {
PADDLE_THROW(
platform::errors::NotFound("Can not find [%s] in attrs", name));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thx!

}
}

return it->second;
}

bool HasInput(const std::string& name) const override {
Expand Down Expand Up @@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>& outputs_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& attrs_default_;
};

} // namespace imperative
Expand Down
Loading