Skip to content

Commit 2584ff7

Browse files
committed
Incorporate cudnn_lstm into LSTM api (PaddlePaddle#27217)
* Incorporate cudnn_lstm into LSTM api. test=develop * Make coalesce_tensor support alignment optionally. test=develop * Reorganize RNN apis. test=develop * Fix cudnn rnn layout conversion. test=develop * Add sequence_length support for RNN cudnn implement. Add optional init_h and init_c gradient for cudnn_lstm_op. test=develop * Use create_parameter for rnn cudnn impl. test=develop * Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program. test=develop * Update RNN api unittest to use set_device. test=develop * Fix set_place for unit tests of RNN apis. test=develop * Fix use_align in coalesce_tensor_op. test=develop * Adjust RNN apis arguments according to comments. test=develop * Polish documents for SimpleRNN apis. test=develop * Refine random seed in cudnn_lstm_op. Expose rnn params from sublayers to RNN. test=develop * Fix RNN saving for jit.save. Refine cudnn_lstm dropout behavior. test=develop * Fix doc of GRU. test=develop * Use ShareDataWith to avoid copying for cudnn_lstm_op test. test=develop * Remove updates on cudnn_lstm temporarily. test=develop * Use ShareDataWith to avoid copying for cudnn_lstm_op test. test=develop * Refine random seed in cudnn_lstm_op. test=develop * Fix test_lstm by adjust ConcreteProgram buffer getter. test=develop * Use create_parameter instead of create_var for rnn._flat_weight for static graph usage. test=develop * Remove W input for cudnn_lstm to pass unused_var_check. test=develop * Add test_predict for RNN unit tests coverage. test=develop * Fix code style of rnn. test=develop * Fix F.rnn usage in rnn.py. test=develop
1 parent a6c1807 commit 2584ff7

8 files changed

Lines changed: 444 additions & 193 deletions

File tree

paddle/fluid/operators/coalesce_tensor_op.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
6767
}
6868

6969
auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
70+
bool use_align = context.Attr<bool>("use_align");
7071

7172
if (context.Attr<bool>("check_name")) {
7273
for (size_t i = 0; i < in_var_names.size(); ++i) {
@@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
9394
context.Attr<int>("dtype"));
9495
size_t size_of_dtype = framework::SizeOfType(dtype);
9596
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
96-
context.GetPlace());
97+
context.GetPlace(), use_align);
9798

9899
// Alloc the continuous space
99100
auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput");
@@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
111112
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
112113
&sub_tensor);
113114

114-
offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) /
115-
size_of_dtype;
115+
offset +=
116+
use_align
117+
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
118+
size_of_dtype
119+
: len;
116120
}
117121
} else if (context.Attr<bool>("set_constant")) {
118122
math::SetConstant<DeviceContext, T> set_constant;
@@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
131135
->ShareDataWith(fused_tensor->Slice(
132136
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
133137
.Resize(dim);
134-
len = platform::Alignment(len * size_of_dtype, context.GetPlace()) /
135-
size_of_dtype;
138+
len = use_align
139+
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
140+
size_of_dtype
141+
: len;
136142
offset += len;
137143
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
138144
<< " address: " << out_tensors[i]->data<void>() << ", ";
@@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
144150
void GetMemSizeAndDtype(
145151
const std::vector<const framework::LoDTensor *> &lod_tensors,
146152
const std::vector<std::string> var_names, size_t *numel,
147-
const size_t &size_of_dtype, const platform::Place &place) const {
153+
const size_t &size_of_dtype, const platform::Place &place,
154+
const bool use_align = true) const {
148155
PADDLE_ENFORCE_EQ(
149156
lod_tensors.size(), var_names.size(),
150157
platform::errors::InvalidArgument(
@@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
167174
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
168175
<< ") "
169176
<< " addres:" << lod_tensors[i]->data<void>() << ", ";
170-
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype,
171-
place) /
172-
size_of_dtype;
177+
*numel += use_align
178+
? platform::Alignment(
179+
static_cast<size_t>(size) * size_of_dtype, place) /
180+
size_of_dtype
181+
: static_cast<size_t>(size);
173182
}
174183

175184
VLOG(10) << ss.str();
@@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
223232
"Whether to check the name of Input and Output to ensure "
224233
"they are the same separately.")
225234
.SetDefault(false);
235+
AddAttr<bool>("use_align",
236+
"Whether to consider memory chunk and take alignment into "
237+
"account for inputs and outputs.")
238+
.SetDefault(true);
226239
AddComment(R"DOC(
227240
CoalesceTensor Operator.
228241

paddle/fluid/operators/cudnn_lstm_op.cu.cc

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/generator.h"
1516
#include "paddle/fluid/framework/op_registry.h"
1617
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
1718
#include "paddle/fluid/operators/math/math_function.h"
@@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
156157
bool is_test = ctx.Attr<bool>("is_test");
157158
int seed = ctx.Attr<int>("seed");
158159

160+
if (!is_test) {
161+
int device_id =
162+
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
163+
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
164+
if (gen_cuda->GetIsInitPy() && seed == 0) {
165+
// If perform `manual_seed` in python and inner seed is not specified
166+
// (equals 0), use global generator generated seed.
167+
seed = static_cast<int>(gen_cuda->Random64());
168+
} else if (seed == 0) {
169+
// use random generated seed
170+
std::random_device rd;
171+
seed = rd();
172+
} // else use `ctx.Attr<int>("seed")` specified seed
173+
}
174+
159175
bool has_seq_length = ctx.HasInput("SequenceLength");
160176
std::vector<int> SequenceLength;
161177
if (has_seq_length) {
@@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
194210

195211
if (!continuous) {
196212
LOG_FIRST_N(WARNING, 2)
197-
<< "If the memory space of the Input WeightList is not "
198-
"continuous, less efficient calculation will be "
199-
"called. Please call coalesce_tensor op to make the "
200-
"input memory continuous.";
213+
<< "If the memory space of the Input WeightList is not continuous, "
214+
"less efficient calculation will be called. Please call "
215+
"flatten_parameters() to make the input memory continuous.";
201216
weight_whole.mutable_data<T>({weight_numel}, place);
202217
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
203218
w_data = weight_whole.data<T>();
219+
if (is_test) { // maybe also reset small weights' ptr for training
220+
int offset = 0;
221+
for (size_t i = 0; i < weight_list.size(); ++i) {
222+
size_t len = weight_list[i]->numel();
223+
auto dim = weight_list[i]->dims();
224+
const_cast<Tensor *>(weight_list[i])
225+
->ShareDataWith(
226+
weight_whole.Slice(static_cast<int64_t>(offset),
227+
static_cast<int64_t>(offset + len)))
228+
.Resize(dim);
229+
offset += len;
230+
}
231+
}
204232
} else {
205233
w_data = const_cast<T *>(weight_list[0]->data<T>());
206234
}
@@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
226254
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
227255
init_h_data, init_c_data, w_data, out_data, last_h_data,
228256
last_c_data, &workspace_data_, workspace_size);
229-
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
230-
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
231-
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
232-
W->mutable_data<T>({weight_numel}, place);
233-
weight_to_tensor<T>(place, stream, weight_list, W);
234-
}
235257
} else {
236258
if (!has_seq_length) {
237259
// for train

paddle/fluid/operators/save_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL(
8989
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
9090
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
9191
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
92+
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
9293
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
9394
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
9495
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/save_op.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
2121
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
2222
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
2323
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
24+
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
2425
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
2526
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
2627
ops::SaveOpKernel<paddle::platform::CUDADeviceContext,

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
174174
# 1. filter `self` in args
175175
if args and isinstance(args[0], layers.Layer):
176176
args = args[1:]
177-
# 2. convert tensor and numpy array into InputSpec
177+
# 2. convert tensor and numpy array into InputSpec
178178
_args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
179179
input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)
180180

@@ -592,9 +592,8 @@ def from_func_spec(func_spec, input_spec, class_instance):
592592
inputs = tuple([class_instance] + list(inputs))
593593

594594
# 2. Gets all ParamBases and buffered VarBases in the function
595-
all_parameters_and_buffers = list(
596-
get_parameters(class_instance).values()) + list(
597-
get_buffers(class_instance).values())
595+
all_parameters_and_buffers = _extract_indeed_params_buffers(
596+
class_instance)
598597

599598
# 3. Builds program only once and returns the output Variables.
600599
with param_guard(get_parameters(
@@ -622,6 +621,17 @@ def from_func_spec(func_spec, input_spec, class_instance):
622621
startup_program=startup_program)
623622

624623

624+
def _extract_indeed_params_buffers(class_instance):
625+
"""
626+
To filter not initialzed buffers.
627+
"""
628+
params = list(get_parameters(class_instance).values())
629+
buffers = list(get_buffers(class_instance).values())
630+
buffers = [buffer for buffer in buffers if buffer.shape != []]
631+
632+
return params + buffers
633+
634+
625635
class ProgramCache(object):
626636
"""
627637
Wrapper class for the program functions defined by dygraph function.

python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
2929
self.time_major = time_major
3030
self.direction = direction
3131
self.num_directions = 2 if direction == "bidirectional" else 1
32-
self.place = paddle.CPUPlace() if place == "cpu" \
33-
else paddle.CUDAPlace(0)
32+
self.place = place
3433

3534
def setUp(self):
36-
paddle.disable_static(self.place)
35+
# Since `set_device` is global, set `set_device` in `setUp` rather than
36+
# `__init__` to avoid using an error device set by another test case.
37+
place = paddle.set_device(self.place)
38+
paddle.disable_static(place)
3739
rnn1 = SimpleRNN(
3840
16, 32, 2, time_major=self.time_major, direction=self.direction)
3941
rnn2 = paddle.nn.SimpleRNN(
@@ -103,11 +105,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
103105
self.time_major = time_major
104106
self.direction = direction
105107
self.num_directions = 2 if direction == "bidirectional" else 1
106-
self.place = paddle.CPUPlace() if place == "cpu" \
107-
else paddle.CUDAPlace(0)
108+
self.place = place
108109

109110
def setUp(self):
110-
paddle.disable_static(self.place)
111+
# Since `set_device` is global, set `set_device` in `setUp` rather than
112+
# `__init__` to avoid using an error device set by another test case.
113+
place = paddle.set_device(self.place)
114+
paddle.disable_static(place)
111115
rnn1 = GRU(16,
112116
32,
113117
2,
@@ -183,11 +187,13 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
183187
self.time_major = time_major
184188
self.direction = direction
185189
self.num_directions = 2 if direction == "bidirectional" else 1
186-
self.place = paddle.CPUPlace() if place == "cpu" \
187-
else paddle.CUDAPlace(0)
190+
self.place = place
188191

189192
def setUp(self):
190-
paddle.disable_static(self.place)
193+
# Since `set_device` is global, set `set_device` in `setUp` rather than
194+
# `__init__` to avoid using an error device set by another test case.
195+
place = paddle.set_device(self.place)
196+
paddle.disable_static(place)
191197
rnn1 = LSTM(
192198
16, 32, 2, time_major=self.time_major, direction=self.direction)
193199
rnn2 = paddle.nn.LSTM(
@@ -251,10 +257,68 @@ def test_with_input_lengths(self):
251257
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
252258
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
253259

260+
def test_predict(self):
261+
place = paddle.set_device(self.place)
262+
paddle.manual_seed(123)
263+
np.random.seed(123)
264+
265+
class Net(paddle.nn.Layer):
266+
def __init__(self):
267+
super(Net, self).__init__()
268+
self.rnn1 = paddle.nn.LSTM(
269+
16, 32, 2, direction="bidirectional", dropout=0.1)
270+
271+
def forward(self, input):
272+
return self.rnn1(input)
273+
274+
x = paddle.randn((4, 10, 16))
275+
x.stop_gradient = False
276+
seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))
277+
mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
278+
mask = paddle.unsqueeze(mask, [2])
279+
rnn = Net()
280+
y, (h, c) = rnn(x)
281+
y = y * mask
282+
loss = paddle.mean(y)
283+
loss.backward()
284+
optimizer = paddle.optimizer.Adam(
285+
learning_rate=0.1, parameters=rnn.parameters())
286+
optimizer.step()
287+
rnn.eval()
288+
y, (h, c) = rnn(x)
289+
# `jit.to_static` would include a train_program, eval mode might cause
290+
# some errors currently, such as dropout grad op gets `is_test == True`.
291+
rnn.train()
292+
293+
rnn = paddle.jit.to_static(
294+
rnn,
295+
[paddle.static.InputSpec(
296+
shape=[None, None, 16], dtype=x.dtype)])
297+
paddle.jit.save(rnn, "./inference/lstm_infer")
298+
299+
paddle.enable_static()
300+
301+
new_scope = paddle.static.Scope()
302+
with paddle.static.scope_guard(new_scope):
303+
exe = paddle.static.Executor(place)
304+
[inference_program, feed_target_names,
305+
fetch_targets] = paddle.static.load_inference_model(
306+
dirname="./inference",
307+
executor=exe,
308+
model_filename="lstm_infer.pdmodel",
309+
params_filename="lstm_infer.pdiparams")
310+
results = exe.run(inference_program,
311+
feed={feed_target_names[0]: x.numpy()},
312+
fetch_list=fetch_targets)
313+
np.testing.assert_equal(
314+
y.numpy(), results[0]) # eval results equal predict results
315+
paddle.disable_static()
316+
254317
def runTest(self):
255318
self.test_with_initial_state()
256319
self.test_with_zero_state()
257320
self.test_with_input_lengths()
321+
self.test_predict()
258322

259323

260324
def load_tests(loader, tests, pattern):

python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
3030
self.time_major = time_major
3131
self.direction = direction
3232
self.num_directions = 2 if direction == "bidirectional" else 1
33-
self.place = paddle.CPUPlace() if place == "cpu" \
34-
else paddle.CUDAPlace(0)
33+
self.place = place
3534

3635
def setUp(self):
36+
# Since `set_device` is global, set `set_device` in `setUp` rather than
37+
# `__init__` to avoid using an error device set by another test case.
38+
place = paddle.set_device(self.place)
3739
rnn1 = SimpleRNN(
3840
16, 32, 2, time_major=self.time_major, direction=self.direction)
3941

@@ -48,7 +50,6 @@ def setUp(self):
4850
time_major=self.time_major,
4951
direction=self.direction)
5052

51-
place = self.place
5253
exe = paddle.static.Executor(place)
5354
scope = paddle.fluid.Scope()
5455
with paddle.static.scope_guard(scope):
@@ -172,10 +173,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
172173
self.time_major = time_major
173174
self.direction = direction
174175
self.num_directions = 2 if direction == "bidirectional" else 1
175-
self.place = paddle.CPUPlace() if place == "cpu" \
176-
else paddle.CUDAPlace(0)
176+
self.place = place
177177

178178
def setUp(self):
179+
# Since `set_device` is global, set `set_device` in `setUp` rather than
180+
# `__init__` to avoid using an error device set by another test case.
181+
place = paddle.set_device(self.place)
179182
rnn1 = GRU(16,
180183
32,
181184
2,
@@ -192,7 +195,6 @@ def setUp(self):
192195
time_major=self.time_major,
193196
direction=self.direction)
194197

195-
place = self.place
196198
exe = paddle.static.Executor(place)
197199
scope = paddle.fluid.Scope()
198200
with paddle.static.scope_guard(scope):
@@ -316,10 +318,12 @@ def __init__(self, time_major=True, direction="forward", place="cpu"):
316318
self.time_major = time_major
317319
self.direction = direction
318320
self.num_directions = 2 if direction == "bidirectional" else 1
319-
self.place = paddle.CPUPlace() if place == "cpu" \
320-
else paddle.CUDAPlace(0)
321+
self.place = place
321322

322323
def setUp(self):
324+
# Since `set_device` is global, set `set_device` in `setUp` rather than
325+
# `__init__` to avoid using an error device set by another test case.
326+
place = paddle.set_device(self.place)
323327
rnn1 = LSTM(
324328
16, 32, 2, time_major=self.time_major, direction=self.direction)
325329

@@ -334,7 +338,6 @@ def setUp(self):
334338
time_major=self.time_major,
335339
direction=self.direction)
336340

337-
place = self.place
338341
exe = paddle.static.Executor(place)
339342
scope = paddle.fluid.Scope()
340343
with paddle.static.scope_guard(scope):

0 commit comments

Comments
 (0)