29 changes: 3 additions & 26 deletions python/paddle/nn/layer/rnn.py
@@ -24,7 +24,7 @@
 from paddle.base.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
 from paddle.base.framework import (
     default_startup_program,
-    in_dygraph_mode,
+    in_dynamic_or_pir_mode,
     program_guard,
 )
 from paddle.common_ops_import import Variable
@@ -106,7 +106,7 @@ def rnn(
 
     """
 
-    if in_dygraph_mode():
+    if in_dynamic_or_pir_mode():
         return _rnn_dynamic_graph(
             cell,
             inputs,
@@ -1590,7 +1590,7 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
         if not self.time_major:
             inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
 
-        if in_dygraph_mode():
+        if in_dynamic_or_pir_mode():
             out, _, state = _C_ops.rnn(
                 inputs,
                 initial_states,
@@ -1606,29 +1606,6 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
                 0,
                 not self.training,
             )
-        elif in_dynamic_mode():
-            _, _, out, state = _legacy_C_ops.rnn(
-                inputs,
-                initial_states,
-                self._all_weights,
-                sequence_length,
-                self._dropout_state,
-                self.state_components,
-                'dropout_prob',
-                self.dropout,
-                'is_bidirec',
-                self.num_directions == 2,
-                'input_size',
-                self.input_size,
-                'hidden_size',
-                self.hidden_size,
-                'num_layers',
-                self.num_layers,
-                'mode',
-                self.mode,
-                'is_test',
-                not self.training,
-            )
         else:
             out = self._helper.create_variable_for_type_inference(inputs.dtype)
             state = [
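Both hunks above route dygraph and PIR execution through the same `_C_ops.rnn` call, and the `elif in_dynamic_mode():` fallback to `_legacy_C_ops.rnn` is deleted outright, leaving only the old static-graph branch as the non-fast path. A minimal, runnable probe of the helper the diff switches to (assuming a Paddle build that exposes `in_dynamic_or_pir_mode`, which is exactly what the new import pulls in; the expected outputs are for a default, non-PIR static build):

```python
import paddle
from paddle.base.framework import in_dynamic_or_pir_mode

# Dygraph is Paddle's default mode, so rnn() takes the _C_ops fast path.
print(in_dynamic_or_pir_mode())  # expected: True

# Under the old static graph (no PIR), the helper returns False and
# rnn() falls through to the program-construction branch instead.
paddle.enable_static()
print(in_dynamic_or_pir_mode())  # expected: False without PIR enabled
paddle.disable_static()
```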
9 changes: 6 additions & 3 deletions test/legacy_test/op_test.py
@@ -39,7 +39,12 @@
 from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
 from testsuite import append_input_output, append_loss_ops, create_op, set_input
 
-sys.path.append("..")
+# Add test/legacy_test and test to sys.path
+legacy_test_dir = pathlib.Path(__file__).parent  # test/legacy_test
+test_dir = legacy_test_dir.parent  # test
+sys.path.append(str(legacy_test_dir.absolute()))
+sys.path.append(str(test_dir.absolute()))
 
 from utils import static_guard
 from white_list import (
@@ -66,8 +71,6 @@
 )
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
 
-sys.path.append(os.path.abspath(os.path.dirname(__file__)))
-
 
 @signature_safe_contextmanager
 def paddle_static_guard():
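The replacement matters because a relative entry like `sys.path.append("..")` is resolved against the process's current working directory at import time, so `from utils import static_guard` only worked when the tests were launched from inside `test/legacy_test`. Anchoring on `__file__` makes the imports location-independent and also subsumes the now-deleted `os.path.dirname(__file__)` append further down. A self-contained sketch of the same pattern (directory names as in the diff):

```python
import sys
import pathlib

# Resolve import roots from __file__ rather than the process CWD, so
# sibling test modules import correctly from any launch directory.
legacy_test_dir = pathlib.Path(__file__).parent  # .../test/legacy_test
test_dir = legacy_test_dir.parent                # .../test
for path in (legacy_test_dir, test_dir):
    sys.path.append(str(path.absolute()))
```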
28 changes: 21 additions & 7 deletions test/legacy_test/test_rnn_op.py
@@ -15,14 +15,18 @@
 import random
 import sys
 import unittest
+from pathlib import Path
 
 import numpy as np
 from op_test import OpTest
 
 import paddle
 from paddle.base import core
 
-sys.path.append("../../test/rnn")
+# Add test/rnn to sys.path
+legacy_test_dir = Path(__file__).parent
+sys.path.append(str(legacy_test_dir.parent / "rnn"))
+
 from convert import get_params_for_net
 from rnn_numpy import LSTM

@@ -45,7 +49,7 @@ def rnn_wrapper(
     seed=0,
     is_test=False,
 ):
-    dropout_state_in = paddle.Tensor()
+    dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
     return paddle._C_ops.rnn(
         Input,
         PreState,
@@ -168,7 +172,9 @@ def rocm_rnn_get_place():
         }
 
     def test_output(self):
-        self.check_output(no_check_set=['Reserve', 'DropoutState'])
+        self.check_output(
+            no_check_set=['Reserve', 'DropoutState'], check_pir=True
+        )
 
     def set_attrs(self):
         pass
@@ -179,7 +185,9 @@ def test_grad(self):
             grad_check_list = ['Input', 'init_h', 'init_c']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
     def test_grad_only_input(self):
@@ -188,7 +196,9 @@ def test_grad_only_input(self):
             grad_check_list = ['Input']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
     def test_grad_only_h(self):
@@ -197,7 +207,9 @@ def test_grad_only_h(self):
             grad_check_list = ['init_h']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
            )
 
     def test_grad_only_c(self):
@@ -206,7 +218,9 @@ def test_grad_only_c(self):
             grad_check_list = ['init_c']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
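Two test-side changes accompany the migration: `rnn_wrapper` now builds `dropout_state_in` with `paddle.tensor.fill_constant` instead of an empty `paddle.Tensor()`, giving the wrapper a materialized 0-d tensor rather than an uninitialized placeholder (my reading of the change; the diff itself states no rationale), and `check_output` opts into the new IR with `check_pir=True` while every `check_grad` call keeps `check_pir` commented out, presumably pending verification of the PIR backward for this op. A runnable sketch of the placeholder change:

```python
import paddle

# A 0-d float32 tensor stands in for the dropout state; unlike a bare
# paddle.Tensor(), it carries a concrete dtype, shape, and value that
# both the legacy tracer and the PIR path can consume.
dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
print(dropout_state_in.shape, dropout_state_in.dtype)  # [] paddle.float32
```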