Add checks to roi_heads in detection module #1091
Conversation
Commits:
- update before transform
- update 9/03/19
- 11/03/19
- 12/03/10 10:34pm
- update 25/03/19
- update 09/04/19
- update 3/5/19
- Update 31/05/19
Codecov Report

|          | master | #1091  | +/-    |
|----------|--------|--------|--------|
| Coverage | 64.65% | 64.53% | -0.13% |
| Files    | 68     | 68     |        |
| Lines    | 5410   | 5417   | +7     |
| Branches | 830    | 834    | +4     |
| Hits     | 3498   | 3496   | -2     |
| Misses   | 1662   | 1669   | +7     |
| Partials | 250    | 252    | +2     |

Continue to review the full report at Codecov.
fmassa left a comment
Thanks a lot for the PR!
I have a question regarding how to support fp16 training in the future. Let me know what you think.
    """
    if targets is not None:
        for t in targets:
            assert t["boxes"].dtype == torch.float32, 'target boxes must be of float type'
Those checks might not work if we want to perform fp16 training in the future. Maybe it would be better to use something like `t["boxes"].dtype.is_floating_point`?
Yes, you are right! Will change it.
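A minimal sketch of the dtype-agnostic check being suggested, using `dtype.is_floating_point` so that fp16 boxes would also pass. The helper name `check_targets` is hypothetical; in the PR the assertions live inline in `roi_heads`:

```python
import torch

def check_targets(targets):
    # Hypothetical standalone helper; the PR performs these asserts inline.
    for t in targets:
        # Accept any floating-point dtype (float32, float16, ...) instead of
        # pinning boxes to torch.float32, so fp16 training still passes.
        assert t["boxes"].dtype.is_floating_point, \
            'target boxes must be of floating point type'
        assert t["labels"].dtype == torch.int64, \
            'target labels must be of int64 type'

# fp16 boxes are accepted, int64 labels are still required:
check_targets([{"boxes": torch.zeros(2, 4, dtype=torch.float16),
                "labels": torch.zeros(2, dtype=torch.int64)}])
```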
    assert t["boxes"].dtype == torch.float32, 'target boxes must be of float type'
    assert t["labels"].dtype == torch.int64, 'target labels must be of int64 type'
    if self.has_mask:
        assert t["masks"].dtype == torch.uint8, 'target masks must be of uint8 type'
Do the masks need to be uint8? I thought the code worked for other types as well, since in the only place where the masks are used in the model we perform a cast to float.
I was following the tutorial, where the dtype of the masks was specified as uint8. Should I remove that check?
I don't think having the masks be uint8 is a hard restriction, so yes, maybe remove this check for now.
Done.
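To illustrate the point above: the uint8 restriction is soft because where the model consumes the ground-truth masks it casts them to float, so e.g. bool and uint8 masks behave identically after the cast. An illustrative sketch, not the PR's code:

```python
import torch

# Two equivalent binary masks with different integer-like dtypes.
m_uint8 = torch.tensor([[0, 1], [1, 0]], dtype=torch.uint8)
m_bool = m_uint8.to(torch.bool)

# After the cast the model performs internally, both become identical
# float tensors, so enforcing uint8 at the input boundary is not strictly
# necessary.
assert torch.equal(m_uint8.to(torch.float32), m_bool.to(torch.float32))
```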
      During training, the model expects both the input tensors, as well as a targets (list of dictionary),
      containing:
    -     - boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
    +     - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
Those are great changes, thanks!
But I'm a bit concerned that this looks like the legacy interface of `torch.FloatTensor`, etc.
I wonder if there is a better way of representing this? For example, in numpy everything is an `ndarray`, but with different types.
Thoughts?
Would descriptions like `boxes (Tensor[N, 4], dtype=torch.float)` or `boxes (FloatTensor(N, 4))` be better?
`(Tensor[N, 4], dtype=torch.float)` is not valid syntax but conveys the requirement, whereas `FloatTensor(N, 4)` would actually create a float tensor with dim = (N, 4).
Let's just keep this as is for now. But I'd like to remove the `FloatTensor` changes and only keep `Int64Tensor` instead, because boxes is not required to be Float.
Sorry for the delayed response. Please have a look at my response @fmassa.

@fmassa Made the changes!
fmassa left a comment
Thanks!
This PR addresses #1027 by:
- adding checks on the `boxes`, `masks`, `keypoints` and `labels` attributes of the targets passed to detection models.