Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/nn/initializer/Bilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from paddle import _C_ops

from ...base import core, framework, unique_name
from ...base.framework import _current_expected_place, in_dygraph_mode
from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer


__all__ = []


Expand Down Expand Up @@ -140,7 +141,7 @@ def forward(self, var, block=None):
if np.prod(shape) > 1024 * 1024:
raise ValueError("The size of input is too big. ")

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(shape),
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ...base import core, framework, unique_name
from ...base.data_feeder import check_type
from ...base.framework import _current_expected_place, in_dygraph_mode
from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer

__all__ = []
Expand Down Expand Up @@ -98,7 +98,7 @@ def forward(self, var, block=None):
"saving it to file and 'load_op' to load it"
)

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(self._value.shape),
Expand Down
5 changes: 5 additions & 0 deletions test/legacy_test/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.base import framework
from paddle.base.core import VarDesc
from paddle.pir_utils import test_with_pir_api
from paddle.regularizer import L2Decay

DELTA = 0.00001
Expand Down Expand Up @@ -541,6 +542,7 @@ def test_msra_initializer_bf16(self):


class TestBilinearInitializer(unittest.TestCase):
@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initializer 相关的 pir 单测无法适配旧 ir,所以需要单独新增 pir 下的单测,可以参考:#59419

def test_bilinear_initializer(self, dtype="float32"):
"""Test the bilinear initializer with supplied arguments"""
program = framework.Program()
Expand Down Expand Up @@ -628,6 +630,7 @@ def test_bilinear_initializer_fp16(self):


class TestNumpyArrayInitializer(unittest.TestCase):
@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

def test_numpy_array_initializer(self, dtype="float32"):
"""Test the numpy array initializer with supplied arguments"""
import numpy
Expand Down Expand Up @@ -876,6 +879,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
Expand Down Expand Up @@ -914,6 +918,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
Expand Down