Skip to content

Commit 8358d61

Browse files
fix 3 bug of new_executor (#37142)
* fix 3 bug, test=develop * refine, test=develop
1 parent b628c31 commit 8358d61

File tree

6 files changed

+20
-5
lines changed

6 files changed

+20
-5
lines changed

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ void InterpreterCore::Convert() {
9898

9999
for (auto& item : op_func_node.input_index) {
100100
for (auto id : item.second) {
101+
if (id == kEmptyVarIndex) {
102+
continue;
103+
}
101104
input_var2op_info_.at(id).push_back(op_idx);
102105
// var can be gc-ed
103106
if (!info.IsBuilt()) {

paddle/fluid/framework/new_executor/interpretercore_garbage_collector.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ void InterpreterCoreGarbageCollector::Add(
6060
void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
6161
paddle::platform::DeviceEvent& event,
6262
const platform::DeviceContext* ctx) {
63+
if (!var) {
64+
return;
65+
}
66+
6367
if (var->IsType<LoDTensor>()) {
6468
Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), event, ctx);
6569
} else if (var->IsType<

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,13 @@ void build_op_func_list(const platform::Place& place,
446446
VariableValueMap ins_map;
447447
VariableIdMap ins_name2id;
448448
bool enforce_exist = true;
449-
if (op->Type() == "recurrent_grad") enforce_exist = false;
449+
if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
450+
op->Type() == "rnn_memory_helper_grad" ||
451+
op->Type() == "conditional_block" ||
452+
op->Type() == "conditional_block_grad" || op->Type() == "while" ||
453+
op->Type() == "while_grad") {
454+
enforce_exist = false;
455+
}
450456
std::tie(ins_map, ins_name2id) =
451457
build_variable_map(inputs_names, var_scope, enforce_exist);
452458

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
480480
VariableScope::VariableScope(Scope* scope) {
481481
// for @EMPTY@ variable
482482
var_list_.push_back(nullptr);
483-
name2id_[kEmptyVarName] = 0;
483+
name2id_[kEmptyVarName] = kEmptyVarIndex;
484484
vec_meta_info_.emplace_back(0, nullptr);
485485
scope_ = scope;
486486
PADDLE_ENFORCE_NE(

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
4343
using OpKernelMap =
4444
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
4545

46+
constexpr int kEmptyVarIndex = 0;
47+
4648
class InterpretercoreInferShapeContext : public InferShapeContext {
4749
public:
4850
InterpretercoreInferShapeContext(const OperatorBase& op,

python/paddle/fluid/executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,13 +598,13 @@ def _get_exe_from_cache(self, program, scope):
598598
assert isinstance(
599599
program, Program), "Required type(Program), but received {}".format(
600600
type(program).__name__)
601-
if program not in self._cached_executors:
601+
if str(program) not in self._cached_executors:
602602
new_program = program.clone()
603603
_prune_feed_ops(new_program)
604604
new_exe = _StandaloneExecutor(self._place, new_program, scope)
605-
self._cached_executors[program] = new_exe
605+
self._cached_executors[str(program)] = new_exe
606606

607-
return self._cached_executors[program]
607+
return self._cached_executors[str(program)]
608608

609609

610610
class Executor(object):

0 commit comments

Comments
 (0)