Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions vllm/model_executor/layers/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Conv Layer Class."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import is_torch_equal


class ConvLayerBase(CustomOp):
"""Conv layer base class."""

num_dim: int

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
super().__init__()

if params_dtype is None:
params_dtype = torch.get_default_dtype()

kernel_size = (
(kernel_size,) * self.num_dim
if isinstance(kernel_size, int)
else kernel_size
)
stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.padding_mode = padding_mode

self.enable_linear = (
(self.kernel_size == self.stride)
and not any(self.padding)
and self.groups == 1
)
self.input_size = in_channels * math.prod(self.kernel_size)

self.weight = nn.Parameter(
torch.empty(
out_channels,
in_channels // groups,
*kernel_size,
dtype=params_dtype,
),
)

if bias:
self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
else:
self.register_parameter("bias", None)

def extra_repr(self) -> str:
s = f"in_channels={self.in_channels}, "
s += f"out_channels={self.out_channels}, "
s += f"kernel_size={self.kernel_size}, "
s += f"stride={self.stride}, "
s += f"padding={self.padding}, "
s += f"bias={self.bias is not None}"
return s


@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
"""Conv layer with Conv2d."""

num_dim = 2

def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4
B, C, H, W = x.shape
K1, K2 = self.kernel_size
H, W = H // K1, W // K2
x = x.unfold(2, K1, K1).unfold(3, K2, K2)
x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
x = F.linear(
x,
self.weight.view(self.out_channels, self.input_size),
self.bias,
)
x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
return x

def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4
x = F.conv2d(
x,
self.weight,
self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return x

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""Expected input shape: (batch_size, in_channels, height, width)"""
assert x.dim() == 4
if self.enable_linear:
return self._forward_mulmat(x)
else:
return self._forward_conv(x)

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# By default, we use CUDNN's convolution ops with optimization.
return self._forward_conv(x)


class CausalConv2dLayer(Conv2dLayer):
"""
A causal version of nn.Conv2d where each location in the 2D matrix would
have no access to locations on its right or down
All arguments are the same as nn.Conv2d except padding which should be
set as None
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
if padding is not None:
raise ValueError(
"Argument padding should be set to None for CausalConv2dLayer."
)
self._left_padding: int = kernel_size - 1
self._right_padding: int = stride - 1
padding = 0

super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
params_dtype=params_dtype,
)

def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
x = super().forward(x)
return x


@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
"""Conv layer with Conv3d."""

num_dim = 3

def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 5
B, C, T, H, W = x.shape
K1, K2, K3 = self.kernel_size
T, H, W = T // K1, H // K2, W // K3
x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
x = F.linear(
x,
self.weight.view(self.out_channels, self.input_size),
self.bias,
)
x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
return x

def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 5
x = F.conv3d(
x,
self.weight,
self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return x

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""Expected input shape: (batch_size, in_channels, time, height, width)"""
if self.enable_linear:
return self._forward_mulmat(x)
else:
return self._forward_conv(x)

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a
# significant performance regression.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# By default, we use CUDNN's convolution ops with optimization.
if self.enable_linear and is_torch_equal("2.9.0"):
return self._forward_mulmat(x)
return self._forward_conv(x)
3 changes: 2 additions & 1 deletion vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
Expand Down Expand Up @@ -315,7 +316,7 @@ def __init__(self, config: CLIPVisionConfig):

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
Expand Down
17 changes: 8 additions & 9 deletions vllm/model_executor/models/glm4_1v.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand Down Expand Up @@ -103,7 +103,6 @@
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
Expand Down Expand Up @@ -486,15 +485,18 @@ def __init__(
self.hidden_size = hidden_size

kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x


Expand Down Expand Up @@ -893,9 +895,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params: set[str] = set()

for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
Expand Down
18 changes: 8 additions & 10 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""

import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias
Expand Down Expand Up @@ -56,12 +55,12 @@
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand Down Expand Up @@ -110,7 +109,6 @@
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
Expand Down Expand Up @@ -525,15 +523,18 @@ def __init__(
self.hidden_size = hidden_size

kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
return_bias=False,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x


Expand Down Expand Up @@ -957,9 +958,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params: set[str] = set()

for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
Expand Down
Loading