
Incorporate cudnn_lstm into LSTM api#27217

Merged
guoshengCS merged 30 commits into PaddlePaddle:develop from guoshengCS:add-lstm-cudnn
Oct 16, 2020

Conversation


@guoshengCS guoshengCS commented Sep 9, 2020

PR types

New features

PR changes

Others

Describe

Incorporate cudnn_lstm into LSTM api

  1. Integrate cudnn_lstm into the LSTM api.
  2. Rename the base class RNNMixin to RNNBase, and expose the params held by RNNCell through RNNBase.
  3. Add a use_align attribute to coalesce_tensor_op so that, when fusing small tensors into a large one during parameter conversion, the memory chunks of the small tensors can be ignored.
  4. Modify the cudnn_lstm kernel: make the gradients w.r.t. in_h/in_c optional, adjust the dropout seed setting in cudnn_lstm, and reset the parameter pointers in the weight list during tests so the non-contiguity check does not trigger a copy on every run.
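The parameter-fusion idea in item 3 can be sketched in pure NumPy (this is only an illustration of packing a weight list into one flat buffer with per-weight views sharing storage; it is not the Paddle `coalesce_tensor` op):

```python
import numpy as np

def flatten_weights(weights):
    """Pack small weight arrays back-to-back into one flat buffer and
    return per-weight views into it; the views share the fused storage."""
    sizes = [w.size for w in weights]
    flat = np.concatenate([w.ravel() for w in weights])
    offsets = np.cumsum([0] + sizes)
    views = [flat[offsets[i]:offsets[i + 1]].reshape(weights[i].shape)
             for i in range(len(weights))]
    return flat, views

w_ih = np.ones((4, 3), dtype=np.float32)
w_hh = np.full((4, 4), 2.0, dtype=np.float32)
flat, views = flatten_weights([w_ih, w_hh])
```

Because each view aliases the flat buffer, updating the fused tensor updates the per-layer weights in place, which is what lets a fused cuDNN kernel and the per-layer Python parameters stay consistent.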


paddle-bot-old bot commented Sep 9, 2020

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

params = self.parameters()
shape = [np.prod(param.shape) for param in params]
# for static-graph, append coalesce_tensor into startup program
with fluid.program_guard(fluid.default_startup_program(),

@iclementine iclementine Sep 10, 2020


  1. coalesce_tensor currently does not guarantee that the fused large tensor is memory-contiguous; this may need to change.
  2. If cudnn lstm gains a calling convention that takes the small weight list, is it still necessary to hold a fused tensor on the Python side? (This may need testing.)

Contributor Author


Thanks. Done.
Added a use_align attribute to coalesce_tensor_op: by default it keeps the original behavior of preserving each small tensor's full memory chunk, and optionally it packs the small tensors contiguously when fusing.

attrs={"copy_data": True,
"dtype": params[0].dtype})

def _cudnn_impl(self, inputs, initial_states, sequence_length):


A small suggestion: once cudnn_lstm_op gets a Python (functional) interface, switch to calling that interface here.

Contributor Author


Thanks. This will be adjusted once cudnn_lstm_op supports a functional interface.
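The layer-wraps-functional pattern the review asks for can be sketched generically (names here are hypothetical stand-ins, not the Paddle API): the Layer holds the parameters, while the functional interface does the computation.

```python
def lstm_forward(inputs, weights):
    # stand-in for a fused functional op (e.g. a future functional lstm)
    return [w * x for w, x in zip(weights, inputs)]

class LSTMLayer:
    """Layer object: owns the weights, delegates compute to the
    functional interface instead of appending ops directly."""

    def __init__(self, weights):
        self.weights = weights

    def __call__(self, inputs):
        return lstm_forward(inputs, self.weights)

layer = LSTMLayer([2.0, 3.0])
out = layer([4.0, 5.0])
```

This keeps the op construction in one place, so the Layer only needs to change its delegate once the functional interface lands.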

self.time_major = time_major
self.num_layers = num_layers
self.state_components = 1
if activation == "tanh":

@iclementine iclementine Sep 10, 2020


Suggest unifying the __init__ of the three RNN networks (SimpleRNN, GRU, LSTM) into one and moving it into RNNMixin, since the init methods of these three classes are currently almost identical; distinguishing them by cls with minor changes should be enough.

If necessary, RNNMixin can also be renamed to RNNBase, since it is not currently public.
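The unification suggested above can be sketched as follows (RNNBase and state_components follow the PR; the rest of the structure is illustrative): one shared __init__ in a non-public base class, with per-subclass differences selected by a mode string.

```python
class RNNBase:
    """Shared __init__ for the three RNN networks; subclasses only
    differ in the mode they pass."""

    def __init__(self, mode, input_size, hidden_size, num_layers=1):
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # LSTM states are (h, c); SimpleRNN and GRU carry only h
        self.state_components = 2 if mode == "LSTM" else 1

class SimpleRNN(RNNBase):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__("RNN_TANH", input_size, hidden_size, num_layers)

class LSTM(RNNBase):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__("LSTM", input_size, hidden_size, num_layers)
```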

Contributor Author


Thanks.
Renamed RNNMixin to RNNBase and adjusted how SimpleRNN, GRU, and LSTM use RNNBase.

Contributor Author

guoshengCS commented Sep 17, 2020

Test code:

import paddle
import paddle.fluid as fluid
import numpy as np
paddle.disable_static()
paddle.manual_seed(123)
x = paddle.randn((4, 10, 16))
x.stop_gradient = False
prev_h = paddle.randn((4, 4, 32))
prev_c = paddle.randn((4, 4, 32))
seq_len = paddle.to_tensor(np.array([10,6,8,5]))
mask = fluid.layers.sequence_mask(seq_len, maxlen=10, dtype=prev_h.dtype)
mask = paddle.unsqueeze(mask, [2])
rnn = paddle.nn.LSTM(16, 32, 2, direction="bidirectional")
y, (h, c) = rnn(x, (prev_h, prev_c), seq_len)
y = y * mask
loss = paddle.mean(y)
loss.backward()
optimizer = paddle.optimizer.Adam(learning_rate=0.1, parameters=rnn.parameters())
optimizer.step()
print(rnn[0].cell_fw.weight_hh)

Multi-GPU:

import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import numpy as np

def train():
    paddle.disable_static()
    paddle.distributed.init_parallel_env()
    paddle.manual_seed(123)
    x = paddle.randn((4, 10, 16))
    x.stop_gradient = False
    prev_h = paddle.randn((4, 4, 32))
    prev_c = paddle.randn((4, 4, 32))
    seq_len = paddle.to_tensor(np.array([10,6,8,5]))
    mask = fluid.layers.sequence_mask(seq_len, maxlen=10, dtype=prev_h.dtype)
    mask = paddle.unsqueeze(mask, [2])
    rnn = paddle.nn.LSTM(16, 32, 2, direction="bidirectional")#, dropout=0.0)
    dp_layer = paddle.DataParallel(rnn)
    y, (h, c) = dp_layer(x, (prev_h, prev_c), seq_len)
    y = y * mask
    loss = paddle.mean(y)
    loss = dp_layer.scale_loss(loss)
    loss.backward()
    dp_layer.apply_collective_grads()
    optimizer = paddle.optimizer.Adam(learning_rate=0.1, parameters=rnn.parameters())
    optimizer.step()
    print(dp_layer._layers[0].cell_fw.weight_hh)

if __name__ == '__main__':
    dist.spawn(train, nprocs=2)

You can compare the cudnn and non-cudnn versions by setting self.could_use_cudnn = False in the Python API RNNBase.__init__.
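The dispatch that the could_use_cudnn flag gates can be illustrated like this (method bodies are stand-ins, not the real kernels): flipping the flag forces the generic, non-fused path, which is what makes the comparison possible.

```python
class RNNBaseSketch:
    """Illustrative dispatch: choose between a fused cuDNN kernel and
    the generic Python-composed RNN at call time."""

    def __init__(self, could_use_cudnn=True):
        self.could_use_cudnn = could_use_cudnn

    def forward(self, inputs):
        if self.could_use_cudnn:
            return "cudnn", inputs   # fused-kernel path
        return "generic", inputs     # per-step composed path

fused = RNNBaseSketch(could_use_cudnn=True).forward([1, 2])
plain = RNNBaseSketch(could_use_cudnn=False).forward([1, 2])
```

Both paths must produce the same numerical result; only the implementation differs, so asserting equality of the two outputs is the natural test.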

Saving for inference:

import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import numpy as np

class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn1 = paddle.nn.LSTM(
            16, 32, 2, direction="bidirectional", dropout=0.1)

    def forward(self, input):
        return self.rnn1(input)

def train():
    paddle.disable_static()
    paddle.distributed.init_parallel_env()
    paddle.manual_seed(123)
    np.random.seed(123)
    x_np = np.random.rand(4, 10, 16).astype("float32")
    x = paddle.randn((4, 10, 16))
    x = paddle.to_tensor(x_np)
    x.stop_gradient = False
    prev_h = paddle.randn((4, 4, 32))
    prev_c = paddle.randn((4, 4, 32))
    seq_len = paddle.to_tensor(np.array([10,6,8,5]))
    mask = fluid.layers.sequence_mask(seq_len, maxlen=10, dtype=prev_h.dtype)
    mask = paddle.unsqueeze(mask, [2])
    rnn = Net()
    dp_layer = paddle.DataParallel(rnn)
    y, (h, c) = dp_layer(x)
    y = y * mask
    loss = paddle.mean(y)
    loss = dp_layer.scale_loss(loss)
    loss.backward()
    dp_layer.apply_collective_grads()
    optimizer = paddle.optimizer.Adam(learning_rate=0.1, parameters=rnn.parameters())
    optimizer.step()
    dp_layer.eval()
    y, (h, c) = dp_layer(x)
    print(y)
    dp_layer.train()
    if dist.get_rank() == 0:
        rnn = paddle.jit.to_static(rnn, [paddle.static.InputSpec(shape=[None, None, 16])])
        print(rnn.forward.concrete_program.main_program)
        paddle.jit.save(rnn, "./infer")

        paddle.enable_static()
        place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
                ) else fluid.CUDAPlace(0)
        new_scope = fluid.Scope()
        with fluid.scope_guard(new_scope):
            exe = fluid.Executor(place)
            [inference_program, feed_target_names, fetch_targets] = (
                fluid.io.load_inference_model(
                    dirname="./", executor=exe, model_filename="infer.pdmodel", params_filename="infer.pdiparams"))
            results = exe.run(inference_program,
                              feed={feed_target_names[0]: x_np.astype("float32")},
                              fetch_list=fetch_targets)
            print(results)
            print(y.numpy() == results[0])  # eval and inference results match

if __name__ == '__main__':
    dist.spawn(train, nprocs=2)

Contributor Author

2020-09-17 12:23:16 ****************
2020-09-17 12:23:16 0. You must have one RD (XiaoguangHu01 or lanxianghit) and one TPM (saxon-zh or jzhang533 or swtkiwi or Heeenrrry or TCChenlong) approval for the api change for the management reason of API interface.
2020-09-17 12:23:16 1. You must have one TPM (saxon-zh or jzhang533 or swtkiwi or Heeenrrry or TCChenlong) approval for the api change for the management reason of API document.
2020-09-17 12:23:16 2. You must have one RD (zhiqiu (Recommend) or phlrain) approval for the api change for the opreator-related api without 'core.ops'.
2020-09-17 12:23:16 For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/paddle_api_development_manual.md]
2020-09-17 12:23:16 Related APIs: paddle.nn.SimpleRNN.flatten_parameters
2020-09-17 12:23:16 paddle.nn.LSTM.flatten_parameters
2020-09-17 12:23:16 paddle.nn.GRU.flatten_parameters
2020-09-17 12:23:16 
2020-09-17 12:23:16 There are 3 approved errors.
2020-09-17 12:23:16 ****************
  • Please confirm the API changes, @XiaoguangHu01 @jzhang533.
  • Since cudnn_lstm uses the same variable as both input and output (the dropout state), dygraph mode here uses append_op rather than core.ops; please confirm, @phlrain.

@guoshengCS guoshengCS requested a review from Xreki September 17, 2020 10:03
Contributor

@jzhang533 jzhang533 left a comment


  • SimpleRNN & SimpleRNNCell have an activation parameter; the formulas and descriptions in the docs need corresponding updates.
  • For the parameters num_layers=1, activation="tanh", direction="forward", dropout=0., time_major=False, suggest reordering to num_layers, direction, time_major, dropout, activation: the first three are all shape-related, so grouping them together is easier to understand, and activation, being unique to SimpleRNN, fits best at the end. The docs could also suggest that users pass these as keyword arguments.
  • When direction='bidirectional', concat is the only merge mode we provide; suggest noting this in the parameter's description.
  • There is a typo on line 1165.
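The suggested reordering can be sketched as a signature (the grouping follows the review; the function name is only illustrative): shape-related arguments first, dropout next, and the SimpleRNN-only activation last, passed as keyword arguments.

```python
def simple_rnn_args(input_size, hidden_size, num_layers=1,
                    direction="forward", time_major=False, dropout=0.,
                    activation="tanh"):
    """Collect constructor arguments in the review's suggested order:
    num_layers, direction, time_major (shape-related), then dropout,
    then the SimpleRNN-only activation."""
    return {"num_layers": num_layers, "direction": direction,
            "time_major": time_major, "dropout": dropout,
            "activation": activation}

cfg = simple_rnn_args(16, 32, num_layers=2, direction="bidirectional",
                      activation="relu")
```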

Member

@ZeyuChen ZeyuChen left a comment


The name of the cudnn_lstm kernel needs discussion, to avoid trouble later when unifying CPU and GPU kernel deployment.

}

self._helper.append_op(
type="cudnn_lstm", inputs=inputs, outputs=outputs, attrs=attrs)
Member


If the op name here is fixed as cudnn_lstm, will it cause trouble when the CPU kernel name has to be unified later, given that cudnn is not a CPU keyword?
Would something like lstm_v2 be more appropriate?

@guoshengCS
Contributor Author

  • SimpleRNN & SimpleRNNCell have an activation parameter; the formulas and descriptions in the docs need corresponding updates.
  • For the parameters num_layers=1, activation="tanh", direction="forward", dropout=0., time_major=False, suggest reordering to num_layers, direction, time_major, dropout, activation: the first three are all shape-related, so grouping them together is easier to understand, and activation, being unique to SimpleRNN, fits best at the end. The docs could also suggest that users pass these as keyword arguments.
  • When direction='bidirectional', concat is the only merge mode we provide; suggest noting this in the parameter's description.
  • There is a typo on line 1165.

@jzhang533 Done, thanks.

jzhang533 previously approved these changes Sep 28, 2020
Contributor

@jzhang533 jzhang533 left a comment


lgtm

@guoshengCS guoshengCS requested a review from zhiqiu October 15, 2020 10:53
jzhang533 previously approved these changes Oct 15, 2020
Contributor

@jzhang533 jzhang533 left a comment


lgtm

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment


LGTM

Contributor

@zhangting2020 zhangting2020 left a comment


LGTM

Contributor

@jzhang533 jzhang533 left a comment


lgtm

Contributor

@zhiqiu zhiqiu left a comment


flatten_parameters can add a dygraph branch in the future.

@guoshengCS guoshengCS merged commit fa9d3fa into PaddlePaddle:develop Oct 16, 2020
guoshengCS added a commit to guoshengCS/Paddle that referenced this pull request Oct 17, 2020
* Incorporate cudnn_lstm into LSTM api.
test=develop

* Make coalesce_tensor support alignment optionally.
test=develop

* Reorganize RNN apis. test=develop

* Fix cudnn rnn layout conversion.
test=develop

* Add sequence_length support for RNN cudnn implement.
Add optional init_h and init_c gradient for cudnn_lstm_op.
test=develop

* Use create_parameter for rnn cudnn impl.
test=develop

* Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program.
test=develop

* Update RNN api unittest to use set_device.
test=develop

* Fix set_place for unit tests of RNN apis.
test=develop

* Fix use_align in coalesce_tensor_op.
test=develop

* Adjust RNN apis arguments according to comments.
test=develop

* Polish documents for SimpleRNN apis.
test=develop

* Refine random seed in cudnn_lstm_op.
Expose rnn params from sublayers to RNN.
test=develop

* Fix RNN saving for jit.save.
Refine cudnn_lstm dropout behavior.
test=develop

* Fix doc of GRU. test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Remove updates on cudnn_lstm temporarily.
test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Refine random seed in cudnn_lstm_op.
test=develop

* Fix test_lstm by adjust ConcreteProgram buffer getter.
test=develop

* Use create_parameter instead of create_var for rnn._flat_weight for static graph usage.
test=develop

* Remove W input for cudnn_lstm to pass unused_var_check.
test=develop

* Add test_predict for RNN unit tests coverage.
test=develop

* Fix code style of rnn.
test=develop

* Fix F.rnn usage in rnn.py.
test=develop
guoshengCS added a commit that referenced this pull request Oct 19, 2020
* Incorporate cudnn_lstm into LSTM api (#27217)


* Fix test_lstm unittest failed and Add more unittest (#28029)

* fix test_lstm unittest failed

* add more unittest

* modify cmakelist

* fix judgement

Co-authored-by: Aurelius84 <[email protected]>
