Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
'Conv2dTransposeInferMeta',
'FusedConv2dAddActInferMeta',
'InterpolateInferMeta',
'DeformableConvInferMeta',
}

_PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'}
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def deform_conv2d(

use_deform_conv2d_v1 = True if mask is None else False

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
pre_bias = _C_ops.deformable_conv(
x,
offset,
Expand Down
56 changes: 28 additions & 28 deletions test/legacy_test/test_deform_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
import paddle.nn.initializer as I
from paddle.pir_utils import test_with_pir_api


class TestDeformConv2D(TestCase):
Expand Down Expand Up @@ -142,38 +143,31 @@ def static_graph_case_dcn(self):
dtype=self.dtype,
)

y_v1 = paddle.static.nn.common.deformable_conv(
input=x,
offset=offset,
mask=None,
num_filters=self.out_channels,
filter_size=self.filter_shape,
y_v1 = paddle.vision.ops.DeformConv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.filter_shape,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
deformable_groups=self.deformable_groups,
im2col_step=1,
param_attr=I.Assign(self.weight),
weight_attr=I.Assign(self.weight),
bias_attr=False if self.no_bias else I.Assign(self.bias),
modulated=False,
)
)(x, offset, None)

y_v2 = paddle.static.nn.common.deformable_conv(
input=x,
offset=offset,
mask=mask,
num_filters=self.out_channels,
filter_size=self.filter_shape,
y_v2 = paddle.vision.ops.DeformConv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.filter_shape,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
deformable_groups=self.deformable_groups,
im2col_step=1,
param_attr=I.Assign(self.weight),
weight_attr=I.Assign(self.weight),
bias_attr=False if self.no_bias else I.Assign(self.bias),
)
)(x, offset, mask)

exe = paddle.static.Executor(self.place)
exe.run(start)
Expand Down Expand Up @@ -217,6 +211,7 @@ def dygraph_case_dcn(self):

return out_v1, out_v2

@test_with_pir_api
def _test_identity(self):
self.prepare()
static_dcn_v1, static_dcn_v2 = self.static_graph_case_dcn()
Expand Down Expand Up @@ -522,10 +517,11 @@ def _test_identity(self):
self.prepare()
static_dcn_v1, static_dcn_v2 = self.static_graph_case_dcn()
dy_dcn_v1, dy_dcn_v2 = self.dygraph_case_dcn()
(
new_static_dcn_v1,
new_static_dcn_v2,
) = self.new_api_static_graph_case_dcn()
with paddle.pir_utils.IrGuard():
(
new_static_dcn_v1,
new_static_dcn_v2,
) = self.new_api_static_graph_case_dcn()
np.testing.assert_array_almost_equal(static_dcn_v1, dy_dcn_v1)
np.testing.assert_array_almost_equal(static_dcn_v2, dy_dcn_v2)
np.testing.assert_array_almost_equal(static_dcn_v1, new_static_dcn_v1)
Expand Down Expand Up @@ -727,6 +723,7 @@ def setUp(self):


class TestDeformConv2DError(unittest.TestCase):
@test_with_pir_api
def test_input_error(self):
def test_input_rank_error():
paddle.enable_static()
Expand All @@ -737,11 +734,14 @@ def test_input_rank_error():
mask = paddle.static.data(
name='error_mask_1', shape=[0, 0, 0], dtype='float32'
)
out = paddle.static.nn.deform_conv2d(
x, offset, mask, 0, 0, deformable_groups=0
)

self.assertRaises(ValueError, test_input_rank_error)
out = paddle.vision.ops.DeformConv2D(
in_channels=0,
out_channels=0,
kernel_size=0,
deformable_groups=0,
)(x, offset, mask)

self.assertRaises(AssertionError, test_input_rank_error)


if __name__ == "__main__":
Expand Down
97 changes: 87 additions & 10 deletions test/legacy_test/test_deformable_conv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -194,13 +195,14 @@ def setUp(self):
self.outputs = {'Output': output}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
{'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.05,
check_pir=True,
)

def init_test_case(self):
Expand Down Expand Up @@ -457,8 +459,84 @@ def test_invalid_groups():

self.assertRaises(ValueError, test_invalid_groups)

@test_with_pir_api
def test_error_api(self):
def test_invalid_input():
paddle.enable_static()
input = [1, 3, 32, 32]
offset = paddle.static.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32'
)
mask = paddle.static.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32'
)
loss = paddle.vision.ops.DeformConv2D(
in_channels=input[1], out_channels=4, kernel_size=1
)(input, offset, mask)

err_type = (
ValueError if paddle.base.framework.in_pir_mode() else TypeError
)
self.assertRaises(err_type, test_invalid_input)

def test_invalid_offset():
paddle.enable_static()
input = paddle.static.data(
name='input', shape=[None, 3, 32, 32], dtype='int32'
)
offset = paddle.static.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32'
)
mask = paddle.static.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32'
)
loss = paddle.vision.ops.DeformConv2D(
in_channels=input.shape[1], out_channels=4, kernel_size=1
)(input, offset, mask)

self.assertRaises(TypeError, test_invalid_offset)

def test_invalid_filter():
paddle.enable_static()
input = paddle.static.data(
name='input_filter', shape=[None, 3, 32, 32], dtype='float32'
)
offset = paddle.static.data(
name='offset_filter', shape=[None, 3, 32, 32], dtype='float32'
)
mask = paddle.static.data(
name='mask_filter', shape=[None, 3, 32, 32], dtype='float32'
)
loss = paddle.vision.ops.DeformConv2D(
in_channels=input.shape[1], out_channels=4, kernel_size=0
)(input, offset, mask)

self.assertRaises(AssertionError, test_invalid_filter)

def test_invalid_groups():
paddle.enable_static()
input = paddle.static.data(
name='input_groups', shape=[1, 1, 1, 1], dtype='float32'
)
offset = paddle.static.data(
name='offset_groups', shape=[1, 1], dtype='float32'
)
mask = paddle.static.data(
name='mask_groups', shape=[1], dtype='float32'
)
loss = paddle.vision.ops.DeformConv2D(
in_channels=input.shape[1],
out_channels=1,
kernel_size=1,
padding=1,
groups=0,
)(input, offset, mask)

self.assertRaises(ZeroDivisionError, test_invalid_groups)


class TestDeformConv2DAPI(unittest.TestCase):
@test_with_pir_api
def test_api(self):
def test_deform_conv2d_v1():
paddle.enable_static()
Expand All @@ -468,11 +546,10 @@ def test_deform_conv2d_v1():
offset = paddle.static.data(
name='offset_v1', shape=[None, 4, 32, 32], dtype='float32'
)
out = paddle.static.nn.deform_conv2d(
input, offset, None, num_filters=4, filter_size=1
)

assert out.shape == (-1, 4, 32, 32)
out = paddle.vision.ops.DeformConv2D(
in_channels=input.shape[1], out_channels=4, kernel_size=1
)(input, offset, None)
assert tuple(out.shape) == (-1, 4, 32, 32)

test_deform_conv2d_v1()

Expand All @@ -487,11 +564,11 @@ def test_deform_conv2d_v2():
mask = paddle.static.data(
name='mask_v2', shape=[None, 2, 32, 32], dtype='float32'
)
out = paddle.static.nn.deform_conv2d(
input, offset, mask, num_filters=4, filter_size=1
)
out = paddle.vision.ops.DeformConv2D(
in_channels=input.shape[1], out_channels=4, kernel_size=1
)(input, offset, mask)

assert out.shape == (-1, 4, 32, 32)
assert tuple(out.shape) == (-1, 4, 32, 32)

test_deform_conv2d_v2()

Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/test_deformable_conv_v1_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,14 @@ def setUp(self):
self.outputs = {'Output': output}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['Input', 'Offset', 'Filter'],
'Output',
max_relative_error=0.05,
check_pir=True,
)

def test_check_grad_no_filter(self):
Expand All @@ -203,6 +204,7 @@ def test_check_grad_no_filter(self):
'Output',
max_relative_error=0.1,
no_grad_set={'Filter'},
check_pir=True,
)

def init_test_case(self):
Expand Down