[NPU]shard index op for npu #35281
@@ -0,0 +1,118 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename T>
class ShardIndexNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    VLOG(4) << "start kernel";
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    int index_num = context.Attr<int>("index_num");
    int nshards = context.Attr<int>("nshards");
    int shard_id = context.Attr<int>("shard_id");
    int ignore_value = context.Attr<int>("ignore_value");
    PADDLE_ENFORCE_GT(
        index_num, 0,
        platform::errors::InvalidArgument(
            "The value 'index_num' for Op(shard_index) must be greater "
            "than 0, but the value given is %d.",
            index_num));
    PADDLE_ENFORCE_GT(nshards, 0,
                      platform::errors::InvalidArgument(
                          "The value 'nshards' for Op(shard_index) must be "
                          "greater than 0, but the value given is %d.",
                          nshards));
    PADDLE_ENFORCE_GE(
        shard_id, 0,
        platform::errors::InvalidArgument(
            "The value 'shard_id' for Op(shard_index) must be greater "
            "than or equal to 0, but the value given is %d.",
            shard_id));
    PADDLE_ENFORCE_LT(
        shard_id, nshards,
        platform::errors::InvalidArgument(
            "The value 'shard_id' for Op(shard_index) must be less than "
            "nshards (%d), but the value given is %d.",
            nshards, shard_id));
    // Each shard owns shard_size = ceil(index_num / nshards) indices.
    int shard_size = (index_num + nshards - 1) / nshards;

    auto place = context.GetPlace();
    out->Resize(in->dims());
    out->set_lod(in->lod());
    out->mutable_data<T>(place);

    // Scalar tensor holding shard_size, used as the second operand below.
    Tensor tmp(in->type());
    tmp.mutable_data<T>(framework::DDim({1}), place);
    FillNpuTensorWithConstant(&tmp, shard_size);

    Tensor condition(framework::proto::VarType::BOOL);
    condition.mutable_data<bool>(in->dims(), place);

    Tensor tmp2(in->type());
Contributor: Rename this one too, something like mod_out or out_mod.
    tmp2.mutable_data<T>(in->dims(), place);

    Tensor tmp3(in->type());
Contributor: Same here.
Contributor (Author): OK.
    tmp3.mutable_data<T>(in->dims(), place);

    auto stream =
        context.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    NpuOpRunner runner;
Contributor: NpuOpRunner("Mod", {*in, tmp}, {tmp2}).Run(stream);
Contributor (Author): OK.
    // tmp2 = in % shard_size
    runner.AddInputs({*in, tmp});
    runner.AddOutputs({tmp2});
    runner.SetType("Mod");
    runner.Run(stream);

    // tmp3 = in / shard_size, i.e. the shard each index belongs to.
    NpuOpRunner runner1;
    runner1.AddInputs({*in, tmp});
    runner1.AddOutputs({tmp3});
    runner1.SetType("FloorDiv");
    runner1.Run(stream);

    // condition = (tmp3 == shard_id); tmp is reused to hold shard_id.
    FillNpuTensorWithConstant(&tmp, shard_id);
    NpuOpRunner runner2;
    runner2.AddInputs({tmp3, tmp});
    runner2.AddOutputs({condition});
    runner2.SetType("Equal");
    runner2.Run(stream);

    // out = condition ? tmp2 : ignore_value
    Tensor tmp4(in->type());
    tmp4.mutable_data<T>(in->dims(), place);
    FillNpuTensorWithConstant(&tmp4, ignore_value);
    tmp4.Resize(in->dims());

    NpuOpRunner runner3;
    runner3.AddInputs({condition, tmp2, tmp4});
    runner3.AddOutputs({*out});
    runner3.SetType("Select");
    runner3.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(shard_index, ops::ShardIndexNPUKernel<int>,
                       ops::ShardIndexNPUKernel<int64_t>);
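The kernel decomposes shard_index into four elementwise NPU ops: Mod, FloorDiv, Equal, and Select. A minimal numpy sketch of the same computation, with np.where standing in for Select (the values are illustrative, not taken from the PR):

import numpy as np

index_num, nshards, shard_id, ignore_value = 20, 2, 0, -1
x = np.array([0, 5, 9, 10, 19])  # input indices in [0, index_num)

shard_size = (index_num + nshards - 1) // nshards  # ceiling division -> 10
mod_out = x % shard_size          # "Mod": offset of the index inside its shard
div_out = x // shard_size         # "FloorDiv": which shard owns the index
condition = div_out == shard_id   # "Equal": mask of indices owned by this shard
out = np.where(condition, mod_out, ignore_value)  # "Select"
print(out)  # [ 0  5  9 -1 -1]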
@@ -0,0 +1,84 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import math
import sys
sys.path.append("..")
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.framework import Program, program_guard
import paddle

paddle.enable_static()
SEED = 2021
def common_setup(self, index_num, nshards, shard_id, ignore_value):
    self.__class__.use_npu = True
    self.__class__.op_type = "shard_index"

    self.op_type = 'shard_index'
    x_lod = [[i for i in range(10)]]
    N = sum(x_lod[0])
    x = [np.random.randint(0, index_num - 1) for i in range(N)]
    x = np.array(x).astype('int32').reshape([N, 1])

    # Reference implementation: each shard owns a contiguous range of
    # ceil(index_num / nshards) indices; indices outside this shard map
    # to ignore_value.
    shard_size = (index_num + nshards - 1) // nshards
    out = np.zeros(shape=x.shape).astype('int32')
    for i in range(N):
        if x[i] // shard_size == shard_id:
            out[i] = x[i] % shard_size
        else:
            out[i] = ignore_value

    self.inputs = {'X': (x, x_lod)}
    self.attrs = {
        'index_num': index_num,
        'nshards': nshards,
        'shard_id': shard_id,
        'ignore_value': ignore_value
    }
    self.outputs = {'Out': (out, x_lod)}
class TestShardIndexShardId0Op(OpTest):
    def setUp(self):
        common_setup(self, 20, 2, 0, -1)

    def test_check_output(self):
        return self.check_output_with_place(place=paddle.NPUPlace(0))


class TestShardIndexShardId1Op(TestShardIndexShardId0Op):
    def setUp(self):
        common_setup(self, 20, 2, 1, -1)


class TestShardIndexIgnoreValueOp(TestShardIndexShardId0Op):
    def setUp(self):
        common_setup(self, 20, 2, 0, -2)


class TestShardIndexNotEvenlyDividedOp(TestShardIndexShardId0Op):
    def setUp(self):
        common_setup(self, 15, 2, 1, -1)


if __name__ == '__main__':
    unittest.main()
Contributor: Rename this later, something like shard_size_tensor or tmp_shard_size.
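For the uneven split exercised by TestShardIndexNotEvenlyDividedOp (index_num=15, nshards=2, shard_id=1), shard_size = ceil(15 / 2) = 8, so shard 0 owns indices 0-7 and shard 1 owns the remaining 8-14. A quick check of the expected output (illustrative, not part of the PR):

import numpy as np

index_num, nshards, shard_id, ignore_value = 15, 2, 1, -1
shard_size = (index_num + nshards - 1) // nshards  # 8
x = np.arange(index_num)
out = np.where(x // shard_size == shard_id, x % shard_size, ignore_value)
print(out)  # [-1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6]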