Skip to content

Commit beef7af

Browse files
authored
Add activation offloader (#74837)
* Add activation offloader
* Fix Mac compile error
* Fix Windows compile error
* Fix compile error on Windows and XPU
* Remove dist_api_gen.py modification
* Add activation offloader unit tests
* Fix unit tests (including on Windows) and improve coverage
* Add more unit tests
* Improve coverage
1 parent e8e81ce commit beef7af

File tree

22 files changed

+800
-33
lines changed

22 files changed

+800
-33
lines changed

paddle/common/flags.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,16 @@ PHI_DEFINE_EXPORTED_bool(
21402140
false,
21412141
"Enable add lock when call AutoGrowthBestFitAllocator::ReleaseImpl");
21422142

2143+
// How many times an offload attempt may be retried; the default of -1
// presumably means "unlimited" — NOTE(review): confirm against the
// activation-offloader code that consumes this flag.
PHI_DEFINE_EXPORTED_int64(offload_retry_times, -1, "Offload retry times.");

// Whether tensors modified in place are still eligible for offloading
// (enabled by default).
PHI_DEFINE_EXPORTED_bool(offload_inplace_tensor,
                         true,
                         "Whether to allow offload inplace tensor.");

// Debug switch: when enabled, offload information is printed.
PHI_DEFINE_EXPORTED_bool(print_offload_info,
                         false,
                         "Whether to print the offload information.");
2152+
21432153
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
21442154
/**
21452155
* FlashAttention related FLAG

paddle/fluid/distributed/collective/process_group_nccl.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,15 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
166166
}
167167
}
168168

169+
// Drops the record tying `tensor`'s allocation to the communication stream
// of the communicator context created for the tensor's place.  A no-op when
// the tensor has no storage, or when no communicator context exists yet for
// that place.
void ProcessGroupNCCL::EraseStream(const phi::DenseTensor& tensor) const {
  if (!tensor.initialized()) {
    return;  // Uninitialized tensor: nothing to untrack.
  }
  const auto ctx_iter =
      place_to_comm_ctx_.find(GetKeyFromPlace(tensor.place()));
  if (ctx_iter == place_to_comm_ctx_.end()) {
    return;  // No communicator has been set up on this place.
  }
  memory::EraseStream(tensor.Holder(), ctx_iter->second->stream());
}
177+
169178
void ProcessGroupNCCL::GroupStart() {
170179
NCCL_CHECK(phi::dynload::ncclGroupStart());
171180
++s_group_call_counter;

paddle/fluid/distributed/collective/process_group_nccl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
9292
std::shared_ptr<phi::distributed::NCCLConfig> nccl_config = nullptr);
9393
~ProcessGroupNCCL();
9494

95+
void EraseStream(const phi::DenseTensor& tensor) const override;
96+
9597
std::string GetBackendName() const override { return "NCCL"; }
9698

9799
phi::DeviceContext* GetDeviceContext(const Place& place) const override;

paddle/fluid/distributed/collective/process_group_with_stream.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class ProcessGroupWithStream : public ProcessGroup {
6060
ProcessGroupWithStream(int rank, int size, int gid)
6161
: ProcessGroup(rank, size, gid) {}
6262

63+
// Base-class stub: backends that track per-stream allocations (e.g. the
// NCCL process group) override this to detach `tensor`'s allocation from
// their communication stream; any backend that does not override it
// reports the operation as unsupported.
// The parameter exists only for overriders, so it is marked
// [[maybe_unused]] to keep -Wunused-parameter builds clean.
// @throws phi::errors::Unimplemented always (in this base implementation).
virtual void EraseStream(
    [[maybe_unused]] const phi::DenseTensor& tensor) const {
  PADDLE_THROW(phi::errors::Unimplemented("EraseStream is not implemented."));
}
66+
6367
virtual ~ProcessGroupWithStream() = default;
6468

6569
std::shared_ptr<ProcessGroup::Task> AllGather(

paddle/fluid/eager/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ set(eager_deps
1313
grad_tensor_holder
1414
custom_operator_node)
1515

16+
# The activation offloader is built only for CUDA (WITH_GPU) builds and is
# then appended to the eager dependency list — NOTE(review): ROCm/XPU builds
# intentionally skip it per the commit description; confirm that remains the
# intent.
if(WITH_GPU)
  cc_library(
    activation_offloader
    SRCS activation_offloader.cc
    DEPS phi_core phi_gpu)
  list(APPEND eager_deps activation_offloader)
endif()
23+
1624
if(WITH_GPU OR WITH_ROCM)
1725
set(eager_deps ${eager_deps} phi_gpu)
1826
endif()

0 commit comments

Comments
 (0)