Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3b8c4ad
Incorporate cudnn_lstm into LSTM api.
guoshengCS Sep 9, 2020
606aad5
Make coalesce_tensor support alignment optionally.
guoshengCS Sep 11, 2020
55391b6
Reorganize RNN apis. test=develop
guoshengCS Sep 15, 2020
f10e8bb
Fix cudnn rnn layout conversion.
guoshengCS Sep 15, 2020
7404b6e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
guoshengCS Sep 15, 2020
d633a40
Add sequence_length support for RNN cudnn implement.
guoshengCS Sep 15, 2020
bcd56e6
Use create_parameter for rnn cudnn impl.
guoshengCS Sep 15, 2020
7d9308e
Move `self._flat_weight = self.create_parameter()` in RNNBase to main…
guoshengCS Sep 16, 2020
762426c
Update RNN api unittest to use set_device.
guoshengCS Sep 16, 2020
f99135b
Fix set_place for unit tests of RNN apis.
guoshengCS Sep 17, 2020
2859cca
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
guoshengCS Sep 27, 2020
05570ac
Fix use_align in coalesce_tensor_op.
guoshengCS Sep 27, 2020
63614e2
Adjust RNN apis arguments according to comments.
guoshengCS Sep 27, 2020
63a79f8
Polish documents for SimpleRNN apis.
guoshengCS Sep 27, 2020
17a3a3d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
guoshengCS Oct 13, 2020
bac2e54
Refine random seed in cudnn_lstm_op.
guoshengCS Oct 13, 2020
2cf6936
Fix RNN saving for jit.save.
guoshengCS Oct 13, 2020
11bf632
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
guoshengCS Oct 13, 2020
50be089
Fix doc of GRU. test=develop
guoshengCS Oct 13, 2020
2e168c0
Use ShareDataWith to avoid copying for cudnn_lstm_op test.
guoshengCS Oct 14, 2020
0f8a455
Remove updates on cudnn_lstm temporarily.
guoshengCS Oct 14, 2020
cdb3514
Use ShareDataWith to avoid copying for cudnn_lstm_op test.
guoshengCS Oct 14, 2020
eb6c671
Refine random seed in cudnn_lstm_op.
guoshengCS Oct 14, 2020
8ec21a0
Fix test_lstm by adjust ConcreteProgram buffer getter.
guoshengCS Oct 14, 2020
7ff3fa0
Use create_parameter instead of create_var for rnn._flat_weight for s…
guoshengCS Oct 14, 2020
a731728
Remove W input for cudnn_lstm to pass unused_var_check.
guoshengCS Oct 14, 2020
47ed7e1
Add test_predict for RNN unit tests coverage.
guoshengCS Oct 15, 2020
277029e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
guoshengCS Oct 15, 2020
3924c1f
Fix code style of rnn.
guoshengCS Oct 15, 2020
065fcc9
Fix F.rnn usage in rnn.py.
guoshengCS Oct 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions paddle/fluid/operators/coalesce_tensor_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
}

auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
bool use_align = context.Attr<bool>("use_align");

if (context.Attr<bool>("check_name")) {
for (size_t i = 0; i < in_var_names.size(); ++i) {
Expand All @@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
context.Attr<int>("dtype"));
size_t size_of_dtype = framework::SizeOfType(dtype);
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
context.GetPlace());
context.GetPlace(), use_align);

// Alloc the continuous space
auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput");
Expand All @@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
&sub_tensor);

offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype;
offset +=
use_align
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
}
} else if (context.Attr<bool>("set_constant")) {
math::SetConstant<DeviceContext, T> set_constant;
Expand All @@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
->ShareDataWith(fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
.Resize(dim);
len = platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype;
len = use_align
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
offset += len;
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
<< " address: " << out_tensors[i]->data<void>() << ", ";
Expand All @@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
void GetMemSizeAndDtype(
const std::vector<const framework::LoDTensor *> &lod_tensors,
const std::vector<std::string> var_names, size_t *numel,
const size_t &size_of_dtype, const platform::Place &place) const {
const size_t &size_of_dtype, const platform::Place &place,
const bool use_align = true) const {
PADDLE_ENFORCE_EQ(
lod_tensors.size(), var_names.size(),
platform::errors::InvalidArgument(
Expand All @@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< ") "
<< " addres:" << lod_tensors[i]->data<void>() << ", ";
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype,
place) /
size_of_dtype;
*numel += use_align
? platform::Alignment(
static_cast<size_t>(size) * size_of_dtype, place) /
size_of_dtype
: static_cast<size_t>(size);
}

VLOG(10) << ss.str();
Expand Down Expand Up @@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether to check the name of Input and Output to ensure "
"they are the same separately.")
.SetDefault(false);
AddAttr<bool>("use_align",
"Whether to consider memory chunk and take alignment into "
"account for inputs and outputs.")
.SetDefault(true);
AddComment(R"DOC(
CoalesceTensor Operator.

Expand Down
42 changes: 32 additions & 10 deletions paddle/fluid/operators/cudnn_lstm_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
Expand Down Expand Up @@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");

if (!is_test) {
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed == 0) {
// If perform `manual_seed` in python and inner seed is not specified
// (equals 0), use global generator generated seed.
seed = static_cast<int>(gen_cuda->Random64());
} else if (seed == 0) {
// use random generated seed
std::random_device rd;
seed = rd();
} // else use `ctx.Attr<int>("seed")` specified seed
}

bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
if (has_seq_length) {
Expand Down Expand Up @@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {

if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not "
"continuous, less efficient calculation will be "
"called. Please call coalesce_tensor op to make the "
"input memory continuous.";
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>();
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<Tensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
Expand All @@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
} else {
if (!has_seq_length) {
// for train
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/save_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
1 change: 1 addition & 0 deletions paddle/fluid/operators/save_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
# 1. filter `self` in args
if args and isinstance(args[0], layers.Layer):
args = args[1:]
# 2. convert tensor and numpy array into InputSpec
# 2. convert tensor and numpy array into InputSpec
_args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)

Expand Down Expand Up @@ -592,9 +592,8 @@ def from_func_spec(func_spec, input_spec, class_instance):
inputs = tuple([class_instance] + list(inputs))

# 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = list(
get_parameters(class_instance).values()) + list(
get_buffers(class_instance).values())
all_parameters_and_buffers = _extract_indeed_params_buffers(
class_instance)

# 3. Builds program only once and returns the output Variables.
with param_guard(get_parameters(
Expand Down Expand Up @@ -622,6 +621,17 @@ def from_func_spec(func_spec, input_spec, class_instance):
startup_program=startup_program)


def _extract_indeed_params_buffers(class_instance):
    """
    Collect the parameters and the initialized buffers of ``class_instance``.

    Buffers whose shape is ``[]`` are dropped: such a shape is treated here
    as a marker of a not-yet-initialized buffer, so it must not be baked
    into the traced program.
    """
    buffers = [
        buf for buf in get_buffers(class_instance).values()
        if buf.shape != []
    ]
    return list(get_parameters(class_instance).values()) + buffers


class ProgramCache(object):
"""
Wrapper class for the program functions defined by dygraph function.
Expand Down
82 changes: 73 additions & 9 deletions python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
paddle.disable_static(self.place)
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.SimpleRNN(
Expand Down Expand Up @@ -103,11 +105,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
paddle.disable_static(self.place)
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = GRU(16,
32,
2,
Expand Down Expand Up @@ -183,11 +187,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
paddle.disable_static(self.place)
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.LSTM(
Expand Down Expand Up @@ -251,10 +257,68 @@ def test_with_input_lengths(self):
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)

def test_predict(self):
    # End-to-end coverage: train one step in dygraph, switch to eval,
    # export the LSTM net with `paddle.jit.save`, then reload it as a
    # static-graph inference model and check the predictions match the
    # dygraph eval outputs.
    place = paddle.set_device(self.place)
    # Fix both framework and numpy seeds so the run is deterministic.
    paddle.manual_seed(123)
    np.random.seed(123)

    class Net(paddle.nn.Layer):
        # Minimal wrapper so `jit.to_static` has a Layer to trace.
        def __init__(self):
            super(Net, self).__init__()
            self.rnn1 = paddle.nn.LSTM(
                16, 32, 2, direction="bidirectional", dropout=0.1)

        def forward(self, input):
            return self.rnn1(input)

    x = paddle.randn((4, 10, 16))  # (batch=4, time=10, input_size=16)
    x.stop_gradient = False
    seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))
    # Mask padded time steps so they do not contribute to the loss.
    mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
    mask = paddle.unsqueeze(mask, [2])
    rnn = Net()
    y, (h, c) = rnn(x)
    y = y * mask
    loss = paddle.mean(y)
    loss.backward()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.1, parameters=rnn.parameters())
    optimizer.step()
    rnn.eval()
    y, (h, c) = rnn(x)  # reference eval outputs to compare against
    # `jit.to_static` would include a train_program, eval mode might cause
    # some errors currently, such as dropout grad op gets `is_test == True`.
    rnn.train()

    rnn = paddle.jit.to_static(
        rnn,
        [paddle.static.InputSpec(
            shape=[None, None, 16], dtype=x.dtype)])
    paddle.jit.save(rnn, "./inference/lstm_infer")

    paddle.enable_static()

    # Run the saved model in a fresh scope so no dygraph state leaks in.
    new_scope = paddle.static.Scope()
    with paddle.static.scope_guard(new_scope):
        exe = paddle.static.Executor(place)
        [inference_program, feed_target_names,
         fetch_targets] = paddle.static.load_inference_model(
             dirname="./inference",
             executor=exe,
             model_filename="lstm_infer.pdmodel",
             params_filename="lstm_infer.pdiparams")
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: x.numpy()},
                          fetch_list=fetch_targets)
        np.testing.assert_equal(
            y.numpy(), results[0])  # eval results equal predict results
    # Restore dygraph mode so later test cases are unaffected.
    paddle.disable_static()

def runTest(self):
    # Drive every check of this suite in a fixed order; `load_tests`
    # instantiates the case directly, so unittest's per-method discovery
    # does not apply here.
    for check in (self.test_with_initial_state,
                  self.test_with_zero_state,
                  self.test_with_input_lengths,
                  self.test_predict):
        check()


def load_tests(loader, tests, pattern):
Expand Down
21 changes: 12 additions & 9 deletions python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction)

Expand All @@ -48,7 +50,6 @@ def setUp(self):
time_major=self.time_major,
direction=self.direction)

place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
Expand Down Expand Up @@ -172,10 +173,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = GRU(16,
32,
2,
Expand All @@ -192,7 +195,6 @@ def setUp(self):
time_major=self.time_major,
direction=self.direction)

place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
Expand Down Expand Up @@ -316,10 +318,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \
else paddle.CUDAPlace(0)
self.place = place

def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction)

Expand All @@ -334,7 +338,6 @@ def setUp(self):
time_major=self.time_major,
direction=self.direction)

place = self.place
exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope):
Expand Down
Loading