diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc
index 82e95204590bd3..a1d038f255e7ad 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.cc
+++ b/paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -123,11 +123,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout)
+    int64_t timeout,
+    int nccl_comm_init_option)
     : ProcessGroupWithStream(rank, size, gid),
       store_(store),
-      pg_timeout_(timeout) {
+      pg_timeout_(timeout),
+      nccl_comm_init_option_(nccl_comm_init_option) {
   LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
+  LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
+            << nccl_comm_init_option_;
 }
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << "ProcessGroupNCCL destruct ";
@@ -718,7 +722,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
 
   phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank});
   phi::distributed::CommContextManager::CreateNCCLCommContext(
-      store_, store_key, rank_, size_, "", &p2p_opts);
+      store_, store_key, rank_, size_, "", &p2p_opts, nccl_comm_init_option_);
 
   NCCL_CHECK(phi::dynload::ncclGroupEnd());
 
@@ -1009,9 +1013,10 @@ std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout) {
-  auto process_group =
-      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid, timeout);
+    int64_t timeout,
+    int nccl_comm_init_option) {
+  auto process_group = std::make_shared<ProcessGroupNCCL>(
+      store, rank, size, gid, timeout, nccl_comm_init_option);
   ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
   return process_group;
 }
diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h
index 22d90370f16afc..a57337f1d47fa2 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.h
+++ b/paddle/fluid/distributed/collective/process_group_nccl.h
@@ -76,13 +76,15 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
       int rank,
       int size,
       int gid,
-      int64_t timeout);
+      int64_t timeout,
+      int nccl_comm_init_option);
 
   ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
                    int rank,
                    int size,
                    int gid,
-                   int64_t timeout = 30 * 60 * 1000);
+                   int64_t timeout = 30 * 60 * 1000,
+                   int nccl_comm_init_option = 0);
 
   ~ProcessGroupNCCL();
   std::string GetBackendName() const override { return "NCCL"; }
@@ -177,6 +179,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
 
   ncclComm_t NCCLComm(const Place& place) const;
 
+  const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; }
+
  private:
   std::shared_ptr<NCCLTask> CreateTask(const Place& place,
                                        int rank,
@@ -247,6 +251,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   static uint64_t s_group_call_counter;
   // default 30 minutes
   int64_t pg_timeout_;
+  int nccl_comm_init_option_;
 
   // optimize memory for process_group
   std::vector<std::pair<std::weak_ptr<phi::Allocation>, gpuStream_t>>
diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h
index d9516c9f4de4e8..2dba64af332060 100644
--- a/paddle/fluid/platform/dynload/nccl.h
+++ b/paddle/fluid/platform/dynload/nccl.h
@@ -31,6 +31,7 @@ namespace dynload {
   __macro(ncclCommInitAll);   \
   __macro(ncclGetUniqueId);   \
   __macro(ncclCommInitRank);  \
+  __macro(ncclCommInitRank2); \
   __macro(ncclCommAbort);     \
   __macro(ncclCommDestroy);   \
   __macro(ncclCommCount);     \
diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc
index 391dbabb1a2109..5e202a2b79d2e6 100644
--- a/paddle/fluid/pybind/communication.cc
+++ b/paddle/fluid/pybind/communication.cc
@@ -58,6 +58,7 @@ void BindCommContextManager(py::module *m) {
            py::arg("size"),
            py::arg("hash_key") = "",
            py::arg("p2p_opt") = nullptr,
+           py::arg("nccl_comm_init_option") = 0,
            py::call_guard<py::gil_scoped_release>())
 #endif
 #if defined(PADDLE_WITH_XPU_BKCL)
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 4577171fd77bb5..df48a677b9692a 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -1235,6 +1235,7 @@ void BindDistributed(py::module *m) {
               py::arg("world_size"),
               py::arg("group_id") = 0,
               py::arg("timeout") = 30 * 60 * 1000,
+              py::arg("nccl_comm_init_option") = 0,
               py::call_guard<py::gil_scoped_release>())
           .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
           .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
diff --git a/paddle/phi/backends/dynload/nccl.cc b/paddle/phi/backends/dynload/nccl.cc
index 147066b43b0317..fe322c2ad7be59 100644
--- a/paddle/phi/backends/dynload/nccl.cc
+++ b/paddle/phi/backends/dynload/nccl.cc
@@ -14,11 +14,20 @@ limitations under the License. */
 
 #include "paddle/phi/backends/dynload/nccl.h"
 
+ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm,
+                               int nranks,
+                               ncclUniqueId commId,
+                               int myrank,
+                               int param) {
+  // fake impl for compilation
+  return ncclInvalidUsage;
+}
+
 namespace phi {
 namespace dynload {
 
 std::once_flag nccl_dso_flag;
-void *nccl_dso_handle;
+void* nccl_dso_handle;
 
 #define DEFINE_WRAP(__name) DynLoad__##__name __name
 
diff --git a/paddle/phi/backends/dynload/nccl.h b/paddle/phi/backends/dynload/nccl.h
index 91b6f5dcd58dc5..278474f12d82b3 100644
--- a/paddle/phi/backends/dynload/nccl.h
+++ b/paddle/phi/backends/dynload/nccl.h
@@ -20,6 +20,18 @@ limitations under the License. */
 #include "paddle/phi/backends/dynload/dynamic_loader.h"
 #include "paddle/phi/backends/dynload/port.h"
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm,
+                               int nranks,
+                               ncclUniqueId commId,
+                               int myrank,
+                               int param);
+#ifdef __cplusplus
+}
+#endif
+
 namespace phi {
 namespace dynload {
 
@@ -28,15 +40,21 @@ extern void* nccl_dso_handle;
 
 #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name)                     \
   struct DynLoad__##__name {                                       \
-    template <typename... Args>                                    \
-    auto operator()(Args... args) -> decltype(__name(args...)) {   \
+    static auto GetNCCLFunc() {                                    \
       using nccl_func = decltype(&::__name);                       \
       std::call_once(nccl_dso_flag, []() {                         \
         nccl_dso_handle = phi::dynload::GetNCCLDsoHandle();        \
       });                                                          \
       static void* p_##__name = dlsym(nccl_dso_handle, #__name);   \
-      return reinterpret_cast<nccl_func>(p_##__name)(args...);     \
+      return reinterpret_cast<nccl_func>(p_##__name);              \
+    }                                                               \
+                                                                    \
+    template <typename... Args>                                     \
+    auto operator()(Args... args) -> decltype(__name(args...)) {    \
+      return GetNCCLFunc()(args...);                                \
     }                                                              \
+                                                                    \
+    static bool IsValid() { return GetNCCLFunc() != nullptr; }       \
   };                                                               \
   extern DynLoad__##__name __name
 
@@ -44,6 +62,7 @@ extern void* nccl_dso_handle;
   __macro(ncclCommInitAll);   \
   __macro(ncclGetUniqueId);   \
   __macro(ncclCommInitRank);  \
+  __macro(ncclCommInitRank2); \
   __macro(ncclCommAbort);     \
   __macro(ncclCommDestroy);   \
   __macro(ncclCommCount);     \
diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc
index 5fd7861cc52b2d..01ffd15f79d283 100644
--- a/paddle/phi/core/distributed/comm_context_manager.cc
+++ b/paddle/phi/core/distributed/comm_context_manager.cc
@@ -62,7 +62,8 @@ void CommContextManager::CreateNCCLCommContext(
     int rank,
     int size,
     const std::string& hash_key,
-    const P2POption* p2p_opt) {
+    const P2POption* p2p_opt,
+    int nccl_comm_init_option) {
   auto& comm_context_manager = CommContextManager::GetInstance();
   if (comm_context_manager.Has(unique_comm_key)) {
     return;
@@ -91,8 +92,8 @@ void CommContextManager::CreateNCCLCommContext(
           << ", unique_comm_key: " << unique_comm_key
           << ", unique_key: " << unique_key
           << ", nccl_id: " << SerializeNCCLUniqueId(nccl_id);
-  auto nccl_comm_context =
-      std::make_unique<NCCLCommContext>(rank, size, nccl_id);
+  auto nccl_comm_context = std::make_unique<NCCLCommContext>(
+      rank, size, nccl_id, nccl_comm_init_option);
   if (CommContextManager::device_id != -1) {
     std::unique_ptr<phi::GPUContext> dev_ctx(
         new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id)));
diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h
index 8c4d802294986f..9e0cb8e5ec3d70 100644
--- a/paddle/phi/core/distributed/comm_context_manager.h
+++ b/paddle/phi/core/distributed/comm_context_manager.h
@@ -77,7 +77,8 @@ class CommContextManager {
      int rank,
      int size,
      const std::string& hash_key = "",
-      const P2POption* opt = nullptr);
+      const P2POption* opt = nullptr,
+      int nccl_comm_init_option = 0);
 #endif
 
 #if defined(PADDLE_WITH_GLOO)
diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc
index 8da676e74d911a..bfa9a494b327a3 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.cc
+++ b/paddle/phi/core/distributed/nccl_comm_context.cc
@@ -30,10 +30,22 @@ namespace distributed {
 // set this flag to `true` and recompile to enable dynamic checks
 constexpr bool FLAGS_enable_nccl_dynamic_check = false;
 
-NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
+NCCLCommContext::NCCLCommContext(int rank,
+                                 int size,
+                                 ncclUniqueId nccl_id,
+                                 int nccl_comm_init_option)
     : CommContext(rank, size) {
-  NCCL_CHECK(
-      phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+  if (nccl_comm_init_option > 0 && phi::dynload::ncclCommInitRank2.IsValid()) {
+    LOG(WARNING) << "Creating modified qp with ncclCommInitRank2.";
+    NCCL_CHECK(phi::dynload::ncclCommInitRank2(
+        &nccl_comm_, size_, nccl_id, rank_, nccl_comm_init_option));
+  } else {
+    if (nccl_comm_init_option > 0) {
+      LOG(WARNING) << "ncclCommInitRank2 is not supported.";
+    }
+    NCCL_CHECK(
+        phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+  }
   NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_));
 }
 
diff --git a/paddle/phi/core/distributed/nccl_comm_context.h b/paddle/phi/core/distributed/nccl_comm_context.h
index 609b5e0defe079..e11c9709976d3f 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.h
+++ b/paddle/phi/core/distributed/nccl_comm_context.h
@@ -39,7 +39,10 @@ namespace distributed {
 
 class NCCLCommContext final : public CommContext {
  public:
-  NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
+  NCCLCommContext(int rank,
+                  int size,
+                  ncclUniqueId nccl_id,
+                  int nccl_comm_init_option = 0);
   ~NCCLCommContext() override = default;
 
   int GetNcclVersion();
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index ead61419af4d61..ed6e0899692382 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -147,6 +147,7 @@ def _new_process_group_impl(
     group_name,
     pg_options,
     group_id=0,
+    nccl_comm_init_option=0,
 ):
     pg = None
     genv = _get_global_env()
@@ -155,7 +156,12 @@ def _new_process_group_impl(
         pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
         pg = core.ProcessGroupNCCL.create(
-            store, rank, world_size, group_id, genv.pg_timeout
+            store,
+            rank,
+            world_size,
+            group_id,
+            genv.pg_timeout,
+            nccl_comm_init_option,
         )
     elif backend == "xccl":
         pg = core.ProcessGroupCustom.create(
@@ -177,7 +183,12 @@ def _set_custom_gid(gid):
     _custom_gid = gid
 
 
-def new_group(ranks=None, backend=None, timeout=_default_timeout):
+def new_group(
+    ranks=None,
+    backend=None,
+    timeout=_default_timeout,
+    nccl_comm_init_option=0,
+):
     """
 
     Creates a new distributed communication group.
@@ -231,6 +242,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
             group_name,
             pg_options=None,
             group_id=gid,
+            nccl_comm_init_option=nccl_comm_init_option,
         )
     else:
         rank = -1
diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 3b5a590ae32e23..1c73198bcc744b 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -29,6 +29,10 @@
     'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.base.core.is_compiled_with_xpu()
 )
 
+g_pipeline_nccl_comm_init_option = int(
+    os.environ.get("FLAGS_pipeline_nccl_comm_init_option", 0)
+)
+
 
 class ParallelMode:
     """
@@ -347,8 +351,16 @@ def _set_comm_group(self, parallel_method="data"):
         parallel_comm_group = None
         parallel_groups = self._topo.get_comm_list(parallel_method)
 
+        group_nccl_comm_init_option = (
+            g_pipeline_nccl_comm_init_option
+            if (parallel_method == "pipe")
+            else 0
+        )
         for group in parallel_groups:
-            comm_group = paddle.distributed.new_group(ranks=group)
+            comm_group = paddle.distributed.new_group(
+                ranks=group,
+                nccl_comm_init_option=group_nccl_comm_init_option,
+            )
             if self.global_rank in group:
                 parallel_group = group
                 parallel_comm_group = comm_group
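
Usage sketch (not part of the patch): the snippet below assumes two GPUs launched with python -m paddle.distributed.launch and a Paddle build whose loaded NCCL library actually exports ncclCommInitRank2; if the symbol is missing, NCCLCommContext logs a warning and falls back to the plain ncclCommInitRank path shown above.

    # Enable the option for fleet's pipeline-parallel groups; the value is read
    # at import time by python/paddle/distributed/fleet/base/topology.py, so set
    # it before importing paddle.distributed.
    import os
    os.environ["FLAGS_pipeline_nccl_comm_init_option"] = "1"

    import paddle.distributed as dist

    dist.init_parallel_env()

    # Or request it explicitly for a hand-built group: any value > 0 is forwarded
    # through ProcessGroupNCCL to ncclCommInitRank2 when that symbol is available.
    group = dist.new_group(ranks=[0, 1], nccl_comm_init_option=1)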