|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +"""Conv Layer Class.""" |
| 4 | + |
| 5 | +import math |
| 6 | +from collections.abc import Callable |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch.nn.functional as F |
| 10 | +import torch.nn.parameter as Parameter |
| 11 | + |
| 12 | +from vllm.model_executor.custom_op import CustomOp |
| 13 | +from vllm.model_executor.utils import set_weight_attrs |
| 14 | + |
| 15 | + |
| 16 | +@CustomOp.register("conv2d") |
| 17 | +class Conv2dLayer(CustomOp): |
| 18 | + """Conv2D layer.""" |
| 19 | + |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + in_channels: int, |
| 23 | + out_channels: int, |
| 24 | + kernel_size: int | tuple, |
| 25 | + stride: int = 1, |
| 26 | + padding: int = 0, |
| 27 | + dilation: int = 1, |
| 28 | + groups: int = 1, |
| 29 | + enable_bias: bool = True, |
| 30 | + padding_mode: str = "zeros", |
| 31 | + ) -> None: |
| 32 | + super().__init__() |
| 33 | + |
| 34 | + if isinstance(kernel_size, int): |
| 35 | + kernel_size = (kernel_size, kernel_size) |
| 36 | + |
| 37 | + self.in_channels = in_channels |
| 38 | + self.out_channels = out_channels |
| 39 | + self.kernel_size = kernel_size |
| 40 | + self.stride = stride |
| 41 | + self.padding: int = padding |
| 42 | + self.dilation = dilation |
| 43 | + self.groups = groups |
| 44 | + self.enable_bias = enable_bias |
| 45 | + self.padding_mode = padding_mode |
| 46 | + |
| 47 | + self.enable_linear: bool = False |
| 48 | + if _enable_linear(kernel_size[0], stride, padding): |
| 49 | + self.enable_linear = True |
| 50 | + |
| 51 | + _create_conv_weights( |
| 52 | + self, |
| 53 | + in_channels=self.in_channels, |
| 54 | + out_channels=self.out_channels, |
| 55 | + kernel_size=self.kernel_size, |
| 56 | + groups=self.groups, |
| 57 | + enable_bias=self.enable_bias, |
| 58 | + enable_linear=self.enable_linear, |
| 59 | + weight_loader=self.weight_loader, |
| 60 | + ) |
| 61 | + |
| 62 | + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): |
| 63 | + if self.enable_linear: |
| 64 | + loaded_weight = _convert_conv_to_linear_weight(loaded_weight) |
| 65 | + param.data.copy_(loaded_weight) |
| 66 | + |
| 67 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 68 | + if self.enable_linear: |
| 69 | + x = F.linear(x, self.weight, self.bias) |
| 70 | + else: |
| 71 | + x = F.conv2d( |
| 72 | + x, |
| 73 | + self.weight, |
| 74 | + bias=self.bias, |
| 75 | + stride=self.stride, |
| 76 | + padding=self.padding, |
| 77 | + dilation=self.dilation, |
| 78 | + groups=self.groups, |
| 79 | + ) |
| 80 | + return x |
| 81 | + |
| 82 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 83 | + return self.forward_native(x) |
| 84 | + |
| 85 | + def extra_repr(self) -> str: |
| 86 | + s = f"in_channels={self.in_channels}, " |
| 87 | + s += f"out_channels={self.out_channels}, " |
| 88 | + s += f"kernel_size={self.kernel_size}, " |
| 89 | + s += f"stride={self.stride}, " |
| 90 | + s += f"padding={self.padding}, " |
| 91 | + s += f"bias={self.bias}, " |
| 92 | + return s |
| 93 | + |
| 94 | + |
| 95 | +@CustomOp.register("causal_conv2d") |
| 96 | +class CausalConv2dLayer(CustomOp): |
| 97 | + """ |
| 98 | + A causal version of nn.Conv2d where each location in the 2D matrix would |
| 99 | + have no access to locations on its right or down |
| 100 | + All arguments are the same as nn.Conv2d except padding which should be |
| 101 | + set as None |
| 102 | + """ |
| 103 | + |
| 104 | + def __init__( |
| 105 | + self, |
| 106 | + in_channels: int, |
| 107 | + out_channels: int, |
| 108 | + kernel_size: int | tuple, |
| 109 | + stride: int = 1, |
| 110 | + padding: int = 0, |
| 111 | + dilation: int = 1, |
| 112 | + groups: int = 1, |
| 113 | + enable_bias: bool = True, |
| 114 | + padding_mode: str = "zeros", |
| 115 | + ) -> None: |
| 116 | + super().__init__() |
| 117 | + |
| 118 | + if isinstance(kernel_size, int): |
| 119 | + kernel_size = (kernel_size, kernel_size) |
| 120 | + |
| 121 | + self.in_channels = in_channels |
| 122 | + self.out_channels = out_channels |
| 123 | + self.kernel_size = kernel_size |
| 124 | + self.stride = stride |
| 125 | + self.dilation = dilation |
| 126 | + self.groups = groups |
| 127 | + self.enable_bias = enable_bias |
| 128 | + self.padding_mode = padding_mode |
| 129 | + |
| 130 | + if padding is not None: |
| 131 | + raise ValueError( |
| 132 | + "Argument padding should be set to None for CausalConv2dLayer." |
| 133 | + ) |
| 134 | + self._left_padding: int = kernel_size[0] - 1 |
| 135 | + self._right_padding: int = stride - 1 |
| 136 | + self.padding: int = 0 |
| 137 | + |
| 138 | + self.enable_linear: bool = False |
| 139 | + if _enable_linear(kernel_size[0], stride, padding): |
| 140 | + self.enable_linear = True |
| 141 | + |
| 142 | + _create_conv_weights( |
| 143 | + self, |
| 144 | + in_channels=self.in_channels, |
| 145 | + out_channels=self.out_channels, |
| 146 | + kernel_size=self.kernel_size, |
| 147 | + groups=self.groups, |
| 148 | + enable_bias=self.enable_bias, |
| 149 | + enable_linear=self.enable_linear, |
| 150 | + weight_loader=self.weight_loader, |
| 151 | + ) |
| 152 | + |
| 153 | + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): |
| 154 | + if self.enable_linear: |
| 155 | + loaded_weight = _convert_conv_to_linear_weight(loaded_weight) |
| 156 | + param.data.copy_(loaded_weight) |
| 157 | + |
| 158 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 159 | + if self.enable_linear: |
| 160 | + x = F.linear(x, self.weight, self.bias) |
| 161 | + else: |
| 162 | + x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0)) |
| 163 | + x = F.conv2d( |
| 164 | + x, |
| 165 | + self.weight, |
| 166 | + bias=self.bias, |
| 167 | + stride=self.stride, |
| 168 | + padding=self.padding, |
| 169 | + dilation=self.dilation, |
| 170 | + groups=self.groups, |
| 171 | + ) |
| 172 | + return x |
| 173 | + |
| 174 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 175 | + return self.forward_native(x) |
| 176 | + |
| 177 | + def extra_repr(self) -> str: |
| 178 | + s = f"in_channels={self.in_channels}, " |
| 179 | + s += f"out_channels={self.out_channels}, " |
| 180 | + s += f"kernel_size={self.kernel_size}, " |
| 181 | + s += f"stride={self.stride}, " |
| 182 | + s += f"padding={self.padding}, " |
| 183 | + s += f"bias={self.bias}, " |
| 184 | + return s |
| 185 | + |
| 186 | + |
| 187 | +@CustomOp.register("conv3d") |
| 188 | +class Conv3dLayer(CustomOp): |
| 189 | + """Conv3D layer with linear weight.""" |
| 190 | + |
| 191 | + def __init__( |
| 192 | + self, |
| 193 | + in_channels: int, |
| 194 | + out_channels: int, |
| 195 | + kernel_size: int | tuple, |
| 196 | + stride: int = 1, |
| 197 | + padding: int = 0, |
| 198 | + dilation: int = 1, |
| 199 | + groups: int = 1, |
| 200 | + enable_bias: bool = True, |
| 201 | + padding_mode: str = "zeros", |
| 202 | + ) -> None: |
| 203 | + super().__init__() |
| 204 | + |
| 205 | + assert isinstance(kernel_size, tuple) and len(kernel_size) == 3 |
| 206 | + |
| 207 | + self.in_channels = in_channels |
| 208 | + self.out_channels = out_channels |
| 209 | + self.kernel_size = kernel_size |
| 210 | + self.stride = stride |
| 211 | + self.padding: int = padding |
| 212 | + self.dilation = dilation |
| 213 | + self.groups = groups |
| 214 | + self.padding_mode = padding_mode |
| 215 | + self.use_linear = False |
| 216 | + |
| 217 | + self.enable_linear: bool = False |
| 218 | + if _enable_linear(kernel_size[0], stride, padding): |
| 219 | + self.enable_linear = True |
| 220 | + |
| 221 | + _create_conv_weights( |
| 222 | + self, |
| 223 | + in_channels=self.in_channels, |
| 224 | + out_channels=self.out_channels, |
| 225 | + kernel_size=self.kernel_size, |
| 226 | + groups=self.groups, |
| 227 | + enable_bias=self.enable_bias, |
| 228 | + enable_linear=self.enable_linear, |
| 229 | + weight_loader=self.weight_loader, |
| 230 | + ) |
| 231 | + |
| 232 | + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): |
| 233 | + if self.enable_linear: |
| 234 | + loaded_weight = _convert_conv_to_linear_weight(loaded_weight) |
| 235 | + param.data.copy_(loaded_weight) |
| 236 | + |
| 237 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 238 | + if self.enable_linear: |
| 239 | + x = F.linear(x, self.weight, self.bias) |
| 240 | + else: |
| 241 | + x = F.conv2d( |
| 242 | + x, |
| 243 | + self.weight, |
| 244 | + bias=self.bias, |
| 245 | + stride=self.stride, |
| 246 | + padding=self.padding, |
| 247 | + dilation=self.dilation, |
| 248 | + groups=self.groups, |
| 249 | + ) |
| 250 | + return x |
| 251 | + |
| 252 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 253 | + return self.forward_native(x) |
| 254 | + |
| 255 | + def extra_repr(self) -> str: |
| 256 | + s = f"in_channels={self.in_channels}, " |
| 257 | + s += f"out_channels={self.out_channels}, " |
| 258 | + s += f"kernel_size={self.kernel_size}, " |
| 259 | + s += f"stride={self.stride}, " |
| 260 | + s += f"padding={self.padding}, " |
| 261 | + s += f"bias={self.bias}, " |
| 262 | + return s |
| 263 | + |
| 264 | + |
| 265 | +def _enable_linear( |
| 266 | + kernel_size: int, |
| 267 | + stride: int, |
| 268 | + padding: int, |
| 269 | +) -> bool: |
| 270 | + assert isinstance(kernel_size, int) and isinstance(stride, int) |
| 271 | + return kernel_size == stride and padding == 0 |
| 272 | + |
| 273 | + |
| 274 | +def _create_conv_weights( |
| 275 | + layer: torch.nn.Module, |
| 276 | + in_channels: int, |
| 277 | + out_channels: int, |
| 278 | + kernel_size: tuple, |
| 279 | + groups: int, |
| 280 | + enable_bias: bool, |
| 281 | + enable_linear: bool, |
| 282 | + weight_loader: Callable, |
| 283 | +) -> None: |
| 284 | + if enable_linear: |
| 285 | + # Use linear computation for better performance. |
| 286 | + weight = Parameter( |
| 287 | + torch.empty((out_channels, in_channels * math.prod(kernel_size))) |
| 288 | + ) |
| 289 | + else: |
| 290 | + # Use normal Conv2D computation. |
| 291 | + weight = Parameter( |
| 292 | + torch.empty((out_channels, in_channels // groups, *kernel_size)) |
| 293 | + ) |
| 294 | + layer.register_parameter("weight", weight) |
| 295 | + set_weight_attrs(weight, {"weight_loader": weight_loader}) |
| 296 | + |
| 297 | + if enable_bias: |
| 298 | + bias = Parameter(torch.empty(out_channels)) |
| 299 | + layer.register_parameter("bias", bias) |
| 300 | + set_weight_attrs(bias, {"weight_loader": weight_loader}) |
| 301 | + else: |
| 302 | + layer.register_parameter("bias", None) |
| 303 | + |
| 304 | + |
| 305 | +# Due to a performance regression with Conv3D in PyTorch2.9, we reshape |
| 306 | +# Conv3D weights to Linear weights for better performance. |
| 307 | +# See: https://github.com/vllm-project/vllm/issues/27406 |
| 308 | +# and https://github.com/pytorch/pytorch/issues/166122 |
| 309 | +# FIXME(Isotr0py): Revert the PR introduces this workaround |
| 310 | +# (https://github.com/vllm-project/vllm/pull/27418), |
| 311 | +# once the performance issue is resolved in PyTorch. |
| 312 | +def _convert_conv_to_linear_weight(conv_weight: torch.Tensor) -> torch.Tensor: |
| 313 | + """ |
| 314 | + Reshape Conv2D or Conv3D weight to Linear weight. |
| 315 | + Only work when kernel_size==stride. |
| 316 | + """ |
| 317 | + out_channels = conv_weight.shape[0] |
| 318 | + linear_weight = conv_weight.reshape(out_channels, -1) |
| 319 | + return linear_weight |
0 commit comments