Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions paddle/fluid/operators/math/sequence_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ void CopyValidData(framework::Tensor* dst_tensor,
const framework::Tensor* src_tensor,
const framework::Vector<size_t>& seq_offsets,
int pad_seq_len, int step_width, bool norm_by_len,
bool size_average, bool length_average, int total_logits_len,
CopyType type, PadLayout layout) {
int seq_num = seq_offsets.size() - 1;
const T* src_data = src_tensor->data<T>();
Expand All @@ -54,7 +55,21 @@ void CopyValidData(framework::Tensor* dst_tensor,
int pad_data_offset = layout == kBatchLengthWidth
? seq_idx * pad_seq_len * step_width
: seq_idx * step_width;
float scale = 1.0f / static_cast<float>(valid_seq_len);

float scale = 1.0f;
if (length_average) {
scale = 1.0f / static_cast<float>(total_logits_len);
VLOG(3) << "[warpctc grad][length_average]: scale " << scale
<< "total_logits_len " << total_logits_len;
} else if (size_average) {
scale = 1.0f / static_cast<float>(seq_num);
VLOG(3) << "[warpctc grad][size_average]: scale " << scale << "B "
<< seq_num;
} else if (norm_by_len) {
scale = 1.0f / static_cast<float>(valid_seq_len);
VLOG(3) << "[warpctc grad][norm_by_len]: scale " << scale << "T "
<< valid_seq_len;
}

for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
const T* src =
Expand Down Expand Up @@ -97,6 +112,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
Expand Down Expand Up @@ -131,7 +147,8 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
}

CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kSeqToPad, layout);
step_width, norm_by_times, false, false, 0, kSeqToPad,
layout);
}
};

Expand All @@ -142,20 +159,23 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int total_logits_len = TotalSequenceLength(seq_offsets);
int step_width = seq_tensor->numel() / seq_tensor_dims[0];

CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);

CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kPadToSeq, layout);
step_width, norm_by_times, size_average, length_average,
total_logits_len, kPadToSeq, layout);
}
};

Expand Down
21 changes: 17 additions & 4 deletions paddle/fluid/operators/math/sequence_padding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
T* dst, const T* src, const T* pad_value, bool is_constant_pad,
const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
const size_t step_width, bool norm_by_len, const PadLayout layout) {
const size_t step_width, bool norm_by_len, bool size_average,
bool length_average, int total_logits_len, const PadLayout layout) {
size_t seq_idx = blockIdx.y;
size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];

Expand All @@ -38,7 +39,15 @@ __global__ void SequencePaddingKernel(
src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

if (step_idx < seq_len) {
float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
float scale = 1.0f;
if (length_average) {
scale = 1.0f / static_cast<float>(total_logits_len);
} else if (size_average) {
scale = 1.0f / static_cast<float>(seq_num);
} else if (norm_by_len) {
scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
}

for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = scale * src_data[i];
}
Expand All @@ -57,6 +66,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
Expand Down Expand Up @@ -107,7 +117,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
step_width, norm_by_times, false, false, 0, layout);
}
};

Expand All @@ -118,6 +128,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
Expand All @@ -126,6 +137,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
int total_logits_len = TotalSequenceLength(seq_offsets);
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;

Expand Down Expand Up @@ -159,7 +171,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
seq_data, pad_data, nullptr, false,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
step_width, norm_by_times, size_average, length_average,
total_logits_len, layout);
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/math/sequence_padding.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class PaddingLoDTensorFunctor {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth);
};

Expand All @@ -117,6 +118,7 @@ class UnpaddingLoDTensorFunctor {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool size_average = false, bool length_average = false,
const PadLayout layout = kBatchLengthWidth);
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/math/sequence_padding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ void TestSequencePadding(const DeviceContext &context,
}

paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
context, seq, &padding, pad_value, -1, 0, false,
context, seq, &padding, pad_value, -1, 0, false, false, false,
paddle::operators::math::kLengthBatchWidth);

seq_back.set_lod(lod);
seq_back.mutable_data<T>(seq_dims, place);
paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
context, padding, &seq_back, -1, 0, false,
context, padding, &seq_back, -1, 0, false, false, false,
paddle::operators::math::kLengthBatchWidth);

if (paddle::platform::is_cpu_place(place)) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/sequence_ops/sequence_pad_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SequencePadOpKernel : public framework::OpKernel<T> {

math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);

LoDTensor seq_len;
seq_len.Resize(len_t->dims());
Expand All @@ -72,7 +72,7 @@ class SequencePadGradOpKernel : public framework::OpKernel<T> {

math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);
}
}
};
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {

int64_t padded_length = x_t->dims()[1];
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
dev_ctx, *x_t, out_t, padded_length, 0, false, false, false,
math::kBatchLengthWidth);
}
};

Expand All @@ -93,7 +94,7 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {

math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x, zero_pads,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);
}
}
};
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/operators/warpctc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
"normalize the gradients by the number of time-step, "
"which is also the sequence's length.")
.SetDefault(false);
AddAttr<bool>(
"size_average",
"(bool, default: false), normalize the loss by the batch size."
"If True, supersedes norm_by_times")
.SetDefault(false);
AddAttr<bool>(
"length_average",
"(bool, default: false), normalize the loss by the total number of "
"frames"
"in the batch. If True, supersedes size_average and norm_by_times")
.SetDefault(false);
AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
Expand Down Expand Up @@ -206,3 +217,21 @@ REGISTER_OP_CPU_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, double>);

// Op-version checkpoint: records that warpctc gained the boolean attributes
// size_average and length_average (both defaulting to false) so that
// programs saved before this change remain loadable and behave identically.
// Fixes: adjacent string literals previously concatenated without spaces,
// producing "batch size.If True" and "number of framesin the batch" in the
// attribute descriptions shown to users.
REGISTER_OP_VERSION(warpctc)
    .AddCheckpoint(
        R"ROC(
      Upgrade warpctc, add new attributes [size_average] and [length_average])ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewAttr(
                "size_average",
                "(bool, default: false), normalize the loss by the batch "
                "size. If True, supersedes norm_by_times",
                false)
            .NewAttr("length_average",
                     "(bool, default: false), normalize the loss by the "
                     "total number of frames in the batch. If True, "
                     "supersedes size_average and norm_by_times",
                     false));
77 changes: 57 additions & 20 deletions paddle/fluid/operators/warpctc_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/fluid/operators/math/sequence_scale.h"
Expand Down Expand Up @@ -150,7 +151,7 @@ class WarpCTCFunctor {
PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS, status,
platform::errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
"warp-ctc [version %d] Error in ComputeCtcLossFunctor: %s",
warpctc_version_, platform::dynload::ctcGetStatusString(status)));
}

Expand Down Expand Up @@ -285,8 +286,8 @@ class WarpCTCKernel : public framework::OpKernel<T> {

math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits,
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
math::kLengthBatchWidth);
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */, false,
false, math::kLengthBatchWidth);
}
const T* warpctc_logits_data = warpctc_logits.data<T>();

Expand Down Expand Up @@ -321,7 +322,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label,
&warpctc_label, label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
false /*norm_by_times*/, false, false, math::kBatchLengthWidth);
} else {
LoDTensor gpu_label;
gpu_label.mutable_data<int>(
Expand All @@ -331,7 +332,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label, &gpu_label,
label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
false /*norm_by_times*/, false, false, math::kBatchLengthWidth);
TensorCopySync(gpu_label, platform::CPUPlace(), &warpctc_label);
}
} else {
Expand Down Expand Up @@ -366,22 +367,49 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {

logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
bool size_average = ctx.Attr<bool>("size_average");
bool length_average = ctx.Attr<bool>("length_average");

if ((norm_by_times && size_average) || (norm_by_times && length_average) ||
(size_average && length_average)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"[warpctc grad] norm_by_times, size_average and length_average "
"should one be true."));
}

if (ctx.HasInput("LogitsLength")) {
size_t max_seq_length = warpctc_grad->dims()[0];
size_t num_sequences = warpctc_grad->dims()[1];
size_t seq_width = warpctc_grad->dims()[2];
size_t Tmax = warpctc_grad->dims()[0];
size_t B = warpctc_grad->dims()[1];
size_t D = warpctc_grad->dims()[2];

auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
framework::Tensor logits_length_cpu;
framework::TensorCopy(*logits_length, platform::CPUPlace(),
&logits_length_cpu);

// total logits length
int total_length = 0;
auto* length_ptr = logits_length_cpu.data<int64_t>();
for (size_t i = 0; i < B; i++) {
total_length += length_ptr[i];
}
VLOG(3) << "[warpctc grad] total logits length: " << total_length;
if (length_average) {
T scale = 1.0;
scale = 1.0 / static_cast<T>(total_length);
VLOG(3) << "[warpctc grad][length_average] scale: " << scale
<< "total logits len: " << total_length;
} else if (size_average) {
T scale = 1.0;
scale = 1.0 / static_cast<T>(B);
VLOG(3) << "[warpctc grad][size_average] scale: " << scale
<< "Batchsize: " << B;
}

LoDTensor logits_grad_with_lod;
auto logits_grad_dims =
framework::make_ddim({static_cast<int64_t>(max_seq_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(seq_width)});
auto logits_grad_dims = framework::make_ddim({static_cast<int64_t>(Tmax),
static_cast<int64_t>(B),
static_cast<int64_t>(D)});
T* logits_grad_cpu_data = logits_grad_with_lod.mutable_data<T>(
logits_grad_dims, platform::CPUPlace());

Expand All @@ -397,25 +425,34 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
scaled_logits.mutable_data<T>(logits_grad_dims, platform::CPUPlace());

const T* loss_grad_data = loss_grad_cpu.data<T>();
for (size_t i = 0; i < max_seq_length; ++i) {
for (size_t j = 0; j < num_sequences; ++j) {
for (size_t i = 0; i < Tmax; ++i) {
for (size_t j = 0; j < B; ++j) {
T scale = 1.0;
if (norm_by_times) {
scale = 1.0 / static_cast<T>(logits_length_cpu.data<int64_t>()[j]);
if (length_average) {
// Compute the avg. log-probability per batch sample and frame.
// https://github.com/espnet/warp-ctc/blob/pytorch_bindings/pytorch_binding/warpctc_pytorch/__init__.py#L42
scale = 1.0 / static_cast<T>(total_length);
} else if (size_average) {
// Compute the avg. log-probability per batch sample.
// https://github.com/espnet/warp-ctc/blob/pytorch_bindings/pytorch_binding/warpctc_pytorch/__init__.py#L46
scale = 1.0 / static_cast<T>(B);
} else if (norm_by_times) {
auto len = static_cast<T>(logits_length_cpu.data<int64_t>()[j]);
scale = 1.0 / len;
}
for (size_t k = 0; k < seq_width; ++k) {
size_t idx = i * (num_sequences * seq_width) + j * seq_width + k;
for (size_t k = 0; k < D; ++k) {
size_t idx = i * (B * D) + j * D + k;
scaled_logits_data[idx] =
logits_grad_cpu_data[idx] * loss_grad_data[j] * scale;
}
}
}

TensorCopySync(scaled_logits, ctx.GetPlace(), logits_grad);
} else {
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
logits_grad, -1, 0, norm_by_times, size_average, length_average,
math::kLengthBatchWidth);

const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
Expand Down
Loading