Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions python/paddle/tests/test_ops_roi_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle
from paddle.vision.ops import roi_align, RoIAlign


class TestRoIAlign(unittest.TestCase):
    """Shape and value checks for ``roi_align`` and the ``RoIAlign`` layer."""

    def setUp(self):
        """Create one random feature map and three random boxes for it."""
        self.data = np.random.rand(1, 256, 32, 32).astype('float32')
        boxes = np.random.rand(3, 4)
        # Push (x2, y2) past (x1, y1) so every box has a positive extent.
        boxes[:, 2] += boxes[:, 0] + 3
        boxes[:, 3] += boxes[:, 1] + 4
        self.boxes = boxes.astype('float32')
        # All three boxes belong to the single image in the batch.
        self.boxes_num = np.array([3], dtype=np.int32)

    def roi_align_functional(self, output_size):
        """Run ``roi_align`` in the current mode and check the output shape.

        Args:
            output_size (int or tuple): forwarded to ``roi_align``; an int
                means a square output.
        """
        if isinstance(output_size, int):
            output_shape = (3, 256, output_size, output_size)
        else:
            output_shape = (3, 256, output_size[0], output_size[1])

        if paddle.in_dynamic_mode():
            data = paddle.to_tensor(self.data)
            boxes = paddle.to_tensor(self.boxes)
            boxes_num = paddle.to_tensor(self.boxes_num)

            align_out = roi_align(
                data, boxes, boxes_num=boxes_num, output_size=output_size)
            np.testing.assert_equal(align_out.shape, output_shape)

        else:
            data = paddle.static.data(
                shape=self.data.shape, dtype=self.data.dtype, name='data')
            boxes = paddle.static.data(
                shape=self.boxes.shape, dtype=self.boxes.dtype, name='boxes')
            boxes_num = paddle.static.data(
                shape=self.boxes_num.shape,
                dtype=self.boxes_num.dtype,
                name='boxes_num')

            align_out = roi_align(
                data, boxes, boxes_num=boxes_num, output_size=output_size)

            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)

            align_out = exe.run(paddle.static.default_main_program(),
                                feed={
                                    'data': self.data,
                                    'boxes': self.boxes,
                                    'boxes_num': self.boxes_num
                                },
                                fetch_list=[align_out])

            # exe.run returns one entry per fetched variable.
            np.testing.assert_equal(align_out[0].shape, output_shape)

    def test_roi_align_functional_dynamic(self):
        """Functional API in dynamic mode, int and tuple output sizes."""
        self.roi_align_functional(3)
        self.roi_align_functional(output_size=(3, 4))

    def test_roi_align_functional_static(self):
        """Functional API in static-graph mode."""
        paddle.enable_static()
        try:
            self.roi_align_functional(3)
        finally:
            # Restore dynamic mode even when the check fails; otherwise the
            # global mode switch would leak into every test that runs later.
            paddle.disable_static()

    def test_RoIAlign(self):
        """Layer API returns the requested (h, w) output shape."""
        roi_align_c = RoIAlign(output_size=(4, 3))
        data = paddle.to_tensor(self.data)
        boxes = paddle.to_tensor(self.boxes)
        boxes_num = paddle.to_tensor(self.boxes_num)

        align_out = roi_align_c(data, boxes, boxes_num)
        np.testing.assert_equal(align_out.shape, (3, 256, 4, 3))

    def test_value(self):
        """Compare the layer output with hand-computed reference values."""
        # 4x4 ramp: data[0, 0, y, x] == 4 * y + x + 1.
        data = np.arange(1, 17).reshape(1, 1, 4, 4).astype(np.float32)
        boxes = np.array(
            [[1., 1., 2., 2.], [1.5, 1.5, 3., 3.]]).astype(np.float32)
        boxes_num = np.array([2]).astype(np.int32)
        # Expected single-bin output for each of the two boxes.
        output = np.array([[[[6.]]], [[[9.75]]]], dtype=np.float32)

        data = paddle.to_tensor(data)
        boxes = paddle.to_tensor(boxes)
        boxes_num = paddle.to_tensor(boxes_num)

        roi_align_c = RoIAlign(output_size=1)
        align_out = roi_align_c(data, boxes, boxes_num)
        np.testing.assert_almost_equal(align_out.numpy(), output)


if __name__ == '__main__':
    # Run the test cases when this module is executed directly as a script.
    unittest.main()
159 changes: 159 additions & 0 deletions python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
'RoIPool',
'psroi_pool',
'PSRoIPool',
'roi_align',
'RoIAlign',
]


Expand Down Expand Up @@ -1138,3 +1140,160 @@ def forward(self, x, boxes, boxes_num):
def extra_repr(self):
    """Return the layer's configuration as a string for its printed form."""
    # Placeholders are filled by name from the instance attributes.
    return 'output_size={_output_size}, spatial_scale={_spatial_scale}'.format(
        **self.__dict__)


def roi_align(x,
              boxes,
              boxes_num,
              output_size,
              spatial_scale=1.0,
              sampling_ratio=-1,
              aligned=True,
              name=None):
    """
    This operator implements the roi_align layer.
    Region of Interest (RoI) Align operator (also known as RoI Align) is to
    perform bilinear interpolation on inputs of nonuniform sizes to obtain
    fixed-size feature maps (e.g. 7*7), as described in Mask R-CNN.

    Dividing each region proposal into equal-sized sections with the pooled_width
    and pooled_height. Location remains the origin result.

    In each ROI bin, the value of the four regularly sampled locations are
    computed directly through bilinear interpolation. The output is the mean of
    four locations. Thus avoid the misaligned problem.

    Args:
        x (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W],
            where N is the batch size, C is the input channel, H is height,
            W is width. The data type is float32 or float64.
        boxes (Tensor): Boxes (RoIs, Regions of Interest) to pool over. It
            should be a 2-D Tensor of shape (num_boxes, 4). The data type is
            float32 or float64. Given as [[x1, y1, x2, y2], ...], (x1, y1) is
            the top left coordinates, and (x2, y2) is the bottom right coordinates.
        boxes_num (Tensor): The number of boxes contained in each picture in
            the batch, the data type is int32. Must not be None in dygraph
            mode; in static mode it is only fed to the operator when given.
        output_size (int or Tuple[int, int]): The pooled output size(h, w), data
            type is int32. If int, h and w are both equal to output_size.
        spatial_scale (float32): Multiplicative spatial scale factor to translate
            ROI coords from their input scale to the scale used when pooling.
            Default: 1.0
        sampling_ratio (int32): number of sampling points in the interpolation
            grid used to compute the output value of each pooled output bin.
            If > 0, then exactly ``sampling_ratio x sampling_ratio`` sampling
            points per bin are used.
            If <= 0, then an adaptive number of grid points are used (computed
            as ``ceil(roi_width / output_width)``, and likewise for height).
            Default: -1
        aligned (bool): If False, use the legacy implementation. If True, shift
            the box coordinates by -0.5 pixel for a better alignment with the
            two neighboring pixel indices. This version is used in Detectron2.
            Default: True
        name(str, optional): For detailed information, please refer to :
            ref:`api_guide_Name`. Usually name is no need to set and None by
            default.

    Returns:
        Tensor: The output of ROIAlignOp is a 4-D tensor with shape (num_boxes,
            channels, pooled_h, pooled_w). The data type is float32 or float64.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.vision.ops import roi_align

            data = paddle.rand([1, 256, 32, 32])
            boxes = paddle.rand([3, 4])
            boxes[:, 2] += boxes[:, 0] + 3
            boxes[:, 3] += boxes[:, 1] + 4
            boxes_num = paddle.to_tensor([3]).astype('int32')
            align_out = roi_align(data, boxes, boxes_num, output_size=3)
            assert align_out.shape == [3, 256, 3, 3]
    """

    check_type(output_size, 'output_size', (int, tuple), 'roi_align')
    if isinstance(output_size, int):
        # A single int requests a square pooled output.
        output_size = (output_size, output_size)

    pooled_height, pooled_width = output_size
    if in_dygraph_mode():
        assert boxes_num is not None, "boxes_num should not be None in dygraph mode."
        # The dygraph op takes its attributes as a flat name/value sequence.
        align_out = core.ops.roi_align(
            x, boxes, boxes_num, "pooled_height", pooled_height, "pooled_width",
            pooled_width, "spatial_scale", spatial_scale, "sampling_ratio",
            sampling_ratio, "aligned", aligned)
        return align_out

    else:
        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'roi_align')
        check_variable_and_dtype(boxes, 'boxes', ['float32', 'float64'],
                                 'roi_align')
        helper = LayerHelper('roi_align', **locals())
        # The feature tensor is passed as ``x``; the helper's default lookup
        # key is ``input``, so name the parameter explicitly to pick up the
        # dtype of ``x`` for the output variable.
        dtype = helper.input_dtype(input_param_name='x')
        align_out = helper.create_variable_for_type_inference(dtype)
        inputs = {
            "X": x,
            "ROIs": boxes,
        }
        if boxes_num is not None:
            inputs['RoisNum'] = boxes_num
        helper.append_op(
            type="roi_align",
            inputs=inputs,
            outputs={"Out": align_out},
            attrs={
                "pooled_height": pooled_height,
                "pooled_width": pooled_width,
                "spatial_scale": spatial_scale,
                "sampling_ratio": sampling_ratio,
                "aligned": aligned,
            })
        return align_out


class RoIAlign(Layer):
    """
    This interface is used to construct a callable object of the `RoIAlign` class.
    Please refer to :ref:`api_paddle_vision_ops_roi_align`.

    Args:
        output_size (int or tuple[int, int]): The pooled output size(h, w),
            data type is int32. If int, h and w are both equal to output_size.
        spatial_scale (float32, optional): Multiplicative spatial scale factor
            to translate ROI coords from their input scale to the scale used
            when pooling. Default: 1.0

    Returns:
        align_out (Tensor): The output of ROIAlign operator is a 4-D tensor with
            shape (num_boxes, channels, pooled_h, pooled_w).

    Examples:
        .. code-block:: python

            import paddle
            from paddle.vision.ops import RoIAlign

            data = paddle.rand([1, 256, 32, 32])
            boxes = paddle.rand([3, 4])
            boxes[:, 2] += boxes[:, 0] + 3
            boxes[:, 3] += boxes[:, 1] + 4
            boxes_num = paddle.to_tensor([3]).astype('int32')
            roi_align = RoIAlign(output_size=(4, 3))
            align_out = roi_align(data, boxes, boxes_num)
            assert align_out.shape == [3, 256, 4, 3]
    """

    def __init__(self, output_size, spatial_scale=1.0):
        super(RoIAlign, self).__init__()
        self._output_size = output_size
        self._spatial_scale = spatial_scale

    def forward(self, x, boxes, boxes_num, aligned=True):
        """Apply ``roi_align`` with this layer's stored configuration.

        Args:
            x (Tensor): Input feature map, forwarded as ``x``.
            boxes (Tensor): Boxes to pool over, forwarded as ``boxes``.
            boxes_num (Tensor): Number of boxes per image, forwarded as
                ``boxes_num``.
            aligned (bool): Forwarded as ``aligned``. Default: True

        Returns:
            Tensor: The result of ``roi_align``. ``sampling_ratio`` is not
            passed, so the functional default is used.
        """
        return roi_align(
            x=x,
            boxes=boxes,
            boxes_num=boxes_num,
            output_size=self._output_size,
            spatial_scale=self._spatial_scale,
            aligned=aligned)

    def extra_repr(self):
        """Return the layer's configuration as a string for its printed form."""
        return 'output_size={}, spatial_scale={}'.format(
            self._output_size, self._spatial_scale)