Merged · Changes from 3 commits
1 change: 1 addition & 0 deletions AUTHORS.md
@@ -28,6 +28,7 @@
| lcy-seso | Ying Cao |
| cjld | Dun Liang |
| lipeng-unisound | Peng Li |
| liuyi05 | Yi Liu |
Member:

are you gavin?
Collaborator Author:

Embarrassing, I've been so busy I mixed things up.
Collaborator Author:

done
| liuyuan | Yuan Liu |
| livc | Zhao Li |
| llxxxll | Yong-Feng Liu |
3 changes: 2 additions & 1 deletion paddle/fluid/API.spec
@@ -158,7 +158,7 @@ paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'par
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '9461e67095a6fc5d568fb2ce8fef66ff'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '54e1675aa0364f4a78fa72804ec0f413'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '52db6229214fc6ab167d7009df29170d'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'ec4115591be842868c86b2e5334245c6'))
paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '98e7927f09ee2270535b29f048e481ec'))
paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '6196c9ec3075ca5a9c058ea1f8492256'))
paddle.fluid.layers.squeeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ebbac07662a6e22e8e299ced880c7775'))
@@ -264,6 +264,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_range', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '401929a9268976ff4ba692d3d70879d9'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))
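A side note on the spec changes above: `one_hot` gains an `allow_out_of_range` flag (default `False`), and the new `shard_index` layer is registered. A minimal sketch of calling the updated `one_hot`, assuming the fluid 1.x API exactly as listed in the spec (the `label` variable here is illustrative):

```python
import paddle.fluid as fluid

# Hypothetical int64 label input of shape [batch, 1].
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# Per the spec, allow_out_of_range defaults to False; when True,
# indices outside [0, depth) presumably yield all-zero rows
# instead of raising an error.
one_hot_label = fluid.layers.one_hot(input=label, depth=10,
                                     allow_out_of_range=False)
```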
115 changes: 115 additions & 0 deletions paddle/fluid/operators/shard_index_op.cc
@@ -0,0 +1,115 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/shard_index_op.h"

namespace paddle {
namespace operators {

class ShardIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ShardIndexOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShardIndexOp should not be null.");

auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) should be at least 2.");
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
"Last dimension of Input(X) should be 1.");
}

ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};

class ShardIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, LoDTensor<int|int64>) Input variable. Each value "
"of X is an index.");
AddOutput(
"Out",
"(Tensor, Tensor<int|int64>) Output tensor with same shape as X. "
"The tensor consists of sharding representations of values in X.");
AddAttr<int>("index_range",
"A positive integer to specify the range of the input X.");

AddAttr<int>("nshards",
"A positive integer to specify the number of shards.");
AddAttr<int>("shard_id", "The current shard id");
AddAttr<int>("ignore_value", "An ingeter value out of sharded range")
.SetDefault(-1);
AddComment(R"DOC(
This operator creates the sharded index for the input. It is commonly used
in model- and data-parallel mixed training, in which the index data (usually
the label) should be recalculated in each trainer according to:

.. math::

    \text{assert} \quad index\_range \bmod nshards == 0

    shard\_range = index\_range / nshards

    y = \begin{cases}
        x \bmod shard\_range, & \text{if } x / shard\_range == shard\_id \\
        ignore\_value, & \text{otherwise}
    \end{cases}

We take the distributed one-hot representation as an example to show what
this layer is used for. The distributed one-hot representation is separated
into multiple shards, and each shard is filled with zeros except for the one
holding the index. To create these sharded representations in each trainer,
the original index must be recalculated (i.e. sharded) beforehand.

Examples:

X is a Tensor of integer values:
X.shape = [4, 1]
X.data = [[1], [6], [12], [19]]

suppose index_range = 20 and nshards = 2, then we get shard_range = 10

if shard_id == 0, we get the Out:
Out.shape = [4, 1]
Out.data = [[1], [6], [-1], [-1]]

if shard_id == 1, we get the Out:
Out.shape = [4, 1]
Out.data = [[-1], [-1], [2], [9]]

The default `ignore_value` of -1 is used in this example.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp,
ops::ShardIndexOpMaker);
REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel<int>,
ops::ShardIndexCPUKernel<int64_t>);
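For a quick end-to-end check of the new op through the Python layer registered in API.spec, here is a hedged usage sketch reproducing the DOC example (the program setup and feed names are illustrative, not part of this PR):

```python
import numpy as np
import paddle.fluid as fluid

label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# index_range=20, nshards=2 -> shard_range=10; shard 0 keeps 0..9.
shard_label = fluid.layers.shard_index(input=label, index_range=20,
                                       nshards=2, shard_id=0)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

x = np.array([[1], [6], [12], [19]]).astype('int64')
out, = exe.run(feed={'label': x}, fetch_list=[shard_label])
print(out)  # expected, per the DOC example: [[1], [6], [-1], [-1]]
```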
77 changes: 77 additions & 0 deletions paddle/fluid/operators/shard_index_op.cu
@@ -0,0 +1,77 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T>
__global__ void ShardIndexInner(const T* in_data, T* out_data,
const int64_t numel, const int index_range,
const int nshards, const int shard_id,
const int ignore_value) {
int shard_range = index_range / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_range);
if (in_data[idx] / shard_range == shard_id) {
out_data[idx] = in_data[idx] % shard_range;
} else {
out_data[idx] = ignore_value;
}
}
}

using LoDTensor = framework::LoDTensor;

template <typename T>
class ShardIndexCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int index_range = context.Attr<int>("index_range");
int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(index_range, 0);
PADDLE_ENFORCE_GT(nshards, 0);
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
"shard_id(%d) is not in range [0, %d)", shard_id, nshards);

out->Resize(in->dims());
out->set_lod(in->lod());
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
ShardIndexInner<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, numel, index_range, nshards, shard_id, ignore_value);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(shard_index, ops::ShardIndexCUDAKernel<int>,
ops::ShardIndexCUDAKernel<int64_t>);
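The launch configuration above is the usual ceiling division: `(numel + threads - 1) / threads` blocks of `PADDLE_CUDA_NUM_THREADS` threads each, so every element gets one thread. As a reference when validating kernel output, a vectorized NumPy sketch of the same element-wise mapping (the function name is illustrative):

```python
import numpy as np

def shard_index_ref(x, index_range, nshards, shard_id, ignore_value=-1):
    """Reference semantics of the shard_index CPU/CUDA kernels."""
    assert ((x >= 0) & (x < index_range)).all(), 'index out of range'
    shard_range = index_range // nshards
    in_shard = (x // shard_range) == shard_id
    return np.where(in_shard, x % shard_range, ignore_value)

x = np.array([[1], [6], [12], [19]])
print(shard_index_ref(x, index_range=20, nshards=2, shard_id=1))
# [[-1], [-1], [ 2], [ 9]]
```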
58 changes: 58 additions & 0 deletions paddle/fluid/operators/shard_index_op.h
@@ -0,0 +1,58 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
template <typename T>
class ShardIndexCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int index_range = context.Attr<int>("index_range");
Member:

could you use max_index or some other specific name for this variable? index_range seems to be a value of width; better to use a detailed variable name.
Collaborator Author (@gavin1332, Jul 21, 2019):

The attribute name "index_range" is indeed ambiguous, and we need a better one. In most cases, users already have a variable holding the number of indices, so a "max_index" attribute would require them to manually subtract 1 from that variable, and we would then have to add 1 back for the "shard_size" calculation. So I changed the attribute "index_range" to "index_num", a more precise name that denotes the number of indices.
int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(index_range, 0);
PADDLE_ENFORCE_GT(nshards, 0);
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
"shard_id(%d) is not in range [0, %d)", shard_id, nshards);

int shard_range = index_range / nshards;
Member:

same as above. shard_width might be better? shard_range could mean the width of a shard, but it could also be read as how many shards we have.
Collaborator Author:

It has been renamed to "shard_size".

out->Resize(in->dims());
out->set_lod(in->lod());
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
for (int64_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE(in_data[i] >= 0 && in_data[i] < index_range,
"Input index(%d) is out of range [0,%d)", in_data[i],
index_range);
if (in_data[i] / shard_range == shard_id) {
out_data[i] = in_data[i] % shard_range;
} else {
out_data[i] = ignore_value;
}
}
}
};

} // namespace operators
} // namespace paddle
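Tying this back to the distributed one-hot motivation in the DOC comment: each trainer one-hot encodes only its own shard, rows mapped to `ignore_value` stay all-zero, and concatenating the shards column-wise recovers the full one-hot matrix. A purely illustrative NumPy sketch (not part of this PR):

```python
import numpy as np

def shard_one_hot(x, index_range, nshards, shard_id):
    # Shard the indices, then one-hot encode within this shard only.
    shard_range = index_range // nshards
    sharded = np.where((x // shard_range) == shard_id,
                       x % shard_range, -1).ravel()
    out = np.zeros((len(sharded), shard_range), dtype=np.int64)
    valid = sharded >= 0
    out[np.nonzero(valid)[0], sharded[valid]] = 1
    return out

x = np.array([[1], [6], [12], [19]])
full = np.concatenate([shard_one_hot(x, 20, 2, s) for s in (0, 1)], axis=1)
assert (full.argmax(axis=1) == x.ravel()).all()  # recovers the labels
```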
2 changes: 2 additions & 0 deletions python/paddle/fluid/framework.py
@@ -3632,6 +3632,8 @@ def __init__(self, block, shape, dtype, **kwargs):

self.do_model_average = kwargs.get('do_model_average', None)

self.is_distributed = False

def __str__(self):
return self.to_string(True)
