@@ -45,90 +45,121 @@ void ParseSafeEagerDeletionSkipVars(
4545 std::vector<std::string>* skip_eager_delete_vars);
4646
4747} // namespace details
48-
49- class ExecutorInfo {
48+ class ExecutorInfoCache {
5049 public:
51- struct CacheValue {
52- std::shared_ptr<ParallelExecutor> executor_{nullptr };
53- std::shared_ptr<ir::Graph> graph_{nullptr };
54-
55- std::vector<std::string> skip_eager_delete_vars_;
50+ struct CacheKey {
51+ CacheKey (const ProgramDesc* program_desc, const platform::Place& place,
52+ int64_t start_op_index, int64_t end_op_index, bool is_grad)
53+ : program_desc_(program_desc),
54+ place_ (place),
55+ start_op_index_(start_op_index),
56+ end_op_index_(end_op_index),
57+ is_grad_(is_grad) {
58+ device_type_ = platform::Place2DeviceType (place);
59+ PADDLE_ENFORCE_NOT_NULL (program_desc_,
60+ " program_desc should not be null." );
61+ }
62+
63+ std::string DebugString () const {
64+ std::stringstream ss;
65+
66+ ss << " \n CacheKey(program_desc: " << program_desc_;
67+ ss << " , start_op_index: " << start_op_index_;
68+ ss << " , end_op_index: " << end_op_index_;
69+ ss << " , is_grad: " << is_grad_;
70+ ss << " , device_type: " << device_type_ << " )" ;
71+
72+ return ss.str ();
73+ }
74+
75+ const ProgramDesc* program_desc_;
76+ platform::Place place_;
77+ int64_t start_op_index_;
78+ int64_t end_op_index_;
79+ bool is_grad_;
80+ platform::DeviceType device_type_;
5681 };
5782
58- bool IsAvailable (bool is_grad) {
59- const auto & executor =
60- is_grad ? backward_info_.executor_ : forward_info_.executor_ ;
61- return executor != nullptr ;
62- }
63-
64- CacheValue& GetMutable (bool is_grad) {
65- return is_grad ? backward_info_ : forward_info_;
66- }
67-
68- private:
69- CacheValue forward_info_;
70- CacheValue backward_info_;
71- };
83+ using KeyType = size_t ;
84+ using ValueType =
85+ std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
86+
87+ struct KeyHasher {
88+ size_t operator ()(const CacheKey& key) const noexcept {
89+ size_t seed = 10 ;
90+ auto * prog_desc = key.program_desc_ ;
91+ /*
92+ * Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
93+ * because a new program will hold same pointer address after an older
94+ * program is destructed with a small probability. Add op size while
95+ * hashing because program may contains at least one block.
96+ */
97+ hash_combine (&seed, prog_desc);
98+ for (size_t i = 0 ; i < prog_desc->Size (); ++i) {
99+ hash_combine (&seed, &prog_desc->Block (i));
100+ hash_combine (&seed, prog_desc->Block (i).OpSize ());
101+ }
102+ hash_combine (&seed, static_cast <int >(key.device_type_ ));
103+ hash_combine (&seed, key.start_op_index_ );
104+ hash_combine (&seed, key.end_op_index_ );
105+ hash_combine (&seed, key.is_grad_ );
106+ VLOG (3 ) << " hash value is : " << seed
107+ << " of key: " << key.DebugString ();
108+ return seed;
109+ }
110+
111+ template <typename T>
112+ void hash_combine (size_t * seed, const T& val) const {
113+ std::hash<T> hasher;
114+ (*seed) ^= hasher (val) + 0x9e3779b9 + ((*seed) << 6 ) + ((*seed >> 2 ));
115+ }
116+ };
72117
73- class ExecutorInfoCache {
74- public:
75118 static ExecutorInfoCache& Instance ();
76119
77- const BuildStrategy& GetBuildStrategy (int64_t program_id) {
78- // If not found, insert build_strategy with default value.
79- return strategy_map_[program_id];
80- }
81-
82- void SetBuildStrategy (int64_t program_id,
83- const BuildStrategy& build_strategy) {
120+ ValueType GetMutable (const CacheKey& key) {
121+ auto key_val = key_hash_func_ (key);
84122 PADDLE_ENFORCE_EQ (
85- strategy_map_. count (program_id ), 0 ,
86- platform::errors::PreconditionNotMet (
87- " program_id: %s already exist in ExecutorInfoCache " , program_id ));
88- strategy_map_[program_id] = build_strategy ;
123+ Has (key_val ), true ,
124+ platform::errors::NotFound ( " %s doesn't exist in ExecutorInfoCache " ,
125+ key. DebugString () ));
126+ return info_map_[key_val] ;
89127 }
90128
91- bool Has (int64_t program_id, bool is_grad) {
92- return info_map_. find (program_id) != info_map_. end () &&
93- info_map_[program_id]. IsAvailable (is_grad );
129+ bool Has (const CacheKey& key) const {
130+ auto key_val = key_hash_func_ (key);
131+ return Has (key_val );
94132 }
95133
96- ExecutorInfo::CacheValue& GetMutable ( int64_t program_id, bool is_grad) {
97- return info_map_[program_id]. GetMutable (is_grad );
134+ bool Has ( const KeyType& key) const {
135+ return info_map_. find (key) != info_map_. end ( );
98136 }
99137
100- void UpdateSkipEagerDeleteVars (int64_t program_id, bool is_grad,
101- const std::vector<std::string>& skip_vars) {
102- auto & cached_value = GetMutable (program_id, is_grad);
103- cached_value.skip_eager_delete_vars_ = std::move (skip_vars);
138+ void Insert (const CacheKey& key, ValueType value) {
139+ auto key_val = key_hash_func_ (key);
140+ PADDLE_ENFORCE_EQ (
141+ Has (key_val), false ,
142+ platform::errors::NotFound (" %s has existed in ExecutorInfoCache" ,
143+ key.DebugString ()));
144+ info_map_.insert ({key_val, value});
104145 }
105146
106- std::vector<std::string>& SkipEagerDeleteVars (int64_t program_id,
107- bool is_grad) {
108- auto & cached_value = GetMutable (program_id, is_grad);
109- return cached_value.skip_eager_delete_vars_ ;
110- }
147+ size_t Size () const { return info_map_.size (); }
111148
112- void Finalize () {
113- // NOTE(Aurelius84): DO NOT perform finalize in destructor
114- // to avoid problems caused by destructor order of static
115- // object.
116- info_map_.clear ();
117- strategy_map_.clear ();
118- }
149+ void Finalize ();
119150
120151 private:
121- std::unordered_map<int64_t , ExecutorInfo> info_map_;
122- std::unordered_map<int64_t , BuildStrategy> strategy_map_;
152+ ExecutorInfoCache () = default;
153+ DISABLE_COPY_AND_ASSIGN (ExecutorInfoCache);
154+
155+ KeyHasher key_hash_func_;
156+ std::unordered_map<KeyType, ValueType> info_map_;
123157};
124158
125159using CacheInfo =
126160 std::pair<std::shared_ptr<ParallelExecutor>, bool /* is_new_created*/ >;
127161
128- CacheInfo GetExecutorInfoFromCache (const ProgramDesc& program_desc,
129- const platform::Place& place,
130- int64_t start_op_index, int64_t end_op_index,
131- bool is_grad, int64_t program_id,
162+ CacheInfo GetExecutorInfoFromCache (const ExecutorInfoCache::CacheKey& cache_key,
132163 framework::Scope* scope);
133164
134165} // namespace framework
0 commit comments