From 1461992434415af37c0b7ecf7a3c524989500a38 Mon Sep 17 00:00:00 2001 From: youth123 <2042519524@qq.com> Date: Tue, 7 Sep 2021 18:59:26 +0800 Subject: [PATCH 1/3] upload global scatter and global gather operators related files --- .../operators/collective/global_gather_op.cc | 114 ++++++++++ .../collective/global_gather_op.cu.cc | 137 ++++++++++++ .../operators/collective/global_gather_op.h | 37 ++++ .../operators/collective/global_scatter_op.cc | 117 ++++++++++ .../collective/global_scatter_op.cu.cc | 137 ++++++++++++ .../operators/collective/global_scatter_op.h | 37 ++++ python/paddle/distributed/utils.py | 204 +++++++++++++++++- .../fluid/tests/unittests/CMakeLists.txt | 6 + .../unittests/collective_global_gather.py | 113 ++++++++++ .../collective_global_gather_dygraph.py | 64 ++++++ .../unittests/collective_global_scatter.py | 101 +++++++++ .../collective_global_scatter_dygraph.py | 62 ++++++ .../test_collective_global_gather.py | 191 ++++++++++++++++ .../test_collective_global_scatter.py | 148 +++++++++++++ 14 files changed, 1467 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/collective/global_gather_op.cc create mode 100644 paddle/fluid/operators/collective/global_gather_op.cu.cc create mode 100644 paddle/fluid/operators/collective/global_gather_op.h create mode 100644 paddle/fluid/operators/collective/global_scatter_op.cc create mode 100644 paddle/fluid/operators/collective/global_scatter_op.cu.cc create mode 100644 paddle/fluid/operators/collective/global_scatter_op.h create mode 100644 python/paddle/fluid/tests/unittests/collective_global_gather.py create mode 100644 python/paddle/fluid/tests/unittests/collective_global_gather_dygraph.py create mode 100644 python/paddle/fluid/tests/unittests/collective_global_scatter.py create mode 100644 python/paddle/fluid/tests/unittests/collective_global_scatter_dygraph.py create mode 100644 python/paddle/fluid/tests/unittests/test_collective_global_gather.py create mode 100644 python/paddle/fluid/tests/unittests/test_collective_global_scatter.py diff --git a/paddle/fluid/operators/collective/global_gather_op.cc b/paddle/fluid/operators/collective/global_gather_op.cc new file mode 100644 index 00000000000000..3c9929326bc688 --- /dev/null +++ b/paddle/fluid/operators/collective/global_gather_op.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/global_gather_op.h" + +namespace paddle { +namespace operators { + +class GlobalGatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GlobalGather"); + OP_INOUT_CHECK(ctx->HasInput("local_count"), "Input", "local_count", + "GlobalGather"); + OP_INOUT_CHECK(ctx->HasInput("global_count"), "Input", "global_count", + "GlobalGather"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GlobalGather"); + int ring_id = ctx->Attrs().Get("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto input_dims = ctx->GetInputDim("X"); + auto ndim_input = input_dims.size(); + // dim check + PADDLE_ENFORCE_EQ(ndim_input, 2, + platform::errors::InvalidArgument( + "The input tensor's dimension must be 2. " + "But received input's dimension = [%s].", + ndim_input)); + framework::DDim out_dims = framework::make_ddim({-1, -1}); + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class GlobalGatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor send."); + AddInput("local_count", + "(Tensor) Tensor which has n_expert * world_size elements that " + "indicates" + "how many data needed to be received from each expert."); + AddInput("global_count", + "(Tensor) Tensor which has n_expert * world_size elements that " + "indicates" + "how many data needed to be sent to each expert."); + AddOutput("Out", "(Tensor) the result of global_gather."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +Global Gather Operator +Gather data in X to n_expert * world_size exeperts according to +local_count and receive tensors from n_expert * world_size experts according +to global_count. +)DOC"); + } +}; + +template +class GlobalGatherOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("global_scatter"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetInput("local_count", this->Input("local_count")); + retv->SetInput("global_count", this->Input("global_count")); + retv->SetOutput("Out", this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR(global_gather, ops::GlobalGatherOp, ops::GlobalGatherOpMaker, + ops::GlobalGatherOpGradMaker, + ops::GlobalGatherOpGradMaker) + +REGISTER_OP_CPU_KERNEL(global_gather, ops::GlobalGatherOpCPUKernel, + ops::GlobalGatherOpCPUKernel, + ops::GlobalGatherOpCPUKernel, + ops::GlobalGatherOpCPUKernel, + ops::GlobalGatherOpCPUKernel); diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc new file mode 100644 index 00000000000000..d5d93f97b369bf --- /dev/null +++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc @@ -0,0 +1,137 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/global_gather_op.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { +template +class GlobalGatherOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) +#if NCCL_VERSION_CODE >= 2703 + auto x = ctx.Input("X"); + auto local_count = ctx.Input("local_count"); + auto global_count = ctx.Input("global_count"); + auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_gpu_place(local_count->place())) { + framework::TensorCopySync(*local_count, platform::CPUPlace(), + &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } else { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_gpu_place(global_count->place())) { + framework::TensorCopySync(*global_count, platform::CPUPlace(), + &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } else { + cpu_global_count_data = global_count->data(); + } + + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = framework::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclSend(send_buf + send_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, j, comm->comm(), stream)); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, j, comm->comm(), stream)); + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(global_gather, ops::GlobalGatherOpCUDAKernel, + ops::GlobalGatherOpCUDAKernel, + ops::GlobalGatherOpCUDAKernel, + ops::GlobalGatherOpCUDAKernel, + ops::GlobalGatherOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/global_gather_op.h b/paddle/fluid/operators/collective/global_gather_op.h new file mode 100644 index 00000000000000..3ff2df9e48f3d9 --- /dev/null +++ b/paddle/fluid/operators/collective/global_gather_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class GlobalGatherOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support global gather op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/global_scatter_op.cc b/paddle/fluid/operators/collective/global_scatter_op.cc new file mode 100644 index 00000000000000..ae7fe85e37d3aa --- /dev/null +++ b/paddle/fluid/operators/collective/global_scatter_op.cc @@ -0,0 +1,117 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/global_scatter_op.h" + +namespace paddle { +namespace operators { + +class GlobalScatterOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GlobalScatter"); + OP_INOUT_CHECK(ctx->HasInput("local_count"), "Input", "local_count", + "GlobalScatter"); + OP_INOUT_CHECK(ctx->HasInput("global_count"), "Input", "global_count", + "GlobalScatter"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GlobalScatter"); + int ring_id = ctx->Attrs().Get("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + auto input_dims = ctx->GetInputDim("X"); + auto ndim_input = input_dims.size(); + // dim check + PADDLE_ENFORCE_EQ(ndim_input, 2, + platform::errors::InvalidArgument( + "The input tensor's dimension must be 2. " + "But received input's dimension = [%s].", + ndim_input)); + + framework::DDim out_dims = framework::make_ddim({-1, -1}); + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class GlobalScatterOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor send."); + AddInput("local_count", + "(Tensor) Tensor which has n_expert * world_size elements that " + "indicates" + "how many data needed to be sent to each expert."); + AddInput("global_count", + "(Tensor) Tensor which has n_expert * world_size elements that " + "indicates" + "how many data needed to be received from each expert."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddOutput("Out", "(Tensor) the result of global_scatter."); + AddComment(R"DOC( +Global Scatter Operator +Scatter data in X which has been put together belong to one expert +to n_expert * world_size exeperts according to local_count +and receive tensors from n_expert * world_size experts according +to global_count. +)DOC"); + } +}; + +template +class GlobalScatterOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("global_gather"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetInput("local_count", this->Input("local_count")); + retv->SetInput("global_count", this->Input("global_count")); + retv->SetOutput("Out", this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR(global_scatter, ops::GlobalScatterOp, + ops::GlobalScatterOpMaker, + ops::GlobalScatterOpGradMaker, + ops::GlobalScatterOpGradMaker) + +REGISTER_OP_CPU_KERNEL(global_scatter, ops::GlobalScatterOpCPUKernel, + ops::GlobalScatterOpCPUKernel, + ops::GlobalScatterOpCPUKernel, + ops::GlobalScatterOpCPUKernel, + ops::GlobalScatterOpCPUKernel); diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc new file mode 100644 index 00000000000000..f67baceadf545a --- /dev/null +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -0,0 +1,137 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/global_scatter_op.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { +template +class GlobalScatterOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) +#if NCCL_VERSION_CODE >= 2703 + auto x = ctx.Input("X"); + auto local_count = ctx.Input("local_count"); + auto global_count = ctx.Input("global_count"); + auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_gpu_place(local_count->place())) { + framework::TensorCopy(*local_count, platform::CPUPlace(), + &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } else { + cpu_local_count_data = local_count->data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_gpu_place(global_count->place())) { + framework::TensorCopy(*global_count, platform::CPUPlace(), + &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } else { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } + + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = framework::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, j, comm->comm(), stream)); + } + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, j, comm->comm(), stream)); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(global_scatter, ops::GlobalScatterOpCUDAKernel, + ops::GlobalScatterOpCUDAKernel, + ops::GlobalScatterOpCUDAKernel, + ops::GlobalScatterOpCUDAKernel, + ops::GlobalScatterOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/global_scatter_op.h b/paddle/fluid/operators/collective/global_scatter_op.h new file mode 100644 index 00000000000000..52b486aef25c2b --- /dev/null +++ b/paddle/fluid/operators/collective/global_scatter_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class GlobalScatterOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support global scatter op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 447c059537ba3f..28d24e37777f19 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -27,6 +27,11 @@ from paddle.fluid import core from distutils.util import strtobool +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.data_feeder import check_variable_and_dtype + + __all__ = [ #noqa 'get_host_name_ip', 'Trainer', @@ -42,9 +47,206 @@ 'terminate_local_procs', 'TrainerProc', 'get_logger', - 'pull_worker_log' + 'pull_worker_log', + 'global_scatter', + 'global_gather' ] + +def global_scatter(x, + local_count, + global_count, + group=None, + use_calc_stream=True): + """ + Scatter data in x which has been put together belong to one expert + to n_expert * world_size exeperts according to local_count and receive tensors + from n_expert * world_size experts according + to global_count. + + Args: + x (Tensor): Tensor. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32 or int64. + local_count (Tensor): Tensor which have n_expert * world_size elements that indicates + how many data needed to be sent. Every element in the list must be a Tensor whose + data type should be int64. + global_count (Tensor): Tensor which have n_expert * world_size elements that indicates + how many data needed to be received. Every element in the list must be a Tensor whose + data type should be int64. + group (Group, optional): The group instance return by new_group or None for global default group. Default: None. + use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. + + Returns: + out (Tensor): The data received from all experts. + + Examples: + .. code-block:: python + + # required: distributed + import numpy as np + import paddle + from paddle.distributed import init_parallel_env + init_parallel_env() + n_expert = 2 + world_size = 2 + d_model = 2 + in_feat = d_model + local_input_buf = np.array([[1, 2],[3, 4],[5, 6],[7, 8],[9, 10]], \ + dtype=np.float32) + if paddle.distributed.ParallelEnv().local_rank == 0: + local_count = np.array([2, 1, 1, 1]) + global_count = np.array([2, 1, 1, 1]) + else: + local_count = np.array([1, 1, 2, 1]) + global_count = np.array([1, 1, 2, 1]) + local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False) + local_count = paddle.to_tensor(local_count, dtype="int64") + global_count = paddle.to_tensor(global_count, dtype="int64") + a = paddle.distributed.utils.global_scatter(local_input_buf, \ + local_count, global_count) + a.stop_gradient = False + print(a) + # out for rank 0: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]] + # out for rank 1: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]] + # backward test + c = a * a + c.backward() + print("local_input_buf.grad: ", local_input_buf.grad) + # out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]] + # out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]] + """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + if in_dygraph_mode(): + return core.ops.global_scatter(x, local_count, \ + global_count, \ + 'use_calc_stream', use_calc_stream, \ + 'ring_id', ring_id) + else: + op_type = 'global_scatter' + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'global_scatter') + check_variable_and_dtype(local_count, 'local_count', ['int64'], + 'global_scatter') + check_variable_and_dtype(global_count, 'global_count', ['int64'], + 'global_scatter') + + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type=op_type, + inputs={ + 'X': [x], + 'local_count': [local_count], + 'global_count': [global_count], + }, + outputs={'Out': [out]}, + attrs={'ring_id': ring_id, + 'use_calc_stream': use_calc_stream}) + return out + + +def global_gather(x, + local_count, + global_count, + group=None, + use_calc_stream=True): + """ + Gather data in x to n_expert * world_size exeperts according to + local_count and receive tensors from n_expert * world_size experts according + to global_count. + + Args: + x (Tensor): Tensor. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32 or int64. + local_count (Tensor): Tensor which have n_expert * world_size elements that indicates + how many data needed to be received. Every element in the list must be a Tensor whose + data type should be int64. + global_count (Tensor): Tensor which have n_expert * world_size elements that indicates + how many data needed to be sent. Every element in the list must be a Tensor whose + data type should be int64. + group (Group, optional): The group instance return by new_group or None for global default group. Default: None. + use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. + + Returns: + None. + + Examples: + .. code-block:: python + + # required: distributed + import numpy as np + import paddle + from paddle.distributed import init_parallel_env + init_parallel_env() + n_expert = 2 + world_size = 2 + d_model = 2 + in_feat = d_model + local_input_buf = np.array([[1, 2],[3, 4],[5, 6],[7, 8],[9, 10]],\ + dtype=np.float32) + if paddle.distributed.ParallelEnv().local_rank == 0: + local_count = np.array([2, 1, 1, 1]) + global_count = np.array([2, 1, 1, 1]) + else: + local_count = np.array([1, 1, 2, 1]) + global_count = np.array([1, 1, 2, 1]) + local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False) + local_count = paddle.to_tensor(local_count, dtype="int64") + global_count = paddle.to_tensor(global_count, dtype="int64") + a = paddle.distributed.utils.global_gather(local_input_buf, local_count, global_count) + print(a) + # out for rank 0: [[1, 2], [3, 4], [7, 8], [1, 2], [7, 8]] + # out for rank 1: [[5, 6], [9, 10], [3, 4], [5, 6], [9, 10]] + a.stop_gradient = False + c = a * a + c.backward() + print("local_input_buf.grad", local_input_buf.grad) + # out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]] + # out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]] + """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + if in_dygraph_mode(): + return core.ops.global_gather(x, local_count, \ + global_count, \ + 'use_calc_stream', use_calc_stream, \ + 'ring_id', ring_id) + else: + op_type = 'global_gather' + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'global_gather') + + check_variable_and_dtype(local_count, 'local_count', ['int64'], + 'global_gather') + + check_variable_and_dtype(global_count, 'global_count', ['int64'], + 'global_gather') + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type=op_type, + inputs={ + 'X': [x], + 'local_count': [local_count], + 'global_count': [global_count] + }, + outputs={'Out': [out]}, + attrs={ + 'ring_id': group, + 'use_calc_stream': use_calc_stream, + }) + return out + + logger = logging.getLogger("root") logger.propagate = False diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2c001614d1bacb..3085e467ecaeca 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -112,6 +112,8 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api) LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api) LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall_api) + LIST(REMOVE_ITEM TEST_OPS test_collective_global_gather) + LIST(REMOVE_ITEM TEST_OPS test_collective_global_scatter) LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api) LIST(REMOVE_ITEM TEST_OPS test_collective_wait) LIST(REMOVE_ITEM TEST_OPS test_memcpy_op) @@ -944,6 +946,8 @@ endif() if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120) + set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 120) + set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120) @@ -964,6 +968,8 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_collective_broadcast_api test_collective_allgather_api test_collective_alltoall_api + test_collective_global_gather + test_collective_global_scatter PROPERTIES LABELS "RUN_TYPE=DIST") endif() set_tests_properties(test_reducescatter_api PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/collective_global_gather.py b/python/paddle/fluid/tests/unittests/collective_global_gather.py new file mode 100644 index 00000000000000..d3a6071ed04dfd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_global_gather.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import os +import sys +import paddle +import paddle.fluid as fluid +import unittest +import paddle.fluid.layers as layers +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main +import pickle + +paddle.enable_static() + + +class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + seed = os.getpid() + np.random.seed(seed) + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + local_input_buf = paddle.static.data( + name="local_input_buf", shape=[-1, in_feat], dtype="float32") + local_expert_count = paddle.static.data( + name="local_expert_count", shape=[tot_expert], dtype="int64") + global_expert_count = paddle.static.data( + name="global_expert_count", shape=[tot_expert], dtype="int64") + + output = paddle.distributed.utils.global_gather( + local_input_buf, local_expert_count, global_expert_count) + + return [output] + + def run_trainer(self, args): + train_prog = fluid.Program() + startup_prog = fluid.Program() + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + paddle.distributed.init_parallel_env() + if args['backend'] == 'nccl': + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace( + device_id) #if args.use_gpu else fluid.CPUPlace() + elif args['backend'] == 'bkcl': + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + else: + place = fluid.CPUPlace() + + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + paddle.disable_static() + np.random.seed(os.getpid()) + local_expert_count = np.random.randint( + 1, 4, size=tot_expert).astype("int64") + local_expert_count = paddle.to_tensor(local_expert_count) + global_expert_count = [] + paddle.distributed.alltoall( + paddle.split( + local_expert_count, 2, axis=0), global_expert_count) + global_expert_count = paddle.concat(global_expert_count, axis=0) + global_expert_count = global_expert_count.numpy() + local_expert_count = local_expert_count.numpy() + fwd_expert_count = sum(global_expert_count) + np.random.seed(os.getpid()) + local_input_buf = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + + paddle.enable_static() + if args['static_mode']: + result = self.get_model(train_prog, startup_prog, rank) + exe = fluid.Executor(place) + exe.run(startup_prog) + fetch_list = [] + for elem in result: + fetch_list.append(elem.name) + out = exe.run(train_prog, + feed={ + 'local_expert_count': local_expert_count, + 'global_expert_count': global_expert_count, + 'local_input_buf': local_input_buf + }, + fetch_list=fetch_list) + + sys.stdout.buffer.write(pickle.dumps(out)) + + +if __name__ == "__main__": + runtime_main(TestCollectiveGlobalGatherAPI, "global_gather") diff --git a/python/paddle/fluid/tests/unittests/collective_global_gather_dygraph.py b/python/paddle/fluid/tests/unittests/collective_global_gather_dygraph.py new file mode 100644 index 00000000000000..20df5f35555964 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_global_gather_dygraph.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import os +import sys +import paddle +import paddle.fluid as fluid +import unittest +import paddle.fluid.layers as layers +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main + + +class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + seed = os.getpid() + np.random.seed(seed) + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + local_expert_count = np.random.randint( + 1, 4, size=tot_expert).astype("int") + local_expert_count = paddle.to_tensor(local_expert_count) + global_expert_count = [] + paddle.distributed.alltoall( + paddle.split( + local_expert_count, 2, axis=0), + global_expert_count) + global_expert_count = paddle.concat(global_expert_count, axis=0) + fwd_expert_count = sum(global_expert_count) + np.random.seed(seed) + local_input_buf = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + local_input_buf = paddle.to_tensor(local_input_buf) + local_input_buf.stop_gradient = False + output = paddle.distributed.utils.global_gather( + local_input_buf, local_expert_count, global_expert_count) + output.stop_gradient = False + c = output * output + c.stop_gradient = False + c.backward() + return [output.numpy(), local_input_buf.grad.numpy()] + + +if __name__ == "__main__": + runtime_main(TestCollectiveGlobalGatherAPI, "global_gather") diff --git a/python/paddle/fluid/tests/unittests/collective_global_scatter.py b/python/paddle/fluid/tests/unittests/collective_global_scatter.py new file mode 100644 index 00000000000000..74d12b61aca410 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_global_scatter.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import os +import sys +import paddle +import paddle.fluid as fluid +import unittest +import paddle.fluid.layers as layers +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main +import pickle + +paddle.enable_static() + + +class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + seed = os.getpid() + np.random.seed(seed) + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + local_input_buf = paddle.static.data( + name="local_input_buf", shape=[-1, in_feat], dtype="float32") + local_expert_count = paddle.static.data( + name="local_expert_count", shape=[tot_expert], dtype="int64") + global_expert_count = [] + paddle.distributed.alltoall( + paddle.split( + local_expert_count, 2, axis=0), + global_expert_count) + global_expert_count = paddle.concat(global_expert_count, axis=0) + output = paddle.distributed.utils.global_scatter( + local_input_buf, local_expert_count, global_expert_count) + return [output] + + def run_trainer(self, args): + train_prog = fluid.Program() + startup_prog = fluid.Program() + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + paddle.distributed.init_parallel_env() + if args['backend'] == 'nccl': + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace( + device_id) #if args.use_gpu else fluid.CPUPlace() + elif args['backend'] == 'bkcl': + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + else: + place = fluid.CPUPlace() + np.random.seed(os.getpid()) + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + local_expert_count = np.random.randint( + 1, 4, size=tot_expert).astype("int64") + fwd_expert_count = sum(local_expert_count) + local_input_buf = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + if args['static_mode']: + result = self.get_model(train_prog, startup_prog, rank) + exe = fluid.Executor(place) + exe.run(startup_prog) + fetch_list = [] + for elem in result: + fetch_list.append(elem.name) + out = exe.run(train_prog, + feed={ + 'local_expert_count': local_expert_count, + 'local_input_buf': local_input_buf + }, + fetch_list=fetch_list) + + sys.stdout.buffer.write(pickle.dumps(out)) + + +if __name__ == "__main__": + runtime_main(TestCollectiveGlobalScatterAPI, "global_scatter") diff --git a/python/paddle/fluid/tests/unittests/collective_global_scatter_dygraph.py b/python/paddle/fluid/tests/unittests/collective_global_scatter_dygraph.py new file mode 100644 index 00000000000000..f7e13a87622745 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_global_scatter_dygraph.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import os +import sys +import paddle +import paddle.fluid as fluid +import unittest +import paddle.fluid.layers as layers +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main + + +class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + seed = os.getpid() + np.random.seed(seed) + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + local_expert_count = np.random.randint( + 1, 4, size=tot_expert).astype("int") + fwd_expert_count = sum(local_expert_count) + local_input_buf = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + local_expert_count = paddle.to_tensor(local_expert_count) + local_input_buf = paddle.to_tensor(local_input_buf) + global_expert_count = [] + paddle.distributed.alltoall( + paddle.split( + local_expert_count, 2, axis=0), + global_expert_count) + global_expert_count = paddle.concat(global_expert_count, axis=0) + local_input_buf.stop_gradient = False + output = paddle.distributed.utils.global_scatter( + local_input_buf, local_expert_count, global_expert_count) + output.stop_gradient = False + c = output * output + c.backward() + return [output.numpy(), local_input_buf.grad.numpy()] + + +if __name__ == "__main__": + runtime_main(TestCollectiveGlobalScatterAPI, "global_scatter") diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_gather.py b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py new file mode 100644 index 00000000000000..efcbdaa09c360b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_api_base import TestDistBase +import os + + +class TestCollectiveGlobalGatherAPI(TestDistBase): + def _setup_config(self): + pass + + def test_global_gather_nccl(self): + paddle.enable_static() + self.check_with_place("collective_global_gather.py", "global_gather", + "nccl") + + def test_global_gather_nccl_dygraph(self): + self.check_with_place( + "collective_global_gather_dygraph.py", + "global_gather", + "nccl", + static_mode="0") + + def check_with_place(self, + model_file, + col_type, + backend="nccl", + path_id="0", + static_mode="1", + check_error_log=False, + need_envs={}): + if backend == "nccl" or backend == "bkcl": + with_gloo = '0' + else: + with_gloo = '1' + required_envs = { + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_eager_delete_tensor_gb": "0.0", + "PATH": os.getenv("PATH"), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), + "FLAGS_call_stack_level": "2", + "GLOG_v": "3", + "NCCL_P2P_DISABLE": "1", + "STATIC_MODE": static_mode, + "PADDLE_WITH_GLOO": with_gloo, + "BACKEND": backend, + "PATH_ID": path_id + } + required_envs.update(need_envs) + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + required_envs["GLOO_LOG_LEVEL"] = "TRACE" + tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, + required_envs) + + if col_type == "global_gather": + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + + np.random.seed(pid0) + local_expert_count1 = np.random.randint( + 1, 4, size=tot_expert).astype("int") + expert_ptr1 = np.ones(tot_expert, dtype=np.int32) + expert_ptr1[0] = 0 + for i in range(1, tot_expert): + expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] + + np.random.seed(pid1) + local_expert_count2 = np.random.randint( + 1, 4, size=tot_expert).astype("int") + expert_ptr2 = np.ones(tot_expert, dtype=np.int32) + expert_ptr2[0] = 0 + for i in range(1, tot_expert): + expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] + + global_expert_count1 = np.zeros(tot_expert).astype("int") + global_expert_count2 = np.zeros(tot_expert).astype("int") + global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert] + global_expert_count1[n_expert:] = local_expert_count2[0:n_expert] + global_expert_count2[0:n_expert] = local_expert_count1[n_expert:] + global_expert_count2[n_expert:] = local_expert_count2[n_expert:] + + np.random.seed(pid0) + fwd_expert_count = sum(global_expert_count1).astype("int") + local_input_buf1 = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + np.random.seed(pid1) + fwd_expert_count = sum(global_expert_count2).astype("int") + local_input_buf2 = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + output1 = [[], [], [], []] + output2 = [[], [], [], []] + send_ptr1 = 0 + send_ptr2 = 0 + + for i in range(n_expert): + for j in range(world_size): + idx = j * n_expert + i + if j == 0: + output1_part1 = local_input_buf1[send_ptr1: \ + send_ptr1 + global_expert_count1[idx], :] + output1_part2 = local_input_buf2[send_ptr2: \ + send_ptr2 + global_expert_count2[idx], :] + output1[i].extend(output1_part1) + output1[i + n_expert].extend(output1_part2) + else: + output2_part1 = local_input_buf1[send_ptr1: \ + send_ptr1 + global_expert_count1[idx]] + output2_part2 = local_input_buf2[send_ptr2: \ + send_ptr2 + global_expert_count2[idx]] + output2[i].extend(output2_part1) + output2[i + n_expert].extend(output2_part2) + send_ptr1 = send_ptr1 + global_expert_count1[idx] + send_ptr2 = send_ptr2 + global_expert_count2[idx] + result1 = [] + result2 = [] + for i in range(tot_expert): + for arr in output1[i]: + if arr == []: + continue + result1.append(arr) + for i in range(tot_expert): + for arr in output2[i]: + if arr == []: + continue + result2.append(arr) + if result1 == []: + output1 = np.array([]) + else: + output1 = np.concatenate( + result1, axis=0).reshape( + sum(local_expert_count1), in_feat) + if result2 == []: + output2 = np.array([]) + else: + output2 = np.concatenate( + result2, axis=0).reshape( + sum(local_expert_count2), in_feat) + + if tr0_out[0] is None or tr0_out[0].shape[0] == 0: + tr0_out[0] = np.array([]) + + if tr1_out[0] is None or tr1_out[0].shape[0] == 0: + tr1_out[0] = np.array([]) + + self.assertTrue( + np.allclose( + tr0_out[0], output1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[0], output2, rtol=1e-05, atol=1e-05)) + if static_mode == 0: + self.assertTrue( + np.allclose( + tr0_out[1], + 2 * local_input_buf1, + rtol=1e-05, + atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[1], + 2 * local_input_buf2, + rtol=1e-05, + atol=1e-05)) + else: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py new file mode 100644 index 00000000000000..038118db0d18a1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_api_base import TestDistBase +import os + + +class TestCollectiveSelectScatterAPI(TestDistBase): + def _setup_config(self): + pass + + def test_global_scatter_nccl(self): + paddle.enable_static() + self.check_with_place("collective_global_scatter.py", "global_scatter", + "nccl") + + def test_global_scatter_nccl_dygraph(self): + self.check_with_place( + "collective_global_scatter_dygraph.py", + "global_scatter", + "nccl", + static_mode="0") + + def check_with_place(self, + model_file, + col_type, + backend="nccl", + path_id="0", + static_mode="1", + check_error_log=False, + need_envs={}): + if backend == "nccl" or backend == "bkcl": + with_gloo = '0' + else: + with_gloo = '1' + required_envs = { + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_eager_delete_tensor_gb": "0.0", + "PATH": os.getenv("PATH"), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), + "FLAGS_call_stack_level": "2", + "GLOG_v": "3", + "NCCL_P2P_DISABLE": "1", + "STATIC_MODE": static_mode, + "PADDLE_WITH_GLOO": with_gloo, + "BACKEND": backend, + "PATH_ID": path_id + } + required_envs.update(need_envs) + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + required_envs["GLOO_LOG_LEVEL"] = "TRACE" + tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, + required_envs) + + if col_type == "global_scatter": + np.random.seed(pid0) + local_expert_count1 = np.random.randint(1, 4, size=4).astype("int") + fwd_expert_count = sum(local_expert_count1) + local_input_buf1 = np.random.rand(fwd_expert_count, + 2).astype("float32") + expert_ptr1 = np.ones(4, dtype=np.int32) + expert_ptr1[0] = 0 + for i in range(1, 4): + expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] + np.random.seed(pid1) + local_expert_count2 = np.random.randint(1, 4, size=4).astype("int") + fwd_expert_count = sum(local_expert_count2) + local_input_buf2 = np.random.rand(fwd_expert_count, + 2).astype("float32") + expert_ptr2 = np.ones(4, dtype=np.int32) + expert_ptr2[0] = 0 + for i in range(1, 4): + expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] + + output1 = [] + output2 = [] + for i in range(2): + for j in range(2): + idx = j * 2 + i + if j == 0: + # send data to 0 card + output1.append(local_input_buf1[expert_ptr1[idx]: \ + expert_ptr1[idx]+local_expert_count1[idx]]) + output1.append(local_input_buf2[expert_ptr2[idx]:\ + expert_ptr2[idx]+local_expert_count2[idx]]) + else: + output2.append(local_input_buf1[expert_ptr1[idx]: \ + expert_ptr1[idx]+local_expert_count1[idx]]) + output2.append(local_input_buf2[expert_ptr2[idx]:\ + expert_ptr2[idx]+local_expert_count2[idx]]) + if output1 == []: + output1 = np.array([]) + else: + output1 = np.concatenate(output1) + if output2 == []: + output2 = np.array([]) + else: + output2 = np.concatenate(output2) + + if tr0_out[0] is None or tr0_out[0].shape[0] == 0: + tr0_out[0] = np.array([]) + + if tr1_out[0] is None or tr1_out[0].shape[0] == 0: + tr1_out[0] = np.array([]) + + self.assertTrue( + np.allclose( + tr0_out[0], output1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[0], output2, rtol=1e-05, atol=1e-05)) + if static_mode == 0: + self.assertTrue( + np.allclose( + tr0_out[1], + 2 * local_input_buf1, + rtol=1e-05, + atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[1], + 2 * local_input_buf2, + rtol=1e-05, + atol=1e-05)) + + +if __name__ == '__main__': + unittest.main() From 3baaa23647a41876303a43068dd23407a9ae3ff0 Mon Sep 17 00:00:00 2001 From: youth123 <2042519524@qq.com> Date: Thu, 9 Sep 2021 19:28:53 +0800 Subject: [PATCH 2/3] fix some review bug --- .../operators/collective/global_gather_op.cc | 2 +- .../collective/global_gather_op.cu.cc | 25 ++- .../operators/collective/global_scatter_op.cc | 2 +- .../collective/global_scatter_op.cu.cc | 25 ++- .../unittests/test_collective_api_base.py | 182 ++++++++++++++++++ .../test_collective_global_gather.py | 149 -------------- .../test_collective_global_scatter.py | 106 ---------- 7 files changed, 218 insertions(+), 273 deletions(-) diff --git a/paddle/fluid/operators/collective/global_gather_op.cc b/paddle/fluid/operators/collective/global_gather_op.cc index 3c9929326bc688..110ff5755695e3 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cc @@ -40,7 +40,7 @@ class GlobalGatherOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ndim_input, 2, platform::errors::InvalidArgument( "The input tensor's dimension must be 2. " - "But received input's dimension = [%s].", + "But received input's dimension = %d.", ndim_input)); framework::DDim out_dims = framework::make_ddim({-1, -1}); ctx->SetOutputDim("Out", out_dims); diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc index d5d93f97b369bf..70b5d0244d3852 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cu.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc @@ -30,29 +30,39 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel { auto x = ctx.Input("X"); auto local_count = ctx.Input("local_count"); auto global_count = ctx.Input("global_count"); + auto local_count_type = local_count->type(); + auto global_count_type = global_count->type(); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } auto out = ctx.Output("Out"); const int64_t* cpu_local_count_data; const int64_t* cpu_global_count_data; auto local_count_len = 0; framework::Tensor cpu_local_count; - if (platform::is_gpu_place(local_count->place())) { + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { framework::TensorCopySync(*local_count, platform::CPUPlace(), &cpu_local_count); cpu_local_count_data = cpu_local_count.data(); local_count_len = cpu_local_count.numel(); - } else { - cpu_local_count_data = local_count->data(); - local_count_len = local_count->numel(); } framework::Tensor cpu_global_count; - if (platform::is_gpu_place(global_count->place())) { + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { framework::TensorCopySync(*global_count, platform::CPUPlace(), &cpu_global_count); cpu_global_count_data = cpu_global_count.data(); - } else { - cpu_global_count_data = global_count->data(); } ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); @@ -111,7 +121,6 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel { } } PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } #else PADDLE_THROW( diff --git a/paddle/fluid/operators/collective/global_scatter_op.cc b/paddle/fluid/operators/collective/global_scatter_op.cc index ae7fe85e37d3aa..2c859643c4df97 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cc @@ -40,7 +40,7 @@ class GlobalScatterOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ndim_input, 2, platform::errors::InvalidArgument( "The input tensor's dimension must be 2. " - "But received input's dimension = [%s].", + "But received input's dimension = %d.", ndim_input)); framework::DDim out_dims = framework::make_ddim({-1, -1}); diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc index f67baceadf545a..64765b549e5c1f 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -30,27 +30,37 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel { auto x = ctx.Input("X"); auto local_count = ctx.Input("local_count"); auto global_count = ctx.Input("global_count"); + auto local_count_type = local_count->type(); + auto global_count_type = global_count->type(); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } auto out = ctx.Output("Out"); const int64_t* cpu_local_count_data; const int64_t* cpu_global_count_data; framework::Tensor cpu_local_count; - if (platform::is_gpu_place(local_count->place())) { + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { framework::TensorCopy(*local_count, platform::CPUPlace(), &cpu_local_count); cpu_local_count_data = cpu_local_count.data(); - } else { - cpu_local_count_data = local_count->data(); } auto global_count_len = 0; framework::Tensor cpu_global_count; - if (platform::is_gpu_place(global_count->place())) { + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { framework::TensorCopy(*global_count, platform::CPUPlace(), &cpu_global_count); cpu_global_count_data = cpu_global_count.data(); global_count_len = cpu_global_count.numel(); - } else { - cpu_global_count_data = global_count->data(); - global_count_len = global_count->numel(); } ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); @@ -110,7 +120,6 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel { } } PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } #else diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 6868fb4c7499e9..00294bf6071b32 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -292,5 +292,187 @@ def check_with_place(self, self.assertTrue( np.allclose( input1, result_data, rtol=1e-05, atol=1e-05)) + elif col_type == "global_gather": + in_feat = 2 + n_expert = 2 + world_size = 2 + tot_expert = n_expert * world_size + + np.random.seed(pid0) + local_expert_count1 = np.random.randint( + 1, 4, size=tot_expert).astype("int") + expert_ptr1 = np.ones(tot_expert, dtype=np.int32) + expert_ptr1[0] = 0 + for i in range(1, tot_expert): + expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] + + np.random.seed(pid1) + local_expert_count2 = np.random.randint( + 1, 4, size=tot_expert).astype("int") + expert_ptr2 = np.ones(tot_expert, dtype=np.int32) + expert_ptr2[0] = 0 + for i in range(1, tot_expert): + expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] + + global_expert_count1 = np.zeros(tot_expert).astype("int") + global_expert_count2 = np.zeros(tot_expert).astype("int") + global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert] + global_expert_count1[n_expert:] = local_expert_count2[0:n_expert] + global_expert_count2[0:n_expert] = local_expert_count1[n_expert:] + global_expert_count2[n_expert:] = local_expert_count2[n_expert:] + + np.random.seed(pid0) + fwd_expert_count = sum(global_expert_count1).astype("int") + local_input_buf1 = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + np.random.seed(pid1) + fwd_expert_count = sum(global_expert_count2).astype("int") + local_input_buf2 = np.random.rand(fwd_expert_count, + in_feat).astype("float32") + output1 = [[], [], [], []] + output2 = [[], [], [], []] + send_ptr1 = 0 + send_ptr2 = 0 + + for i in range(n_expert): + for j in range(world_size): + idx = j * n_expert + i + if j == 0: + output1_part1 = local_input_buf1[send_ptr1: \ + send_ptr1 + global_expert_count1[idx], :] + output1_part2 = local_input_buf2[send_ptr2: \ + send_ptr2 + global_expert_count2[idx], :] + output1[i].extend(output1_part1) + output1[i + n_expert].extend(output1_part2) + else: + output2_part1 = local_input_buf1[send_ptr1: \ + send_ptr1 + global_expert_count1[idx]] + output2_part2 = local_input_buf2[send_ptr2: \ + send_ptr2 + global_expert_count2[idx]] + output2[i].extend(output2_part1) + output2[i + n_expert].extend(output2_part2) + send_ptr1 = send_ptr1 + global_expert_count1[idx] + send_ptr2 = send_ptr2 + global_expert_count2[idx] + result1 = [] + result2 = [] + for i in range(tot_expert): + for arr in output1[i]: + if arr == []: + continue + result1.append(arr) + for i in range(tot_expert): + for arr in output2[i]: + if arr == []: + continue + result2.append(arr) + if result1 == []: + output1 = np.array([]) + else: + output1 = np.concatenate( + result1, axis=0).reshape( + sum(local_expert_count1), in_feat) + if result2 == []: + output2 = np.array([]) + else: + output2 = np.concatenate( + result2, axis=0).reshape( + sum(local_expert_count2), in_feat) + + if tr0_out[0] is None or tr0_out[0].shape[0] == 0: + tr0_out[0] = np.array([]) + + if tr1_out[0] is None or tr1_out[0].shape[0] == 0: + tr1_out[0] = np.array([]) + + self.assertTrue( + np.allclose( + tr0_out[0], output1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[0], output2, rtol=1e-05, atol=1e-05)) + if static_mode == 0: + self.assertTrue( + np.allclose( + tr0_out[1], + 2 * local_input_buf1, + rtol=1e-05, + atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[1], + 2 * local_input_buf2, + rtol=1e-05, + atol=1e-05)) + + elif col_type == "global_scatter": + np.random.seed(pid0) + local_expert_count1 = np.random.randint(1, 4, size=4).astype("int") + fwd_expert_count = sum(local_expert_count1) + local_input_buf1 = np.random.rand(fwd_expert_count, + 2).astype("float32") + expert_ptr1 = np.ones(4, dtype=np.int32) + expert_ptr1[0] = 0 + for i in range(1, 4): + expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] + np.random.seed(pid1) + local_expert_count2 = np.random.randint(1, 4, size=4).astype("int") + fwd_expert_count = sum(local_expert_count2) + local_input_buf2 = np.random.rand(fwd_expert_count, + 2).astype("float32") + expert_ptr2 = np.ones(4, dtype=np.int32) + expert_ptr2[0] = 0 + for i in range(1, 4): + expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] + + output1 = [] + output2 = [] + for i in range(2): + for j in range(2): + idx = j * 2 + i + if j == 0: + # send data to 0 card + output1.append(local_input_buf1[expert_ptr1[idx]: \ + expert_ptr1[idx]+local_expert_count1[idx]]) + output1.append(local_input_buf2[expert_ptr2[idx]:\ + expert_ptr2[idx]+local_expert_count2[idx]]) + else: + output2.append(local_input_buf1[expert_ptr1[idx]: \ + expert_ptr1[idx]+local_expert_count1[idx]]) + output2.append(local_input_buf2[expert_ptr2[idx]:\ + expert_ptr2[idx]+local_expert_count2[idx]]) + if output1 == []: + output1 = np.array([]) + else: + output1 = np.concatenate(output1) + if output2 == []: + output2 = np.array([]) + else: + output2 = np.concatenate(output2) + + if tr0_out[0] is None or tr0_out[0].shape[0] == 0: + tr0_out[0] = np.array([]) + + if tr1_out[0] is None or tr1_out[0].shape[0] == 0: + tr1_out[0] = np.array([]) + + self.assertTrue( + np.allclose( + tr0_out[0], output1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[0], output2, rtol=1e-05, atol=1e-05)) + if static_mode == 0: + self.assertTrue( + np.allclose( + tr0_out[1], + 2 * local_input_buf1, + rtol=1e-05, + atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out[1], + 2 * local_input_buf2, + rtol=1e-05, + atol=1e-05)) else: pass diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_gather.py b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py index efcbdaa09c360b..c9dee529c21a16 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_global_gather.py +++ b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py @@ -37,155 +37,6 @@ def test_global_gather_nccl_dygraph(self): "nccl", static_mode="0") - def check_with_place(self, - model_file, - col_type, - backend="nccl", - path_id="0", - static_mode="1", - check_error_log=False, - need_envs={}): - if backend == "nccl" or backend == "bkcl": - with_gloo = '0' - else: - with_gloo = '1' - required_envs = { - "FLAGS_fraction_of_gpu_memory_to_use": "0.15", - "FLAGS_eager_delete_tensor_gb": "0.0", - "PATH": os.getenv("PATH"), - "PYTHONPATH": os.getenv("PYTHONPATH", ""), - "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), - "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), - "FLAGS_call_stack_level": "2", - "GLOG_v": "3", - "NCCL_P2P_DISABLE": "1", - "STATIC_MODE": static_mode, - "PADDLE_WITH_GLOO": with_gloo, - "BACKEND": backend, - "PATH_ID": path_id - } - required_envs.update(need_envs) - if check_error_log: - required_envs["GLOG_v"] = "3" - required_envs["GLOG_logtostderr"] = "1" - required_envs["GLOO_LOG_LEVEL"] = "TRACE" - tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, - required_envs) - - if col_type == "global_gather": - in_feat = 2 - n_expert = 2 - world_size = 2 - tot_expert = n_expert * world_size - - np.random.seed(pid0) - local_expert_count1 = np.random.randint( - 1, 4, size=tot_expert).astype("int") - expert_ptr1 = np.ones(tot_expert, dtype=np.int32) - expert_ptr1[0] = 0 - for i in range(1, tot_expert): - expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] - - np.random.seed(pid1) - local_expert_count2 = np.random.randint( - 1, 4, size=tot_expert).astype("int") - expert_ptr2 = np.ones(tot_expert, dtype=np.int32) - expert_ptr2[0] = 0 - for i in range(1, tot_expert): - expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] - - global_expert_count1 = np.zeros(tot_expert).astype("int") - global_expert_count2 = np.zeros(tot_expert).astype("int") - global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert] - global_expert_count1[n_expert:] = local_expert_count2[0:n_expert] - global_expert_count2[0:n_expert] = local_expert_count1[n_expert:] - global_expert_count2[n_expert:] = local_expert_count2[n_expert:] - - np.random.seed(pid0) - fwd_expert_count = sum(global_expert_count1).astype("int") - local_input_buf1 = np.random.rand(fwd_expert_count, - in_feat).astype("float32") - np.random.seed(pid1) - fwd_expert_count = sum(global_expert_count2).astype("int") - local_input_buf2 = np.random.rand(fwd_expert_count, - in_feat).astype("float32") - output1 = [[], [], [], []] - output2 = [[], [], [], []] - send_ptr1 = 0 - send_ptr2 = 0 - - for i in range(n_expert): - for j in range(world_size): - idx = j * n_expert + i - if j == 0: - output1_part1 = local_input_buf1[send_ptr1: \ - send_ptr1 + global_expert_count1[idx], :] - output1_part2 = local_input_buf2[send_ptr2: \ - send_ptr2 + global_expert_count2[idx], :] - output1[i].extend(output1_part1) - output1[i + n_expert].extend(output1_part2) - else: - output2_part1 = local_input_buf1[send_ptr1: \ - send_ptr1 + global_expert_count1[idx]] - output2_part2 = local_input_buf2[send_ptr2: \ - send_ptr2 + global_expert_count2[idx]] - output2[i].extend(output2_part1) - output2[i + n_expert].extend(output2_part2) - send_ptr1 = send_ptr1 + global_expert_count1[idx] - send_ptr2 = send_ptr2 + global_expert_count2[idx] - result1 = [] - result2 = [] - for i in range(tot_expert): - for arr in output1[i]: - if arr == []: - continue - result1.append(arr) - for i in range(tot_expert): - for arr in output2[i]: - if arr == []: - continue - result2.append(arr) - if result1 == []: - output1 = np.array([]) - else: - output1 = np.concatenate( - result1, axis=0).reshape( - sum(local_expert_count1), in_feat) - if result2 == []: - output2 = np.array([]) - else: - output2 = np.concatenate( - result2, axis=0).reshape( - sum(local_expert_count2), in_feat) - - if tr0_out[0] is None or tr0_out[0].shape[0] == 0: - tr0_out[0] = np.array([]) - - if tr1_out[0] is None or tr1_out[0].shape[0] == 0: - tr1_out[0] = np.array([]) - - self.assertTrue( - np.allclose( - tr0_out[0], output1, rtol=1e-05, atol=1e-05)) - self.assertTrue( - np.allclose( - tr1_out[0], output2, rtol=1e-05, atol=1e-05)) - if static_mode == 0: - self.assertTrue( - np.allclose( - tr0_out[1], - 2 * local_input_buf1, - rtol=1e-05, - atol=1e-05)) - self.assertTrue( - np.allclose( - tr1_out[1], - 2 * local_input_buf2, - rtol=1e-05, - atol=1e-05)) - else: - pass - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py index 038118db0d18a1..2b4555de2744d6 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py +++ b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py @@ -37,112 +37,6 @@ def test_global_scatter_nccl_dygraph(self): "nccl", static_mode="0") - def check_with_place(self, - model_file, - col_type, - backend="nccl", - path_id="0", - static_mode="1", - check_error_log=False, - need_envs={}): - if backend == "nccl" or backend == "bkcl": - with_gloo = '0' - else: - with_gloo = '1' - required_envs = { - "FLAGS_fraction_of_gpu_memory_to_use": "0.15", - "FLAGS_eager_delete_tensor_gb": "0.0", - "PATH": os.getenv("PATH"), - "PYTHONPATH": os.getenv("PYTHONPATH", ""), - "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), - "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), - "FLAGS_call_stack_level": "2", - "GLOG_v": "3", - "NCCL_P2P_DISABLE": "1", - "STATIC_MODE": static_mode, - "PADDLE_WITH_GLOO": with_gloo, - "BACKEND": backend, - "PATH_ID": path_id - } - required_envs.update(need_envs) - if check_error_log: - required_envs["GLOG_v"] = "3" - required_envs["GLOG_logtostderr"] = "1" - required_envs["GLOO_LOG_LEVEL"] = "TRACE" - tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, - required_envs) - - if col_type == "global_scatter": - np.random.seed(pid0) - local_expert_count1 = np.random.randint(1, 4, size=4).astype("int") - fwd_expert_count = sum(local_expert_count1) - local_input_buf1 = np.random.rand(fwd_expert_count, - 2).astype("float32") - expert_ptr1 = np.ones(4, dtype=np.int32) - expert_ptr1[0] = 0 - for i in range(1, 4): - expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1] - np.random.seed(pid1) - local_expert_count2 = np.random.randint(1, 4, size=4).astype("int") - fwd_expert_count = sum(local_expert_count2) - local_input_buf2 = np.random.rand(fwd_expert_count, - 2).astype("float32") - expert_ptr2 = np.ones(4, dtype=np.int32) - expert_ptr2[0] = 0 - for i in range(1, 4): - expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1] - - output1 = [] - output2 = [] - for i in range(2): - for j in range(2): - idx = j * 2 + i - if j == 0: - # send data to 0 card - output1.append(local_input_buf1[expert_ptr1[idx]: \ - expert_ptr1[idx]+local_expert_count1[idx]]) - output1.append(local_input_buf2[expert_ptr2[idx]:\ - expert_ptr2[idx]+local_expert_count2[idx]]) - else: - output2.append(local_input_buf1[expert_ptr1[idx]: \ - expert_ptr1[idx]+local_expert_count1[idx]]) - output2.append(local_input_buf2[expert_ptr2[idx]:\ - expert_ptr2[idx]+local_expert_count2[idx]]) - if output1 == []: - output1 = np.array([]) - else: - output1 = np.concatenate(output1) - if output2 == []: - output2 = np.array([]) - else: - output2 = np.concatenate(output2) - - if tr0_out[0] is None or tr0_out[0].shape[0] == 0: - tr0_out[0] = np.array([]) - - if tr1_out[0] is None or tr1_out[0].shape[0] == 0: - tr1_out[0] = np.array([]) - - self.assertTrue( - np.allclose( - tr0_out[0], output1, rtol=1e-05, atol=1e-05)) - self.assertTrue( - np.allclose( - tr1_out[0], output2, rtol=1e-05, atol=1e-05)) - if static_mode == 0: - self.assertTrue( - np.allclose( - tr0_out[1], - 2 * local_input_buf1, - rtol=1e-05, - atol=1e-05)) - self.assertTrue( - np.allclose( - tr1_out[1], - 2 * local_input_buf2, - rtol=1e-05, - atol=1e-05)) - if __name__ == '__main__': unittest.main() From 01aab4b8cfbf793464e7c0b205e307714be1a768 Mon Sep 17 00:00:00 2001 From: youth123 <2042519524@qq.com> Date: Thu, 9 Sep 2021 19:35:50 +0800 Subject: [PATCH 3/3] add commas to avoid conflict --- python/paddle/distributed/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 28d24e37777f19..a8d304907a6bb1 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -49,7 +49,7 @@ 'get_logger', 'pull_worker_log', 'global_scatter', - 'global_gather' + 'global_gather', ]