|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | """Conv Layer Class.""" |
4 | 4 |
|
| 5 | +import math |
| 6 | + |
5 | 7 | import torch |
6 | 8 | import torch.nn as nn |
7 | 9 | import torch.nn.functional as F |
8 | 10 |
|
9 | 11 | from vllm.model_executor.custom_op import CustomOp |
10 | 12 | from vllm.model_executor.layers.linear import ReplicatedLinear |
11 | 13 | from vllm.model_executor.layers.quantization.base_config import QuantizationConfig |
| 14 | +from vllm.model_executor.utils import set_weight_attrs |
12 | 15 |
|
13 | 16 |
|
14 | | -@CustomOp.register("conv") |
15 | 17 | class ConvLayerBase(CustomOp): |
16 | 18 | """Conv layer base class.""" |
17 | 19 |
|
18 | | - def __init__( |
19 | | - self, |
20 | | - ) -> None: |
21 | | - super().__init__() |
22 | | - |
23 | | - |
24 | | -class Conv2dLayer(ConvLayerBase): |
25 | | - """Conv layer with Conv2d.""" |
| 20 | + num_dim: int |
26 | 21 |
|
27 | 22 | def __init__( |
28 | 23 | self, |
29 | 24 | in_channels: int, |
30 | 25 | out_channels: int, |
31 | | - kernel_size: int | tuple, |
32 | | - stride: int | tuple | None, |
33 | | - padding: int | tuple | str | None, |
34 | | - dilation: int | tuple | None, |
35 | | - groups: int | None, |
36 | | - bias: bool | None, |
37 | | - padding_mode: str | None, |
| 26 | + kernel_size: int | tuple[int, ...], |
| 27 | + stride: int | tuple[int, ...] = 1, |
| 28 | + padding: int | tuple[int, ...] = 0, |
| 29 | + dilation: int | tuple[int, ...] = 1, |
| 30 | + groups: int = 1, |
| 31 | + bias: bool = True, |
| 32 | + padding_mode: str = "zeros", |
| 33 | + *, |
| 34 | + params_dtype: torch.dtype | None = None, |
38 | 35 | ) -> None: |
39 | 36 | super().__init__() |
40 | 37 |
|
| 38 | + if params_dtype is None: |
| 39 | + params_dtype = torch.get_default_dtype() |
| 40 | + |
| 41 | + kernel_size = ( |
| 42 | + (kernel_size,) * self.num_dim |
| 43 | + if isinstance(kernel_size, int) |
| 44 | + else kernel_size |
| 45 | + ) |
| 46 | + stride = (stride,) * self.num_dim if isinstance(stride, int) else stride |
| 47 | + padding = (padding,) * self.num_dim if isinstance(padding, int) else padding |
| 48 | + dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation |
| 49 | + |
41 | 50 | self.in_channels = in_channels |
42 | 51 | self.out_channels = out_channels |
43 | 52 | self.kernel_size = kernel_size |
44 | 53 | self.stride = stride |
45 | 54 | self.padding = padding |
46 | 55 | self.dilation = dilation |
47 | 56 | self.groups = groups |
48 | | - self.bias = bias |
49 | 57 | self.padding_mode = padding_mode |
50 | 58 |
|
51 | | - self.proj = nn.Conv2d( |
52 | | - in_channels=in_channels, |
53 | | - out_channels=out_channels, |
54 | | - kernel_size=kernel_size, |
55 | | - stride=stride, |
56 | | - padding=padding, |
57 | | - dilation=dilation, |
58 | | - groups=groups, |
59 | | - bias=bias, |
60 | | - padding_mode=padding_mode, |
| 59 | + self.can_linearize = ( |
| 60 | + (self.kernel_size == self.stride) |
| 61 | + and not any(self.padding) |
| 62 | + and self.groups == 1 |
61 | 63 | ) |
62 | 64 |
|
63 | | - def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
64 | | - x = self.proj(x) |
65 | | - return x |
| 65 | + if self.can_linearize: |
| 66 | + self.weight = nn.Parameter( |
| 67 | + torch.empty( |
| 68 | + out_channels, |
| 69 | + in_channels * math.prod(self.kernel_size), |
| 70 | + dtype=params_dtype, |
| 71 | + ), |
| 72 | + ) |
| 73 | + else: |
| 74 | + self.weight = nn.Parameter( |
| 75 | + torch.empty( |
| 76 | + out_channels, |
| 77 | + in_channels // groups, |
| 78 | + *kernel_size, |
| 79 | + dtype=params_dtype, |
| 80 | + ), |
| 81 | + ) |
| 82 | + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) |
66 | 83 |
|
67 | | - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
68 | | - return self.forward_native(x) |
| 84 | + if bias: |
| 85 | + self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype)) |
| 86 | + set_weight_attrs( |
| 87 | + self.bias, |
| 88 | + { |
| 89 | + "weight_loader": self.weight_loader, |
| 90 | + }, |
| 91 | + ) |
| 92 | + else: |
| 93 | + self.register_parameter("bias", None) |
69 | 94 |
|
70 | 95 | def extra_repr(self) -> str: |
71 | 96 | s = f"in_channels={self.in_channels}, " |
72 | 97 | s += f"out_channels={self.out_channels}, " |
73 | 98 | s += f"kernel_size={self.kernel_size}, " |
74 | 99 | s += f"stride={self.stride}, " |
75 | 100 | s += f"padding={self.padding}, " |
76 | | - s += f"bias={self.bias}, " |
| 101 | + s += f"bias={self.bias is not None}" |
77 | 102 | return s |
78 | 103 |
|
| 104 | + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): |
| 105 | + param.data.copy_(loaded_weight.view(param.shape)) |
| 106 | + |
| 107 | + |
| 108 | +@CustomOp.register("conv2d") |
| 109 | +class Conv2dLayer(ConvLayerBase): |
| 110 | + """Conv layer with Conv2d.""" |
| 111 | + |
| 112 | + num_dim = 2 |
| 113 | + |
| 114 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 115 | + """Expected input shape: (batch_size, in_channels, height, width)""" |
| 116 | + assert x.dim() == 4 |
| 117 | + if self.can_linearize: |
| 118 | + B, C, H, W = x.shape |
| 119 | + K1, K2 = self.kernel_size |
| 120 | + H, W = H // K1, W // K2 |
| 121 | + x = x.view(-1, self.in_channels * math.prod(self.kernel_size)) |
| 122 | + x = F.linear(x, self.weight, self.bias) |
| 123 | + x = x.view(B, self.out_channels, H, W) |
| 124 | + else: |
| 125 | + x = F.conv2d( |
| 126 | + x, |
| 127 | + self.weight, |
| 128 | + self.bias, |
| 129 | + stride=self.stride, |
| 130 | + padding=self.padding, |
| 131 | + dilation=self.dilation, |
| 132 | + groups=self.groups, |
| 133 | + ) |
| 134 | + return x |
| 135 | + |
| 136 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 137 | + return self.forward_native(x) |
| 138 | + |
| 139 | + |
| 140 | +@CustomOp.register("conv3d") |
| 141 | +class Conv3dLayer(ConvLayerBase): |
| 142 | + """Conv layer with Conv3d.""" |
| 143 | + |
| 144 | + num_dim = 3 |
| 145 | + |
| 146 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 147 | + """Expected input shape: (batch_size, in_channels, time, height, width)""" |
| 148 | + assert x.dim() == 5 |
| 149 | + if self.can_linearize: |
| 150 | + B, C, T, H, W = x.shape |
| 151 | + K1, K2, K3 = self.kernel_size |
| 152 | + T, H, W = T // K1, H // K2, W // K3 |
| 153 | + x = x.view(-1, self.in_channels * math.prod(self.kernel_size)) |
| 154 | + x = F.linear(x, self.weight, self.bias) |
| 155 | + x = x.view(B, self.out_channels, T, H, W) |
| 156 | + else: |
| 157 | + x = F.conv3d( |
| 158 | + x, |
| 159 | + self.weight, |
| 160 | + self.bias, |
| 161 | + stride=self.stride, |
| 162 | + padding=self.padding, |
| 163 | + dilation=self.dilation, |
| 164 | + groups=self.groups, |
| 165 | + ) |
| 166 | + return x |
| 167 | + |
| 168 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 169 | + return self.forward_native(x) |
| 170 | + |
79 | 171 |
|
80 | 172 | class CausalConv2dLayer(Conv2dLayer): |
81 | 173 | """ |
|
0 commit comments