Skip to content

Commit 6a6d6a7

Browse files
committed
[PIR] polish the ir_mapping implimentation.
1 parent da5399a commit 6a6d6a7

File tree

3 files changed

+53
-40
lines changed

3 files changed

+53
-40
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
10301030
pir::IrMapping mapper;
10311031
auto cloned_program = program.Clone(mapper);
10321032
std::vector<pir::OpResult> associated_array_key, associated_array_value;
1033-
for (auto &pair : mapper.Map<pir::Value>()) {
1033+
for (auto &pair : mapper.GetMap<pir::Value>()) {
10341034
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
10351035
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
10361036
}
@@ -1119,12 +1119,12 @@ SplitedResult SplitForwardBackward(
11191119
auto *cloned_op = op->Clone(forward_mapper, clone_options);
11201120
forward_program->block()->push_back(cloned_op);
11211121
});
1122-
auto &forward_value_map = forward_mapper.MutableMap<pir::Value>();
1122+
auto &forward_value_map = forward_mapper.GetMutableMap<pir::Value>();
11231123

11241124
// backward program construc.
11251125
// Step1. insert data op for inputs_values and middle_values
11261126
pir::IrMapping backward_mapper;
1127-
auto &backward_value_map = backward_mapper.MutableMap<pir::Value>();
1127+
auto &backward_value_map = backward_mapper.GetMutableMap<pir::Value>();
11281128
int counter = 0;
11291129
auto create_data_fn = [&backward_builder,
11301130
&backward_inputs,

paddle/pir/core/ir_mapping.h

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,65 @@ namespace pir {
2020
class Block;
2121
class Operation;
2222

23+
namespace detail {
24+
template <typename T, typename... OthersT>
25+
struct ExactlyOneIrType {
26+
using type = void;
27+
};
28+
template <typename T, typename FirstT, typename... OthersT>
29+
struct ExactlyOneIrType<T, FirstT, OthersT...> {
30+
using type =
31+
std::conditional_t<std::is_convertible<T, FirstT>::value,
32+
FirstT,
33+
typename ExactlyOneIrType<T, OthersT...>::type>;
34+
};
35+
} // namespace detail
2336
class IrMapping {
2437
public:
2538
template <typename T>
26-
void Add(T from, T to) {
39+
using IrType =
40+
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
41+
template <typename T>
42+
std::unordered_map<T, T>& GetMutableMap() {
43+
if constexpr (std::is_same<T, Value>::value) {
44+
return value_map_;
45+
} else if constexpr (std::is_same<T, Block*>::value) {
46+
return block_map_;
47+
} else if constexpr (std::is_same<T, Operation*>::value) {
48+
return operation_map_;
49+
} else {
50+
IR_THROW("Not support type in IRMapping.");
51+
}
52+
}
53+
template <typename T>
54+
const std::unordered_map<T, T>& GetMap() const {
55+
if constexpr (std::is_same<T, Value>::value) {
56+
return value_map_;
57+
} else if constexpr (std::is_same<T, Block*>::value) {
58+
return block_map_;
59+
} else if constexpr (std::is_same<T, Operation*>::value) {
60+
return operation_map_;
61+
} else {
62+
IR_THROW("Not support type in IRMapping.");
63+
}
64+
}
65+
template <typename T, typename S>
66+
void Add(T from, S to) {
2767
if (!from) return;
28-
MutableMap<T>()[from] = to;
68+
GetMutableMap<IrType<T>>()[from] = to;
2969
}
3070

3171
template <typename T>
3272
T Lookup(T from) const {
3373
if (!from) return static_cast<T>(nullptr);
34-
IR_ENFORCE(Map<T>().count(from) > 0, "Not found key in IRMapping.");
35-
return Map<T>().at(from);
74+
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
75+
"Not found key in IRMapping.");
76+
return GetMap<IrType<T>>().at(from);
3677
}
3778

3879
template <typename T>
3980
void Earse(T from) {
40-
MutableMap<T>().erase(from);
81+
GetMutableMap<IrType<T>>().erase(from);
4182
}
4283

4384
void Clear() {
@@ -46,37 +87,10 @@ class IrMapping {
4687
operation_map_.clear();
4788
}
4889

49-
template <typename T>
50-
using MapType = std::unordered_map<T, T>;
51-
52-
template <typename T>
53-
const MapType<T> &Map() const {
54-
if constexpr (std::is_convertible<T, Value>::value)
55-
return value_map_;
56-
else if constexpr (std::is_convertible<T, Block *>::value)
57-
return block_map_;
58-
else if constexpr (std::is_convertible<T, Operation *>::value)
59-
return operation_map_;
60-
else
61-
IR_THROW("Not support type in IRMapping.");
62-
}
63-
64-
template <typename T>
65-
MapType<T> &MutableMap() {
66-
if constexpr (std::is_convertible<T, Value>::value)
67-
return value_map_;
68-
else if constexpr (std::is_convertible<T, Block *>::value)
69-
return block_map_;
70-
else if constexpr (std::is_convertible<T, Operation *>::value)
71-
return operation_map_;
72-
else
73-
IR_THROW("Not support type in IRMapping.");
74-
}
75-
7690
private:
77-
MapType<Value> value_map_;
78-
MapType<Block *> block_map_;
79-
MapType<Operation *> operation_map_;
91+
std::unordered_map<Value, Value> value_map_;
92+
std::unordered_map<Block*, Block*> block_map_;
93+
std::unordered_map<Operation*, Operation*> operation_map_;
8094
};
8195

8296
} // namespace pir

paddle/pir/core/operation.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
158158

159159
// record outputs mapping info
160160
for (uint32_t i = 0; i < num_results_; ++i) {
161-
ir_mapping.Add(static_cast<Value>(result(i)),
162-
static_cast<Value>(new_op->result(i)));
161+
ir_mapping.Add(result(i), new_op->result(i));
163162
}
164163

165164
if (options.IsCloneRegions()) {

0 commit comments

Comments
 (0)