Skip to content

Commit 2d9f1c7

Browse files
committed
update conv2d conv3d impl
Signed-off-by: Isotr0py <[email protected]>
1 parent a627804 commit 2d9f1c7

File tree

2 files changed

+134
-45
lines changed

2 files changed

+134
-45
lines changed

vllm/model_executor/layers/multi_modal/conv.py

Lines changed: 125 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,80 +2,172 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Conv Layer Class."""
44

5+
import math
6+
57
import torch
68
import torch.nn as nn
79
import torch.nn.functional as F
810

911
from vllm.model_executor.custom_op import CustomOp
1012
from vllm.model_executor.layers.linear import ReplicatedLinear
1113
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
14+
from vllm.model_executor.utils import set_weight_attrs
1215

1316

14-
@CustomOp.register("conv")
1517
class ConvLayerBase(CustomOp):
1618
"""Conv layer base class."""
1719

18-
def __init__(
19-
self,
20-
) -> None:
21-
super().__init__()
22-
23-
24-
class Conv2dLayer(ConvLayerBase):
25-
"""Conv layer with Conv2d."""
20+
num_dim: int
2621

2722
def __init__(
2823
self,
2924
in_channels: int,
3025
out_channels: int,
31-
kernel_size: int | tuple,
32-
stride: int | tuple | None,
33-
padding: int | tuple | str | None,
34-
dilation: int | tuple | None,
35-
groups: int | None,
36-
bias: bool | None,
37-
padding_mode: str | None,
26+
kernel_size: int | tuple[int, ...],
27+
stride: int | tuple[int, ...] = 1,
28+
padding: int | tuple[int, ...] = 0,
29+
dilation: int | tuple[int, ...] = 1,
30+
groups: int = 1,
31+
bias: bool = True,
32+
padding_mode: str = "zeros",
33+
*,
34+
params_dtype: torch.dtype | None = None,
3835
) -> None:
3936
super().__init__()
4037

38+
if params_dtype is None:
39+
params_dtype = torch.get_default_dtype()
40+
41+
kernel_size = (
42+
(kernel_size,) * self.num_dim
43+
if isinstance(kernel_size, int)
44+
else kernel_size
45+
)
46+
stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
47+
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
48+
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
49+
4150
self.in_channels = in_channels
4251
self.out_channels = out_channels
4352
self.kernel_size = kernel_size
4453
self.stride = stride
4554
self.padding = padding
4655
self.dilation = dilation
4756
self.groups = groups
48-
self.bias = bias
4957
self.padding_mode = padding_mode
5058

51-
self.proj = nn.Conv2d(
52-
in_channels=in_channels,
53-
out_channels=out_channels,
54-
kernel_size=kernel_size,
55-
stride=stride,
56-
padding=padding,
57-
dilation=dilation,
58-
groups=groups,
59-
bias=bias,
60-
padding_mode=padding_mode,
59+
self.can_linearize = (
60+
(self.kernel_size == self.stride)
61+
and not any(self.padding)
62+
and self.groups == 1
6163
)
6264

63-
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
64-
x = self.proj(x)
65-
return x
65+
if self.can_linearize:
66+
self.weight = nn.Parameter(
67+
torch.empty(
68+
out_channels,
69+
in_channels * math.prod(self.kernel_size),
70+
dtype=params_dtype,
71+
),
72+
)
73+
else:
74+
self.weight = nn.Parameter(
75+
torch.empty(
76+
out_channels,
77+
in_channels // groups,
78+
*kernel_size,
79+
dtype=params_dtype,
80+
),
81+
)
82+
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
6683

67-
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
68-
return self.forward_native(x)
84+
if bias:
85+
self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
86+
set_weight_attrs(
87+
self.bias,
88+
{
89+
"weight_loader": self.weight_loader,
90+
},
91+
)
92+
else:
93+
self.register_parameter("bias", None)
6994

7095
def extra_repr(self) -> str:
7196
s = f"in_channels={self.in_channels}, "
7297
s += f"out_channels={self.out_channels}, "
7398
s += f"kernel_size={self.kernel_size}, "
7499
s += f"stride={self.stride}, "
75100
s += f"padding={self.padding}, "
76-
s += f"bias={self.bias}, "
101+
s += f"bias={self.bias is not None}"
77102
return s
78103

104+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
105+
param.data.copy_(loaded_weight.view(param.shape))
106+
107+
108+
@CustomOp.register("conv2d")
109+
class Conv2dLayer(ConvLayerBase):
110+
"""Conv layer with Conv2d."""
111+
112+
num_dim = 2
113+
114+
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
115+
"""Expected input shape: (batch_size, in_channels, height, width)"""
116+
assert x.dim() == 4
117+
if self.can_linearize:
118+
B, C, H, W = x.shape
119+
K1, K2 = self.kernel_size
120+
H, W = H // K1, W // K2
121+
x = x.view(-1, self.in_channels * math.prod(self.kernel_size))
122+
x = F.linear(x, self.weight, self.bias)
123+
x = x.view(B, self.out_channels, H, W)
124+
else:
125+
x = F.conv2d(
126+
x,
127+
self.weight,
128+
self.bias,
129+
stride=self.stride,
130+
padding=self.padding,
131+
dilation=self.dilation,
132+
groups=self.groups,
133+
)
134+
return x
135+
136+
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
137+
return self.forward_native(x)
138+
139+
140+
@CustomOp.register("conv3d")
141+
class Conv3dLayer(ConvLayerBase):
142+
"""Conv layer with Conv3d."""
143+
144+
num_dim = 3
145+
146+
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
147+
"""Expected input shape: (batch_size, in_channels, time, height, width)"""
148+
assert x.dim() == 5
149+
if self.can_linearize:
150+
B, C, T, H, W = x.shape
151+
K1, K2, K3 = self.kernel_size
152+
T, H, W = T // K1, H // K2, W // K3
153+
x = x.view(-1, self.in_channels * math.prod(self.kernel_size))
154+
x = F.linear(x, self.weight, self.bias)
155+
x = x.view(B, self.out_channels, T, H, W)
156+
else:
157+
x = F.conv3d(
158+
x,
159+
self.weight,
160+
self.bias,
161+
stride=self.stride,
162+
padding=self.padding,
163+
dilation=self.dilation,
164+
groups=self.groups,
165+
)
166+
return x
167+
168+
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
169+
return self.forward_native(x)
170+
79171

80172
class CausalConv2dLayer(Conv2dLayer):
81173
"""

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
# limitations under the License.
2727
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
2828

29-
import math
3029
from collections.abc import Callable, Iterable, Mapping, Sequence
3130
from functools import lru_cache, partial
3231
from typing import Annotated, Any, Literal, TypeAlias
@@ -63,7 +62,7 @@
6362
QKVParallelLinear,
6463
RowParallelLinear,
6564
)
66-
from vllm.model_executor.layers.multi_modal import get_conv_layer
65+
from vllm.model_executor.layers.multi_modal.conv import Conv3dLayer
6766
from vllm.model_executor.layers.quantization import QuantizationConfig
6867
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
6968
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -110,7 +109,6 @@
110109
maybe_prefix,
111110
)
112111
from .vision import (
113-
conv3d_to_linear_weight,
114112
get_vit_attn_backend,
115113
run_dp_sharded_mrope_vision_model,
116114
)
@@ -555,16 +553,18 @@ def __init__(
555553
self.hidden_size = hidden_size
556554

557555
kernel_size = (temporal_patch_size, patch_size, patch_size)
558-
self.proj = get_conv_layer(
559-
input_size=in_channels * math.prod(kernel_size),
560-
output_size=hidden_size,
556+
self.proj = Conv3dLayer(
557+
in_channels,
558+
hidden_size,
559+
kernel_size=kernel_size,
560+
stride=kernel_size,
561561
bias=False,
562-
return_bias=False,
563-
conv_type="linear",
564562
)
565563

566564
def forward(self, x: torch.Tensor) -> torch.Tensor:
567-
x = self.proj(x)
565+
L, C = x.shape
566+
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
567+
x = self.proj(x).view(L, self.hidden_size)
568568
return x
569569

570570

@@ -988,9 +988,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
988988
loaded_params: set[str] = set()
989989

990990
for name, loaded_weight in weights:
991-
if name.endswith("patch_embed.proj.weight"):
992-
loaded_weight = conv3d_to_linear_weight(loaded_weight)
993-
994991
for param_name, weight_name, shard_id in stacked_params_mapping:
995992
if weight_name not in name:
996993
continue

0 commit comments

Comments
 (0)