4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/Bilinear.py
@@ -17,7 +17,7 @@
from paddle import _C_ops

from ...base import core, framework, unique_name
-from ...base.framework import _current_expected_place, in_dygraph_mode
+from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer

__all__ = []
@@ -140,7 +140,7 @@ def forward(self, var, block=None):
if np.prod(shape) > 1024 * 1024:
raise ValueError("The size of input is too big. ")

-        if in_dygraph_mode():
+        if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(shape),
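The switch from in_dygraph_mode to in_dynamic_or_pir_mode routes PIR programs through the same fast _C_ops.assign_value_ path. A minimal sketch of what the new guard amounts to, assuming it simply composes the framework's dygraph and PIR checks (the exact composition is an assumption):

# Hedged sketch of the new guard (assumed composition, not the exact
# framework source):
from paddle.base import framework

def in_dynamic_or_pir_mode_sketch():
    # True during eager/dygraph execution, or when the PIR API is active.
    return framework.in_dygraph_mode() or framework.in_pir_mode()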
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/assign.py
@@ -16,7 +16,7 @@

from ...base import core, framework, unique_name
from ...base.data_feeder import check_type
-from ...base.framework import _current_expected_place, in_dygraph_mode
+from ...base.framework import _current_expected_place, in_dynamic_or_pir_mode
from .initializer import Initializer

__all__ = []
@@ -98,7 +98,7 @@ def forward(self, var, block=None):
"saving it to file and 'load_op' to load it"
)

-        if in_dygraph_mode():
+        if in_dynamic_or_pir_mode():
_C_ops.assign_value_(
out_var,
list(self._value.shape),
107 changes: 107 additions & 0 deletions test/legacy_test/test_initializer.py
@@ -23,6 +23,7 @@
from paddle import base
from paddle.base import framework
from paddle.base.core import VarDesc
from paddle.pir_utils import test_with_pir_api
from paddle.regularizer import L2Decay

DELTA = 0.00001
@@ -726,6 +727,59 @@ def test_type_error(self):
self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')


class TestBilinearInitializerPir(unittest.TestCase):
def setUp(self):
self.init_uniform_op_name = 'pd_op.uniform'
self.init_normal_op_name = 'pd_op.assign_value'
self.set_parameter_op_name = 'builtin.set_parameter'

def get_init_ops_by_op_name(self, block, op_name):
checked_ops = []
for op in block.ops:
# get init op
if op_name == op.name():
checked_ops.append(op)
return checked_ops

def test_bilinear_initializer(self, dtype="float32"):
"""Test the bilinear initializer with supplied arguments"""
with paddle.pir_utils.IrGuard():
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
param = paddle.pir.core.create_parameter(
dtype=dtype,
shape=[5, 10],
Review comment (Contributor): The shape here should be aligned with the one used in the TestBilinearInitializer.test_bilinear_initializer unit test.

name="param",
initializer=paddle.nn.initializer.Bilinear(),
)

block = startup.global_block()
checked_ops = self.get_init_ops_by_op_name(
block, self.init_uniform_op_name
)
num_ops = 2 if dtype in ["float16", "uint16", "float64"] else 1
init_op = checked_ops[0]
self.assertEqual(len(checked_ops), 1)
return block

def test_bilinear_initializer_fp64(self):
self.test_bilinear_initializer(dtype='float64')

def test_bilinear_initializer_fp16(self):
"""Test the bilinear initializer with supplied arguments"""
block = self.test_bilinear_initializer("float16")
self.assertTrue(check_cast_op(block.ops[1]))
Review comment (Contributor): To check whether a PIR block contains a cast op, a new check_cast_op function needs to be written; the current check_cast_op does not work in PIR mode. Besides, in PIR mode block.ops[1] is not necessarily a cast op.

Reply (Contributor, PR author): So wouldn't it be more appropriate to simply skip the check?
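One PIR-compatible alternative, sketched under the assumption that the PIR cast op is named 'pd_op.cast' (the op-name check mirrors the get_init_ops_by_op_name helper above):

# Hedged sketch of a PIR-aware cast check; the 'pd_op.cast' name is an
# assumption, not confirmed by this PR.
def check_cast_op_pir(block):
    # Scan the whole block instead of relying on block.ops[1], since op
    # order is not guaranteed in PIR mode.
    return any(op.name() == 'pd_op.cast' for op in block.ops)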


def test_bilinear_initializer_bf16(self):
"""Test the bilinear initializer with supplied arguments"""
block = self.test_bilinear_initializer("uint16")
self.assertTrue(check_cast_op(block.ops[1]))

def test_type_error(self):
self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')


class TestBilinearInitializerDygraphAPI(unittest.TestCase):
def func_test_case(self):
factor = 2
@@ -811,6 +865,57 @@ def test_numpy_array_initializer_bf16(self):
self.assertTrue(block.ops[1])


class TestNumpyArrayInitializerPir(unittest.TestCase):
def setUp(self):
self.init_uniform_op_name = 'pd_op.uniform'
self.init_normal_op_name = 'pd_op.gaussian'
Review comment (Contributor): self.init_op_name = pd_op.assign_value
self.set_parameter_op_name = 'builtin.set_parameter'

def get_init_ops_by_op_name(self, block, op_name):
checked_ops = []
for op in block.ops:
# get init op
if op_name == op.name():
checked_ops.append(op)
return checked_ops

def test_numpy_array_initializer(self, dtype="float32"):
"""Test the numpy array initializer with supplied arguments"""
import numpy

main = paddle.static.Program()
startup = paddle.static.Program()
np_array = numpy.random.random(10000).astype(dtype)
with paddle.static.program_guard(main, startup):
param = paddle.pir.core.create_parameter(
dtype=dtype,
shape=[5, 10],
name="param",
initializer=paddle.nn.initializer.Assign(np_array),
)

block = startup.global_block()
checked_ops = self.get_init_ops_by_op_name(
block, self.init_uniform_op_name
)
num_ops = 2 if dtype in ["float16", "uint16"] else 1
self.assertEqual(len(checked_ops), num_ops)
Review comment (Contributor): Likewise, there is no need to assert on the number of ops here.
init_op = checked_ops[0]
self.assertEqual(init_op.type, 'assign_value')
Review comment (Contributor): There is no need to check init_op.type again, since self.get_init_ops_by_op_name above already fetches ops by op_name. Changing this to self.assertEqual(len(checked_ops), 1) is enough.

assert (init_op.attr('fp32_values') == np_array).all()
return block

def test_numpy_array_initializer_fp16(self):
"""Test the numpy array initializer with float16"""
block = self.test_numpy_array_initializer("float16")
self.assertTrue(block.ops[1])

def test_numpy_array_initializer_bf16(self):
"""Test the numpy array initializer with bfloat16"""
block = self.test_numpy_array_initializer("uint16")
self.assertTrue(block.ops[1])
Review comment (Contributor): This line is no longer needed.



class TestSetGlobalInitializer(unittest.TestCase):
def test_set_global_weight_initilizer(self):
"""Test Set Global Param initilizer with UniformInitializer"""
@@ -1026,6 +1131,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
@@ -1064,6 +1170,7 @@ def run_dynamic_graph(dtype):
)
return w

@test_with_pir_api
def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
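For reference, the test_with_pir_api decorator applied in the two hunks above runs a test body under both IRs. A rough sketch of the idea, assuming IrGuard enables PIR on entry and restores the previous state on exit (this is not Paddle's exact implementation):

import paddle

# Hedged sketch: run the wrapped test once on the legacy static-graph
# path and once under PIR.
def test_with_pir_api_sketch(func):
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)  # legacy program path
        with paddle.pir_utils.IrGuard():  # switch to PIR for the second run
            func(*args, **kwargs)
    return wrapper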