1 change: 0 additions & 1 deletion paddle/fluid/framework/op_proto_maker.cc
@@ -74,7 +74,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize) |
static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kCollective),
static_cast<int>(OpRole::kNotSpecified)})
.SetDefault(static_cast<int>(OpRole::kNotSpecified));
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
3 changes: 0 additions & 3 deletions paddle/fluid/framework/op_proto_maker.h
@@ -34,9 +34,6 @@ enum class OpRole {
kDist = 0x0008,
// Tag all learning rate scheduler operators.
kLRSched = 0x0010,
// Collective role is for all collective operators and other operators used
// for collective training
kCollective = 0x0020,

kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
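For context, OpRole values are bit flags, so a single op attribute can carry several roles OR-ed together; that is why the checker in op_proto_maker.cc above enumerates combinations such as kOptimize | kLRSched. A minimal sketch of the idea, using only the values visible in the excerpt above (the CombineRoles and HasRole helpers are hypothetical, not part of the header):

#include <cstdint>

// Values copied from the enum excerpted above; roles are powers of two.
enum class OpRole : std::uint32_t {
  kDist = 0x0008,
  kLRSched = 0x0010,
  kLoss = 0x0100,
};

// Roles compose with bitwise OR into a single int attribute.
inline int CombineRoles(OpRole a, OpRole b) {
  return static_cast<int>(a) | static_cast<int>(b);
}

// Hypothetical check: does a stored role attribute include `role`?
inline bool HasRole(int attr, OpRole role) {
  return (attr & static_cast<int>(role)) != 0;
}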
21 changes: 9 additions & 12 deletions paddle/fluid/operators/collective/c_allgather_op.cc
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include <future> // NOLINT

#include <memory>
#include <ostream>

namespace paddle {
namespace operators {
@@ -25,8 +24,7 @@ class CAllGatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SyncFCGather op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks");
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2");
framework::DDim dim = ctx->GetInputDim("X");
@@ -49,10 +47,10 @@ class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("nranks",
"Total trainer count of the distributed training job");
AddComment(R"DOC(
***CAllGather Operator***
CAllGather Operator
Each rank receives the aggregation of data from all ranks in the order of the ranks.

Call NCCL collective AllGather internally.https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather
reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allgather
)DOC");
}
};
@@ -81,9 +79,8 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker,
ops::CAllGatherOpMaker);

REGISTER_OP_CPU_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CPUDeviceContext, float>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, double>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, plat::float16>);
REGISTER_OP_CPU_KERNEL(c_allgather, ops::CAllGatherOpCPUKernel<float>,
ops::CAllGatherOpCPUKernel<double>,
ops::CAllGatherOpCPUKernel<int>,
ops::CAllGatherOpCPUKernel<int64_t>,
ops::CAllGatherOpCPUKernel<plat::float16>);
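The shape rule in InferShape above is worth making concrete: allgather concatenates the per-rank inputs along dimension 0, so an input of shape [N, C] on each of nranks ranks yields an output of shape [nranks * N, C] on every rank. A minimal standalone sketch of that rule (not the Paddle API itself):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the c_allgather output-shape rule from InferShape above:
// dim[0] is multiplied by nranks, the remaining dims are unchanged.
std::vector<int64_t> AllGatherOutShape(std::vector<int64_t> in_dims, int nranks) {
  assert(nranks >= 2 && "matches the op's PADDLE_ENFORCE_GE(nranks, 2) check");
  in_dims[0] *= nranks;
  return in_dims;
}

// Example: 4 ranks, each contributing a [32, 128] tensor -> a [128, 128] output.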
64 changes: 58 additions & 6 deletions paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -14,12 +14,64 @@ limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"

#include <memory>

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());

int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
PADDLE_ENFORCE_EQ(nranks, comm->nranks());

auto place = ctx.GetPlace();
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);

int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();

cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}

PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CUDADeviceContext, float>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, double>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel<float>,
ops::CAllGatherOpCUDAKernel<double>,
ops::CAllGatherOpCUDAKernel<int>,
ops::CAllGatherOpCUDAKernel<int64_t>,
ops::CAllGatherOpCUDAKernel<plat::float16>);
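The CUDA kernel above maps directly onto NCCL's allgather: each rank passes send_numel elements and receives nranks * send_numel elements, and the call is asynchronous with respect to the host, ordered on the given CUDA stream (either the computation stream or the communicator's own stream, per the use_calc_stream attribute). A minimal standalone sketch of that call contract, with error handling elided and an already-initialized ncclComm_t assumed:

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: allgather `count` floats from every rank into `recv`.
// `recv` must hold nranks * count elements; `comm` and `stream` are
// assumed to be set up elsewhere (cf. NCCLCommContext in the kernel above).
void AllGatherFloats(const float* send, float* recv, size_t count,
                     ncclComm_t comm, cudaStream_t stream) {
  // Enqueued on `stream`; synchronize the stream before reading `recv`.
  ncclAllGather(send, recv, count, ncclFloat, comm, stream);
}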
47 changes: 5 additions & 42 deletions paddle/fluid/operators/collective/c_allgather_op.h
@@ -1,4 +1,4 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>
@@ -22,52 +23,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class CAllGatherOpKernel : public framework::OpKernel<T> {
template <typename T>
class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllGatherOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());

int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();

framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);

int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();

cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}

PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
PADDLE_THROW("unimplemented cpu kernel for CAllGatherOp.");
}
};

39 changes: 39 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_max_op.cc
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace operators {

class CAllReduceMaxOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Max"; }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max, ops::CAllReduceOp,
ops::CAllReduceMaxOpMaker);

REGISTER_OP_CPU_KERNEL(c_allreduce_max,
ops::CAllReduceOpCPUKernel<ops::kRedMax, float>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, double>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, plat::float16>);
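c_allreduce_max reuses the generic allreduce kernels from c_allreduce_op.h, specialized by a reduction-kind template parameter (kRedMax here, kRedMin below). That header is not shown in this diff, but the dispatch presumably reduces to a mapping from the reduction kind onto NCCL's reduction ops; a hedged sketch of such a mapping, with the enum names taken from the registrations above and the function itself hypothetical:

#include <nccl.h>
#include <stdexcept>

// Reduction kinds as used in the kernel registrations above.
enum ReduceKind { kRedMax, kRedMin };

// Hypothetical mapping from a kernel's ReduceKind to NCCL's ncclRedOp_t.
ncclRedOp_t ToNCCLRedOp(ReduceKind kind) {
  switch (kind) {
    case kRedMax:
      return ncclMax;
    case kRedMin:
      return ncclMin;
  }
  throw std::runtime_error("unknown reduction kind");
}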
paddle/fluid/operators/collective/c_allreduce_max_op.cu.cc
@@ -18,8 +18,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
c_allreduce_max, ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, plat::float16>)
39 changes: 39 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_min_op.cc
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace operators {

class CAllReduceMinOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Min"; }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min, ops::CAllReduceOp,
ops::CAllReduceMinOpMaker);

REGISTER_OP_CPU_KERNEL(c_allreduce_min,
ops::CAllReduceOpCPUKernel<ops::kRedMin, float>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, double>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, plat::float16>);
25 changes: 25 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_min_op.cu.cc
@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
c_allreduce_min, ops::CAllReduceOpCUDAKernel<ops::kRedMin, float>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, plat::float16>)