From dd3329de6633c6f21edb17db2c575fd43da7c376 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 4 Jun 2021 10:27:07 +0000 Subject: [PATCH 01/17] optimize attr default value, test=develop --- paddle/fluid/framework/attribute.h | 87 +++++++++++++------ paddle/fluid/framework/grad_op_desc_maker.h | 4 + paddle/fluid/framework/op_proto_maker.cc | 2 + paddle/fluid/imperative/basic_engine.cc | 4 +- paddle/fluid/imperative/dygraph_grad_maker.h | 4 + paddle/fluid/imperative/execution_context.h | 20 +++-- paddle/fluid/imperative/infer_shape_context.h | 5 +- .../fluid/imperative/infer_var_type_context.h | 24 +++-- paddle/fluid/imperative/layer.cc | 10 ++- paddle/fluid/imperative/op_base.h | 30 +++++-- .../fluid/imperative/partial_grad_engine.cc | 5 +- paddle/fluid/imperative/prepared_operator.cc | 37 ++++---- paddle/fluid/imperative/prepared_operator.h | 12 ++- paddle/fluid/imperative/tests/test_layer.cc | 16 ++-- paddle/fluid/imperative/tracer.cc | 5 +- .../test_common_infer_shape_functions.cc | 2 +- 16 files changed, 186 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 66b988ee1f1fb6..83f42e75b42d15 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -208,15 +208,27 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); class AttrReader { public: - explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} + explicit AttrReader(const AttributeMap& attrs) + : attrs_(attrs), attrs_default_(nullptr) {} + + AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default) + : attrs_(attrs), attrs_default_(&attrs_default) {} template inline const T& Get(const std::string& name) const { - PADDLE_ENFORCE_NE(attrs_.count(name), 0, - platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); + auto it = attrs_.find(name); + if (it == attrs_.end() && attrs_default_ != nullptr) { + it = attrs_default_->find(name); + if (it == attrs_default_->end()) { + PADDLE_THROW(platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + } + } else { + PADDLE_THROW(platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + } - Attribute& attr = const_cast(attrs_.at(name)); + Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); T* attr_value = extract_attr(attr); return *attr_value; @@ -224,6 +236,7 @@ class AttrReader { private: const AttributeMap& attrs_; + const AttributeMap* attrs_default_; }; // check whether a value(attribute) fit a certain limit @@ -345,8 +358,8 @@ class TypedAttrChecker { return *this; } - void operator()(AttributeMap* attr_map, - bool get_default_value_only = false) const { + void operator()(AttributeMap* attr_map, bool get_default_value_only = false, + bool no_default_value = false) const { if (get_default_value_only) { if (!default_value_setter_.empty()) { attr_map->emplace(attr_name_, default_value_setter_[0]()); @@ -354,21 +367,32 @@ class TypedAttrChecker { return; } - auto it = attr_map->find(attr_name_); - if (it == attr_map->end()) { - // user do not set this attr - PADDLE_ENFORCE_EQ( - default_value_setter_.empty(), false, - platform::errors::InvalidArgument( - "Attribute (%s) is not set correctly.", attr_name_)); - // default_value_setter_ has no more than one element - attr_map->emplace(attr_name_, default_value_setter_[0]()); - } - it = attr_map->find(attr_name_); - ExtractAttribute extract_attr(attr_name_); - T* attr_value = extract_attr(it->second); - for (const auto& checker : value_checkers_) { - checker(*attr_value); + if (no_default_value) { + auto it = attr_map->find(attr_name_); + if (it != attr_map->end()) { + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } + } + } else { + auto it = attr_map->find(attr_name_); + if (it == attr_map->end()) { + // user do not set this attr + PADDLE_ENFORCE_EQ( + default_value_setter_.empty(), false, + platform::errors::InvalidArgument( + "Attribute (%s) is not set correctly.", attr_name_)); + // default_value_setter_ has no more than one element + attr_map->emplace(attr_name_, default_value_setter_[0]()); + } + it = attr_map->find(attr_name_); + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } } } @@ -380,7 +404,7 @@ class TypedAttrChecker { // check whether op's all attributes fit their own limits class OpAttrChecker { - typedef std::function AttrChecker; + typedef std::function AttrChecker; public: template @@ -390,18 +414,19 @@ class OpAttrChecker { return *(checker.target>()); } - void Check(AttributeMap* attr_map, bool explicit_only = false) const { + void Check(AttributeMap* attr_map, bool explicit_only = false, + bool no_default_value = false) const { auto checker_num = attr_checkers_.size(); if (explicit_only) checker_num = explicit_checker_num_; for (size_t i = 0; i < checker_num; ++i) { - attr_checkers_[i](attr_map, false); + attr_checkers_[i](attr_map, false, no_default_value); } } AttributeMap GetAttrsDefaultValuesMap() const { AttributeMap default_values_map; for (const auto& checker : attr_checkers_) { - checker(&default_values_map, true); + checker(&default_values_map, true, false); } return default_values_map; } @@ -410,9 +435,19 @@ class OpAttrChecker { explicit_checker_num_ = attr_checkers_.size(); } + void InitDefaultMap() { + for (const auto& checker : attr_checkers_) { + checker(&attrs_default_, true, false); + } + } + + const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; } + private: std::vector attr_checkers_; + AttributeMap attrs_default_; + // in order to improve the efficiency of dynamic graph mode, // we divede the attribute into explicit type and implicit type. // for explicit attribute, we mean the attribute added in the customized diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index b0247fe795b3ea..ec6ebcf0e6740e 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -229,6 +229,10 @@ class SingleGradOpMaker imperative::TracedGradOp traced_grad_op(node); try { this->Apply(&traced_grad_op); + traced_grad_op.SetAttrDefaultMap( + paddle::framework::OpInfoMap::Instance() + .Get(this->ForwardOpType()) + .checker_->GetAttrDefaultMap()); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); throw std::move(exception); diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 0b9fd0a47e22c7..9d2bf183f0cbc9 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -93,6 +93,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, AddAttr(OpDeviceAttrName(), "Device type of this operator.") .SetDefault(""); Validate(); + + op_checker_->InitDefaultMap(); } } // namespace framework diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 7bcc3d6c608c94..f52b9b73df6789 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -474,10 +474,10 @@ void BasicEngine::Execute() { try { if (tmp_ins_ptr == nullptr) { OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), - cur_op.place()); + cur_op.AttrsDefault(), cur_op.place()); } else { OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, - cur_op.Attrs(), cur_op.place()); + cur_op.Attrs(), cur_op.AttrsDefault(), cur_op.place()); } } catch (platform::EnforceNotMet& exception) { Clear(); diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 7fefc9ccc67b52..dceb663fc8c153 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -285,6 +285,10 @@ class TracedGradOp { return op_->SetAttrMap(attrs); } + void SetAttrDefaultMap(const framework::AttributeMap& attrs) { + return op_->SetAttrDefaultMap(attrs); + } + void SetAttr(const std::string& name, const framework::Attribute& v) { op_->SetAttr(name, v); } diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index 398b1292e2ffe0..a80c84e38442d2 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::RuntimeContext& ctx, const NameVarMap& var_base_map_in, const NameVarMap& var_base_map_out, - const framework::AttributeMap& attrs) + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) : ExecutionContext(op, scope, device_context, ctx), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), - attrs_(attrs) {} + attrs_(attrs), + attrs_default_(attrs_default) {} std::string InputName(const std::string& name) const override { auto it = var_base_map_in_.find(name); @@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { } bool HasAttr(const std::string& name) const override { - return attrs_.count(name) != 0; + return attrs_.count(name) != 0 || attrs_default_.count(name) != 0; } const framework::AttributeMap& Attrs() const override { return attrs_; } @@ -100,9 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::Attribute& GetAttr(const std::string& name) const override { auto it = attrs_.find(name); - PADDLE_ENFORCE_NE( - it, attrs_.end(), - platform::errors::NotFound("can not find [%s] in attrs", name)); + bool find = (it != attrs_.end()); + if (it == attrs_.end()) { + it = attrs_default_.find(name); + find = (it != attrs_default_.end()); + } + + PADDLE_ENFORCE_NE(find, false, platform::errors::NotFound( + "can not find [%s] in attrs", name)); return it->second; } @@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { const NameVarMap& var_base_map_in_; const NameVarMap& var_base_map_out_; const framework::AttributeMap& attrs_; + const framework::AttributeMap& attrs_default_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index fcd4545a2c82d3..79f71775d35304 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { DygraphInferShapeContext(const NameVarMap* in, const NameVarMap* out, const framework::AttributeMap* attr, + const framework::AttributeMap* attr_default, const std::string op_type) : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr), + attrs_default_(attr_default), op_type_(op_type) {} bool HasInput(const std::string& name) const override { @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { } framework::AttrReader Attrs() const override { - return framework::AttrReader(*attrs_); + return framework::AttrReader(*attrs_, *attrs_default_); } std::vector Inputs(const std::string& name) const override { @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const NameVarMap* var_base_map_in_; const NameVarMap* var_base_map_out_; const framework::AttributeMap* attrs_; + const framework::AttributeMap* attrs_default_; const std::string op_type_; }; diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h index f740507fa50860..4286e6217b54d3 100644 --- a/paddle/fluid/imperative/infer_var_type_context.h +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -32,20 +32,29 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { public: RuntimeInferVarTypeContext(const NameVarMap& inputs, const NameVarMap& outputs, - const framework::AttributeMap& attrs_map) + const framework::AttributeMap& attrs_map, + const framework::AttributeMap& attrs_map_default) : InferVarTypeContext(nullptr, nullptr), inputs_(inputs), outputs_(outputs), - attrs_(attrs_map) {} + attrs_(attrs_map), + attrs_default_(attrs_map_default) {} virtual ~RuntimeInferVarTypeContext() {} framework::Attribute GetAttr(const std::string& name) const override { - auto iter = attrs_.find(name); - PADDLE_ENFORCE_EQ( - iter != attrs_.end(), true, - platform::errors::NotFound("Cannot find attribute %s", name)); - return iter->second; + auto it = attrs_.find(name); + + bool find = (it != attrs_.end()); + if (it == attrs_.end()) { + it = attrs_default_.find(name); + find = (it != attrs_default_.end()); + } + + PADDLE_ENFORCE_NE(find, false, platform::errors::NotFound( + "Can not find [%s] in attrs", name)); + + return it->second; } bool HasInput(const std::string& name) const override { @@ -233,6 +242,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { const NameVarMap& inputs_; const NameVarMap& outputs_; const framework::AttributeMap& attrs_; + const framework::AttributeMap& attrs_default_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index a4af3117d3e32e..078dcb5db89f5b 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -329,6 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const platform::Place& place) { auto* op_kernel = dynamic_cast(&op); PADDLE_ENFORCE_NOT_NULL( @@ -336,7 +337,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, "Only support operator with kernel in Dygraph mode.")); auto& info = op.Info(); if (info.infer_var_type_) { - RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs); + RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs, + attrs_default); info.infer_var_type_(&infer_var_type_ctx); } @@ -395,16 +397,18 @@ void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, place); + OpBaseRunImpl(op, ins, outs, attrs, attrs_default, place); } void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, place); + OpBaseRunImpl(op, ins, outs, attrs, attrs_default, place); } void ClearNoNeedBufferInputs(OpBase* op) { diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 0164ff9313cdfe..4be7660824f0e4 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -50,6 +50,10 @@ class OpBase { const framework::AttributeMap& Attrs() const { return attrs_; } + const framework::AttributeMap& AttrsDefault() const { + return *attrs_default_; + } + const framework::OpInfo& Info() const { PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet( "OpBase::Info() should be called after " @@ -99,6 +103,10 @@ class OpBase { void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } + void SetAttrDefaultMap(const framework::AttributeMap& attrs_default) { + attrs_default_ = &attrs_default; + } + void SetAttr(const std::string& name, const framework::Attribute& v) { attrs_[name] = v; } @@ -110,14 +118,23 @@ class OpBase { const framework::AttributeMap& Attrs() { return attrs_; } - bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } + const framework::AttributeMap& AttrsDefault() { return *attrs_default_; } + + bool HasAttr(const std::string& name) const { + return attrs_.count(name) > 0 || attrs_default_->count(name) > 0; + } const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); - PADDLE_ENFORCE_NE( - it, attrs_.end(), - platform::errors::NotFound("can not find attribute [%s]", name)); - return it->second; + if (it != attrs_.end()) { + return it->second; + } else { + auto it_default = attrs_default_->find(name); + PADDLE_ENFORCE_NE( + it_default, attrs_default_->end(), + platform::errors::NotFound("can not find attribute [%s]", name)); + return it_default->second; + } } template @@ -156,12 +173,14 @@ class OpBase { const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const platform::Place& place); static void Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const platform::Place& place); private: @@ -174,6 +193,7 @@ class OpBase { NameVarMap ins_; NameVarMap outs_; framework::AttributeMap attrs_; + const framework::AttributeMap* attrs_default_; std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 3da3a05ed1071c..c0c260b306866f 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -887,8 +887,9 @@ void PartialGradTask::RunEachOp(OpBase *op) { OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place()); if (create_graph_) { - auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, - op->Attrs(), op->place(), {}); + auto double_grad_node = + CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), + op->AttrsDefault(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. If you " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 2a3b6424d4a14e..2d5154d4078598 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -88,7 +88,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -105,9 +106,9 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif // 1. get expected kernel key - auto expected_kernel_key = - op.GetExpectedKernelType(DygraphExecutionContext( - op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); + auto expected_kernel_key = op.GetExpectedKernelType( + DygraphExecutionContext(op, framework::Scope(), *dev_ctx, ctx, + ins, outs, attrs, attrs_default)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; // 2. check if op[type] has kernel registered. @@ -145,16 +146,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareImpl(ins, outs, op, place, attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { + return PrepareImpl(ins, outs, op, place, attrs, attrs_default); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareImpl(ins, outs, op, place, attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { + return PrepareImpl(ins, outs, op, place, attrs, + attrs_default); } template @@ -163,17 +167,18 @@ static void PreparedOpRunImpl( const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx, const NameVarMap& ins, - const NameVarMap& outs, const framework::AttributeMap& attrs) { + const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { // TODO(zjl): remove scope in dygraph framework::Scope scope; DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, - op.Type()); + &attrs_default, op.Type()); static_cast(op).InferShape( &infer_shape_ctx); func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, - attrs)); + attrs, attrs_default)); /** * [ Why need handle complex gradient to real gradient? ] @@ -194,16 +199,18 @@ static void PreparedOpRunImpl( void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, - outs, attrs); + outs, attrs, attrs_default); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, - ins, outs, attrs); + ins, outs, attrs, attrs_default); } } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 1f6be5483be30b..f5b133f7bdc187 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -151,20 +151,24 @@ class PreparedOp { const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default); void Run(const NameVarMap& in, const NameVarMap& out, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default); void Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default); const framework::OpKernelType& kernel_type() const { return kernel_type_; } diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index 4a30ffb7e3d01f..3c9711a051ac80 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -43,10 +43,12 @@ template class TestRuntimeInferVarTypeContext : public RuntimeInferVarTypeContext { public: - TestRuntimeInferVarTypeContext(const NameVarMap& inputs, - const NameVarMap& outputs, - const framework::AttributeMap& attrs_map) - : RuntimeInferVarTypeContext(inputs, outputs, attrs_map) {} + TestRuntimeInferVarTypeContext( + const NameVarMap& inputs, const NameVarMap& outputs, + const framework::AttributeMap& attrs_map, + const framework::AttributeMap& attrs_map_default) + : RuntimeInferVarTypeContext(inputs, outputs, attrs_map, + attrs_map_default) {} bool HasVar(const std::string& name) const { return RuntimeInferVarTypeContext::HasVar(name); @@ -125,7 +127,7 @@ TEST(test_layer, test_runtime_context) { auto* ctx = new imperative::TestRuntimeInferVarTypeContext( - ins, outs, attrs); + ins, outs, attrs, {}); ASSERT_TRUE(ctx->HasInput("X")); ASSERT_TRUE(ctx->HasOutput("Out")); @@ -358,7 +360,7 @@ TEST(test_layer, test_dygraph_execution_context) { framework::Scope scope; DygraphExecutionContext dy_exe_context( - *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map); + *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {}); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); @@ -386,7 +388,7 @@ TEST(test_layer, test_dygraph_infershape_context) { concat_att_map["axis"] = 1; DygraphInferShapeContext infer_shape_ctx( - &ins, &outs, &concat_att_map, "dummy"); + &ins, &outs, &concat_att_map, {}, "dummy"); bool have_x = infer_shape_ctx.HasOutputs("Out"); ASSERT_EQ(have_x, true); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 41ad70e5a5741b..492564c9f1278c 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -154,7 +154,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, const auto& op_info = op->Info(); auto* attr_checker = op_info.Checker(); if (attr_checker) { - attr_checker->Check(&attrs, true); + attr_checker->Check(&attrs, true, true); } NameVarBaseMap new_ins = ins; @@ -181,7 +181,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #endif } - OpBase::Run(*op, new_ins, outs, attrs, place); + OpBase::Run(*op, new_ins, outs, attrs, attr_checker->GetAttrDefaultMap(), + place); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(type, &exception); throw std::move(exception); diff --git a/paddle/fluid/operators/test_common_infer_shape_functions.cc b/paddle/fluid/operators/test_common_infer_shape_functions.cc index ca8f6ce84fc571..60eeb66ae7d1ec 100644 --- a/paddle/fluid/operators/test_common_infer_shape_functions.cc +++ b/paddle/fluid/operators/test_common_infer_shape_functions.cc @@ -48,7 +48,7 @@ class DygraphInferShapeTest { void SetOpType(const std::string& op_type) { op_type_ = op_type; } void Run(std::function infer_shape) { imperative::DygraphInferShapeContext ctx( - &ins_, &outs_, &attrs_, op_type_); + &ins_, &outs_, &attrs_, {}, op_type_); infer_shape(&ctx); for (const auto& pair : expected_dims_) { auto out = outs_[pair.first][0]; From c8b557b972f1cd8156a970f8b1d5f1bad67244dd Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 4 Jun 2021 10:37:15 +0000 Subject: [PATCH 02/17] refine, test=develop --- paddle/fluid/imperative/layer.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 078dcb5db89f5b..64e182d2b223d8 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -371,13 +371,14 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, * after the execution of op, but the original input is directly * overwritten in the previous dynamic graph implemention. */ - auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs); + auto prepared_op = + PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, attrs_default); auto tmp_ins_ptr = PrepareData(*op_kernel, ins, prepared_op.kernel_type()); if (tmp_ins_ptr == nullptr) { - prepared_op.Run(ins, outs, attrs); + prepared_op.Run(ins, outs, attrs, attrs_default); } else { - prepared_op.Run(*tmp_ins_ptr, outs, attrs); + prepared_op.Run(*tmp_ins_ptr, outs, attrs, attrs_default); } VLOG(4) << LayerDebugString(op.Type(), ins, outs); From 8b1cdab1c527b6d2e0f01b5503d0b085b97f22f0 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 4 Jun 2021 10:48:13 +0000 Subject: [PATCH 03/17] refine, test=develop --- paddle/fluid/imperative/partial_grad_engine.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index c0c260b306866f..981c0be9513856 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -884,12 +884,12 @@ void PartialGradTask::RunEachOp(OpBase *op) { } // Run op - OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place()); + OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->AttrsDefault(), + op->place()); if (create_graph_) { - auto double_grad_node = - CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), - op->AttrsDefault(), op->place(), {}); + auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, + op->Attrs(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. If you " From 0cf637418e11bdd7b6c9c37726223a43025112aa Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 4 Jun 2021 11:07:43 +0000 Subject: [PATCH 04/17] refine, test=develop --- paddle/fluid/imperative/tests/test_prepare_op.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index 7d6882a4ee7d00..5e269d74044d24 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -93,7 +93,7 @@ TEST(test_prepare_op, test_prepare_op) { ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), - place, split_attr_map)); + place, split_attr_map, {})); } const framework::Tensor* GetTensorFromVar(const framework::Variable& var); @@ -144,7 +144,7 @@ TEST(test_prepare_op, test_prepare_data) { // test if it can be transformed to GPU place auto prepared_op = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), gpu_place, - attr_map); + attr_map, {}); PrepareData( dynamic_cast(*op), ins, prepared_op.kernel_type()); @@ -193,7 +193,7 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { // test if it never transferred on GPU place auto prepared_op = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), cpu_place, - attr_map); + attr_map, {}); PrepareData( dynamic_cast(*op), ins, prepared_op.kernel_type()); From 2facfcaaae3b7d79cb3319b7d25c59ef2325da25 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 7 Jun 2021 02:31:58 +0000 Subject: [PATCH 05/17] fix bug in AttrReader, test=develop --- paddle/fluid/framework/attribute.h | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 83f42e75b42d15..96ae831202246f 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -217,15 +217,17 @@ class AttrReader { template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); - if (it == attrs_.end() && attrs_default_ != nullptr) { - it = attrs_default_->find(name); - if (it == attrs_default_->end()) { + if (it == attrs_.end()) { + if (attrs_default_ != nullptr) { + it = attrs_default_->find(name); + if (it == attrs_default_->end()) { + PADDLE_THROW(platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + } + } else { PADDLE_THROW(platform::errors::NotFound( "Attribute (%s) should be in AttributeMap.", name)); } - } else { - PADDLE_THROW(platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); } Attribute& attr = const_cast(it->second); From 911300cb19c193be3c23a873c2106f1cbdab8d67 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 7 Jun 2021 06:12:52 +0000 Subject: [PATCH 06/17] fix bug, test=develop --- paddle/fluid/framework/grad_op_desc_maker.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index ec6ebcf0e6740e..851c255d43af6e 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -228,11 +228,11 @@ class SingleGradOpMaker { imperative::TracedGradOp traced_grad_op(node); try { - this->Apply(&traced_grad_op); traced_grad_op.SetAttrDefaultMap( paddle::framework::OpInfoMap::Instance() .Get(this->ForwardOpType()) .checker_->GetAttrDefaultMap()); + this->Apply(&traced_grad_op); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); throw std::move(exception); From 3598f3ce15f12427eb9167796c4ae813b0e4d468 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 8 Jun 2021 03:01:33 +0000 Subject: [PATCH 07/17] fix double_grad, test=develop --- paddle/fluid/framework/details/op_registry.h | 2 ++ paddle/fluid/framework/grad_op_desc_maker.h | 18 ++++++++++++++---- paddle/fluid/framework/type_defs.h | 1 + paddle/fluid/imperative/dygraph_grad_maker.h | 12 +++++++++++- paddle/fluid/imperative/layer.cc | 6 +++--- paddle/fluid/imperative/layer.h | 4 ++-- paddle/fluid/imperative/partial_grad_engine.cc | 5 +++-- paddle/fluid/imperative/tracer.cc | 3 ++- 8 files changed, 38 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index df5370e42ee9f3..a6b51f1532b477 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -249,8 +249,10 @@ struct OpInfoFiller { const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const std::map& inplace_map) { T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map); + maker.SetDygraphAttrsDefaultMap(attrs_default); return maker(); }; } diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 851c255d43af6e..2a9ba6fb55c6ea 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -219,6 +219,19 @@ class SingleGradOpMaker public: using GradOpBaseMakerBase::GradOpBaseMakerBase; + virtual const framework::Attribute& GetAttr(const std::string& name) const { + auto it = Attrs().find(name); + if (it == Attrs().end()) { + it = this->AttrsDefault().find(name); + PADDLE_ENFORCE_EQ(it != this->AttrsDefault().end(), true, + platform::errors::NotFound( + "Cannot find attribute [%s] in operator [%s]", name, + this->ForwardOpType())); + } + + return it->second; + } + std::shared_ptr operator()() const final { auto node = this->NewGradNode(); auto& inplace_map = this->GetInplaceMap(); @@ -228,10 +241,7 @@ class SingleGradOpMaker { imperative::TracedGradOp traced_grad_op(node); try { - traced_grad_op.SetAttrDefaultMap( - paddle::framework::OpInfoMap::Instance() - .Get(this->ForwardOpType()) - .checker_->GetAttrDefaultMap()); + traced_grad_op.SetAttrDefaultMap(this->AttrsDefault()); this->Apply(&traced_grad_op); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index e43cccfe648165..7b72485c521a5a 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN = const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_out*/, const framework::AttributeMap& /*attributes*/, + const framework::AttributeMap& /*attributes default*/, const std::map& /*inplace_map*/)>; using InferVarTypeFN = diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index dceb663fc8c153..ace4d2364cad42 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -113,9 +113,18 @@ class GradOpBaseMakerBase { return vec_temp; } + // Only for dygraph + void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) { + attrs_default_ = &attrs_default; + } + + const framework::AttributeMap& AttrsDefault() const { + return *attrs_default_; + } + const framework::AttributeMap& Attrs() const { return attrs_; } - const framework::Attribute& GetAttr(const std::string& name) const { + virtual const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); PADDLE_ENFORCE_EQ( it != attrs_.end(), true, @@ -199,6 +208,7 @@ class GradOpBaseMakerBase { const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; + const framework::AttributeMap* attrs_default_; const std::map& inplace_map_; }; diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 64e182d2b223d8..2ce49ef060672a 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -451,15 +451,15 @@ void ClearNoNeedBufferInputs(OpBase* op) { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place, + const framework::AttributeMap& attrs_default, const platform::Place& place, const std::map& inplace_map) { const auto& info = op.Info(); if (!info.dygraph_grad_op_maker_) { return nullptr; } - auto grad_node = - info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map); + auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, + attrs_default, inplace_map); if (grad_node && !grad_node->empty()) { for (auto& grad_op : *grad_node) { grad_op.SetId(OpBase::GenerateUniqueId()); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index bbede47e364298..1933cb4dc35b12 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -108,7 +108,7 @@ class VarBase { void ClearGradVarBase() { grad_var_ = nullptr; } - void SetGradVarBase(VarBase& grad_var) { + void SetGradVarBase(VarBase& grad_var) { // NOLINT MutableGradVarBase()->CopyFrom(grad_var, true); } @@ -283,7 +283,7 @@ class Layer { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place, + const framework::AttributeMap& attrs_default, const platform::Place& place, const std::map& inplace_map); void ClearNoNeedBufferInputs(OpBase* op); diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 981c0be9513856..253948fa755a48 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -888,8 +888,9 @@ void PartialGradTask::RunEachOp(OpBase *op) { op->place()); if (create_graph_) { - auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, - op->Attrs(), op->place(), {}); + auto double_grad_node = + CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), + op->AttrsDefault(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. If you " diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 492564c9f1278c..104ac31fca64a5 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -205,7 +205,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { - CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map); + CreateGradOpNode(*op, new_ins, outs, attrs, + attr_checker->GetAttrDefaultMap(), place, inplace_map); } else { VLOG(3) << "No Grad to track for Op: " << type; } From 2ff8504bf82f715f0891f5a961cd294b083b3dd9 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 8 Jun 2021 03:04:38 +0000 Subject: [PATCH 08/17] refine, test=develop --- paddle/fluid/framework/custom_operator.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index c4b833ec94c294..52a817830b7c0b 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo( const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, + const framework::AttributeMap& attrs_default, const std::map& inplace_map) { CustomGradOpMaker maker( type, var_base_map_in, var_base_map_out, attrs, inplace_map, grad_op_name, grad_op_inputs, grad_op_outputs); + maker.SetDygraphAttrsDefaultMap(attrs_default); return maker(); }; From e22ea5f9e8bda38b45cb3a6bc45412a06b297d70 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 8 Jun 2021 06:03:48 +0000 Subject: [PATCH 09/17] refine, test=develop --- paddle/fluid/imperative/execution_context.h | 9 ++++----- paddle/fluid/imperative/infer_var_type_context.h | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index a80c84e38442d2..a1015003a5c6a4 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -102,15 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::Attribute& GetAttr(const std::string& name) const override { auto it = attrs_.find(name); - bool find = (it != attrs_.end()); if (it == attrs_.end()) { it = attrs_default_.find(name); - find = (it != attrs_default_.end()); + if (it == attrs_default_.end()) { + PADDLE_THROW( + platform::errors::NotFound("Can not find [%s] in attrs", name)); + } } - PADDLE_ENFORCE_NE(find, false, platform::errors::NotFound( - "can not find [%s] in attrs", name)); - return it->second; } diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h index 4286e6217b54d3..ea9bff59ba8f3c 100644 --- a/paddle/fluid/imperative/infer_var_type_context.h +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -45,15 +45,14 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { framework::Attribute GetAttr(const std::string& name) const override { auto it = attrs_.find(name); - bool find = (it != attrs_.end()); if (it == attrs_.end()) { it = attrs_default_.find(name); - find = (it != attrs_default_.end()); + if (it == attrs_default_.end()) { + PADDLE_THROW( + platform::errors::NotFound("Can not find [%s] in attrs", name)); + } } - PADDLE_ENFORCE_NE(find, false, platform::errors::NotFound( - "Can not find [%s] in attrs", name)); - return it->second; } From d44aeb66ea2e4b14082b0226b582a02b53c457b6 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 8 Jun 2021 11:41:07 +0000 Subject: [PATCH 10/17] fix checker null, test=develop --- paddle/fluid/imperative/tracer.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 104ac31fca64a5..0fd320743518a7 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -157,6 +157,11 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, attr_checker->Check(&attrs, true, true); } + static paddle::framework::AttributeMap empty_attrs_map = {}; + const paddle::framework::AttributeMap& attrs_default = + attr_checker == nullptr ? empty_attrs_map + : attr_checker->GetAttrDefaultMap(); + NameVarBaseMap new_ins = ins; if (enable_autocast_) { VLOG(5) << "Auto mixed precision run operator: " << type; @@ -181,8 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #endif } - OpBase::Run(*op, new_ins, outs, attrs, attr_checker->GetAttrDefaultMap(), - place); + OpBase::Run(*op, new_ins, outs, attrs, attrs_default, place); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(type, &exception); throw std::move(exception); @@ -205,8 +209,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { - CreateGradOpNode(*op, new_ins, outs, attrs, - attr_checker->GetAttrDefaultMap(), place, inplace_map); + CreateGradOpNode(*op, new_ins, outs, attrs, attrs_default, place, + inplace_map); } else { VLOG(3) << "No Grad to track for Op: " << type; } From 32e2f6005db2c7cfc60f4643974281f6e4aee736 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 18 Jun 2021 08:47:23 +0000 Subject: [PATCH 11/17] for test, test=develop --- paddle/fluid/imperative/op_base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 4be7660824f0e4..84e64f8778c4e7 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -83,7 +83,7 @@ class OpBase { void CheckAttrs() { auto& info = Info(); if (info.Checker() != nullptr) { - info.Checker()->Check(&attrs_, true); + info.Checker()->Check(&attrs_, true, true); } } From eb9b97e139422b9f8bfbd01ce875d52cdf2cc78f Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 18 Jun 2021 10:36:54 +0000 Subject: [PATCH 12/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 39 ++++++++++++------------ paddle/fluid/framework/op_proto_maker.cc | 3 +- paddle/fluid/imperative/op_base.h | 2 +- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 96ae831202246f..a00c2c9260c034 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -217,18 +217,16 @@ class AttrReader { template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); - if (it == attrs_.end()) { + bool found = it != attrs.ends(); + if (it == attrs.ends()) { if (attrs_default_ != nullptr) { it = attrs_default_->find(name); - if (it == attrs_default_->end()) { - PADDLE_THROW(platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); - } - } else { - PADDLE_THROW(platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); + found = it != attrs_default_->end(); } } + PADDLE_ENFORCE(found == true, + platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)) Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); @@ -249,8 +247,8 @@ class GreaterThanChecker { void operator()(const T& value) const { PADDLE_ENFORCE_GT( value, lower_bound_, - platform::errors::OutOfRange( - "Check for attribute value greater than a certain value failed.")); + platform::errors::OutOfRange("Check for attribute value greater than " + "a certain value failed.")); } private: @@ -347,9 +345,9 @@ class TypedAttrChecker { TypedAttrChecker& SetDefault(const T& default_value) { PADDLE_ENFORCE_EQ( default_value_setter_.empty(), true, - platform::errors::AlreadyExists( - "Attribute (%s) has a default value and cannot be set repeatedly.", - attr_name_)); + platform::errors::AlreadyExists("Attribute (%s) has a default value " + "and cannot be set repeatedly.", + attr_name_)); default_value_setter_.push_back(DefaultValueSetter(default_value)); return *this; } @@ -361,7 +359,7 @@ class TypedAttrChecker { } void operator()(AttributeMap* attr_map, bool get_default_value_only = false, - bool no_default_value = false) const { + bool without_default_value = false) const { if (get_default_value_only) { if (!default_value_setter_.empty()) { attr_map->emplace(attr_name_, default_value_setter_[0]()); @@ -369,7 +367,7 @@ class TypedAttrChecker { return; } - if (no_default_value) { + if (without_default_value) { auto it = attr_map->find(attr_name_); if (it != attr_map->end()) { ExtractAttribute extract_attr(attr_name_); @@ -387,9 +385,9 @@ class TypedAttrChecker { platform::errors::InvalidArgument( "Attribute (%s) is not set correctly.", attr_name_)); // default_value_setter_ has no more than one element - attr_map->emplace(attr_name_, default_value_setter_[0]()); + auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]()); + it = tmp.first; } - it = attr_map->find(attr_name_); ExtractAttribute extract_attr(attr_name_); T* attr_value = extract_attr(it->second); for (const auto& checker : value_checkers_) { @@ -417,11 +415,11 @@ class OpAttrChecker { } void Check(AttributeMap* attr_map, bool explicit_only = false, - bool no_default_value = false) const { + bool without_default_value = false) const { auto checker_num = attr_checkers_.size(); if (explicit_only) checker_num = explicit_checker_num_; for (size_t i = 0; i < checker_num; ++i) { - attr_checkers_[i](attr_map, false, no_default_value); + attr_checkers_[i](attr_map, false, without_default_value); } } @@ -455,7 +453,8 @@ class OpAttrChecker { // for explicit attribute, we mean the attribute added in the customized // op makers, usually it's defined in the overloaded Make method. // for implicit attribute, we mean the attribute added outside of the Make - // method like "op_role", "op_role_var", and they are useless in dynamic graph + // method like "op_role", "op_role_var", and they are useless in dynamic + // graph // mode size_t explicit_checker_num_; }; diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 9d2bf183f0cbc9..d86b555d9a843d 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, op_checker_ = attr_checker; Make(); op_checker_->RecordExplicitCheckerNum(); + op_checker_->InitDefaultMap(); AddAttr(OpRoleAttrName(), "The role of this operator") .InEnum( @@ -93,8 +94,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, AddAttr(OpDeviceAttrName(), "Device type of this operator.") .SetDefault(""); Validate(); - - op_checker_->InitDefaultMap(); } } // namespace framework diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 84e64f8778c4e7..4be7660824f0e4 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -83,7 +83,7 @@ class OpBase { void CheckAttrs() { auto& info = Info(); if (info.Checker() != nullptr) { - info.Checker()->Check(&attrs_, true, true); + info.Checker()->Check(&attrs_, true); } } From e681acd6f03a220c28e8a8f866985481e03ca7ef Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 18 Jun 2021 10:43:46 +0000 Subject: [PATCH 13/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index a00c2c9260c034..fc9d409455ba65 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -217,8 +217,8 @@ class AttrReader { template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); - bool found = it != attrs.ends(); - if (it == attrs.ends()) { + bool found = it != attrs_.ends(); + if (it == attrs_.ends()) { if (attrs_default_ != nullptr) { it = attrs_default_->find(name); found = it != attrs_default_->end(); @@ -226,7 +226,7 @@ class AttrReader { } PADDLE_ENFORCE(found == true, platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)) + "Attribute (%s) should be in AttributeMap.", name)); Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); From ea470d7d9f96ae2e43bbe3cbd412e90b54510033 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 18 Jun 2021 10:48:57 +0000 Subject: [PATCH 14/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index fc9d409455ba65..db1a616fb23049 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -217,8 +217,8 @@ class AttrReader { template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); - bool found = it != attrs_.ends(); - if (it == attrs_.ends()) { + bool found = it != attrs_.end(); + if (it == attrs_.end()) { if (attrs_default_ != nullptr) { it = attrs_default_->find(name); found = it != attrs_default_->end(); From f54c3596ec661c7b9ac4547c8cc65cf5b4249d1e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 18 Jun 2021 11:17:31 +0000 Subject: [PATCH 15/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index db1a616fb23049..a2a55e5e80c368 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -224,9 +224,9 @@ class AttrReader { found = it != attrs_default_->end(); } } - PADDLE_ENFORCE(found == true, - platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); + PADDLE_ENFORCE_EQ(found, true, + platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); From 2d4c9dc4058b9257b752c5b542eace03e7470fc9 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 21 Jun 2021 07:31:23 +0000 Subject: [PATCH 16/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 32 +++++++++---------- paddle/fluid/framework/custom_operator.cc | 4 +-- paddle/fluid/framework/details/op_registry.h | 4 +-- paddle/fluid/framework/grad_op_desc_maker.h | 6 ++-- paddle/fluid/framework/op_proto_maker.cc | 2 +- paddle/fluid/imperative/basic_engine.cc | 5 +-- paddle/fluid/imperative/dygraph_grad_maker.h | 10 +++--- paddle/fluid/imperative/execution_context.h | 17 +++++----- paddle/fluid/imperative/infer_shape_context.h | 6 ++-- .../fluid/imperative/infer_var_type_context.h | 12 +++---- paddle/fluid/imperative/layer.cc | 22 ++++++------- paddle/fluid/imperative/layer.h | 4 +-- paddle/fluid/imperative/op_base.h | 22 ++++++------- .../fluid/imperative/partial_grad_engine.cc | 6 ++-- paddle/fluid/imperative/prepared_operator.cc | 26 +++++++-------- paddle/fluid/imperative/prepared_operator.h | 8 ++--- paddle/fluid/imperative/tracer.cc | 10 +++--- 17 files changed, 99 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index a2a55e5e80c368..a73819a400ab91 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -209,19 +209,19 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) - : attrs_(attrs), attrs_default_(nullptr) {} + : attrs_(attrs), default_attrs_(nullptr) {} - AttrReader(const AttributeMap& attrs, const AttributeMap& attrs_default) - : attrs_(attrs), attrs_default_(&attrs_default) {} + AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs) + : attrs_(attrs), default_attrs_(&default_attrs) {} template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); bool found = it != attrs_.end(); - if (it == attrs_.end()) { - if (attrs_default_ != nullptr) { - it = attrs_default_->find(name); - found = it != attrs_default_->end(); + if (!found) { + if (default_attrs_ != nullptr) { + it = default_attrs_->find(name); + found = it != default_attrs_->end(); } } PADDLE_ENFORCE_EQ(found, true, @@ -236,7 +236,7 @@ class AttrReader { private: const AttributeMap& attrs_; - const AttributeMap* attrs_default_; + const AttributeMap* default_attrs_; }; // check whether a value(attribute) fit a certain limit @@ -359,7 +359,7 @@ class TypedAttrChecker { } void operator()(AttributeMap* attr_map, bool get_default_value_only = false, - bool without_default_value = false) const { + bool only_check_exist_value = false) const { if (get_default_value_only) { if (!default_value_setter_.empty()) { attr_map->emplace(attr_name_, default_value_setter_[0]()); @@ -367,7 +367,7 @@ class TypedAttrChecker { return; } - if (without_default_value) { + if (only_check_exist_value) { auto it = attr_map->find(attr_name_); if (it != attr_map->end()) { ExtractAttribute extract_attr(attr_name_); @@ -415,11 +415,11 @@ class OpAttrChecker { } void Check(AttributeMap* attr_map, bool explicit_only = false, - bool without_default_value = false) const { + bool only_check_exist_value = false) const { auto checker_num = attr_checkers_.size(); if (explicit_only) checker_num = explicit_checker_num_; for (size_t i = 0; i < checker_num; ++i) { - attr_checkers_[i](attr_map, false, without_default_value); + attr_checkers_[i](attr_map, false, only_check_exist_value); } } @@ -435,18 +435,18 @@ class OpAttrChecker { explicit_checker_num_ = attr_checkers_.size(); } - void InitDefaultMap() { + void InitDefaultAttributeMap() { for (const auto& checker : attr_checkers_) { - checker(&attrs_default_, true, false); + checker(&default_attrs_, true, false); } } - const AttributeMap& GetAttrDefaultMap() const { return attrs_default_; } + const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; } private: std::vector attr_checkers_; - AttributeMap attrs_default_; + AttributeMap default_attrs_; // in order to improve the efficiency of dynamic graph mode, // we divede the attribute into explicit type and implicit type. diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 52a817830b7c0b..b1c5ff86d19790 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -781,12 +781,12 @@ void RegisterOperatorWithMetaInfo( const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const std::map& inplace_map) { CustomGradOpMaker maker( type, var_base_map_in, var_base_map_out, attrs, inplace_map, grad_op_name, grad_op_inputs, grad_op_outputs); - maker.SetDygraphAttrsDefaultMap(attrs_default); + maker.SetDygraphDefaultAttrsMap(default_attrs); return maker(); }; diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index a6b51f1532b477..27f55e237f5168 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -249,10 +249,10 @@ struct OpInfoFiller { const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const std::map& inplace_map) { T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map); - maker.SetDygraphAttrsDefaultMap(attrs_default); + maker.SetDygraphDefaultAttrsMap(default_attrs); return maker(); }; } diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 2a9ba6fb55c6ea..8d52d62c2587d4 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -222,8 +222,8 @@ class SingleGradOpMaker virtual const framework::Attribute& GetAttr(const std::string& name) const { auto it = Attrs().find(name); if (it == Attrs().end()) { - it = this->AttrsDefault().find(name); - PADDLE_ENFORCE_EQ(it != this->AttrsDefault().end(), true, + it = this->DefaultAttrsMap().find(name); + PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true, platform::errors::NotFound( "Cannot find attribute [%s] in operator [%s]", name, this->ForwardOpType())); @@ -241,7 +241,7 @@ class SingleGradOpMaker { imperative::TracedGradOp traced_grad_op(node); try { - traced_grad_op.SetAttrDefaultMap(this->AttrsDefault()); + traced_grad_op.SetAttrDefaultMap(this->DefaultAttrsMap()); this->Apply(&traced_grad_op); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index d86b555d9a843d..8fbea51584d3ca 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, op_checker_ = attr_checker; Make(); op_checker_->RecordExplicitCheckerNum(); - op_checker_->InitDefaultMap(); + op_checker_->InitDefaultAttributeMap(); AddAttr(OpRoleAttrName(), "The role of this operator") .InEnum( diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index f52b9b73df6789..84ee1fbe5df96a 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -474,10 +474,11 @@ void BasicEngine::Execute() { try { if (tmp_ins_ptr == nullptr) { OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), - cur_op.AttrsDefault(), cur_op.place()); + cur_op.DefaultAttrsMap(), cur_op.place()); } else { OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, - cur_op.Attrs(), cur_op.AttrsDefault(), cur_op.place()); + cur_op.Attrs(), cur_op.DefaultAttrsMap(), + cur_op.place()); } } catch (platform::EnforceNotMet& exception) { Clear(); diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index ace4d2364cad42..40281b19daba41 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -114,12 +114,12 @@ class GradOpBaseMakerBase { } // Only for dygraph - void SetDygraphAttrsDefaultMap(const framework::AttributeMap& attrs_default) { - attrs_default_ = &attrs_default; + void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) { + default_attrs_ = &default_attrs; } - const framework::AttributeMap& AttrsDefault() const { - return *attrs_default_; + const framework::AttributeMap& DefaultAttrsMap() const { + return *default_attrs_; } const framework::AttributeMap& Attrs() const { return attrs_; } @@ -208,7 +208,7 @@ class GradOpBaseMakerBase { const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; - const framework::AttributeMap* attrs_default_; + const framework::AttributeMap* default_attrs_; const std::map& inplace_map_; }; diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index a1015003a5c6a4..5446add86788b2 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -36,12 +36,12 @@ class DygraphExecutionContext : public framework::ExecutionContext { const NameVarMap& var_base_map_in, const NameVarMap& var_base_map_out, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) + const framework::AttributeMap& default_attrs) : ExecutionContext(op, scope, device_context, ctx), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), attrs_(attrs), - attrs_default_(attrs_default) {} + default_attrs_(default_attrs) {} std::string InputName(const std::string& name) const override { auto it = var_base_map_in_.find(name); @@ -94,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { } bool HasAttr(const std::string& name) const override { - return attrs_.count(name) != 0 || attrs_default_.count(name) != 0; + return attrs_.count(name) != 0 || default_attrs_.count(name) != 0; } const framework::AttributeMap& Attrs() const override { return attrs_; } @@ -103,10 +103,11 @@ class DygraphExecutionContext : public framework::ExecutionContext { auto it = attrs_.find(name); if (it == attrs_.end()) { - it = attrs_default_.find(name); - if (it == attrs_default_.end()) { - PADDLE_THROW( - platform::errors::NotFound("Can not find [%s] in attrs", name)); + it = default_attrs_.find(name); + if (it == default_attrs_.end()) { + PADDLE_THROW(platform::errors::NotFound( + "Can not find [%s] in attributes of op %s.", name, + this->GetOp().Type())); } } @@ -198,7 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { const NameVarMap& var_base_map_in_; const NameVarMap& var_base_map_out_; const framework::AttributeMap& attrs_; - const framework::AttributeMap& attrs_default_; + const framework::AttributeMap& default_attrs_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 79f71775d35304..613e0eb57bc6d9 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -40,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr), - attrs_default_(attr_default), + default_attrs_(attr_default), op_type_(op_type) {} bool HasInput(const std::string& name) const override { @@ -103,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { } framework::AttrReader Attrs() const override { - return framework::AttrReader(*attrs_, *attrs_default_); + return framework::AttrReader(*attrs_, *default_attrs_); } std::vector Inputs(const std::string& name) const override { @@ -397,7 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const NameVarMap* var_base_map_in_; const NameVarMap* var_base_map_out_; const framework::AttributeMap* attrs_; - const framework::AttributeMap* attrs_default_; + const framework::AttributeMap* default_attrs_; const std::string op_type_; }; diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h index ea9bff59ba8f3c..5c7b5a8b8526fa 100644 --- a/paddle/fluid/imperative/infer_var_type_context.h +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -38,7 +38,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { inputs_(inputs), outputs_(outputs), attrs_(attrs_map), - attrs_default_(attrs_map_default) {} + default_attrs_(attrs_map_default) {} virtual ~RuntimeInferVarTypeContext() {} @@ -46,10 +46,10 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { auto it = attrs_.find(name); if (it == attrs_.end()) { - it = attrs_default_.find(name); - if (it == attrs_default_.end()) { - PADDLE_THROW( - platform::errors::NotFound("Can not find [%s] in attrs", name)); + it = default_attrs_.find(name); + if (it == default_attrs_.end()) { + PADDLE_THROW(platform::errors::NotFound( + "Can not find [%s] in attributes.", name)); } } @@ -241,7 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { const NameVarMap& inputs_; const NameVarMap& outputs_; const framework::AttributeMap& attrs_; - const framework::AttributeMap& attrs_default_; + const framework::AttributeMap& default_attrs_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 2ce49ef060672a..6e28ecd9971abc 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -329,7 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const platform::Place& place) { auto* op_kernel = dynamic_cast(&op); PADDLE_ENFORCE_NOT_NULL( @@ -338,7 +338,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, auto& info = op.Info(); if (info.infer_var_type_) { RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs, - attrs_default); + default_attrs); info.infer_var_type_(&infer_var_type_ctx); } @@ -372,13 +372,13 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, * overwritten in the previous dynamic graph implemention. */ auto prepared_op = - PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, attrs_default); + PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs); auto tmp_ins_ptr = PrepareData(*op_kernel, ins, prepared_op.kernel_type()); if (tmp_ins_ptr == nullptr) { - prepared_op.Run(ins, outs, attrs, attrs_default); + prepared_op.Run(ins, outs, attrs, default_attrs); } else { - prepared_op.Run(*tmp_ins_ptr, outs, attrs, attrs_default); + prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs); } VLOG(4) << LayerDebugString(op.Type(), ins, outs); @@ -398,18 +398,18 @@ void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, attrs_default, place); + OpBaseRunImpl(op, ins, outs, attrs, default_attrs, place); } void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, attrs_default, place); + OpBaseRunImpl(op, ins, outs, attrs, default_attrs, place); } void ClearNoNeedBufferInputs(OpBase* op) { @@ -451,7 +451,7 @@ void ClearNoNeedBufferInputs(OpBase* op) { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, const platform::Place& place, + const framework::AttributeMap& default_attrs, const platform::Place& place, const std::map& inplace_map) { const auto& info = op.Info(); if (!info.dygraph_grad_op_maker_) { @@ -459,7 +459,7 @@ std::shared_ptr CreateGradOpNode( } auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, - attrs_default, inplace_map); + default_attrs, inplace_map); if (grad_node && !grad_node->empty()) { for (auto& grad_op : *grad_node) { grad_op.SetId(OpBase::GenerateUniqueId()); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 1933cb4dc35b12..56e16ba199707c 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -108,7 +108,7 @@ class VarBase { void ClearGradVarBase() { grad_var_ = nullptr; } - void SetGradVarBase(VarBase& grad_var) { // NOLINT + void SetGradVarBase(const VarBase& grad_var) { MutableGradVarBase()->CopyFrom(grad_var, true); } @@ -283,7 +283,7 @@ class Layer { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, const platform::Place& place, + const framework::AttributeMap& default_attrs, const platform::Place& place, const std::map& inplace_map); void ClearNoNeedBufferInputs(OpBase* op); diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 4be7660824f0e4..f2065a95c5c642 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -50,8 +50,8 @@ class OpBase { const framework::AttributeMap& Attrs() const { return attrs_; } - const framework::AttributeMap& AttrsDefault() const { - return *attrs_default_; + const framework::AttributeMap& DefaultAttrsMap() const { + return *default_attrs_; } const framework::OpInfo& Info() const { @@ -103,8 +103,8 @@ class OpBase { void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } - void SetAttrDefaultMap(const framework::AttributeMap& attrs_default) { - attrs_default_ = &attrs_default; + void SetAttrDefaultMap(const framework::AttributeMap& default_attrs) { + default_attrs_ = &default_attrs; } void SetAttr(const std::string& name, const framework::Attribute& v) { @@ -118,10 +118,10 @@ class OpBase { const framework::AttributeMap& Attrs() { return attrs_; } - const framework::AttributeMap& AttrsDefault() { return *attrs_default_; } + const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; } bool HasAttr(const std::string& name) const { - return attrs_.count(name) > 0 || attrs_default_->count(name) > 0; + return attrs_.count(name) > 0 || default_attrs_->count(name) > 0; } const framework::Attribute& GetAttr(const std::string& name) const { @@ -129,9 +129,9 @@ class OpBase { if (it != attrs_.end()) { return it->second; } else { - auto it_default = attrs_default_->find(name); + auto it_default = default_attrs_->find(name); PADDLE_ENFORCE_NE( - it_default, attrs_default_->end(), + it_default, default_attrs_->end(), platform::errors::NotFound("can not find attribute [%s]", name)); return it_default->second; } @@ -173,14 +173,14 @@ class OpBase { const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const platform::Place& place); static void Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default, + const framework::AttributeMap& default_attrs, const platform::Place& place); private: @@ -193,7 +193,7 @@ class OpBase { NameVarMap ins_; NameVarMap outs_; framework::AttributeMap attrs_; - const framework::AttributeMap* attrs_default_; + const framework::AttributeMap* default_attrs_; std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 253948fa755a48..d905b1350821c4 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -884,13 +884,13 @@ void PartialGradTask::RunEachOp(OpBase *op) { } // Run op - OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->AttrsDefault(), - op->place()); + OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), + op->DefaultAttrsMap(), op->place()); if (create_graph_) { auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), - op->AttrsDefault(), op->place(), {}); + op->DefaultAttrsMap(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. If you " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index e042555f1ea0b5..6bdb042ebd5572 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -92,7 +92,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { + const framework::AttributeMap& default_attrs) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -111,7 +111,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, // 1. get expected kernel key auto expected_kernel_key = op.GetExpectedKernelType( DygraphExecutionContext(op, framework::Scope(), *dev_ctx, ctx, - ins, outs, attrs, attrs_default)); + ins, outs, attrs, default_attrs)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; // 2. check if op[type] has kernel registered. @@ -150,8 +150,8 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { - return PrepareImpl(ins, outs, op, place, attrs, attrs_default); + const framework::AttributeMap& default_attrs) { + return PrepareImpl(ins, outs, op, place, attrs, default_attrs); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, @@ -159,9 +159,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { + const framework::AttributeMap& default_attrs) { return PrepareImpl(ins, outs, op, place, attrs, - attrs_default); + default_attrs); } template @@ -171,17 +171,17 @@ static void PreparedOpRunImpl( const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { + const framework::AttributeMap& default_attrs) { // TODO(zjl): remove scope in dygraph framework::Scope scope; DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, - &attrs_default, op.Type()); + &default_attrs, op.Type()); static_cast(op).InferShape( &infer_shape_ctx); func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, - attrs, attrs_default)); + attrs, default_attrs)); if (FLAGS_check_nan_inf) { framework::details::CheckOpHasNanOrInfInDygraph( @@ -208,17 +208,17 @@ static void PreparedOpRunImpl( void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { + const framework::AttributeMap& default_attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, - outs, attrs, attrs_default); + outs, attrs, default_attrs); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default) { + const framework::AttributeMap& default_attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, - ins, outs, attrs, attrs_default); + ins, outs, attrs, default_attrs); } } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index f5b133f7bdc187..53f876c498cd04 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -152,23 +152,23 @@ class PreparedOp { const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default); + const framework::AttributeMap& default_attrs); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default); + const framework::AttributeMap& default_attrs); void Run(const NameVarMap& in, const NameVarMap& out, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default); + const framework::AttributeMap& default_attrs); void Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& attrs_default); + const framework::AttributeMap& default_attrs); const framework::OpKernelType& kernel_type() const { return kernel_type_; } diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 0fd320743518a7..367f948ef63b22 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -154,13 +154,13 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, const auto& op_info = op->Info(); auto* attr_checker = op_info.Checker(); if (attr_checker) { - attr_checker->Check(&attrs, true, true); + attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); } static paddle::framework::AttributeMap empty_attrs_map = {}; - const paddle::framework::AttributeMap& attrs_default = + const paddle::framework::AttributeMap& default_attrs = attr_checker == nullptr ? empty_attrs_map - : attr_checker->GetAttrDefaultMap(); + : attr_checker->GetDefaultAttrMap(); NameVarBaseMap new_ins = ins; if (enable_autocast_) { @@ -186,7 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #endif } - OpBase::Run(*op, new_ins, outs, attrs, attrs_default, place); + OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(type, &exception); throw std::move(exception); @@ -209,7 +209,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { - CreateGradOpNode(*op, new_ins, outs, attrs, attrs_default, place, + CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place, inplace_map); } else { VLOG(3) << "No Grad to track for Op: " << type; From 806564c66e9c3ff2de0094d65f019ca04bb70a80 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 23 Jun 2021 01:24:52 +0000 Subject: [PATCH 17/17] refine, test=develop --- paddle/fluid/framework/attribute.h | 2 +- paddle/fluid/framework/grad_op_desc_maker.h | 2 +- paddle/fluid/framework/ir/op_compat_sensible_pass.cc | 2 +- paddle/fluid/framework/type_defs.h | 2 +- paddle/fluid/imperative/dygraph_grad_maker.h | 4 ++-- paddle/fluid/imperative/infer_shape_context.h | 4 ++-- paddle/fluid/imperative/infer_var_type_context.h | 4 ++-- paddle/fluid/imperative/op_base.h | 2 +- paddle/fluid/imperative/tests/test_layer.cc | 4 ++-- paddle/fluid/pybind/pybind.cc | 2 +- 10 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index a73819a400ab91..e9e18757656339 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -423,7 +423,7 @@ class OpAttrChecker { } } - AttributeMap GetAttrsDefaultValuesMap() const { + AttributeMap GetDefaultAttrsMap() const { AttributeMap default_values_map; for (const auto& checker : attr_checkers_) { checker(&default_values_map, true, false); diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 8d52d62c2587d4..ebbfd446a03de2 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -241,7 +241,7 @@ class SingleGradOpMaker { imperative::TracedGradOp traced_grad_op(node); try { - traced_grad_op.SetAttrDefaultMap(this->DefaultAttrsMap()); + traced_grad_op.SetDefaultAttrsMap(this->DefaultAttrsMap()); this->Apply(&traced_grad_op); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index cbb12839362f38..56637a6c7b2b3e 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() { return *this; } const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); - const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap(); + const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap(); if (attrs.find(attr_name_) == attrs.end()) { LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; conditions_.emplace_back([](const Attribute& attr) { return false; }); diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 7b72485c521a5a..951daea47bde3b 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -71,7 +71,7 @@ using DygraphGradOpMakerFN = const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_out*/, const framework::AttributeMap& /*attributes*/, - const framework::AttributeMap& /*attributes default*/, + const framework::AttributeMap& /*default attributes*/, const std::map& /*inplace_map*/)>; using InferVarTypeFN = diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 40281b19daba41..f1eb8aa62c9271 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -295,8 +295,8 @@ class TracedGradOp { return op_->SetAttrMap(attrs); } - void SetAttrDefaultMap(const framework::AttributeMap& attrs) { - return op_->SetAttrDefaultMap(attrs); + void SetDefaultAttrsMap(const framework::AttributeMap& attrs) { + return op_->SetDefaultAttrsMap(attrs); } void SetAttr(const std::string& name, const framework::Attribute& v) { diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 613e0eb57bc6d9..7efe1177f5dc78 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -35,12 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { DygraphInferShapeContext(const NameVarMap* in, const NameVarMap* out, const framework::AttributeMap* attr, - const framework::AttributeMap* attr_default, + const framework::AttributeMap* default_attr, const std::string op_type) : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr), - default_attrs_(attr_default), + default_attrs_(default_attr), op_type_(op_type) {} bool HasInput(const std::string& name) const override { diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h index 5c7b5a8b8526fa..7defc339f4f81d 100644 --- a/paddle/fluid/imperative/infer_var_type_context.h +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -33,12 +33,12 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { RuntimeInferVarTypeContext(const NameVarMap& inputs, const NameVarMap& outputs, const framework::AttributeMap& attrs_map, - const framework::AttributeMap& attrs_map_default) + const framework::AttributeMap& default_attrs_map) : InferVarTypeContext(nullptr, nullptr), inputs_(inputs), outputs_(outputs), attrs_(attrs_map), - default_attrs_(attrs_map_default) {} + default_attrs_(default_attrs_map) {} virtual ~RuntimeInferVarTypeContext() {} diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index f2065a95c5c642..acb125a82925d7 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -103,7 +103,7 @@ class OpBase { void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } - void SetAttrDefaultMap(const framework::AttributeMap& default_attrs) { + void SetDefaultAttrsMap(const framework::AttributeMap& default_attrs) { default_attrs_ = &default_attrs; } diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index 3c9711a051ac80..064f47f54979a1 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -46,9 +46,9 @@ class TestRuntimeInferVarTypeContext TestRuntimeInferVarTypeContext( const NameVarMap& inputs, const NameVarMap& outputs, const framework::AttributeMap& attrs_map, - const framework::AttributeMap& attrs_map_default) + const framework::AttributeMap& default_attrs_map) : RuntimeInferVarTypeContext(inputs, outputs, attrs_map, - attrs_map_default) {} + default_attrs_map) {} bool HasVar(const std::string& name) const { return RuntimeInferVarTypeContext::HasVar(name); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 86084297c4ae65..67f004e61cbfdf 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1308,7 +1308,7 @@ All parameter, weight, gradient are variables in Paddle. if (info != nullptr) { if (info->HasOpProtoAndChecker()) { auto op_checker = info->Checker(); - res = op_checker->GetAttrsDefaultValuesMap(); + res = op_checker->GetDefaultAttrsMap(); } } return res;