Skip to content

Commit f321e5f

Browse files
committed
refine case when thread_num = 1
1 parent f068e08 commit f321e5f

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,16 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
4747
<< "Change thread number to 1 because the toposort order is unique";
4848
strategy_.num_threads_ = 1;
4949
}
50-
pool_.reset(new ::ThreadPool(strategy.num_threads_));
50+
if (strategy_.num_threads_ > 1) {
51+
pool_.reset(new ::ThreadPool(strategy.num_threads_));
52+
} else {
53+
auto nodes = ir::TopologySortOperations(*graph_);
54+
traced_ops_.clear();
55+
traced_ops_.reserve(nodes.size());
56+
for (auto *node : nodes) {
57+
traced_ops_.push_back(&node->Wrapper<OpHandleBase>());
58+
}
59+
}
5160
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
5261
int dep = static_cast<int>(op->NotReadyInputSize());
5362
op_deps_.emplace(op, dep);
@@ -228,7 +237,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
228237
OpHandleBase *op,
229238
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
230239
++remaining_;
231-
this->pool_->enqueue([=] {
240+
auto func = [=] {
232241
std::deque<OpHandleBase *> op_queue;
233242
op_queue.push_front(op);
234243

@@ -287,7 +296,12 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
287296
}
288297
--remaining_;
289298
complete_q->Push(complete);
290-
});
299+
};
300+
if (pool_) {
301+
pool_->enqueue(func);
302+
} else {
303+
func();
304+
}
291305
}
292306

293307
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {

0 commit comments

Comments
 (0)