Skip to content

Commit 6d901de

Browse files
committed
[0-size Tensor No.166、257] Add 0-size Tensor support for paddle.nn.functional.ctc_loss [fluid_ops] (PaddlePaddle#74042)
* Fix * Fix * Fix * ci * Fix * Fix * Fix
1 parent 5a8936e commit 6d901de

File tree

5 files changed

+38
-0
lines changed

5 files changed

+38
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4489,6 +4489,17 @@ bool WarpctcOpInferSymbolicShape(
44894489
infer_context->GetShapeOrDataForValue(op->operand_source(0));
44904490
const std::vector<symbol::DimExpr> &logits_shape =
44914491
logits_shape_or_data.shape();
4492+
bool logits_0_size = false;
4493+
for (size_t i = 0; i < logits_shape.size(); ++i) {
4494+
if (logits_shape[i] == 0) {
4495+
logits_0_size = true;
4496+
break;
4497+
}
4498+
}
4499+
if (logits_0_size) {
4500+
PADDLE_THROW(
4501+
common::errors::InvalidArgument("The input size can not be zero."));
4502+
}
44924503

44934504
symbol::DimExpr max_sequence_length, num_sequences;
44944505
symbol::DimExpr sequence_width = symbol::DimExpr(1);

paddle/phi/infermeta/multiary.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5645,6 +5645,9 @@ void WarpctcInferMeta(const MetaTensor& logits,
56455645
MetaTensor* loss,
56465646
MetaTensor* warpctcgrad) {
56475647
auto logits_dims = logits.dims();
5648+
if (common::product(logits_dims) == 0) {
5649+
PADDLE_THROW(errors::InvalidArgument("The input size can not be zero."));
5650+
}
56485651
int num_sequences, sequence_width, max_sequence_length;
56495652

56505653
if (logits_length && labels_length) {

python/paddle/signal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ def istft(
569569
fft_size = x.shape[-2]
570570

571571
if in_dynamic_mode():
572+
assert x.size != 0, 'x should not be an empty tensor.'
572573
if onesided:
573574
assert (
574575
fft_size == n_fft // 2 + 1

test/legacy_test/test_signal.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,5 +1045,12 @@ def test_istft(self):
10451045
),
10461046

10471047

1048+
class TestIstftException_ZeroSize(unittest.TestCase):
1049+
def test_istft(self):
1050+
self.x = np.random.random([5, 0])
1051+
with self.assertRaises(AssertionError):
1052+
paddle.signal.istft(paddle.to_tensor(self.x), 512)
1053+
1054+
10481055
if __name__ == '__main__':
10491056
unittest.main()

test/legacy_test/test_warpctc_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,24 @@ def test_dygraph_with_lod():
613613
reduction='none',
614614
)
615615

616+
def test_dygraph_zero_size():
617+
logits = np.random.uniform(0.1, 1.0, [0, 15]).astype("float32")
618+
# labels should not be blank
619+
labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
620+
softmax = paddle.to_tensor(logits)
621+
labels = paddle.to_tensor(labels)
622+
623+
paddle.nn.functional.ctc_loss(
624+
log_probs=softmax,
625+
labels=labels,
626+
input_lengths=None,
627+
label_lengths=None,
628+
reduction='none',
629+
)
630+
616631
paddle.disable_static()
617632
self.assertRaises(ValueError, test_dygraph_with_lod)
633+
self.assertRaises(ValueError, test_dygraph_zero_size)
618634
paddle.enable_static()
619635

620636

0 commit comments

Comments
 (0)