Skip to content

Commit 577fdde

Browse files
authored
Revert "[Dy2Stat] Refactor ExecutorCache logic and pre-support BuildStrategy for pass (#34181)" (#34348)
This reverts commit 609f822.
1 parent 0f60998 commit 577fdde

File tree

11 files changed

+156
-165
lines changed

11 files changed

+156
-165
lines changed

paddle/fluid/framework/executor_cache.cc

Lines changed: 40 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/executor_cache.h"
16-
#include "paddle/fluid/framework/op_info.h"
1716

1817
namespace paddle {
1918
namespace framework {
@@ -26,11 +25,11 @@ namespace framework {
2625

2726
namespace details {
2827

29-
static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
28+
static ExecutionStrategy GetExecutionStrategy(
29+
const ExecutorInfoCache::CacheKey &cache_key) {
3030
framework::ExecutionStrategy execution_strategy;
3131

32-
auto device_type = platform::Place2DeviceType(place);
33-
switch (device_type) {
32+
switch (cache_key.device_type_) {
3433
case platform::DeviceType::CPU: {
3534
execution_strategy.num_threads_ = 2;
3635
break;
@@ -47,9 +46,9 @@ static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
4746
}
4847
default:
4948
PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
50-
device_type));
49+
cache_key.device_type_));
5150
}
52-
execution_strategy.use_device_ = device_type;
51+
execution_strategy.use_device_ = cache_key.device_type_;
5352

5453
return execution_strategy;
5554
}
@@ -137,44 +136,58 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
137136
return g_exe_cache_info_map;
138137
}
139138

140-
CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
141-
const platform::Place &place,
142-
int64_t start_op_index, int64_t end_op_index,
143-
bool is_grad, int64_t program_id,
139+
void ExecutorInfoCache::Finalize() {
140+
// NOTE(Aurelius84): DO NOT perform finalize in destructor
141+
// to avoid problems caused by destructor order of static
142+
// object.
143+
info_map_.clear();
144+
}
145+
146+
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey &cache_key,
144147
framework::Scope *scope) {
145148
auto &cached_exe_info = framework::ExecutorInfoCache::Instance();
146149

147-
if (!cached_exe_info.Has(program_id, is_grad)) {
148-
VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
149-
auto execution_strategy = details::GetExecutionStrategy(place);
150-
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);
150+
if (!cached_exe_info.Has(cache_key)) {
151+
VLOG(1) << "create exe_info for " << cache_key.DebugString();
152+
153+
// TODO(Aurelius84): Consider using an LRU algorithm to replace this.
154+
if (cached_exe_info.Size() > 4u /* max_cached_size*/) {
155+
VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
156+
"all cache!";
157+
cached_exe_info.Finalize();
158+
}
159+
160+
framework::BuildStrategy build_strategy;
161+
auto execution_strategy = details::GetExecutionStrategy(cache_key);
151162

152-
// 2. Construct Graph and ParallelExecutor.
153163
auto graph = std::make_shared<framework::ir::Graph>(
154-
program_desc, start_op_index, end_op_index);
164+
*cache_key.program_desc_, cache_key.start_op_index_,
165+
cache_key.end_op_index_);
155166
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
156-
place, scope, execution_strategy, build_strategy, graph.get());
167+
cache_key.place_, scope, execution_strategy, build_strategy,
168+
graph.get());
157169
parallel_executor->PrepareVariables(scope);
158170

159-
// 3. Insert value into cached map.
160-
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
161-
cached_value.executor_ = parallel_executor;
162-
cached_value.graph_ = std::move(graph);
163-
return std::make_pair(parallel_executor, /*is_new_created=*/true);
171+
framework::ExecutorInfoCache::ValueType cache_val = {parallel_executor,
172+
graph};
173+
cached_exe_info.Insert(cache_key, cache_val);
174+
175+
bool is_new_created = true;
176+
return std::make_pair(parallel_executor, is_new_created);
164177
} else {
165-
VLOG(1) << "get exe_info from cache by: " << program_id
166-
<< " is_grad: " << is_grad;
167-
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
178+
VLOG(1) << "get exe_info from cache by: " << cache_key.DebugString();
179+
bool is_new_created = false;
180+
auto cache_val = cached_exe_info.GetMutable(cache_key);
181+
auto parallel_executor = cache_val.first;
168182

169-
auto &parallel_executor = cached_value.executor_;
170183
// update op_handle scope_map in pe->executor_->Graph
171184
std::unordered_map<Scope *, Scope *> scope_map = {
172185
{parallel_executor->GetLocalScopes().front(), scope}};
173186
parallel_executor->ResetOpHandleScopeMapOfGraphs(scope_map);
174187
// need to recreate tmp variables in new scope
175188
parallel_executor->PrepareVariables(scope);
176189

177-
return std::make_pair(parallel_executor, /*is_new_created=*/false);
190+
return std::make_pair(parallel_executor, is_new_created);
178191
}
179192
}
180193

paddle/fluid/framework/executor_cache.h

Lines changed: 92 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -45,90 +45,121 @@ void ParseSafeEagerDeletionSkipVars(
4545
std::vector<std::string>* skip_eager_delete_vars);
4646

4747
} // namespace details
48-
49-
class ExecutorInfo {
48+
class ExecutorInfoCache {
5049
public:
51-
struct CacheValue {
52-
std::shared_ptr<ParallelExecutor> executor_{nullptr};
53-
std::shared_ptr<ir::Graph> graph_{nullptr};
54-
55-
std::vector<std::string> skip_eager_delete_vars_;
50+
struct CacheKey {
51+
CacheKey(const ProgramDesc* program_desc, const platform::Place& place,
52+
int64_t start_op_index, int64_t end_op_index, bool is_grad)
53+
: program_desc_(program_desc),
54+
place_(place),
55+
start_op_index_(start_op_index),
56+
end_op_index_(end_op_index),
57+
is_grad_(is_grad) {
58+
device_type_ = platform::Place2DeviceType(place);
59+
PADDLE_ENFORCE_NOT_NULL(program_desc_,
60+
"program_desc should not be null.");
61+
}
62+
63+
std::string DebugString() const {
64+
std::stringstream ss;
65+
66+
ss << "\n CacheKey(program_desc: " << program_desc_;
67+
ss << ", start_op_index: " << start_op_index_;
68+
ss << ", end_op_index: " << end_op_index_;
69+
ss << ", is_grad: " << is_grad_;
70+
ss << ", device_type: " << device_type_ << ")";
71+
72+
return ss.str();
73+
}
74+
75+
const ProgramDesc* program_desc_;
76+
platform::Place place_;
77+
int64_t start_op_index_;
78+
int64_t end_op_index_;
79+
bool is_grad_;
80+
platform::DeviceType device_type_;
5681
};
5782

58-
bool IsAvailable(bool is_grad) {
59-
const auto& executor =
60-
is_grad ? backward_info_.executor_ : forward_info_.executor_;
61-
return executor != nullptr;
62-
}
63-
64-
CacheValue& GetMutable(bool is_grad) {
65-
return is_grad ? backward_info_ : forward_info_;
66-
}
67-
68-
private:
69-
CacheValue forward_info_;
70-
CacheValue backward_info_;
71-
};
83+
using KeyType = size_t;
84+
using ValueType =
85+
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
86+
87+
struct KeyHasher {
88+
size_t operator()(const CacheKey& key) const noexcept {
89+
size_t seed = 10;
90+
auto* prog_desc = key.program_desc_;
91+
/*
92+
* Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
93+
* because a new program will hold same pointer address after an older
94+
* program is destructed with a small probability. Add op size while
95+
hashing because program may contain at least one block.
96+
*/
97+
hash_combine(&seed, prog_desc);
98+
for (size_t i = 0; i < prog_desc->Size(); ++i) {
99+
hash_combine(&seed, &prog_desc->Block(i));
100+
hash_combine(&seed, prog_desc->Block(i).OpSize());
101+
}
102+
hash_combine(&seed, static_cast<int>(key.device_type_));
103+
hash_combine(&seed, key.start_op_index_);
104+
hash_combine(&seed, key.end_op_index_);
105+
hash_combine(&seed, key.is_grad_);
106+
VLOG(3) << "hash value is : " << seed
107+
<< " of key: " << key.DebugString();
108+
return seed;
109+
}
110+
111+
template <typename T>
112+
void hash_combine(size_t* seed, const T& val) const {
113+
std::hash<T> hasher;
114+
(*seed) ^= hasher(val) + 0x9e3779b9 + ((*seed) << 6) + ((*seed >> 2));
115+
}
116+
};
72117

73-
class ExecutorInfoCache {
74-
public:
75118
static ExecutorInfoCache& Instance();
76119

77-
const BuildStrategy& GetBuildStrategy(int64_t program_id) {
78-
// If not found, insert build_strategy with default value.
79-
return strategy_map_[program_id];
80-
}
81-
82-
void SetBuildStrategy(int64_t program_id,
83-
const BuildStrategy& build_strategy) {
120+
ValueType GetMutable(const CacheKey& key) {
121+
auto key_val = key_hash_func_(key);
84122
PADDLE_ENFORCE_EQ(
85-
strategy_map_.count(program_id), 0,
86-
platform::errors::PreconditionNotMet(
87-
"program_id: %s already exist in ExecutorInfoCache", program_id));
88-
strategy_map_[program_id] = build_strategy;
123+
Has(key_val), true,
124+
platform::errors::NotFound("%s doesn't exist in ExecutorInfoCache",
125+
key.DebugString()));
126+
return info_map_[key_val];
89127
}
90128

91-
bool Has(int64_t program_id, bool is_grad) {
92-
return info_map_.find(program_id) != info_map_.end() &&
93-
info_map_[program_id].IsAvailable(is_grad);
129+
bool Has(const CacheKey& key) const {
130+
auto key_val = key_hash_func_(key);
131+
return Has(key_val);
94132
}
95133

96-
ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
97-
return info_map_[program_id].GetMutable(is_grad);
134+
bool Has(const KeyType& key) const {
135+
return info_map_.find(key) != info_map_.end();
98136
}
99137

100-
void UpdateSkipEagerDeleteVars(int64_t program_id, bool is_grad,
101-
const std::vector<std::string>& skip_vars) {
102-
auto& cached_value = GetMutable(program_id, is_grad);
103-
cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
138+
void Insert(const CacheKey& key, ValueType value) {
139+
auto key_val = key_hash_func_(key);
140+
PADDLE_ENFORCE_EQ(
141+
Has(key_val), false,
142+
platform::errors::NotFound("%s has existed in ExecutorInfoCache",
143+
key.DebugString()));
144+
info_map_.insert({key_val, value});
104145
}
105146

106-
std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
107-
bool is_grad) {
108-
auto& cached_value = GetMutable(program_id, is_grad);
109-
return cached_value.skip_eager_delete_vars_;
110-
}
147+
size_t Size() const { return info_map_.size(); }
111148

112-
void Finalize() {
113-
// NOTE(Aurelius84): DO NOT perform finalize in destructor
114-
// to avoid problems caused by destructor order of static
115-
// object.
116-
info_map_.clear();
117-
strategy_map_.clear();
118-
}
149+
void Finalize();
119150

120151
private:
121-
std::unordered_map<int64_t, ExecutorInfo> info_map_;
122-
std::unordered_map<int64_t, BuildStrategy> strategy_map_;
152+
ExecutorInfoCache() = default;
153+
DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache);
154+
155+
KeyHasher key_hash_func_;
156+
std::unordered_map<KeyType, ValueType> info_map_;
123157
};
124158

125159
using CacheInfo =
126160
std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
127161

128-
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
129-
const platform::Place& place,
130-
int64_t start_op_index, int64_t end_op_index,
131-
bool is_grad, int64_t program_id,
162+
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey& cache_key,
132163
framework::Scope* scope);
133164

134165
} // namespace framework

paddle/fluid/operators/run_program_op.cc

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -103,10 +103,6 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
103103
"(bool, default false) Set to true for inference only, false "
104104
"for training.")
105105
.SetDefault(false);
106-
AddAttr<int64_t>(
107-
"program_id",
108-
"(int64_t)"
109-
"The unique hash id used as cache key for ExecutorInfoCache.");
110106
AddComment(R"DOC(
111107
RunProgram operator.
112108

0 commit comments

Comments
 (0)