Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions paddle/fluid/imperative/bkcl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
#include <utility>
#include <vector>

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"
Expand Down Expand Up @@ -77,7 +76,7 @@ void BKCLParallelContext::Init() {
bkcl_ids.resize(strategy_.nrings_);

if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
// generate the unique bkclid on the root worker
for (size_t i = 0; i < bkcl_ids.size(); ++i) {
auto ret = bkcl_get_unique_id(&bkcl_ids[i]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
Expand All @@ -99,6 +98,28 @@ void BKCLParallelContext::Init() {
}
}

// Initialize a single BKCL communication ring identified by ring_id.
// Unlike Init(), which builds strategy_.nrings_ rings at once, this creates
// exactly one communicator so callers can add rings incrementally.
void BKCLParallelContext::InitWithRingID(int ring_id) {
  // Only one unique id is needed since only one ring is created here.
  std::vector<BKCLUniqueId> bkcl_ids;
  bkcl_ids.resize(1);

  if (strategy_.local_rank_ == 0) {
    // generate the unique bkclid on the root worker
    auto ret = bkcl_get_unique_id(&bkcl_ids[0]);
    PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
                      platform::errors::PreconditionNotMet(
                          "BKCL get unique id failed [%d]", ret));
  }
  // Broadcast the root worker's id to all other ranks so every participant
  // joins the same communicator.
  BcastBKCLId(bkcl_ids, 0);

  int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device;
  VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_
          << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
          << " ring id: " << ring_id;
  // it will assign bkcl_comm in XPUDeviceContext within ring_id
  platform::BKCLCommContext::Instance().CreateBKCLComm(
      &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
}

void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/bkcl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class BKCLParallelContext : public ParallelContext {

void Init() override;

void InitWithRingID(int ring_id) override;

void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;
Expand Down
24 changes: 24 additions & 0 deletions paddle/fluid/imperative/nccl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ void NCCLParallelContext::Init() {
}
}

void NCCLParallelContext::InitWithRingID(int ring_id) {
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);

if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
}
BcastNCCLId(nccl_ids, 0);

int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);

compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
}

void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/nccl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class NCCLParallelContext : public ParallelContext {

void Init() override;

void InitWithRingID(int ring_id) override;

void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/parallel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ParallelContext {

virtual void Init() = 0;

virtual void InitWithRingID(int ring_id) = 0;

virtual void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) = 0;
Expand Down
70 changes: 38 additions & 32 deletions paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,20 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {

class CSyncCalcStreamOp : public framework::OperatorBase {
class CSyncCalcStreamOp : public framework::OperatorWithKernel {
public:
CSyncCalcStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};

Expand All @@ -65,10 +45,36 @@ Call calculation stream synchronization.
}
};

// Kernel that blocks the host until the calculation (default) stream of the
// current device has drained. Compiled as a no-op error on non-GPU builds.
template <typename T>
class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
    // Fetch the per-place compute context; its stream is the one the
    // preceding calculation ops were launched on.
    auto* compute_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(ctx.GetPlace()));
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(compute_ctx->stream()));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(compute_ctx->stream()));
#endif
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU."));
#endif
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);

REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamCudaKernel<float>);
74 changes: 41 additions & 33 deletions paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,25 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

class CSyncCommStreamOp : public framework::OperatorBase {
class CSyncCommStreamOp : public framework::OperatorWithKernel {
public:
CSyncCommStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
using framework::OperatorWithKernel::OperatorWithKernel;

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
auto stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
void InferShape(framework::InferShapeContext* ctx) const override {}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};

Expand All @@ -72,10 +52,38 @@ Call communication stream synchronization.
}
};

// Kernel that blocks the host until the NCCL/RCCL communication stream of the
// ring selected by the "ring_id" attribute has drained. Compiled as an error
// on builds without NCCL/RCCL support.
template <typename T>
class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    // Each ring owns its own communicator (and stream) per place.
    int ring = ctx.Attr<int>("ring_id");
    auto comm_stream = platform::NCCLCommContext::Instance()
                           .Get(ring, ctx.GetPlace())
                           ->stream();
#ifdef PADDLE_WITH_RCCL
    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(comm_stream));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(comm_stream));
#endif
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU."));
#endif
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);

REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamCudaKernel<float>);
10 changes: 8 additions & 2 deletions paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,10 @@ void BindImperative(py::module *m_ptr) {
m, "NCCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::CUDAPlace &>())
.def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
.def("init", [](imperative::NCCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::NCCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
Expand All @@ -1450,7 +1453,10 @@ void BindImperative(py::module *m_ptr) {
m, "BKCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::XPUPlace &>())
.def("init", [](imperative::BKCLParallelContext &self) { self.Init(); });
.def("init", [](imperative::BKCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::BKCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"fill_constant", {"Out"}},
{"matmul", {"Out"}},
{"c_broadcast", {"Out"}},
{"c_sync_calc_stream", {"Out"}},
{"c_sync_comm_stream", {"Out"}},
{"c_allreduce_sum", {"Out"}},
{"c_allreduce_max", {"Out"}},
{"c_allreduce_min", {"Out"}},
Expand Down
Loading