-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Implementation of the MNASNet family of models #829
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
50bfbe6
Add initial mnasnet impl
1e100 e1c5506
Remove all type hints, comply with PyTorch overall style
1e100 0d77acc
Expose models
1e100 c41aaab
Remove avgpool from features() and add separately
1e100 d6115f9
Merge upstream
1e100 568bd50
Fix python3-only stuff, replace subclasses with functions
1e100 5617b8e
fix __all__
1e100 ba0ad4d
Fix typo
1e100 bd4836b
Remove conditional dropout
1e100 5ac43bd
Merge branch 'master' of github.com:1e100/vision
1e100 102ba55
Make dropout functional
1e100 9c8b827
Addressing @fmassa's feedback, round 1
1e100 2872b1f
Replaced adaptive avgpool with mean on H and W to prevent collapsing …
1e100 05b387b
Partially address feedback
1e100 2d39797
YAPF
1e100 8b5f7b9
Removed redundant class vars
1e100 8de71fe
Merge master
1e100 40471ac
Update urls to releases
1e100 b1d54ec
Add information to models.rst
1e100 ec717d0
Replace init with kaiming_normal_ in fan-out mode
1e100 8b2dba9
Use load_state_dict_from_url
1e100 06177ee
Merge master
1e100 c34df87
Merge master again
1e100 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,3 +6,4 @@ | |
| from .densenet import * | ||
| from .googlenet import * | ||
| from .mobilenet import * | ||
| from .mnasnet import * | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| import math | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] | ||
|
|
||
| # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is | ||
| # 1.0 - tensorflow. | ||
| _BN_MOMENTUM = 1 - 0.9997 | ||
|
|
||
|
|
||
| class _InvertedResidual(nn.Module): | ||
1e100 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, | ||
| bn_momentum=0.1): | ||
| super(_InvertedResidual, self).__init__() | ||
| assert stride in [1, 2] | ||
| assert kernel_size in [3, 5] | ||
| mid_ch = in_ch * expansion_factor | ||
| self.apply_residual = (in_ch == out_ch and stride == 1) | ||
| self.layers = nn.Sequential( | ||
| # Pointwise | ||
| nn.Conv2d(in_ch, mid_ch, 1, bias=False), | ||
| nn.BatchNorm2d(mid_ch, momentum=bn_momentum), | ||
| nn.ReLU(inplace=True), | ||
| # Depthwise | ||
| nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, | ||
| stride=stride, groups=mid_ch, bias=False), | ||
| nn.BatchNorm2d(mid_ch, momentum=bn_momentum), | ||
| nn.ReLU(inplace=True), | ||
| # Linear pointwise. Note that there's no activation. | ||
| nn.Conv2d(mid_ch, out_ch, 1, bias=False), | ||
| nn.BatchNorm2d(out_ch, momentum=bn_momentum)) | ||
|
|
||
| def forward(self, input): | ||
| if self.apply_residual: | ||
| return self.layers(input) + input | ||
| else: | ||
| return self.layers(input) | ||
|
|
||
|
|
||
| def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, | ||
| bn_momentum): | ||
| """ Creates a stack of inverted residuals. """ | ||
| assert repeats >= 1 | ||
| # First one has no skip, because feature map size changes. | ||
| first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, | ||
| bn_momentum=bn_momentum) | ||
| remaining = [] | ||
| for _ in range(1, repeats): | ||
| remaining.append( | ||
| _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, | ||
| bn_momentum=bn_momentum)) | ||
| return nn.Sequential(first, *remaining) | ||
|
|
||
|
|
||
| def _round_to_multiple_of(val, divisor, round_up_bias=0.9): | ||
| """ Asymmetric rounding to make `val` divisible by `divisor`. With default | ||
| bias, will round up, unless the number is no more than 10% greater than the | ||
| smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ | ||
| assert 0.0 < round_up_bias < 1.0 | ||
| new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) | ||
| return new_val if new_val >= round_up_bias * val else new_val + divisor | ||
|
|
||
|
|
||
| def _scale_depths(depths, alpha): | ||
| """ Scales tensor depths as in reference MobileNet code, prefers rouding up | ||
| rather than down. """ | ||
| return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] | ||
|
|
||
|
|
||
| class MNASNet(torch.nn.Module): | ||
| """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. | ||
| >>> model = MNASNet(1000, 1.0) | ||
| >>> x = torch.rand(1, 3, 224, 224) | ||
| >>> y = model.forward(x) | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| >>> y.dim() | ||
| 1 | ||
| >>> y.nelement() | ||
| 1000 | ||
| """ | ||
|
|
||
| def __init__(self, num_classes, alpha, dropout=0.2): | ||
| super(MNASNet, self).__init__() | ||
| self.alpha = alpha | ||
| self.num_classes = num_classes | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self.dropout = dropout | ||
| depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) | ||
| layers = [ | ||
1e100 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # First layer: regular conv. | ||
| nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), | ||
| nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), | ||
| nn.ReLU(inplace=True), | ||
| # Depthwise separable, no skip. | ||
| nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), | ||
| nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), | ||
| nn.ReLU(inplace=True), | ||
| nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), | ||
| nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), | ||
| # MNASNet blocks: stacks of inverted residuals. | ||
| _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), | ||
| _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), | ||
| _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), | ||
| _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), | ||
| _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), | ||
| _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), | ||
| # Final mapping to classifier input. | ||
| nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), | ||
| nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), | ||
| nn.ReLU(inplace=True), | ||
| ] | ||
| self.layers = nn.Sequential(*layers) | ||
| self.classifier = nn.Linear(1280, self.num_classes) | ||
|
|
||
| self._initialize_weights() | ||
|
|
||
| def features(self, x): | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return self.layers(x) | ||
|
|
||
| def forward(self, x): | ||
| x = self.features(x) | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Equivalent to global avgpool and removing H and W dimensions. | ||
| x = x.mean([2, 3]) | ||
| if self.dropout > 0.0: | ||
| x = nn.functional.dropout(x, p=self.dropout, training=self.training, | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| inplace=True) | ||
| return self.classifier(x) | ||
|
|
||
| def _initialize_weights(self): | ||
| for m in self.modules(): | ||
| if isinstance(m, nn.Conv2d): | ||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| m.weight.data.normal_(0, math.sqrt(2.0 / n)) | ||
| if m.bias is not None: | ||
| m.bias.data.zero_() | ||
| elif isinstance(m, nn.BatchNorm2d): | ||
| m.weight.data.fill_(1.0) | ||
| m.bias.data.zero_() | ||
| elif isinstance(m, nn.Linear): | ||
| n = m.weight.size(1) | ||
| m.weight.data.normal_(0, 0.01) | ||
| m.bias.data.zero_() | ||
|
|
||
|
|
||
| def mnasnet0_5(num_classes): | ||
1e100 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ MNASNet with depth multiplier of 0.5. """ | ||
| return MNASNet(num_classes, alpha=0.5) | ||
|
|
||
|
|
||
| def mnasnet0_75(num_classes): | ||
| """ MNASNet with depth multiplier of 0.75. """ | ||
| return MNASNet(num_classes, alpha=0.75) | ||
|
|
||
|
|
||
| def mnasnet1_0(num_classes): | ||
| """ MNASNet with depth multiplier of 1.0. """ | ||
| return MNASNet(num_classes, alpha=1.0) | ||
|
|
||
|
|
||
| def mnasnet1_3(num_classes): | ||
| """ MNASNet with depth multiplier of 1.3. """ | ||
| return MNASNet(num_classes, alpha=1.3) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.