Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ platform::DeviceContext* ParseDeviceContext(
return dev_ctx;
}

if (op_name == interpreter::kMemcpyD2H) {
if (op_name.compare(paddle::dialect::MemcpyD2hOp::name()) == 0) {
dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority)
.get()
.get();
interpreter::SetDeviceCommContext(op, dev_ctx);
return dev_ctx;
} else if (op_name == interpreter::kMemcpyH2D) {
} else if (op_name.compare(paddle::dialect::MemcpyH2dOp::name()) == 0) {
dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority)
.get()
.get();
Expand All @@ -114,9 +114,11 @@ platform::DeviceContext* ParseDeviceContext(
// DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
// c_allreduce_op.h). Now it is just a temporary solution for ONLY
// c_allreduce_sum which is used in ResNet50 distributed training.
if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream")
.dyn_cast<pir::BoolAttribute>()
.data() == false) {
if ((op_name.compare(paddle::dialect::CAllreduceSumOp::name()) == 0 ||
op_name.compare(paddle::dialect::CAllreduceSum_Op::name()) == 0) &&
op_attributes.at("use_calc_stream")
.dyn_cast<pir::BoolAttribute>()
.data() == false) {
int ring_id =
op_attributes.at("ring_id").dyn_cast<pir::Int32Attribute>().data();
if (FLAGS_dynamic_static_unified_comm) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ LegacyKernelInstruction::LegacyKernelInstruction(
}
SetEventsToWaitInfo(events_to_wait);
}
VLOG(6) << "finish process dist attributes";
VLOG(6) << "finish process dist attributes for " << op_name
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
<< GetExecutionStream() << ", " << GetStreamPriority() << ", "
<< GetSchedulingPriority() << "]";

SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";
Expand Down Expand Up @@ -137,28 +140,29 @@ LegacyKernelInstruction::LegacyKernelInstruction(

operator_base_ = BuildOperatorBase(op, *value_exec_info_, yaml_info_parser);

SetDeviceContext(
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())),
place,
GetExecutionStream(),
GetStreamPriority()));
VLOG(6) << "finish process device context";

paddle::framework::VariableValueMap in_map;
paddle::framework::VariableValueMap out_map;
auto dev_ctx = phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend()));

runtime_context_ = std::make_shared<paddle::framework::RuntimeContext>(
paddle::framework::RuntimeContext(in_map, out_map));
BuildRuntimeContext(
op, *value_exec_info, yaml_info_parser, runtime_context_.get());

kernel_context_ = new paddle::framework::ExecutionContext(
*operator_base_, *inner_scope, *dev_ctx, *(runtime_context_.get()));
kernel_context_ =
new paddle::framework::ExecutionContext(*operator_base_,
*inner_scope,
DeviceContext(),
*(runtime_context_.get()));

VLOG(6) << "finish process kernel context";
SetDeviceContext(
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())),
place,
GetExecutionStream(),
GetStreamPriority()));
VLOG(6) << "finish process device context";

InitInputsOutputsIds(op, *value_exec_info);
VLOG(6) << "finish process inputs outputs index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ PhiKernelInstruction::PhiKernelInstruction(
}
SetEventsToWaitInfo(events_to_wait);
}
VLOG(6) << "finish process dist attributes";
VLOG(6) << "finish process dist attributes for " << op_name
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
<< GetExecutionStream() << ", " << GetStreamPriority() << ", "
<< GetSchedulingPriority() << "]";

SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";
Expand Down Expand Up @@ -137,6 +140,16 @@ PhiKernelInstruction::PhiKernelInstruction(
phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
VLOG(6) << "finish process select kernel";

platform::DeviceContext* dev_ctx =
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())),
place,
GetExecutionStream(),
GetStreamPriority());
SetDeviceContext(dev_ctx);
VLOG(6) << "finish process device context";

BuildPhiContext<phi::KernelContext,
const phi::TensorBase*,
phi::TensorBase*,
Expand All @@ -145,19 +158,9 @@ PhiKernelInstruction::PhiKernelInstruction(
true>(
op, *value_exec_info_, yaml_info_parser, &kernel_context_);

kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())));
kernel_context_.SetDeviceContext(dev_ctx);
VLOG(6) << "finish process kernel context";

SetDeviceContext(
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())),
place,
GetExecutionStream(),
GetStreamPriority()));
VLOG(6) << "finish process device context";

InitInputsOutputsIds(op, *value_exec_info);
VLOG(6) << "finish process inputs outputs index";

Expand Down