1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -26,6 +26,7 @@
'FusedConv2dAddActInferMeta',
'InterpolateInferMeta',
'DeformableConvInferMeta',
'MatrixNMSInferMeta',
}

_PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'}
144 changes: 73 additions & 71 deletions python/paddle/tensor/math.py
@@ -2154,49 +2154,50 @@ def mm(input, mat2, name=None):


"""
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.matmul(input, mat2, False, False)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'mm'
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'mm'
)
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-2]:
if not ((x_shape[-1] == -1) or (y_shape[-2] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-2]:
if not ((x_shape[-1] == -1) or (y_shape[-2] == -1)):
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape)
)

if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape)
)

__check_input(input, mat2)

__check_input(input, mat2)
if in_pir_mode():
return _C_ops.matmul(input, mat2, False, False)
else:
helper = LayerHelper('mm', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
@@ -2514,33 +2515,33 @@ def inner(x, y, name=None):
nx = x.reshape((-1, xshape[-1]))
ny = y.reshape((-1, yshape[-1]))

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'inner'
)
x_shape = list(xshape)
y_shape = list(yshape)

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-1]:
if not ((x_shape[-1] == -1) or (y_shape[-1] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's last dim should be "
"equal to Y's last dim for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)

__check_input(nx, ny)

if in_dynamic_or_pir_mode():
return _C_ops.matmul(
nx, paddle.transpose(ny, [1, 0]), False, False
).reshape(dstshape)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'inner'
)
x_shape = list(xshape)
y_shape = list(yshape)

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-1]:
if not ((x_shape[-1] == -1) or (y_shape[-1] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's last dim should be "
"equal to Y's last dim for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)

__check_input(nx, ny)
helper = LayerHelper('inner', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
@@ -2584,22 +2585,23 @@ def outer(x, y, name=None):
nx = x.reshape((-1, 1))
ny = y.reshape((1, -1))

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.matmul(nx, ny, False, False)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val,
name,
['float16', 'float32', 'float64', 'int32', 'int64'],
'outer',
)

__check_input(nx, ny)
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val,
name,
['float16', 'float32', 'float64', 'int32', 'int64'],
'outer',
)

__check_input(nx, ny)
if in_pir_mode():
return _C_ops.matmul(nx, ny, False, False)
else:
helper = LayerHelper('outer', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
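Taken together, the math.py changes above give mm, inner and outer the same control flow: dynamic mode still calls the C++ kernel directly, the __check_input validation now runs before both static-graph paths (which is what lets the PIR error tests below exercise it), and only the legacy path falls through to LayerHelper. A minimal sketch of that flow, assuming the mode-helper import paths shown below (this is not the actual Paddle source):

from paddle import _C_ops
from paddle.base.framework import in_dynamic_mode, in_pir_mode  # assumed import path


def mm_like(x, y):
    # Dynamic (eager) mode: call the kernel directly, no static checks.
    if in_dynamic_mode():
        return _C_ops.matmul(x, y, False, False)

    # Static graph: the dtype/shape validation now runs for PIR as well.
    _check_input(x, y)

    if in_pir_mode():
        return _C_ops.matmul(x, y, False, False)

    # The legacy static graph would fall through to LayerHelper/append_op
    # here; omitted in this sketch.
    raise NotImplementedError("legacy static-graph branch omitted")


def _check_input(x, y):
    # Simplified stand-in for the __check_input shown in the diff above.
    if x.shape[-1] != y.shape[-2] and -1 not in (x.shape[-1], y.shape[-2]):
        raise ValueError(
            f"X's width must equal Y's height, got {list(x.shape)} and {list(y.shape)}"
        )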
5 changes: 3 additions & 2 deletions test/legacy_test/test_activation_op.py
@@ -692,6 +692,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
@@ -4861,7 +4862,7 @@ def test_check_grad(self):
create_test_act_fp16_class(
TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(TestLogSigmoid)
create_test_act_fp16_class(TestLogSigmoid, check_pir=True)
create_test_act_fp16_class(
TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True
)
@@ -5051,7 +5052,7 @@ def test_check_grad(self):
TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestLogSigmoid)
create_test_act_bf16_class(TestLogSigmoid, check_pir=True)
create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestTanhshrink, check_pir=True)
create_test_act_bf16_class(TestHardShrink, check_pir=True)
2 changes: 2 additions & 0 deletions test/legacy_test/test_inner.py
@@ -124,6 +124,7 @@ def test_multiply_dynamic_case5(self):


class TestMultiplyError(unittest.TestCase):
@test_with_pir_api
def test_errors_static_case1(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
@@ -134,6 +135,7 @@ def test_errors_static_case1(self):
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, paddle.inner, x, y)

@test_with_pir_api
def test_errors_static_case2(self):
# test static computation graph: inputs must be broadcastable
paddle.enable_static()
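Most of the test changes in this PR follow the pattern of the test_inner.py hunk above: the error test is decorated with test_with_pir_api so the same assertions run under both the legacy static-graph IR and PIR. A minimal, self-contained sketch of that pattern (the test class and method names here are made up for illustration; the decorator import matches the one added in test_matrix_nms_op.py below):

import unittest

import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api


class InnerLikeErrorTest(unittest.TestCase):
    @test_with_pir_api
    def test_int8_rejected(self):
        # The decorated body is exercised under the legacy IR and under PIR.
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
            y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
            # int8 is not an allowed dtype, so the check that now also runs
            # on the PIR path raises TypeError in both modes.
            self.assertRaises(TypeError, paddle.inner, x, y)


if __name__ == '__main__':
    unittest.main()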
1 change: 1 addition & 0 deletions test/legacy_test/test_log_softmax.py
@@ -236,6 +236,7 @@ def test_check_api(self):
self.check_api(axis)
self.check_api(-1, 'float64')

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name='X1', shape=[100], dtype='int32')
1 change: 0 additions & 1 deletion test/legacy_test/test_logcumsumexp_op.py
@@ -199,7 +199,6 @@ def test_gpu(self):

self.run_static(use_gpu=True)

# @test_with_pir_api
def test_name(self):
with base.program_guard(base.Program()):
x = paddle.static.data('x', [3, 4])
1 change: 1 addition & 0 deletions test/legacy_test/test_logsumexp.py
@@ -228,6 +228,7 @@ def set_attrs_addition(self):


class TestLogsumexpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
self.assertRaises(TypeError, paddle.logsumexp, 1)
1 change: 1 addition & 0 deletions test/legacy_test/test_matmul_op.py
@@ -221,6 +221,7 @@ def test_dygraph_without_out(self):


class API_TestMmError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle_static_guard():

2 changes: 2 additions & 0 deletions test/legacy_test/test_matrix_nms_op.py
@@ -19,6 +19,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api


def python_matrix_nms(
@@ -310,6 +311,7 @@ def set_argument(self):


class TestMatrixNMSError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
M = 1200
N = 7
16 changes: 9 additions & 7 deletions test/legacy_test/test_matrix_power_op.py
@@ -297,6 +297,7 @@ def test_dygraph(self):


class TestMatrixPowerAPIError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
input_np = np.random.random([4, 4]).astype("float64")

@@ -317,13 +318,6 @@ def test_errors(self):
)
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2)

# When out is set, the data type must be the same as input.
input = paddle.static.data(
name="input_1", shape=[4, 4], dtype="float32"
)
out = paddle.static.data(name="output", shape=[4, 4], dtype="float64")
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2, out)

# The number of dimensions of input must be >= 2.
input = paddle.static.data(name="input_2", shape=[4], dtype="float32")
self.assertRaises(ValueError, paddle.linalg.matrix_power, input, 2)
Expand All @@ -348,6 +342,14 @@ def test_errors(self):
ValueError, paddle.linalg.matrix_power, input, -956301312
)

def test_old_ir_errors(self):
# When out is set, the data type must be the same as input.
input = paddle.static.data(
name="input_1", shape=[4, 4], dtype="float32"
)
out = paddle.static.data(name="output", shape=[4, 4], dtype="float64")
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2, out)


class TestMatrixPowerSingularAPI(unittest.TestCase):
def setUp(self):
1 change: 1 addition & 0 deletions test/legacy_test/test_maxout_op.py
@@ -124,6 +124,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out3_ref, out3.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
1 change: 1 addition & 0 deletions test/legacy_test/test_multi_dot_op.py
@@ -260,6 +260,7 @@ def get_inputs_and_outputs(self):

# python API test
class TestMultiDotOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
1 change: 1 addition & 0 deletions test/legacy_test/test_outer.py
@@ -142,6 +142,7 @@ def test_multiply_dynamic(self):


class TestMultiplyError(unittest.TestCase):
@test_with_pir_api
def test_errors_static(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()