|
1 | 1 | from functools import partial |
2 | | -from typing import Any, List, Optional |
| 2 | +from typing import Any, Optional, Sequence |
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | from torch import nn |
@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel): |
46 | 46 |
|
47 | 47 |
|
48 | 48 | class DeepLabHead(nn.Sequential): |
49 | | - def __init__(self, in_channels: int, num_classes: int) -> None: |
| 49 | + def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None: |
50 | 50 | super().__init__( |
51 | | - ASPP(in_channels, [12, 24, 36]), |
| 51 | + ASPP(in_channels, atrous_rates), |
52 | 52 | nn.Conv2d(256, 256, 3, padding=1, bias=False), |
53 | 53 | nn.BatchNorm2d(256), |
54 | 54 | nn.ReLU(), |
@@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
83 | 83 |
|
84 | 84 |
|
85 | 85 | class ASPP(nn.Module): |
86 | | - def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: |
| 86 | + def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None: |
87 | 87 | super().__init__() |
88 | 88 | modules = [] |
89 | 89 | modules.append( |
|
0 commit comments