diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py
new file mode 100644
index 000000000000..e6f2d2990c24
--- /dev/null
+++ b/vllm/model_executor/layers/conv.py
@@ -0,0 +1,236 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Conv Layer Class."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vllm.model_executor.custom_op import CustomOp
+from vllm.utils.torch_utils import is_torch_equal
+
+
+class ConvLayerBase(CustomOp):
+    """Conv layer base class."""
+
+    num_dim: int
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int | tuple[int, ...],
+        stride: int | tuple[int, ...] = 1,
+        padding: int | tuple[int, ...] = 0,
+        dilation: int | tuple[int, ...] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+        *,
+        params_dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        kernel_size = (
+            (kernel_size,) * self.num_dim
+            if isinstance(kernel_size, int)
+            else kernel_size
+        )
+        stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
+        padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
+        dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.padding_mode = padding_mode
+
+        self.enable_linear = (
+            (self.kernel_size == self.stride)
+            and not any(self.padding)
+            and self.groups == 1
+        )
+        self.input_size = in_channels * math.prod(self.kernel_size)
+
+        self.weight = nn.Parameter(
+            torch.empty(
+                out_channels,
+                in_channels // groups,
+                *kernel_size,
+                dtype=params_dtype,
+            ),
+        )
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
+        else:
+            self.register_parameter("bias", None)
+
+    def extra_repr(self) -> str:
+        s = f"in_channels={self.in_channels}, "
+        s += f"out_channels={self.out_channels}, "
+        s += f"kernel_size={self.kernel_size}, "
+        s += f"stride={self.stride}, "
+        s += f"padding={self.padding}, "
+        s += f"bias={self.bias is not None}"
+        return s
+
+
+@CustomOp.register("conv2d")
+class Conv2dLayer(ConvLayerBase):
+    """Conv layer with Conv2d."""
+
+    num_dim = 2
+
+    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 4
+        B, C, H, W = x.shape
+        K1, K2 = self.kernel_size
+        H, W = H // K1, W // K2
+        x = x.unfold(2, K1, K1).unfold(3, K2, K2)
+        x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
+        x = F.linear(
+            x,
+            self.weight.view(self.out_channels, self.input_size),
+            self.bias,
+        )
+        x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
+        return x
+
+    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 4
+        x = F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
+        return x
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """Expected input shape: (batch_size, in_channels, height, width)"""
+        assert x.dim() == 4
+        if self.enable_linear:
+            return self._forward_mulmat(x)
+        else:
+            return self._forward_conv(x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # By default, we use CUDNN's convolution ops with optimization.
+        return self._forward_conv(x)
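# --- Editor's note (illustration, not part of the diff) ----------------------
# The `_forward_mulmat` fast path above relies on a standard identity: with
# kernel_size == stride, zero padding, and groups == 1, a convolution over
# non-overlapping patches is an unfold followed by one matmul. A minimal
# standalone sketch with plain PyTorch (illustrative shapes only):
import torch
import torch.nn.functional as F

B, C, H, W, D, K = 2, 3, 28, 28, 8, 14
x = torch.randn(B, C, H, W)
weight = torch.randn(D, C, K, K)
bias = torch.randn(D)

# Reference: ordinary convolution with kernel_size == stride and no padding.
ref = F.conv2d(x, weight, bias, stride=K)

# Matmul path: unfold non-overlapping K x K patches, then a single linear.
patches = x.unfold(2, K, K).unfold(3, K, K)              # (B, C, H//K, W//K, K, K)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, C * K * K)
out = F.linear(patches, weight.view(D, -1), bias)
out = out.view(B, H // K, W // K, D).permute(0, 3, 1, 2)

assert torch.allclose(ref, out, atol=1e-3)
# ------------------------------------------------------------------------------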
+
+
+class CausalConv2dLayer(Conv2dLayer):
+    """
+    A causal version of nn.Conv2d where each location in the 2D matrix has
+    no access to locations to its right or below it.
+    All arguments are the same as for nn.Conv2d, except padding, which must
+    be left as None.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        padding: int | None = None,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+        *,
+        params_dtype: torch.dtype | None = None,
+    ) -> None:
+        if padding is not None:
+            raise ValueError(
+                "Argument padding should be set to None for CausalConv2dLayer."
+            )
+        self._left_padding: int = kernel_size - 1
+        self._right_padding: int = stride - 1
+        padding = 0
+
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            params_dtype=params_dtype,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
+        x = super().forward(x)
+        return x
+
+
+@CustomOp.register("conv3d")
+class Conv3dLayer(ConvLayerBase):
+    """Conv layer with Conv3d."""
+
+    num_dim = 3
+
+    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 5
+        B, C, T, H, W = x.shape
+        K1, K2, K3 = self.kernel_size
+        T, H, W = T // K1, H // K2, W // K3
+        x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
+        x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
+        x = F.linear(
+            x,
+            self.weight.view(self.out_channels, self.input_size),
+            self.bias,
+        )
+        x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
+        return x
+
+    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 5
+        x = F.conv3d(
+            x,
+            self.weight,
+            self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
+        return x
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """Expected input shape: (batch_size, in_channels, time, height, width)"""
+        if self.enable_linear:
+            return self._forward_mulmat(x)
+        else:
+            return self._forward_conv(x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # PyTorch 2.9.0 disabled cuDNN for Conv3D, which caused a significant
+        # performance regression; on that version, fall back to the matmul
+        # path when the layer is equivalent to a linear projection.
+        # See: https://github.com/vllm-project/vllm/issues/27406
+        # and https://github.com/pytorch/pytorch/issues/166122
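# --- Editor's note (illustration, not part of the diff) ----------------------
# A sketch of how Conv3dLayer behaves when configured as a non-overlapping
# patch embedding, as the vision models below do. The sizes are assumptions
# (Qwen2.5-VL-style temporal_patch_size=2, patch_size=14), and the layer is
# assumed to be constructible outside a full engine context.
import torch
from vllm.model_executor.layers.conv import Conv3dLayer

layer = Conv3dLayer(
    in_channels=3,
    out_channels=1280,
    kernel_size=(2, 14, 14),
    stride=(2, 14, 14),
    bias=False,
    params_dtype=torch.float32,
)
# kernel_size == stride, zero padding, groups == 1 -> linear fast path available.
assert layer.enable_linear

x = torch.randn(4, 3, 2, 14, 14)    # (batch, channels, T, H, W): one patch each
out = layer.forward_native(x)       # takes the matmul path since enable_linear is True
assert out.shape == (4, 1280, 1, 1, 1)
# forward_cuda only falls back to the matmul path on PyTorch 2.9.0 (where cuDNN
# Conv3D is disabled); otherwise it dispatches to F.conv3d.
# ------------------------------------------------------------------------------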
+        if self.enable_linear and is_torch_equal("2.9.0"):
+            return self._forward_mulmat(x)
+        return self._forward_conv(x)
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 50f476dfd185..5d611deb942d 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -20,6 +20,7 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -315,7 +316,7 @@ def __init__(self, config: CLIPVisionConfig):
 
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
-        self.patch_embedding = nn.Conv2d(
+        self.patch_embedding = Conv2dLayer(
             in_channels=config.num_channels,
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index b2d4fe0c0139..6953b805653b 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -56,12 +56,12 @@
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
+from vllm.model_executor.layers.conv import Conv3dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
-    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -103,7 +103,6 @@
     maybe_prefix,
 )
 from .vision import (
-    conv3d_to_linear_weight,
     get_vit_attn_backend,
     run_dp_sharded_mrope_vision_model,
 )
@@ -486,15 +485,18 @@ def __init__(
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = ReplicatedLinear(
-            in_channels * math.prod(kernel_size),
+        self.proj = Conv3dLayer(
+            in_channels,
             hidden_size,
+            kernel_size=kernel_size,
+            stride=kernel_size,
             bias=True,
-            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.proj(x)
+        L, C = x.shape
+        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
+        x = self.proj(x).view(L, self.hidden_size)
         return x
@@ -893,9 +895,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
-            if name.endswith("patch_embed.proj.weight"):
-                loaded_weight = conv3d_to_linear_weight(loaded_weight)
-
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
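# --- Editor's note (illustration, not part of the diff) ----------------------
# Shape walkthrough for the GLM-4.1V patch-embed forward above (the Qwen2-VL,
# Qwen2.5-VL, Qwen3-Omni, and Qwen3-VL changes below follow the same pattern).
# Sizes are illustrative only; the real values come from the vision config.
import torch
import torch.nn.functional as F

in_channels, temporal_patch_size, patch_size, hidden_size = 3, 2, 14, 1536
L = 1024                                          # number of flattened patches
flat_dim = in_channels * temporal_patch_size * patch_size * patch_size

x = torch.randn(L, flat_dim)                      # what the model hands in: (L, C)
x = x.view(L, in_channels, temporal_patch_size, patch_size, patch_size)

# With kernel_size == stride, each patch maps to exactly one output position,
# so the conv output is (L, hidden_size, 1, 1, 1) and a view recovers (L, hidden_size).
weight = torch.randn(hidden_size, in_channels, temporal_patch_size, patch_size, patch_size)
out = F.conv3d(x, weight, stride=(temporal_patch_size, patch_size, patch_size))
assert out.shape == (L, hidden_size, 1, 1, 1)
out = out.view(L, hidden_size)
# ------------------------------------------------------------------------------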
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Literal, TypeAlias @@ -56,12 +55,12 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -110,7 +109,6 @@ maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -525,15 +523,18 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=False, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -957,9 +958,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 13b54bbe1748..5d21e249fc4c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,7 +25,6 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias @@ -54,9 +53,9 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -107,7 +106,6 @@ maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -566,15 +564,18 @@ def __init__( self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, embed_dim, + kernel_size=kernel_size, + stride=kernel_size, bias=False, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.embed_dim) return x @@ -844,9 +845,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 5df2372a842c..40b80ce2387c 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -22,7 +22,6 @@ # limitations under the License. 
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Any @@ -54,9 +53,9 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -102,7 +101,6 @@ maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_llm_pos_ids_for_vision, get_vit_attn_backend, ) @@ -138,16 +136,18 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=True, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = self.proj(x) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -566,9 +566,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 5f5bde1dd72d..faeb9f81d961 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -24,7 +24,6 @@ # limitations under the License. 
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" -import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from itertools import islice @@ -57,9 +56,9 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -114,7 +113,6 @@ maybe_prefix, ) from .vision import ( - conv3d_to_linear_weight, get_vit_attn_backend, run_dp_sharded_mrope_vision_model, ) @@ -139,15 +137,18 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = ReplicatedLinear( - in_channels * math.prod(kernel_size), + self.proj = Conv3dLayer( + in_channels, hidden_size, + kernel_size=kernel_size, + stride=kernel_size, bias=True, - return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) return x @@ -579,9 +580,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 0e814e5c86ad..e5d70eb7bc2f 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) return llm_pos_ids - - -# Due to a performance regression with Conv3D in PyTorch2.9, we reshape -# Conv3D weights to Linear weights for better performance. -# See: https://github.com/vllm-project/vllm/issues/27406 -# and https://github.com/pytorch/pytorch/issues/166122 -# FIXME(Isotr0py): Revert the PR introduces this workaround -# (https://github.com/vllm-project/vllm/pull/27418), -# once the performance issue is resolved in PyTorch. -def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor: - """ - Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride. - """ - out_channels, in_channels, kt, kh, kw = conv3d_weight.shape - linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw) - return linear_weight