Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
class_center_sample,
cosine_similarity,
dropout,
dropout1d,
dropout2d,
dropout3d,
feature_alpha_dropout,
Expand Down Expand Up @@ -216,6 +217,7 @@
'gumbel_softmax',
'sequence_mask',
'dropout',
'dropout1d',
'dropout2d',
'dropout3d',
'alpha_dropout',
Expand Down
69 changes: 69 additions & 0 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Literal

import numpy
Expand Down Expand Up @@ -1427,6 +1428,74 @@ def get_attrs(prog, dropout_prob, is_test, seed):
return ret


def dropout1d(
input: paddle.Tensor,
p: float = 0.5,
training: bool = True,
inplace: bool = False,
) -> paddle.Tensor:
"""
Randomly zero out entire 1D channels (feature maps) during training.

Args:
input: Input tensor of shape [C, L] (2D) or [N, C, L] (3D)
p: Probability of a channel being zeroed. Default: 0.5
training: If False, returns input unchanged. Default: True
inplace: If True, modifies input tensor in-place. Default: False
WARNING: Currently not implemented (will behave as False).
TODO: Implement in-place operation in future versions.
Default: False

Returns:
Tensor with the same shape as input, where entire channels are zeroed with probability p

Examples:
.. code-block:: python

>>> import paddle

# Case 1: 3D input (batched)
>>> x = paddle.randn([2, 3, 10]) # [N, C, L]
>>> y_train = paddle.nn.functional.dropout1d(x, p=0.2) # Training mode
>>> y_test = paddle.nn.functional.dropout1d(x, p=0.2, training=False) # Test mode
>>> print("Original first channel:", x[0, 0, :])
>>> print("Train output (may be zeroed):", y_train[0, 0, :])
>>> print("Test output (always unchanged):", y_test[0, 0, :])

# Case 2: 2D input (single sample)
>>> x = paddle.randn([3, 8]) # [C, L]
>>> y = paddle.nn.functional.dropout1d(x, p=0.5)
>>> print("Input shape:", x.shape)
>>> print("Output shape:", y.shape)
>>> print("Zeroed channels count:", paddle.sum(y == 0).item())
"""
if p < 0 or p > 1:
raise ValueError(f"dropout probability must be in [0, 1], got {p}")

ndim = input.ndim
if ndim not in [2, 3]:
raise RuntimeError(f"dropout1d expects 2D or 3D input, got {ndim}D")

if inplace:
warnings.warn(
"inplace=True is currently not supported in dropout1d and will be ignored. "
"This parameter is reserved for future implementation."
)
# TODO: Implement actual in-place operation when supported by dropout

need_squeeze = ndim == 2
if need_squeeze:
input = input.unsqueeze(0) # [C, L] -> [1, C, L]

# Apply dropout along channel dimension
result = dropout(input, p=p, axis=1, training=training)

if need_squeeze:
result = result.squeeze(0) # [1, C, L] -> [C, L]

return result


def dropout2d(
x: Tensor,
p: float = 0.5,
Expand Down
121 changes: 121 additions & 0 deletions test/legacy_test/test_dropout_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,121 @@ def test_dygraph(self):
)


class TestDropout1DFAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = get_places()

def check_static_result(
self, place, input_name, input_shape, training=False, p=0.0
):
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
input_var = paddle.static.data(
name=input_name, shape=input_shape, dtype="float32"
)
res = paddle.nn.functional.dropout1d(
input=input_var, p=p, training=training
)
in_np = np.random.random(input_shape).astype("float32")
exe = base.Executor(place)
fetches = exe.run(
main_prog,
feed={input_name: in_np},
fetch_list=[res],
)

np.testing.assert_allclose(fetches[0], in_np, rtol=1e-05)

def test_static(self):
for place in self.places:
self.check_static_result(
place=place,
input_name="input_2d",
input_shape=[3, 4],
training=False,
p=0.0,
)

self.check_static_result(
place=place,
input_name="input_3d",
input_shape=[2, 3, 4],
training=False,
p=0.0,
)

self.check_static_result(
place=place,
input_name="input_2d_1",
input_shape=[3, 4],
training=False,
p=1.0,
)

self.check_static_result(
place=place,
input_name="input_3d_1",
input_shape=[2, 3, 4],
training=False,
p=1.0,
)

def test_dygraph(self):
for place in self.places:
with base.dygraph.guard(place):
# Test 2D input
in_np_2d = np.random.random([3, 4]).astype("float32")
input_2d = paddle.to_tensor(in_np_2d)
res1 = paddle.nn.functional.dropout1d(
input=input_2d, p=0.0, training=False
)
np.testing.assert_allclose(res1.numpy(), in_np_2d, rtol=1e-05)

# Test 3D input
in_np_3d = np.random.random([2, 3, 4]).astype("float32")
input_3d = paddle.to_tensor(in_np_3d)
res2 = paddle.nn.functional.dropout1d(
input=input_3d, p=0.0, training=False
)
np.testing.assert_allclose(res2.numpy(), in_np_3d, rtol=1e-05)


class TestDropout1DFAPIError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):

def test_xdim_1d():
# dimensions of x should be 2 or 3
x = paddle.static.data(name='x1', shape=[4], dtype="float32")
paddle.nn.functional.dropout1d(x)

self.assertRaises(RuntimeError, test_xdim_1d)

def test_xdim_4d():
# dimensions of x should be 2 or 3
x = paddle.static.data(
name='x2', shape=[2, 3, 4, 5], dtype="float32"
)
paddle.nn.functional.dropout1d(x)

self.assertRaises(RuntimeError, test_xdim_4d)

def test_prob_range():
# p should be in [0, 1]
x = paddle.static.data(
name='x3', shape=[2, 3, 4], dtype="float32"
)
paddle.nn.functional.dropout1d(x, p=1.5)

self.assertRaises(ValueError, test_prob_range)


class TestDropout2DFAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
Expand Down Expand Up @@ -1404,6 +1519,12 @@ def test_p_tensor(self):
np.testing.assert_array_equal(static_res, dygraph_res)


class TestDropOut1DWithProbTensor(TestDropOutWithProbTensor):
def init_info(self):
self.shape = [2, 3, 4]
self.api = paddle.nn.functional.dropout1d


class TestDropOut2DWithProbTensor(TestDropOutWithProbTensor):
def init_info(self):
self.shape = [2, 3, 10, 10]
Expand Down