Skip to content

Commit 656e60b

Browse files
authored
new class: op_version_registry, test=develop (#26542)
1 parent 24566e9 commit 656e60b

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
122122
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
123123
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
124124
device_context)
125+
126+
cc_library(op_version_registry SRCS op_version_registry.cc DEPS framework_proto boost)
127+
cc_test(op_version_registry_test SRCS op_version_registry_test.cc DEPS op_version_registry)
128+
125129
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
126130
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
127131
cc_library(no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc DEPS attribute device_context)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_version_registry.h"
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include <unordered_map>
20+
#include <utility>
21+
#include <vector>
22+
23+
#include <boost/any.hpp>
24+
#include "paddle/fluid/framework/framework.pb.h"
25+
#include "paddle/fluid/platform/enforce.h"
26+
27+
namespace paddle {
28+
namespace framework {
29+
namespace compatible {
30+
31+
struct OpUpdateRecord {
32+
enum class Type { kInvalid = 0, kModifyAttr, kNewAttr };
33+
Type type_;
34+
std::string remark_;
35+
};
36+
37+
struct ModifyAttr : OpUpdateRecord {
38+
ModifyAttr(const std::string& name, const std::string& remark,
39+
boost::any default_value)
40+
: OpUpdateRecord({Type::kModifyAttr, remark}),
41+
name_(name),
42+
default_value_(default_value) {
43+
// TODO(Shixiaowei02): Check the data type with proto::OpDesc.
44+
}
45+
46+
private:
47+
std::string name_;
48+
boost::any default_value_;
49+
};
50+
struct NewAttr : OpUpdateRecord {
51+
NewAttr(const std::string& name, const std::string& remark)
52+
: OpUpdateRecord({Type::kNewAttr, remark}), name_(name) {}
53+
54+
private:
55+
std::string name_;
56+
};
57+
58+
class OpVersionDesc {
59+
public:
60+
OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark,
61+
boost::any default_value) {
62+
infos_.push_back(std::shared_ptr<OpUpdateRecord>(
63+
new compatible::ModifyAttr(name, remark, default_value)));
64+
return *this;
65+
}
66+
67+
OpVersionDesc& NewAttr(const std::string& name, const std::string& remark) {
68+
infos_.push_back(
69+
std::shared_ptr<OpUpdateRecord>(new compatible::NewAttr(name, remark)));
70+
return *this;
71+
}
72+
73+
private:
74+
std::vector<std::shared_ptr<OpUpdateRecord>> infos_;
75+
};
76+
77+
class OpVersion {
78+
public:
79+
OpVersion& AddCheckpoint(const std::string& note,
80+
const OpVersionDesc& op_version_desc) {
81+
checkpoints_.push_back(Checkpoint({note, op_version_desc}));
82+
return *this;
83+
}
84+
85+
private:
86+
struct Checkpoint {
87+
std::string note_;
88+
OpVersionDesc op_version_desc_;
89+
};
90+
std::vector<Checkpoint> checkpoints_;
91+
};
92+
93+
class OpVersionRegistrar {
94+
public:
95+
static OpVersionRegistrar& GetInstance() {
96+
static OpVersionRegistrar instance;
97+
return instance;
98+
}
99+
OpVersion& Register(const std::string& op_type) {
100+
if (op_version_map_.find(op_type) != op_version_map_.end()) {
101+
PADDLE_THROW("'%s' is registered in operator version more than once.",
102+
op_type);
103+
}
104+
op_version_map_.insert({op_type, OpVersion()});
105+
return op_version_map_[op_type];
106+
}
107+
108+
private:
109+
std::unordered_map<std::string, OpVersion> op_version_map_;
110+
111+
OpVersionRegistrar() = default;
112+
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
113+
};
114+
115+
} // namespace compatible
116+
} // namespace framework
117+
} // namespace paddle
118+
119+
#define REGISTER_OP_VERSION(op_type) \
120+
static paddle::framework::compatible::OpVersion \
121+
RegisterOpVersion__##op_type = \
122+
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
123+
.Register(#op_type)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <glog/logging.h>
16+
#include <gtest/gtest.h>
17+
18+
#include "paddle/fluid/framework/op_version_registry.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace compatible {
23+
24+
TEST(test_operator_version, test_operator_version) {
25+
REGISTER_OP_VERSION(test__)
26+
.AddCheckpoint(
27+
R"ROC(
28+
Upgrade reshape, modified one attribute [axis] and add a new attribute [size].
29+
)ROC",
30+
framework::compatible::OpVersionDesc()
31+
.ModifyAttr("axis",
32+
"Increased from the original one method to two.", -1)
33+
.NewAttr("size",
34+
"In order to represent a two-dimensional rectangle, the "
35+
"parameter size is added."))
36+
.AddCheckpoint(
37+
R"ROC(
38+
Add a new attribute [height]
39+
)ROC",
40+
framework::compatible::OpVersionDesc().NewAttr(
41+
"height",
42+
"In order to represent a two-dimensional rectangle, the "
43+
"parameter height is added."));
44+
}
45+
} // namespace compatible
46+
} // namespace framework
47+
} // namespace paddle

0 commit comments

Comments
 (0)