Skip to content

Commit 5017ba8

Browse files
authored
[PIR] Fix bug in ParseDeviceContext (#60705)
* fix * fix * fix * fix * fix
1 parent b83a6fd commit 5017ba8

3 files changed

Lines changed: 40 additions & 31 deletions

File tree

paddle/fluid/framework/new_executor/instruction/instruction_util.cc

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -91,13 +91,13 @@ platform::DeviceContext* ParseDeviceContext(
9191
return dev_ctx;
9292
}
9393

94-
if (op_name == interpreter::kMemcpyD2H) {
94+
if (op_name.compare(paddle::dialect::MemcpyD2hOp::name()) == 0) {
9595
dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority)
9696
.get()
9797
.get();
9898
interpreter::SetDeviceCommContext(op, dev_ctx);
9999
return dev_ctx;
100-
} else if (op_name == interpreter::kMemcpyH2D) {
100+
} else if (op_name.compare(paddle::dialect::MemcpyH2dOp::name()) == 0) {
101101
dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority)
102102
.get()
103103
.get();
@@ -114,9 +114,11 @@ platform::DeviceContext* ParseDeviceContext(
114114
// DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
115115
// c_allreduce_op.h). Now it is just a temporary solution for ONLY
116116
// c_allreduce_sum which is used in ResNet50 distributed training.
117-
if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream")
118-
.dyn_cast<pir::BoolAttribute>()
119-
.data() == false) {
117+
if ((op_name.compare(paddle::dialect::CAllreduceSumOp::name()) == 0 ||
118+
op_name.compare(paddle::dialect::CAllreduceSum_Op::name()) == 0) &&
119+
op_attributes.at("use_calc_stream")
120+
.dyn_cast<pir::BoolAttribute>()
121+
.data() == false) {
120122
int ring_id =
121123
op_attributes.at("ring_id").dyn_cast<pir::Int32Attribute>().data();
122124
if (FLAGS_dynamic_static_unified_comm) {

paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ LegacyKernelInstruction::LegacyKernelInstruction(
9090
}
9191
SetEventsToWaitInfo(events_to_wait);
9292
}
93-
VLOG(6) << "finish process dist attributes";
93+
VLOG(6) << "finish process dist attributes for " << op_name
94+
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
95+
<< GetExecutionStream() << ", " << GetStreamPriority() << ", "
96+
<< GetSchedulingPriority() << "]";
9497

9598
SetKernelType(AnalyseOpFuncType(op, place));
9699
VLOG(6) << "finish process analyse kernel type";
@@ -137,28 +140,29 @@ LegacyKernelInstruction::LegacyKernelInstruction(
137140

138141
operator_base_ = BuildOperatorBase(op, *value_exec_info_, yaml_info_parser);
139142

143+
SetDeviceContext(
144+
ParseDeviceContext(op,
145+
phi::DeviceContextPool::Instance().Get(
146+
phi::TransToPhiPlace(kernel_key.backend())),
147+
place,
148+
GetExecutionStream(),
149+
GetStreamPriority()));
150+
VLOG(6) << "finish process device context";
151+
140152
paddle::framework::VariableValueMap in_map;
141153
paddle::framework::VariableValueMap out_map;
142-
auto dev_ctx = phi::DeviceContextPool::Instance().Get(
143-
phi::TransToPhiPlace(kernel_key.backend()));
144-
145154
runtime_context_ = std::make_shared<paddle::framework::RuntimeContext>(
146155
paddle::framework::RuntimeContext(in_map, out_map));
147156
BuildRuntimeContext(
148157
op, *value_exec_info, yaml_info_parser, runtime_context_.get());
149158

150-
kernel_context_ = new paddle::framework::ExecutionContext(
151-
*operator_base_, *inner_scope, *dev_ctx, *(runtime_context_.get()));
159+
kernel_context_ =
160+
new paddle::framework::ExecutionContext(*operator_base_,
161+
*inner_scope,
162+
DeviceContext(),
163+
*(runtime_context_.get()));
152164

153165
VLOG(6) << "finish process kernel context";
154-
SetDeviceContext(
155-
ParseDeviceContext(op,
156-
phi::DeviceContextPool::Instance().Get(
157-
phi::TransToPhiPlace(kernel_key.backend())),
158-
place,
159-
GetExecutionStream(),
160-
GetStreamPriority()));
161-
VLOG(6) << "finish process device context";
162166

163167
InitInputsOutputsIds(op, *value_exec_info);
164168
VLOG(6) << "finish process inputs outputs index";

paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ PhiKernelInstruction::PhiKernelInstruction(
9494
}
9595
SetEventsToWaitInfo(events_to_wait);
9696
}
97-
VLOG(6) << "finish process dist attributes";
97+
VLOG(6) << "finish process dist attributes for " << op_name
98+
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
99+
<< GetExecutionStream() << ", " << GetStreamPriority() << ", "
100+
<< GetSchedulingPriority() << "]";
98101

99102
SetKernelType(AnalyseOpFuncType(op, place));
100103
VLOG(6) << "finish process analyse kernel type";
@@ -137,6 +140,16 @@ PhiKernelInstruction::PhiKernelInstruction(
137140
phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
138141
VLOG(6) << "finish process select kernel";
139142

143+
platform::DeviceContext* dev_ctx =
144+
ParseDeviceContext(op,
145+
phi::DeviceContextPool::Instance().Get(
146+
phi::TransToPhiPlace(kernel_key.backend())),
147+
place,
148+
GetExecutionStream(),
149+
GetStreamPriority());
150+
SetDeviceContext(dev_ctx);
151+
VLOG(6) << "finish process device context";
152+
140153
BuildPhiContext<phi::KernelContext,
141154
const phi::TensorBase*,
142155
phi::TensorBase*,
@@ -145,19 +158,9 @@ PhiKernelInstruction::PhiKernelInstruction(
145158
true>(
146159
op, *value_exec_info_, yaml_info_parser, &kernel_context_);
147160

148-
kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
149-
phi::TransToPhiPlace(kernel_key.backend())));
161+
kernel_context_.SetDeviceContext(dev_ctx);
150162
VLOG(6) << "finish process kernel context";
151163

152-
SetDeviceContext(
153-
ParseDeviceContext(op,
154-
phi::DeviceContextPool::Instance().Get(
155-
phi::TransToPhiPlace(kernel_key.backend())),
156-
place,
157-
GetExecutionStream(),
158-
GetStreamPriority()));
159-
VLOG(6) << "finish process device context";
160-
161164
InitInputsOutputsIds(op, *value_exec_info);
162165
VLOG(6) << "finish process inputs outputs index";
163166

0 commit comments

Comments
 (0)