diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 8946674cc58f92..cbde8a7f10597b 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -24,7 +24,7 @@
 from paddle.base.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
 from paddle.base.framework import (
     default_startup_program,
-    in_dygraph_mode,
+    in_dynamic_or_pir_mode,
     program_guard,
 )
 from paddle.common_ops_import import Variable
@@ -106,7 +106,7 @@ def rnn(
 
     """
 
-    if in_dygraph_mode():
+    if in_dynamic_or_pir_mode():
         return _rnn_dynamic_graph(
             cell,
             inputs,
@@ -1590,7 +1590,7 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
         if not self.time_major:
             inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
 
-        if in_dygraph_mode():
+        if in_dynamic_or_pir_mode():
             out, _, state = _C_ops.rnn(
                 inputs,
                 initial_states,
@@ -1606,29 +1606,6 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
                 0,
                 not self.training,
             )
-        elif in_dynamic_mode():
-            _, _, out, state = _legacy_C_ops.rnn(
-                inputs,
-                initial_states,
-                self._all_weights,
-                sequence_length,
-                self._dropout_state,
-                self.state_components,
-                'dropout_prob',
-                self.dropout,
-                'is_bidirec',
-                self.num_directions == 2,
-                'input_size',
-                self.input_size,
-                'hidden_size',
-                self.hidden_size,
-                'num_layers',
-                self.num_layers,
-                'mode',
-                self.mode,
-                'is_test',
-                not self.training,
-            )
         else:
             out = self._helper.create_variable_for_type_inference(inputs.dtype)
             state = [
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index 127b5636ea230f..4192493ec46766 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -39,7 +39,12 @@
 from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
 from testsuite import append_input_output, append_loss_ops, create_op, set_input
 
-sys.path.append("..")
+# Add test/legacy_test and test to sys.path
+legacy_test_dir = pathlib.Path(__file__).parent  # test/legacy_test
+test_dir = legacy_test_dir.parent  # test
+sys.path.append(str(legacy_test_dir.absolute()))
+sys.path.append(str(test_dir.absolute()))
+
 from utils import static_guard
 from white_list import (
     check_shape_white_list,
@@ -66,8 +71,6 @@
 )
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
 
-sys.path.append(os.path.abspath(os.path.dirname(__file__)))
-
 
 @signature_safe_contextmanager
 def paddle_static_guard():
diff --git a/test/legacy_test/test_rnn_op.py b/test/legacy_test/test_rnn_op.py
index 4eb2d8332d9eca..0bdfd9ce76754f 100644
--- a/test/legacy_test/test_rnn_op.py
+++ b/test/legacy_test/test_rnn_op.py
@@ -15,6 +15,7 @@
 import random
 import sys
 import unittest
+from pathlib import Path
 
 import numpy as np
 from op_test import OpTest
@@ -22,7 +23,10 @@
 import paddle
 from paddle.base import core
 
-sys.path.append("../../test/rnn")
+# Add test/rnn to sys.path
+legacy_test_dir = Path(__file__).parent
+sys.path.append(str(legacy_test_dir.parent / "rnn"))
+
 from convert import get_params_for_net
 from rnn_numpy import LSTM
 
@@ -45,7 +49,7 @@ def rnn_wrapper(
     seed=0,
     is_test=False,
 ):
-    dropout_state_in = paddle.Tensor()
+    dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
     return paddle._C_ops.rnn(
         Input,
         PreState,
@@ -168,7 +172,9 @@ def rocm_rnn_get_place():
         }
 
     def test_output(self):
-        self.check_output(no_check_set=['Reserve', 'DropoutState'])
+        self.check_output(
+            no_check_set=['Reserve', 'DropoutState'], check_pir=True
+        )
 
     def set_attrs(self):
         pass
@@ -179,7 +185,9 @@ def test_grad(self):
             grad_check_list = ['Input', 'init_h', 'init_c']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
     def test_grad_only_input(self):
@@ -188,7 +196,9 @@ def test_grad_only_input(self):
             grad_check_list = ['Input']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
     def test_grad_only_h(self):
@@ -197,7 +207,9 @@ def test_grad_only_h(self):
             grad_check_list = ['init_h']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )
 
     def test_grad_only_c(self):
@@ -206,7 +218,9 @@ def test_grad_only_c(self):
             grad_check_list = ['init_c']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                # check_pir=True,
             )