Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/Bilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddle import _C_ops

from ...base import core, framework, unique_name
from ...base.framework import _current_expected_place, in_dygraph_mode
from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer

__all__ = []
Expand Down Expand Up @@ -140,7 +140,7 @@ def forward(self, var, block=None):
if np.prod(shape) > 1024 * 1024:
raise ValueError("The size of input is too big. ")

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(shape),
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ...base import core, framework, unique_name
from ...base.data_feeder import check_type
from ...base.framework import _current_expected_place, in_dygraph_mode
from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer

__all__ = []
Expand Down Expand Up @@ -98,7 +98,7 @@ def forward(self, var, block=None):
"saving it to file and 'load_op' to load it"
)

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(self._value.shape),
Expand Down
102 changes: 102 additions & 0 deletions test/legacy_test/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.base import framework
from paddle.base.core import VarDesc
from paddle.pir_utils import test_with_pir_api
from paddle.regularizer import L2Decay

DELTA = 0.00001
Expand Down Expand Up @@ -726,6 +727,57 @@ def test_type_error(self):
self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')


class TestBilinearInitializerPir(unittest.TestCase):
    """Check that ``paddle.nn.initializer.Bilinear`` lowers to the expected
    initialization ops in a PIR (new IR) startup program."""

    def setUp(self):
        # Op names as they appear in a PIR startup block.
        self.init_uniform_op_name = 'pd_op.uniform'
        self.init_normal_op_name = 'pd_op.assign_value'
        self.set_parameter_op_name = 'builtin.set_parameter'

    def get_init_ops_by_op_name(self, block, op_name):
        """Return every op in *block* whose name equals *op_name*."""
        return [op for op in block.ops if op.name() == op_name]

    def test_bilinear_initializer(self, dtype="float32"):
        """Build a parameter with the Bilinear initializer under PIR and
        verify exactly one init op of the expected name is emitted.

        Returns the startup block so dtype-specific wrappers can inspect it.
        """
        with paddle.pir_utils.IrGuard():
            main = paddle.static.Program()
            startup = paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                param = paddle.pir.core.create_parameter(
                    dtype=dtype,
                    shape=[8, 1, 3, 3],
                    name="param",
                    initializer=paddle.nn.initializer.Bilinear(),
                )

            block = startup.global_block()
            checked_ops = self.get_init_ops_by_op_name(
                block, self.init_uniform_op_name
            )
            # Assert the count BEFORE any indexing so an empty result fails
            # with a clear assertion error rather than an IndexError.
            self.assertEqual(len(checked_ops), 1)
            return block

    def test_bilinear_initializer_fp64(self):
        """Bilinear initializer with float64 parameters."""
        self.test_bilinear_initializer(dtype='float64')

    def test_bilinear_initializer_fp16(self):
        """Bilinear initializer with float16 parameters."""
        self.test_bilinear_initializer("float16")

    def test_bilinear_initializer_bf16(self):
        """Bilinear initializer with bfloat16 (uint16 storage) parameters."""
        self.test_bilinear_initializer("uint16")

    def test_type_error(self):
        """An integer dtype is not supported and must raise TypeError."""
        self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')

class TestBilinearInitializerDygraphAPI(unittest.TestCase):
def func_test_case(self):
factor = 2
Expand Down Expand Up @@ -811,6 +863,54 @@ def test_numpy_array_initializer_bf16(self):
self.assertTrue(block.ops[1])


class TestNumpyArrayInitializerPir(unittest.TestCase):
    """Check that ``paddle.nn.initializer.Assign`` (numpy-array source)
    lowers to the expected initialization ops in a PIR startup program."""

    def setUp(self):
        # Op names as they appear in a PIR startup block.
        self.init_uniform_op_name = 'pd_op.uniform'
        self.init_normal_op_name = 'pd_op.assign_value'
        self.set_parameter_op_name = 'builtin.set_parameter'

    def get_init_ops_by_op_name(self, block, op_name):
        """Return every op in *block* whose name equals *op_name*."""
        return [op for op in block.ops if op.name() == op_name]

    def test_numpy_array_initializer(self, dtype="float32"):
        """Build a parameter initialized from a numpy array and verify
        exactly one init op is emitted whose values match the array.

        Returns the startup block so dtype-specific wrappers can inspect it.
        """
        import numpy

        # NOTE(review): the sibling PIR test class wraps program
        # construction in paddle.pir_utils.IrGuard(); this one does not —
        # confirm whether the guard is required here as well.
        main = paddle.static.Program()
        startup = paddle.static.Program()
        np_array = numpy.random.random(10000).astype(dtype)
        with paddle.static.program_guard(main, startup):
            param = paddle.pir.core.create_parameter(
                dtype=np_array.dtype,
                shape=np_array.shape,
                name="param",
                initializer=paddle.nn.initializer.Assign(np_array),
            )

        block = startup.global_block()
        checked_ops = self.get_init_ops_by_op_name(
            block, self.init_uniform_op_name
        )
        # Assert the count BEFORE indexing so an empty result fails with a
        # clear assertion error rather than an IndexError.
        self.assertEqual(len(checked_ops), 1)
        init_op = checked_ops[0]
        # NOTE(review): 'fp32_values' is read for float16/uint16 inputs
        # too — presumably the values are stored as fp32 regardless of the
        # parameter dtype; verify against the assign_value op definition.
        assert (init_op.attrs()['fp32_values'] == np_array).all()
        return block

    def test_numpy_array_initializer_fp16(self):
        """Assign initializer with a float16 source array."""
        self.test_numpy_array_initializer("float16")

    def test_numpy_array_initializer_bf16(self):
        """Assign initializer with a bfloat16 (uint16 storage) source array."""
        self.test_numpy_array_initializer("uint16")

class TestSetGlobalInitializer(unittest.TestCase):
def test_set_global_weight_initilizer(self):
"""Test Set Global Param initilizer with UniformInitializer"""
Expand Down Expand Up @@ -1026,6 +1126,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
Expand Down Expand Up @@ -1064,6 +1165,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
Expand Down