Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions paddle/fluid/operators/math/context_project.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ template <typename DeviceContext, typename T>
class ContextProjectFunctor {
public:
void operator()(const DeviceContext& context, const LoDTensor& in,
const Tensor& padding_data, bool padding_trainable,
const Tensor* padding_data, bool padding_trainable,
const int context_start, const int context_length,
const int context_stride, const int up_pad,
const int down_pad, Tensor* col) {
Expand Down Expand Up @@ -132,6 +132,7 @@ class ContextProjectFunctor {
}
}
if (padding_trainable) {
PADDLE_ENFORCE_NOT_NULL(padding_data);
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Expand All @@ -150,7 +151,7 @@ class ContextProjectFunctor {
k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size);
Tensor w_sub = padding_data->Slice(k, k + padding_size);
framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
}
Expand Down Expand Up @@ -180,7 +181,7 @@ class ContextProjectFunctor {
Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice(
Tensor w_sub = padding_data->Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/sequence_ops/sequence_conv_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SequenceConvKernel : public framework::OpKernel<T> {

int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
auto sequence_width = static_cast<int64_t>(in->dims()[1]);

framework::DDim col_shape = {in->dims()[0],
context_length * sequence_width};
Expand All @@ -62,7 +62,7 @@ class SequenceConvKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, &col, static_cast<T>(0));
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;

seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
context_start, context_length, context_stride, up_pad,
down_pad, &col);

Expand Down Expand Up @@ -93,7 +93,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {

int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
auto sequence_width = static_cast<int64_t>(in->dims()[1]);

math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
Expand Down Expand Up @@ -144,7 +144,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
padding_data = context.Input<Tensor>("PaddingData");
}

seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
context_start, context_length, context_stride, up_pad,
down_pad, &col);

Expand Down