17 changes: 11 additions & 6 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -123,11 +123,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout)
+    int64_t timeout,
+    int nccl_comm_init_option)
     : ProcessGroupWithStream(rank, size, gid),
       store_(store),
-      pg_timeout_(timeout) {
+      pg_timeout_(timeout),
+      nccl_comm_init_option_(nccl_comm_init_option) {
   LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
+  LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
+            << nccl_comm_init_option_;
 }
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << "ProcessGroupNCCL destruct ";
@@ -718,7 +722,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
 
   phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank});
   phi::distributed::CommContextManager::CreateNCCLCommContext(
-      store_, store_key, rank_, size_, "", &p2p_opts);
+      store_, store_key, rank_, size_, "", &p2p_opts, nccl_comm_init_option_);
 
   NCCL_CHECK(phi::dynload::ncclGroupEnd());
 
@@ -1009,9 +1013,10 @@ std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout) {
-  auto process_group =
-      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid, timeout);
+    int64_t timeout,
+    int nccl_comm_init_option) {
+  auto process_group = std::make_shared<ProcessGroupNCCL>(
+      store, rank, size, gid, timeout, nccl_comm_init_option);
   ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
   return process_group;
 }
9 changes: 7 additions & 2 deletions paddle/fluid/distributed/collective/process_group_nccl.h
@@ -76,13 +76,15 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
       int rank,
       int size,
       int gid,
-      int64_t timeout);
+      int64_t timeout,
+      int nccl_comm_init_option);
 
   ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
                    int rank,
                    int size,
                    int gid,
-                   int64_t timeout = 30 * 60 * 1000);
+                   int64_t timeout = 30 * 60 * 1000,
+                   int nccl_comm_init_option = 0);
   ~ProcessGroupNCCL();
 
   std::string GetBackendName() const override { return "NCCL"; }
@@ -177,6 +179,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
 
   ncclComm_t NCCLComm(const Place& place) const;
 
+  const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; }
+
  private:
   std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                          int rank,
@@ -247,6 +251,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   static uint64_t s_group_call_counter;
   // default 30 minutes
   int64_t pg_timeout_;
+  int nccl_comm_init_option_;
 
   // optimize memory for process_group
   std::vector<std::pair<std::weak_ptr<phi::Allocation>, gpuStream_t>>
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/nccl.h
@@ -31,6 +31,7 @@ namespace dynload {
   __macro(ncclCommInitAll); \
   __macro(ncclGetUniqueId); \
   __macro(ncclCommInitRank); \
+  __macro(ncclCommInitRank2); \
   __macro(ncclCommAbort); \
   __macro(ncclCommDestroy); \
   __macro(ncclCommCount); \
1 change: 1 addition & 0 deletions paddle/fluid/pybind/communication.cc
@@ -58,6 +58,7 @@ void BindCommContextManager(py::module *m) {
           py::arg("size"),
           py::arg("hash_key") = "",
           py::arg("p2p_opt") = nullptr,
+          py::arg("nccl_comm_init_option") = 0,
           py::call_guard<py::gil_scoped_release>())
 #endif
 #if defined(PADDLE_WITH_XPU_BKCL)
1 change: 1 addition & 0 deletions paddle/fluid/pybind/distributed_py.cc
@@ -1235,6 +1235,7 @@ void BindDistributed(py::module *m) {
           py::arg("world_size"),
           py::arg("group_id") = 0,
           py::arg("timeout") = 30 * 60 * 1000,
+          py::arg("nccl_comm_init_option") = 0,
           py::call_guard<py::gil_scoped_release>())
       .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
       .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
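With this binding, the new keyword is reachable from Python on the static factory. A minimal sketch, assuming `store`, `rank`, and `world_size` come from the usual process-group bootstrap and that `core` is imported from `paddle.base` as elsewhere in this PR; the value 1 is illustrative and only has an effect when the underlying NCCL build actually provides ncclCommInitRank2:

    from paddle.base import core

    # `store`, `rank`, and `world_size` are assumed to come from the usual
    # bootstrap; 0 (the default) keeps the stock ncclCommInitRank path.
    pg = core.ProcessGroupNCCL.create(
        store,
        rank,
        world_size,
        group_id=0,
        timeout=30 * 60 * 1000,
        nccl_comm_init_option=1,
    )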
11 changes: 10 additions & 1 deletion paddle/phi/backends/dynload/nccl.cc
@@ -14,11 +14,20 @@ limitations under the License. */
 
 #include "paddle/phi/backends/dynload/nccl.h"
 
+ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm,
+                               int nranks,
+                               ncclUniqueId commId,
+                               int myrank,
+                               int param) {
+  // fake impl for compilation
+  return ncclInvalidUsage;
+}
+
 namespace phi {
 namespace dynload {
 
 std::once_flag nccl_dso_flag;
-void *nccl_dso_handle;
+void* nccl_dso_handle;
 
 #define DEFINE_WRAP(__name) DynLoad__##__name __name
 
13 changes: 13 additions & 0 deletions paddle/phi/backends/dynload/nccl.h
@@ -20,6 +20,18 @@ limitations under the License. */
 #include "paddle/phi/backends/dynload/dynamic_loader.h"
 #include "paddle/phi/backends/dynload/port.h"
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm,
+                               int nranks,
+                               ncclUniqueId commId,
+                               int myrank,
+                               int param);
+#ifdef __cplusplus
+}
+#endif
+
 namespace phi {
 namespace dynload {
 
@@ -44,6 +56,7 @@ extern void* nccl_dso_handle;
   __macro(ncclCommInitAll); \
   __macro(ncclGetUniqueId); \
   __macro(ncclCommInitRank); \
+  __macro(ncclCommInitRank2); \
   __macro(ncclCommAbort); \
   __macro(ncclCommDestroy); \
   __macro(ncclCommCount); \
7 changes: 4 additions & 3 deletions paddle/phi/core/distributed/comm_context_manager.cc
@@ -62,7 +62,8 @@ void CommContextManager::CreateNCCLCommContext(
     int rank,
     int size,
     const std::string& hash_key,
-    const P2POption* p2p_opt) {
+    const P2POption* p2p_opt,
+    int nccl_comm_init_option) {
   auto& comm_context_manager = CommContextManager::GetInstance();
   if (comm_context_manager.Has(unique_comm_key)) {
     return;
@@ -91,8 +92,8 @@
             << ", unique_comm_key: " << unique_comm_key
             << ", unique_key: " << unique_key
             << ", nccl_id: " << SerializeNCCLUniqueId(nccl_id);
-  auto nccl_comm_context =
-      std::make_unique<NCCLCommContext>(rank, size, nccl_id);
+  auto nccl_comm_context = std::make_unique<NCCLCommContext>(
+      rank, size, nccl_id, nccl_comm_init_option);
   if (CommContextManager::device_id != -1) {
     std::unique_ptr<phi::GPUContext> dev_ctx(
         new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id)));
3 changes: 2 additions & 1 deletion paddle/phi/core/distributed/comm_context_manager.h
@@ -77,7 +77,8 @@ class CommContextManager {
       int rank,
       int size,
       const std::string& hash_key = "",
-      const P2POption* opt = nullptr);
+      const P2POption* opt = nullptr,
+      int nccl_comm_init_option = 0);
 #endif
 
 #if defined(PADDLE_WITH_GLOO)
15 changes: 12 additions & 3 deletions paddle/phi/core/distributed/nccl_comm_context.cc
@@ -30,10 +30,19 @@ namespace distributed {
 // set this flag to `true` and recompile to enable dynamic checks
 constexpr bool FLAGS_enable_nccl_dynamic_check = false;
 
-NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
+NCCLCommContext::NCCLCommContext(int rank,
+                                 int size,
+                                 ncclUniqueId nccl_id,
+                                 int nccl_comm_init_option)
     : CommContext(rank, size) {
-  NCCL_CHECK(
-      phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+  if (nccl_comm_init_option > 0) {
+    LOG(WARNING) << "Creating modified qp with ncclCommInitRank2.";
+    NCCL_CHECK(phi::dynload::ncclCommInitRank2(
+        &nccl_comm_, size_, nccl_id, rank_, nccl_comm_init_option));
+  } else {
+    NCCL_CHECK(
+        phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+  }
   NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_));
 }
 
5 changes: 4 additions & 1 deletion paddle/phi/core/distributed/nccl_comm_context.h
@@ -39,7 +39,10 @@ namespace distributed {
 
 class NCCLCommContext final : public CommContext {
  public:
-  NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
+  NCCLCommContext(int rank,
+                  int size,
+                  ncclUniqueId nccl_id,
+                  int nccl_comm_init_option = 0);
   ~NCCLCommContext() override = default;
 
   int GetNcclVersion();
16 changes: 14 additions & 2 deletions python/paddle/distributed/collective.py
@@ -147,6 +147,7 @@ def _new_process_group_impl(
     group_name,
     pg_options,
     group_id=0,
+    nccl_comm_init_option=0,
 ):
     pg = None
     genv = _get_global_env()
@@ -155,7 +156,12 @@
         pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
         pg = core.ProcessGroupNCCL.create(
-            store, rank, world_size, group_id, genv.pg_timeout
+            store,
+            rank,
+            world_size,
+            group_id,
+            genv.pg_timeout,
+            nccl_comm_init_option,
         )
     elif backend == "xccl":
         pg = core.ProcessGroupCustom.create(
@@ -177,7 +183,12 @@ def _set_custom_gid(gid):
     _custom_gid = gid
 
 
-def new_group(ranks=None, backend=None, timeout=_default_timeout):
+def new_group(
+    ranks=None,
+    backend=None,
+    timeout=_default_timeout,
+    nccl_comm_init_option=0,
+):
     """
 
     Creates a new distributed communication group.
Expand Down Expand Up @@ -231,6 +242,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
group_name,
pg_options=None,
group_id=gid,
nccl_comm_init_option=nccl_comm_init_option,
)
else:
rank = -1
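On top of new_group, opting a single communication group into the modified init path is a one-liner. A minimal usage sketch under a two-node-per-group setup; the value 1 is illustrative, and per the C++ side any value greater than 0 routes initialization through ncclCommInitRank2:

    import paddle.distributed as dist

    dist.init_parallel_env()
    # Default behaviour is unchanged: nccl_comm_init_option=0.
    plain_group = dist.new_group(ranks=[0, 1])
    # Opt this group into the ncclCommInitRank2 path.
    tuned_group = dist.new_group(ranks=[2, 3], nccl_comm_init_option=1)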
14 changes: 13 additions & 1 deletion python/paddle/distributed/fleet/base/topology.py
@@ -29,6 +29,10 @@
     'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.base.core.is_compiled_with_xpu()
 )
 
+g_pipeline_nccl_comm_init_option = int(
+    os.environ.get("FLAGS_pipeline_nccl_comm_init_option", 0)
+)
+
 
 class ParallelMode:
     """
@@ -347,8 +351,16 @@ def _set_comm_group(self, parallel_method="data"):
         parallel_comm_group = None
         parallel_groups = self._topo.get_comm_list(parallel_method)
 
+        group_nccl_comm_init_option = (
+            g_pipeline_nccl_comm_init_option
+            if (parallel_method == "pipe")
+            else 0
+        )
         for group in parallel_groups:
-            comm_group = paddle.distributed.new_group(ranks=group)
+            comm_group = paddle.distributed.new_group(
+                ranks=group,
+                nccl_comm_init_option=group_nccl_comm_init_option,
+            )
             if self.global_rank in group:
                 parallel_group = group
                 parallel_comm_group = comm_group
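For fleet users the switch is environment-driven rather than API-driven. A minimal sketch, assuming a hybrid-parallel fleet setup; since topology.py reads the flag once at module import time, it must be set before that module is imported (typically in the launcher's environment):

    import os

    # Read once at import time by topology.py; only the pipeline ("pipe")
    # comm groups pick it up, every other group keeps option 0.
    os.environ["FLAGS_pipeline_nccl_comm_init_option"] = "1"

    import paddle.distributed.fleet as fleet
    fleet.init(is_collective=True)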