Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/paddle/nn/functional/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
__all__ = []


@param_one_alias(["x", "input"])
def one_hot(
x: Tensor,
num_classes: int,
num_classes: int = -1,
name: str | None = None,
) -> Tensor:
"""
Expand Down Expand Up @@ -72,11 +73,17 @@ def one_hot(
so it throws an exception.


.. note::
Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
For example, ``one_hot(input=tensor_x, ...)`` is equivalent to ``one_hot(x=tensor_x, ...)``.


Args:
x(Tensor): Tensor with shape :math:`[N_1, N_2, ..., N_k]` ,
which contains at least one dimension. The data type is int32 or int64.
alias: ``input``.
num_classes(int): An integer defining the `num_classes` of the one hot dimension. If input `x`
is word id, `num_classes` is generally the dictionary size.
is word id, `num_classes` is generally the dictionary size. Default value: -1.
name(str|None, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Expand All @@ -103,7 +110,8 @@ def one_hot(
[1., 0., 0., 0.]])

"""

if not isinstance(num_classes, paddle.pir.Value) and num_classes == -1:
num_classes = x.max() + 1
if in_dynamic_or_pir_mode():
return _C_ops.one_hot(x, num_classes)
else:
Expand Down
42 changes: 0 additions & 42 deletions test/ir/pir/test_special_op_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,48 +293,6 @@ def test_op(self):
_ = pir.translate_to_pir(main_program.desc)


class TestOneHotOpTranscriber(unittest.TestCase):
    """Checks that old-IR programs containing ``one_hot`` translate to PIR.

    ``one_hot``'s depth can arrive either as a mutable attribute (a tensor
    produced at runtime) or as a normal attribute (a plain int); both desc
    forms must survive ``pir.translate_to_pir``.
    """

    def _build_and_translate(self, make_depth):
        # Build an old-IR program whose one_hot depth comes from
        # ``make_depth()`` (called inside the program guard so any ops it
        # creates land in this program), then translate the desc to PIR.
        with paddle.pir_utils.OldIrGuard():
            cpu_place = core.Place()
            cpu_place.set_place(paddle.CPUPlace())
            scope = paddle.static.Scope()
            program = paddle.static.Program()
            with (
                paddle.static.scope_guard(scope),
                paddle.static.program_guard(program),
            ):
                depth = make_depth()
                label = paddle.static.data(
                    name="label", shape=[-1, 1], dtype="int64"
                )
                _ = paddle.nn.functional.one_hot(x=label, num_classes=depth)

            _ = pir.translate_to_pir(program.desc)

    def test_mutable_attribute(self):
        # Depth supplied as a tensor -> mutable attribute in the op desc.
        self._build_and_translate(
            lambda: paddle.assign(np.array([10], dtype=np.int32))
        )

    def test_normal_attribute(self):
        # Depth supplied as a plain int -> normal attribute in the op desc.
        self._build_and_translate(lambda: 10)


class TestReduceOpTranscriber(unittest.TestCase):
def test_reduce_all(self):
place = core.Place()
Expand Down
66 changes: 65 additions & 1 deletion test/legacy_test/test_one_hot_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from op_test import OpTest
from op_test import OpTest, get_places

import paddle
from paddle import base
Expand Down Expand Up @@ -283,6 +283,70 @@ def test_check_output(self):
self.check_output()


class TestOneHotAPI_Compatibility(unittest.TestCase):
    """Compatibility tests for ``paddle.nn.functional.one_hot``.

    Covers four calling conventions in both dygraph and static-graph modes:
    positional args, paddle-style keyword args, the torch-style ``input=``
    alias, and the ``num_classes=-1`` default (depth inferred from the data
    as ``x.max() + 1``). Results are checked against a NumPy reference.
    """

    def setUp(self):
        np.random.seed(123)
        paddle.enable_static()
        self.places = get_places()
        self.shape = [5]
        self.dtype = 'int32'
        self.init_data()

    def init_data(self):
        # NumPy reference: labels in [0, 8) and their one-hot encoding.
        self.np_input = np.random.randint(0, 8, self.shape).astype(self.dtype)
        self.num_classes = self.np_input.max() + 1
        self.np_out = np.eye(self.num_classes)[self.np_input]

    def test_dygraph_Compatibility(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.np_input)
        paddle_dygraph_out = []
        # Positional args.
        out1 = paddle.nn.functional.one_hot(x, self.num_classes)
        paddle_dygraph_out.append(out1)
        # Paddle-style keyword args.
        out2 = paddle.nn.functional.one_hot(x=x, num_classes=self.num_classes)
        paddle_dygraph_out.append(out2)
        # torch-style ``input=`` alias for ``x``.
        out3 = paddle.nn.functional.one_hot(
            input=x, num_classes=self.num_classes
        )
        paddle_dygraph_out.append(out3)
        # Default num_classes=-1 -> inferred from the data.
        out4 = paddle.nn.functional.one_hot(x, -1)
        paddle_dygraph_out.append(out4)
        for out in paddle_dygraph_out:
            np.testing.assert_allclose(self.np_out, out.numpy())
        paddle.enable_static()

    def test_static_Compatibility(self):
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with base.program_guard(main, startup):
            x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
            # Positional args.
            out1 = paddle.nn.functional.one_hot(x, self.num_classes)
            # Paddle-style keyword args.
            out2 = paddle.nn.functional.one_hot(
                x=x, num_classes=self.num_classes
            )
            # torch-style ``input=`` alias for ``x``.
            out3 = paddle.nn.functional.one_hot(
                input=x, num_classes=self.num_classes
            )
            # Default num_classes=-1 -> inferred from the data.
            out4 = paddle.nn.functional.one_hot(x, -1)
            exe = base.Executor(paddle.CPUPlace())
            # NOTE: out4 must be fetched too — previously it was built but
            # left out of fetch_list, so the default-args path was never
            # actually executed or verified in static mode.
            fetches = exe.run(
                main,
                feed={"x": self.np_input},
                fetch_list=[out1, out2, out3, out4],
            )
            for out in fetches:
                np.testing.assert_allclose(out, self.np_out)


if __name__ == '__main__':
    # Run the op tests under static-graph mode by default.
    paddle.enable_static()
    unittest.main()
Loading