Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions paddle/fluid/operators/cudnn_lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
Expand All @@ -25,7 +26,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");

Expand Down Expand Up @@ -122,7 +122,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
" cudnn concatenate all the weight to one Tensor")
.AsDispensable();
AddInput("WeightList",
"(vector<Tensor>), stores weight and bias data when the weight "
"use the list format. ")
.AsDispensable()
.AsDuplicable();
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
Expand Down Expand Up @@ -216,7 +222,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");

Expand All @@ -228,7 +233,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
};

SetOutGradDim("Input");
SetOutGradDim("W");
if (ctx->HasInputs("WeightList")) {
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
}
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
Expand All @@ -251,7 +259,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("WeightList")) {
op->SetInput("WeightList", this->Input("WeightList"));
}
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
Expand All @@ -262,8 +272,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));

if (this->HasInput("WeightList")) {
op->SetOutput(framework::GradVarName("WeightList"),
this->InputGrad("WeightList", false));
}

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
op->SetAttrMap(this->Attrs());
Expand All @@ -290,3 +304,20 @@ REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);

// Register NotImpleKernel<float> as the CPU kernel for both the forward and
// the gradient op.
// NOTE(review): NotImpleKernel is defined outside this view; presumably it is
// a placeholder that rejects CPU execution -- confirm.
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);

// Registers an op-version checkpoint for cudnn_lstm. The checkpoint note
// mentions adding the [WeightList] input and making [W] dispensable; the
// OpVersionDesc below records two new inputs (WeightList, SequenceLength) and
// two new outputs (StateOut, Reserve).
// NOTE(review): the "[W] becomes dispensable" change appears only in the note
// text, not as a descriptor call -- presumably waiting on the TODO below;
// confirm.
// TODO(Shixiaowei02) Add ModifyInput support
REGISTER_OP_VERSION(cudnn_lstm)
.AddCheckpoint(
R"ROC(
Upgrade cudnn_lstm add a new input [WeightList] and modify input [W] to dispensable.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput(
"WeightList",
"The WeightList stores weight and bias data. WeightList is "
"dispensable.")
.NewInput("SequenceLength",
"When the input data is padding, set this parameter. "
"SequenceLength is dispensable.")
.NewOutput("StateOut", "Store the global drop state when training")
.NewOutput("Reserve",
"A temporary output Tensor to store the reserve_data"));
162 changes: 144 additions & 18 deletions paddle/fluid/operators/cudnn_lstm_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,66 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Reports whether the tensors in `weight_list` sit back to back in memory,
// i.e. for every adjacent pair the data pointer of tensor i advanced by its
// numel() equals the data pointer of tensor i + 1.
//
// T     element type used to obtain each tensor's data pointer.
// Type  indexable container with size(), whose elements are pointer-like
//       handles exposing data<T>() and numel().
//
// Returns true when every adjacent pair is contiguous. A list with zero or
// one element has no adjacent pairs and is therefore reported as contiguous.
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
  // Compare `i + 1` against size() instead of `i` against `size() - 1`:
  // size() is unsigned, so `size() - 1` would wrap around for an empty list
  // and the loop would index far out of bounds.
  for (size_t i = 0; i + 1 < weight_list.size(); ++i) {
    auto *in_data = weight_list[i]->template data<T>();
    auto *in_after_data = weight_list[i + 1]->template data<T>();
    auto in_size = weight_list[i]->numel();
    // A single gap is enough to decide the answer.
    if (in_data + in_size != in_after_data) {
      return false;
    }
  }
  return true;
}

// Adds up numel() of every tensor in `weight_list` and returns the total.
// The running total is kept in an int, so the result is narrowed to int.
int size_sum(const std::vector<const Tensor *> &weight_list) {
  int total = 0;
  for (const Tensor *tensor : weight_list) {
    total += tensor->numel();
  }
  return total;
}

// Packs the tensors of `weight_list`, in order, into the buffer of `weight`.
// Tensor i is written at an element offset equal to the summed numel() of the
// tensors before it, using memory::Copy between CUDAPlace locations on
// `stream`.
//
// place        not read by this function.
// stream       CUDA stream handed to every memory::Copy call.
// weight_list  source tensors; each is read through data<T>().
// weight       destination tensor; accessed through data<T>() (not
//              mutable_data), so it presumably must already hold enough
//              storage for all sources -- confirm at call sites.
template <typename T>
void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
                      const std::vector<const Tensor *> &weight_list,
                      Tensor *weight) {
  auto dst_base = weight->data<T>();
  int copied = 0;  // elements already written into `weight`
  for (const Tensor *src_tensor : weight_list) {
    const T *src = src_tensor->data<T>();
    auto count = src_tensor->numel();

    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
                 dst_base + copied,
                 BOOST_GET_CONST(platform::CUDAPlace, src_tensor->place()),
                 src, count * sizeof(T), stream);
    copied += count;
  }
}

// Splits the flat buffer of `weight` into consecutive chunks and copies chunk
// i into (*weight_grad)[i]. Chunk i has weight_input[i]->numel() elements and
// starts right after chunk i - 1. Each destination is prepared with
// mutable_data<T>(place) before the memory::Copy on `stream`.
//
// place         passed to mutable_data<T>() of every destination tensor.
// stream        CUDA stream handed to every memory::Copy call.
// weight_grad   destination tensors, indexed in parallel with weight_input.
// weight_input  tensors whose numel() defines each chunk length.
// weight        flat source tensor, read through data<T>().
//
// NOTE(review): the loop runs over weight_input.size() and indexes
// weight_grad with the same index, so weight_grad is assumed to have at least
// as many entries -- confirm at call sites.
template <typename T>
void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
                           std::vector<Tensor *> *weight_grad,
                           const std::vector<const Tensor *> &weight_input,
                           const Tensor *weight) {
  auto *src_base = weight->data<T>();
  int consumed = 0;  // elements of `weight` already handed out
  for (size_t idx = 0; idx < weight_input.size(); ++idx) {
    Tensor *dst_tensor = (*weight_grad)[idx];
    auto count = weight_input[idx]->numel();
    T *dst = dst_tensor->mutable_data<T>(place);

    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dst_tensor->place()), dst,
                 BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
                 src_base + consumed, count * sizeof(T), stream);
    consumed += count;
  }
}

template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
Expand Down Expand Up @@ -75,8 +135,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const Tensor *init_h = ctx.Input<Tensor>("InitH");
const Tensor *init_c = ctx.Input<Tensor>("InitC");

auto w = ctx.Input<Tensor>("W");

Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("LastH");
Tensor *last_c = ctx.Output<Tensor>("LastC");
Expand All @@ -87,8 +145,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();

const T *w_data = w->data<T>();

T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
Expand All @@ -113,11 +169,45 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;

size_t workspace_size;
size_t reserve_size;
Tensor weight_whole;
T *w_data = nullptr;
int weight_numel;
bool w_initialized = false;
auto place = ctx.GetPlace();
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
if (is_test && ctx.HasInput("W")) {
auto *W = ctx.Input<Tensor>("W");
w_initialized = W->IsInitialized() ? true : false;
weight_numel = W->numel();
}
if (!w_initialized) {
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
weight_numel = size_sum(weight_list);

if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not "
"continuous, less efficient calculation will be "
"called. Please call coalesce_tensor op to make the "
"input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>();
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
} else {
auto *W = ctx.Input<Tensor>("W");
w_data = const_cast<T *>(W->data<T>());
}

ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
Expand All @@ -136,6 +226,12 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里is_test==true时是否会每次拷贝呢,可否在W未被初始化的时候拷贝呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python端预测时会初始化W,这时候用的是W, 而且不会拷贝数据。
C++预测时不会初始化W,用的是weight_list,但是会拷贝weight_list到W。

} else {
if (!has_seq_length) {
// for train
Expand Down Expand Up @@ -176,21 +272,22 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *state_out = ctx.Input<Tensor>("StateOut");
auto weight_list = ctx.MultiInput<Tensor>("WeightList");

auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
Expand All @@ -199,26 +296,57 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();

auto *weight_data = weight->data<T>();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();

auto place = ctx.GetPlace();
int weight_numel = size_sum(weight_list);
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);

auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
Tensor weight_whole;
T *weight_data = nullptr;

if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}

Tensor weight_grad;
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
weight_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();

int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}

in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();

init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad->data<T>();
if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;

init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad->data<T>();
if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;

float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
Expand All @@ -236,7 +364,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
int weight_numel = weight->numel();

size_t workspace_size;
size_t reserve_size;
Expand Down Expand Up @@ -268,8 +395,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
Expand All @@ -288,7 +414,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
rnn.weight_desc(), weight_grad->data<T>(),
rnn.weight_desc(), weight_grad_data,
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
Expand Down
20 changes: 7 additions & 13 deletions python/paddle/fluid/layers/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,23 +2443,17 @@ def lstm(input,
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
num_dirrection = 2 if is_bidirec == True else 1

for i in range(num_layers):
if i == 0:
input_weight_size = (input_size * hidden_size) * 4
input_weight_size = (input_size * hidden_size) * 4 * num_dirrection
else:
if is_bidirec:
input_weight_size = (hidden_size * 2 * hidden_size) * 4
else:
input_weight_size = (hidden_size * hidden_size) * 4
input_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection

hidden_weight_size = (hidden_size * hidden_size) * 4

if is_bidirec:
weight_size += (input_weight_size + hidden_weight_size) * 2
weight_size += hidden_size * 8 * 2
else:
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8 * num_dirrection

weight = helper.create_parameter(
attr=helper.param_attr,
Expand Down
Loading