Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -161,8 +162,8 @@ struct OpInfoFiller<T, kOperator> {
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
info->proto_ = std::make_shared<proto::OpProto>();
info->checker_ = std::make_shared<OpAttrChecker>();
T maker;
maker(info->proto_, info->checker_);
info->proto_->set_type(op_type);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ static void InitInferShapeFuncs() {
void OpDesc::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted.");
auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
auto checker = OpInfoMap::Instance().Get(Type()).Checker();
if (checker == nullptr) {
// checker is not configured. That operator could be generated by Paddle,
// not by users.
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/framework/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
Expand All @@ -36,13 +38,13 @@ class InferShapeBase {
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
proto::OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
std::shared_ptr<proto::OpProto> proto_;
std::shared_ptr<OpAttrChecker> checker_;
InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_;
InferInplaceOpFN infer_inplace_;
InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;

OpInfo() : proto_{nullptr}, checker_{nullptr} {}
// NOTE(zjl): this flag is added to check whether
// the grad maker is the default one.
bool use_default_grad_op_desc_maker_{false};
Expand Down Expand Up @@ -70,7 +72,7 @@ struct OpInfo {
return grad_op_maker_;
}

const OpAttrChecker* Checker() const { return checker_; }
const std::shared_ptr<OpAttrChecker> Checker() const { return checker_; }

const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const {
return infer_no_need_buffer_vars_;
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
}
}

void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) {
void OpProtoAndCheckerMaker::operator()(
std::shared_ptr<proto::OpProto> proto,
std::shared_ptr<OpAttrChecker> attr_checker) {
proto_ = proto;
op_checker_ = attr_checker;
Make();
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/framework/op_proto_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
Expand Down Expand Up @@ -49,7 +51,8 @@ class OpProtoAndCheckerMaker {
static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }

void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
void operator()(std::shared_ptr<proto::OpProto> proto,
std::shared_ptr<OpAttrChecker> attr_checker);

virtual void Make() = 0;

Expand Down Expand Up @@ -99,9 +102,8 @@ class OpProtoAndCheckerMaker {
private:
void CheckNoDuplicatedInOutAttrs();
void Validate();

proto::OpProto *proto_;
OpAttrChecker *op_checker_;
std::shared_ptr<proto::OpProto> proto_;
std::shared_ptr<OpAttrChecker> op_checker_;
bool validated_{false};
};
} // namespace framework
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/framework/op_proto_maker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};

TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto op_proto = std::make_shared<paddle::framework::proto::OpProto>();
auto op_checker = std::make_shared<paddle::framework::OpAttrChecker>();
TestAttrProtoMaker proto_maker;
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
ASSERT_THROW(proto_maker(op_proto, op_checker),
paddle::platform::EnforceNotMet);
}

Expand All @@ -41,9 +41,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};

TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto op_proto = std::make_shared<paddle::framework::proto::OpProto>();
auto op_checker = std::make_shared<paddle::framework::OpAttrChecker>();
TestAttrProtoMaker proto_maker;
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
ASSERT_THROW(proto_maker(op_proto, op_checker),
paddle::platform::EnforceNotMet);
}