Merged
33 commits
154cadb
Merge pull request #1 from pytorch/master
ekagra-ranjan Feb 14, 2019
e964780
Merge pull request #2 from pytorch/master
ekagra-ranjan Mar 9, 2019
103b25a
Merge pull request #3 from pytorch/master
ekagra-ranjan Mar 11, 2019
2925fed
Merge pull request #5 from pytorch/master
ekagra-ranjan Mar 12, 2019
70e21e7
Merge pull request #9 from pytorch/master
ekagra-ranjan Mar 25, 2019
149b436
Merge pull request #10 from pytorch/master
ekagra-ranjan Apr 9, 2019
eda387b
Merge pull request #11 from pytorch/master
ekagra-ranjan May 3, 2019
1ff32d3
Merge pull request #13 from pytorch/master
ekagra-ranjan May 31, 2019
227111b
Merge pull request #14 from pytorch/master
ekagra-ranjan Jul 4, 2019
ecda81b
add float32 to keypoint_rcnn docs
ekagra-ranjan Jul 4, 2019
511202d
add float32 to faster_rcnn docs
ekagra-ranjan Jul 4, 2019
dbdd372
add float32 to mask_rcnn
ekagra-ranjan Jul 4, 2019
0662e6b
Update faster_rcnn.py
ekagra-ranjan Jul 4, 2019
778207f
Update keypoint_rcnn.py
ekagra-ranjan Jul 4, 2019
299904a
Update mask_rcnn.py
ekagra-ranjan Jul 4, 2019
d5335a6
Update faster_rcnn.py
ekagra-ranjan Jul 4, 2019
87d6927
make keypoints float
ekagra-ranjan Jul 4, 2019
076fd78
make masks uint8
ekagra-ranjan Jul 4, 2019
8d9fbf1
Update keypoint_rcnn.py
ekagra-ranjan Jul 4, 2019
3efb26d
make labels Int64
ekagra-ranjan Jul 4, 2019
4e6299f
make labels Int64
ekagra-ranjan Jul 4, 2019
1068ff5
make labels Int64
ekagra-ranjan Jul 4, 2019
41c36f6
Add checks for boxes, labels, masks, keypoints
ekagra-ranjan Jul 4, 2019
80c409a
Merge branch 'master' into mz-rpn-float
ekagra-ranjan Jul 4, 2019
03932e9
update mask dim
ekagra-ranjan Jul 4, 2019
2fbef71
remove dtype
ekagra-ranjan Jul 4, 2019
86db726
check only if targets is not None
ekagra-ranjan Jul 4, 2019
105373e
account for targets being a list
ekagra-ranjan Jul 4, 2019
5467466
update target to be list of dict
ekagra-ranjan Jul 4, 2019
f1ec459
Update faster_rcnn.py
ekagra-ranjan Jul 4, 2019
3fffae4
Update keypoint_rcnn.py
ekagra-ranjan Jul 4, 2019
f11dc3b
allow boxes to be of float16 type as well
ekagra-ranjan Jul 9, 2019
3fb2a9e
remove checks on mask
ekagra-ranjan Jul 9, 2019
20 changes: 10 additions & 10 deletions torchvision/models/detection/faster_rcnn.py
@@ -30,21 +30,21 @@ class FasterRCNN(GeneralizedRCNN):

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
- boxes (FloatTensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
Member

Those are great changes, thanks!

But I'm a bit concerned that this looks like the legacy interface of torch.FloatTensor, etc.
I wonder if there is a better way of representing this? For example, in numpy, everything is an ndarray, but with different dtypes.

Thoughts?

Contributor Author

ekagra-ranjan Jul 9, 2019

Would descriptions like boxes(Tensor[N, 4], dtype=torch.float) or boxes(FloatTensor(N, 4)) be better?

(Tensor[N, 4], dtype=torch.float) is not valid syntax but conveys the requirement, whereas FloatTensor(N, 4) would actually create a float tensor with dim = (N, 4).

Member

Let's just keep this as is for now. But I'd like to remove the FloatTensor changes and keep only Int64Tensor, because boxes are not required to be float.

between 0 and H and 0 and W
- labels (Tensor[N]): the class label for each ground-truth box
- labels (Int64Tensor[N]): the class label for each ground-truth box

The model returns a Dict[Tensor] during training, containing the classification and regression
losses for both the RPN and the R-CNN.

During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
- boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
- boxes (FloatTensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
- labels (Tensor[N]): the predicted labels for each image
- labels (Int64Tensor[N]): the predicted labels for each image
- scores (Tensor[N]): the scores of each prediction

Arguments:
@@ -298,21 +298,21 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (``Tensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
between ``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the class label for each ground-truth box
- labels (``Int64Tensor[N]``): the class label for each ground-truth box

The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
losses for both the RPN and the R-CNN.

During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows:
- boxes (``Tensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the predicted labels for each image
- labels (``Int64Tensor[N]``): the predicted labels for each image
- scores (``Tensor[N]``): the scores of each prediction

Example::
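As a rough illustration of the target format the docstrings above describe, the following sketch builds a `targets` list for a batch of two images. The helper name `make_detection_target` and all box and label values are made up for this example; only the `boxes` and `labels` keys and their dtypes come from the docstring.

```python
import torch


def make_detection_target(boxes, labels):
    """Hypothetical helper: build one per-image target dict.

    boxes  -> float tensor of shape [N, 4] in [x0, y0, x1, y1] format
    labels -> int64 tensor of shape [N]
    """
    return {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }


# `targets` is a list with one dict per image in the batch.
targets = [
    make_detection_target([[10, 20, 110, 220]], [1]),
    make_detection_target([[0, 0, 50, 50], [30, 40, 90, 120]], [2, 3]),
]

for t in targets:
    print(t["boxes"].shape, t["boxes"].dtype, t["labels"].dtype)
```

The dtype is passed explicitly because Python integer lists would otherwise produce int64 boxes, which the checks added in roi_heads.py would reject.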
28 changes: 14 additions & 14 deletions torchvision/models/detection/keypoint_rcnn.py
@@ -24,12 +24,12 @@ class KeypointRCNN(FasterRCNN):

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
- boxes (FloatTensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
between 0 and H and 0 and W
- labels (Tensor[N]): the class label for each ground-truth box
- keypoints (Tensor[N, K, 3]): the K keypoints location for each of the N instances, in the
- labels (Int64Tensor[N]): the class label for each ground-truth box
- keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

The model returns a Dict[Tensor] during training, containing the classification and regression
@@ -38,11 +38,11 @@ class KeypointRCNN(FasterRCNN):
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
- boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
- boxes (FloatTensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
- labels (Tensor[N]): the predicted labels for each image
- labels (Int64Tensor[N]): the predicted labels for each image
- scores (Tensor[N]): the scores of each prediction
- keypoints (Tensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
- keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

Arguments:
backbone (nn.Module): the network used to compute the features for the model.
@@ -274,12 +274,12 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (``Tensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
between ``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the class label for each ground-truth box
- keypoints (``Tensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
- keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
@@ -288,11 +288,11 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows:
- boxes (``Tensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the predicted labels for each image
- labels (``Int64Tensor[N]``): the predicted labels for each image
- scores (``Tensor[N]``): the scores of each prediction
- keypoints (``Tensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
- keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

Example::

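The `[N, K, 3]` keypoint layout from the docstrings above can be sketched as follows. The coordinates and the choice of `K = 3` are invented for illustration; the only parts taken from the docstring are the `[x, y, visibility]` ordering and the convention that `visibility=0` marks a keypoint as not visible.

```python
import torch

# One instance (N = 1) with K = 3 hypothetical keypoints, each [x, y, visibility].
keypoints = torch.tensor(
    [[[15.0, 30.0, 1.0],
      [40.0, 60.0, 1.0],
      [0.0, 0.0, 0.0]]],  # visibility = 0: this keypoint is not visible
    dtype=torch.float32,
)

# Boolean mask of shape [N, K] selecting the visible keypoints.
visible = keypoints[..., 2] > 0
num_visible = int(visible.sum())

# x/y coordinates of the visible keypoints only, shape [num_visible, 2].
visible_xy = keypoints[visible][:, :2]
print(num_visible, visible_xy.tolist())
```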
28 changes: 14 additions & 14 deletions torchvision/models/detection/mask_rcnn.py
@@ -26,24 +26,24 @@ class MaskRCNN(FasterRCNN):

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
- boxes (FloatTensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
between 0 and H and 0 and W
- labels (Tensor[N]): the class label for each ground-truth box
- masks (Tensor[N, 1, H, W]): the segmentation binary masks for each instance
- labels (Int64Tensor[N]): the class label for each ground-truth box
- masks (UInt8Tensor[N, 1, H, W]): the segmentation binary masks for each instance

The model returns a Dict[Tensor] during training, containing the classification and regression
losses for both the RPN and the R-CNN, and the mask loss.

During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
- boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
- boxes (FloatTensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
- labels (Tensor[N]): the predicted labels for each image
- labels (Int64Tensor[N]): the predicted labels for each image
- scores (Tensor[N]): the scores of each prediction
- masks (Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
- masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
obtain the final segmentation masks, the soft masks can be thresholded, generally
with a value of 0.5 (mask >= 0.5)

@@ -273,24 +273,24 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,

The behavior of the model changes depending if it is in training or evaluation mode.

During training, the model expects both the input tensors, as well as a targets dictionary,
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (``Tensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values
between ``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the class label for each ground-truth box
- masks (``Tensor[N, H, W]``): the segmentation binary masks for each instance
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
- masks (``UInt8Tensor[N, 1, H, W]``): the segmentation binary masks for each instance

The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
losses for both the RPN and the R-CNN, and the mask loss.

During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows:
- boxes (``Tensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between
``0`` and ``H`` and ``0`` and ``W``
- labels (``Tensor[N]``): the predicted labels for each image
- labels (``Int64Tensor[N]``): the predicted labels for each image
- scores (``Tensor[N]``): the scores of each prediction
- masks (``Tensor[N, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
- masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
obtain the final segmentation masks, the soft masks can be thresholded, generally
with a value of 0.5 (``mask >= 0.5``)

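The thresholding step mentioned in the mask docstrings above (`mask >= 0.5`) might look like the sketch below. It assumes the predicted soft masks arrive as floating-point values in the 0-1 range with shape `[N, 1, H, W]`; the tensor values are invented, and the conversion to `uint8` is one possible choice for the binary result, not something the PR prescribes.

```python
import torch

# Hypothetical soft masks for N = 1 instance on a 2x2 image, values in [0, 1].
soft_masks = torch.tensor([[[[0.10, 0.50],
                             [0.90, 0.49]]]])  # shape [1, 1, 2, 2]

# Threshold at 0.5, as the docstring suggests, and store the result as uint8.
binary_masks = (soft_masks >= 0.5).to(torch.uint8)

print(binary_masks.shape, binary_masks.dtype)
print(binary_masks[0, 0].tolist())
```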
13 changes: 11 additions & 2 deletions torchvision/models/detection/roi_heads.py
@@ -17,6 +17,8 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
Arguments:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)

Returns:
classification_loss (Tensor)
@@ -55,7 +57,7 @@ def maskrcnn_inference(x, labels):

Arguments:
x (Tensor): the mask logits
boxes (list[BoxList]): bounding boxes that are used as
labels (list[BoxList]): bounding boxes that are used as
reference, one for each image

Returns:
@@ -250,7 +252,7 @@ def keypointrcnn_inference(x, boxes):

# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily gor paste_mask_in_image
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
@@ -525,6 +527,13 @@ def forward(self, features, proposals, image_shapes, targets=None):
image_shapes (List[Tuple[H, W]])
targets (List[Dict])
"""
if targets is not None:
for t in targets:
assert t["boxes"].dtype.is_floating_point, 'target boxes must be of float type'
assert t["labels"].dtype == torch.int64, 'target labels must be of int64 type'
if self.has_keypoint:
assert t["keypoints"].dtype == torch.float32, 'target keypoints must be of float32 type'

if self.training:
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)

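To see how the dtype checks added to `forward` behave, here is a standalone sketch that mirrors them outside the class. The function name `check_targets` and the `has_keypoint` argument are hypothetical stand-ins for the method body and `self.has_keypoint`; the example targets are invented.

```python
import torch


def check_targets(targets, has_keypoint=False):
    """Hypothetical standalone version of the dtype checks in the diff above."""
    if targets is None:
        return
    for t in targets:
        # Any floating-point dtype is accepted for boxes (float16/32/64).
        assert t["boxes"].dtype.is_floating_point, "target boxes must be of float type"
        assert t["labels"].dtype == torch.int64, "target labels must be of int64 type"
        if has_keypoint:
            assert t["keypoints"].dtype == torch.float32, "target keypoints must be of float32 type"


good_targets = [{
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    "labels": torch.tensor([1], dtype=torch.int64),
}]
# Integer boxes: built from Python ints without an explicit dtype.
bad_targets = [{
    "boxes": torch.tensor([[0, 0, 10, 10]]),
    "labels": torch.tensor([1], dtype=torch.int64),
}]

check_targets(good_targets)  # passes silently
try:
    check_targets(bad_targets)
except AssertionError as err:
    print("rejected:", err)
```

Using `dtype.is_floating_point` rather than comparing against `torch.float32` is what lets float16 boxes through, which matches the "allow boxes to be of float16 type as well" commit in this PR.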