Merged
24 changes: 22 additions & 2 deletions vllm/model_executor/layers/conv.py
@@ -3,6 +3,7 @@
"""Conv Layer Class."""

import math
from typing import Literal

import torch
import torch.nn as nn
@@ -23,11 +24,11 @@ def __init__(
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] = 0,
padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
@@ -36,6 +37,22 @@
if params_dtype is None:
params_dtype = torch.get_default_dtype()

valid_padding_strings = {"same", "valid"}
if isinstance(padding, str) and padding not in valid_padding_strings:
raise ValueError(
f"Invalid padding string '{padding}'. "
f"Expected one of {valid_padding_strings}."
)

if padding == "same":
padding = (
kernel_size // 2
if isinstance(kernel_size, int)
else tuple(k // 2 for k in kernel_size)
)
elif padding == "valid":
padding = 0

kernel_size = (
(kernel_size,) * self.num_dim
if isinstance(kernel_size, int)
@@ -45,6 +62,9 @@ def __init__(
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation

if padding == "same" and any(s != 1 for s in stride):
raise ValueError("padding='same' is not supported for strided convolutions")
Comment on lines +65 to +66
Contributor
high

The addition of this validation check is crucial for correctness. padding='same' behavior is not well-defined for strided convolutions in all frameworks, and explicitly disallowing it prevents potential silent miscalculations or unexpected output dimensions. This improves the robustness of the Conv2dLayer.
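
For context, stock PyTorch enforces the same restriction with an identical message; a quick standalone check (plain PyTorch, independent of the vLLM layer):

```python
import torch.nn as nn

# nn.Conv2d rejects padding="same" whenever any stride is greater than 1,
# which is the behavior the new Conv2dLayer check mirrors.
try:
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding="same")
except ValueError as err:
    print(err)  # padding='same' is not supported for strided convolutions
```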


self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
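A minimal standalone sketch of the string-padding resolution added in this hunk (the helper name and the asserts are illustrative only; the real logic lives inline in Conv2dLayer.__init__ and also normalizes kernel_size, stride, and dilation to tuples):

```python
from typing import Literal

def resolve_padding(
    padding: int | tuple[int, ...] | Literal["same", "valid"],
    kernel_size: int | tuple[int, ...],
) -> int | tuple[int, ...]:
    """Mirror of the string handling added above: map 'same'/'valid' to numbers."""
    if isinstance(padding, str):
        if padding not in {"same", "valid"}:
            raise ValueError(f"Invalid padding string '{padding}'.")
        if padding == "valid":
            return 0  # no padding at all
        # "same": pad by half the kernel on each side (stride 1, odd kernels)
        return (
            kernel_size // 2
            if isinstance(kernel_size, int)
            else tuple(k // 2 for k in kernel_size)
        )
    return padding

assert resolve_padding("valid", 3) == 0
assert resolve_padding("same", 3) == 1
assert resolve_padding("same", (3, 5)) == (1, 2)
assert resolve_padding(2, 3) == 2
```
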
3 changes: 2 additions & 1 deletion vllm/model_executor/models/aimv2.py
@@ -12,6 +12,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -58,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
3 changes: 2 additions & 1 deletion vllm/model_executor/models/blip.py
@@ -12,6 +12,7 @@
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -47,7 +48,7 @@ def __init__(self, config: BlipVisionConfig | Blip2VisionConfig):

self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
29 changes: 14 additions & 15 deletions vllm/model_executor/models/chameleon.py
@@ -22,6 +22,7 @@
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ def forward(self, hidden_state: torch.Tensor):
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
self.conv = Conv2dLayer(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)

@@ -577,23 +578,23 @@ def __init__(
self.norm1 = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.conv1 = torch.nn.Conv2d(
self.conv1 = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = torch.nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.dropout = torch.nn.Dropout(config.dropout)
self.conv2 = torch.nn.Conv2d(
self.conv2 = Conv2dLayer(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
self.conv_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
self.nin_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)

@@ -626,16 +627,16 @@ def __init__(self, in_channels: int):
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = torch.nn.Conv2d(
self.q = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
self.k = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
self.v = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
self.proj_out = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)

@@ -681,7 +682,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
latent_channels = config.latent_channels
channel_multiplier = config.channel_multiplier

self.conv_in = torch.nn.Conv2d(
self.conv_in = Conv2dLayer(
in_channels, base_channels, kernel_size=3, stride=1, padding=1
)

@@ -738,7 +739,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
self.norm_out = torch.nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = torch.nn.Conv2d(
self.conv_out = Conv2dLayer(
block_in,
2 * latent_channels if double_latent else latent_channels,
kernel_size=3,
@@ -779,10 +780,8 @@ def __init__(self, config: ChameleonVQVAEConfig):
super().__init__()
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
config.embed_dim, config.latent_channels, 1
)
self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen

def encode(
13 changes: 8 additions & 5 deletions vllm/model_executor/models/deepencoder.py
@@ -19,6 +19,7 @@
from transformers import CLIPVisionConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -133,14 +134,14 @@ def __init__(
self.blocks.append(block)

self.neck = nn.Sequential(
nn.Conv2d(
Conv2dLayer(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
Conv2dLayer(
out_chans,
out_chans,
kernel_size=3,
@@ -150,8 +151,10 @@ def __init__(
LayerNorm2d(out_chans),
)

self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
self.net_2 = Conv2dLayer(
256, 512, kernel_size=3, stride=2, padding=1, bias=False
)
self.net_3 = Conv2dLayer(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)

@@ -500,7 +503,7 @@ def __init__(
"""
super().__init__()

self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)

3 changes: 2 additions & 1 deletion vllm/model_executor/models/dots_ocr.py
@@ -22,6 +22,7 @@
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -471,7 +472,7 @@ def __init__(self, config):
self.temporal_patch_size = config.temporal_patch_size
self.embed_dim = config.embed_dim
self.config = config
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.num_channels,
config.embed_dim,
kernel_size=(config.patch_size, config.patch_size),
4 changes: 2 additions & 2 deletions vllm/model_executor/models/glm4_1v.py
@@ -56,7 +56,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -734,7 +734,7 @@ def __init__(
self.post_conv_layernorm = RMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
self.downsample = nn.Conv2d(
self.downsample = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
5 changes: 3 additions & 2 deletions vllm/model_executor/models/glm4v.py
@@ -24,6 +24,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -78,7 +79,7 @@ class GLMVImagePixelInputs(TensorSchema):
class EVA2CLIPPatchEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.in_channels,
config.hidden_size,
kernel_size=config.patch_size,
@@ -333,7 +334,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.linear_proj",
)
self.conv = nn.Conv2d(
self.conv = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,
3 changes: 2 additions & 1 deletion vllm/model_executor/models/idefics2_vision_model.py
@@ -30,6 +30,7 @@
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -60,7 +61,7 @@ def __init__(self, config: Idefics2VisionConfig):
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
Comment on lines +64 to 67

P0: Idefics2 vision embeddings now call Conv2dLayer with padding string

The new Conv2dLayer wrapper forwards padding directly to F.conv2d and does not implement the "valid" shortcut that nn.Conv2d provided. Using the string here will cause a runtime failure when forward runs. Replace with the correct numeric padding (0) or add conversion logic.

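As the comment notes, "valid" in nn.Conv2d is simply no padding; a quick standalone check of that equivalence (plain PyTorch with hypothetical shapes, not the actual Idefics2 configuration):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)
conv_valid = nn.Conv2d(3, 8, kernel_size=14, stride=14, padding="valid")
conv_zero = nn.Conv2d(3, 8, kernel_size=14, stride=14, padding=0)
conv_zero.load_state_dict(conv_valid.state_dict())  # copy weights so outputs are comparable

assert torch.allclose(conv_valid(x), conv_zero(x))  # "valid" behaves exactly like padding=0
```
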
3 changes: 2 additions & 1 deletion vllm/model_executor/models/intern_vit.py
@@ -24,6 +24,7 @@
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -51,7 +52,7 @@ def __init__(self, config: PretrainedConfig):

self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
3 changes: 2 additions & 1 deletion vllm/model_executor/models/interns1_vit.py
@@ -16,6 +16,7 @@

from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -43,7 +44,7 @@ def __init__(self, config):
self.num_patches = num_patches
self.patch_shape = patch_shape

self.projection = nn.Conv2d(
self.projection = Conv2dLayer(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)

3 changes: 2 additions & 1 deletion vllm/model_executor/models/keye.py
@@ -24,6 +24,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -204,7 +205,7 @@ def __init__(self, config: PretrainedConfig):
self.image_size = config.image_size
self.patch_size = config.patch_size

self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
Comment on lines +208 to 211

P0: Keye vision embeddings pass unsupported padding="valid"

Conv2dLayer's constructor only handles integer or tuple padding values. Passing the string "valid", which nn.Conv2d previously accepted, will lead to an exception in forward when the convolution executes. Replace the string with the equivalent numeric padding.

Member
Can you update the type annotation to account for this?

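For reference, the annotation change that addresses this request is the one that landed in vllm/model_executor/layers/conv.py (first hunk of this diff); a small standalone rendering of that signature fragment:

```python
from typing import Literal

# Padding argument type as annotated in Conv2dLayer.__init__ above.
PaddingArg = int | tuple[int, ...] | Literal["same", "valid"]
```
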
3 changes: 2 additions & 1 deletion vllm/model_executor/models/midashenglm.py
@@ -39,6 +39,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -120,7 +121,7 @@ def __init__(
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_chans,
embed_dim,
kernel_size=self.patch_size,
3 changes: 2 additions & 1 deletion vllm/model_executor/models/moonvit.py
@@ -53,6 +53,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available

from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
@@ -244,7 +245,7 @@ def __init__(
)
self.patch_size = patch_size

self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_dim, out_dim, kernel_size=patch_size, stride=patch_size
)
