-
Notifications
You must be signed in to change notification settings - Fork 6k
supports distributed classification #18690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/operators/shard_index_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class ShardIndexOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| PADDLE_ENFORCE(ctx->HasInput("X"), | ||
| "Input(X) of ShardIndexOp should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasOutput("Out"), | ||
| "Output(Out) of ShardIndexOp should not be null."); | ||
|
|
||
| auto x_dims = ctx->GetInputDim("X"); | ||
| PADDLE_ENFORCE_GE(x_dims.size(), 2, | ||
| "Rank of Input(X) should be at least 2."); | ||
| if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { | ||
| PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, | ||
| "Last dimension of Input(X) should be 1."); | ||
| } | ||
|
|
||
| ctx->SetOutputDim("Out", x_dims); | ||
| ctx->ShareLoD("X", /* --> */ "Out"); | ||
| } | ||
|
|
||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext& ctx) const override { | ||
| return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), | ||
| ctx.device_context()); | ||
| } | ||
| }; | ||
|
|
||
| class ShardIndexOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| void Make() override { | ||
| AddInput("X", | ||
| "(LoDTensor, LoDTensor<int|int64>) Input variable. Each value " | ||
| "of X is an index."); | ||
| AddOutput( | ||
| "Out", | ||
| "(Tensor, Tensor<int|int64>) Output tensor with same shape as X. " | ||
| "The tensor consists of sharding representations of values in X."); | ||
| AddAttr<int>("index_num", | ||
| "A positive integer to specify the range of the input X."); | ||
|
|
||
| AddAttr<int>("nshards", | ||
| "A positive integer to specify the number of shards."); | ||
| AddAttr<int>("shard_id", "The current shard id"); | ||
| AddAttr<int>("ignore_value", "An ingeter value out of sharded range") | ||
| .SetDefault(-1); | ||
| AddComment(R"DOC( | ||
| This layer creates the sharded index for input. This layers is used in | ||
| model- and data- parallel mixed training generally, in which the index | ||
| data (usually the label) should be recaculated in each trainer according | ||
| to | ||
|
|
||
| .. math:: | ||
|
|
||
| assert index_num % nshards == 0 | ||
|
|
||
| shard_size = index_num / nshards | ||
|
|
||
| y = x % shard_size if x / shard_size == shard_id else ignore_value | ||
|
|
||
| We take the distributed one-hot representation to show what this layer is | ||
| used for. The distributed one-hot representation is seperated into multiple | ||
| shards, and each shard is filling zeros except the one with the index | ||
| inside. In order to create these sharded representation in each trainer, | ||
| the original index should be recalculated (i.e. sharded) before. | ||
|
|
||
| Examples: | ||
|
|
||
| X is a Tensor of integer values: | ||
| X.shape = [4, 1] | ||
| X.data = [[1], [6], [12], [19]] | ||
|
|
||
| suppose index_num = 20 and nshards = 2, then we get shard_size = 10 | ||
|
|
||
| if shard_id == 0, we get the Out: | ||
| Out.shape = [4, 1] | ||
| Out.data = [[1], [6], [-1], [-1]] | ||
|
|
||
| if shard_id == 1, we get the Out: | ||
| Out.shape = [4, 1] | ||
| Out.data = [[-1], [-1], [2], [9]] | ||
|
|
||
| the default `ignore_value` -1 is used in this example. | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp, | ||
| ops::ShardIndexOpMaker); | ||
| REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel<int>, | ||
| ops::ShardIndexCPUKernel<int64_t>); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/operators/shard_index_op.h" | ||
| #include "paddle/fluid/platform/cuda_primitives.h" | ||
| #include "paddle/fluid/platform/gpu_info.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using platform::PADDLE_CUDA_NUM_THREADS; | ||
|
|
||
| template <typename T> | ||
| __global__ void ShardIndexInner(const T* in_data, T* out_data, | ||
| const int64_t numel, const int index_num, | ||
| const int nshards, const int shard_id, | ||
| const int ignore_value) { | ||
| int shard_size = index_num / nshards; | ||
| int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (idx < numel) { | ||
| assert(in_data[idx] >= 0 && in_data[idx] < index_num); | ||
| if (in_data[idx] / shard_size == shard_id) { | ||
| out_data[idx] = in_data[idx] % shard_size; | ||
| } else { | ||
| out_data[idx] = ignore_value; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| using LoDTensor = framework::LoDTensor; | ||
|
|
||
| template <typename T> | ||
| class ShardIndexCUDAKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* in = context.Input<LoDTensor>("X"); | ||
| auto* out = context.Output<LoDTensor>("Out"); | ||
| int index_num = context.Attr<int>("index_num"); | ||
| int nshards = context.Attr<int>("nshards"); | ||
| int shard_id = context.Attr<int>("shard_id"); | ||
| int ignore_value = context.Attr<int>("ignore_value"); | ||
| PADDLE_ENFORCE_GT(index_num, 0); | ||
| PADDLE_ENFORCE_GT(nshards, 0); | ||
| PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, | ||
| "shard_id(%d) is not in range [0, %d)", shard_id, nshards); | ||
|
|
||
| out->Resize(in->dims()); | ||
| out->set_lod(in->lod()); | ||
| auto* in_data = in->data<T>(); | ||
| auto* out_data = out->mutable_data<T>(context.GetPlace()); | ||
| int64_t numel = in->numel(); | ||
| auto stream = | ||
| context.template device_context<platform::CUDADeviceContext>().stream(); | ||
| ShardIndexInner<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / | ||
| PADDLE_CUDA_NUM_THREADS, | ||
| PADDLE_CUDA_NUM_THREADS, 0, stream>>>( | ||
| in_data, out_data, numel, index_num, nshards, shard_id, ignore_value); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_CUDA_KERNEL(shard_index, ops::ShardIndexCUDAKernel<int>, | ||
| ops::ShardIndexCUDAKernel<int64_t>); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using LoDTensor = framework::LoDTensor; | ||
| template <typename T> | ||
| class ShardIndexCPUKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* in = context.Input<LoDTensor>("X"); | ||
| auto* out = context.Output<LoDTensor>("Out"); | ||
| int index_num = context.Attr<int>("index_num"); | ||
| int nshards = context.Attr<int>("nshards"); | ||
| int shard_id = context.Attr<int>("shard_id"); | ||
| int ignore_value = context.Attr<int>("ignore_value"); | ||
| PADDLE_ENFORCE_GT(index_num, 0); | ||
| PADDLE_ENFORCE_GT(nshards, 0); | ||
| PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, | ||
| "shard_id(%d) is not in range [0, %d)", shard_id, nshards); | ||
|
|
||
| int shard_size = index_num / nshards; | ||
|
|
||
| out->Resize(in->dims()); | ||
| out->set_lod(in->lod()); | ||
| auto* in_data = in->data<T>(); | ||
| auto* out_data = out->mutable_data<T>(context.GetPlace()); | ||
| int64_t numel = in->numel(); | ||
| for (int64_t i = 0; i < numel; ++i) { | ||
| PADDLE_ENFORCE(in_data[i] >= 0 && in_data[i] < index_num, | ||
| "Input index(%d) is out of range [0,%d)", in_data[i], | ||
| index_num); | ||
| if (in_data[i] / shard_size == shard_id) { | ||
| out_data[i] = in_data[i] % shard_size; | ||
| } else { | ||
| out_data[i] = ignore_value; | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you use assert here? do you check whether it works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want just make sure the input is in the valid range