Skip to content

Commit 60111dd

Browse files
committed
Extract conv layer as CustomOp
Signed-off-by: shen-shanshan <[email protected]>
1 parent ac0bb2c commit 60111dd

File tree

8 files changed

+418
-75
lines changed

8 files changed

+418
-75
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Multi-Modal Layers."""
4+
5+
from .conv import CausalConv2dLayer, Conv2dLayer, Conv3dLayer
6+
7+
8+
def get_conv_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple,
    stride: int,
    padding: int,
    dilation: int,
    groups: int,
    enable_bias: bool,
    padding_mode: str,
    conv_type: str,
):
    """Build the conv layer selected by ``conv_type``.

    All arguments except ``conv_type`` are forwarded unchanged, by keyword,
    to the constructor of the selected layer class.

    Args:
        in_channels: Number of input channels (must be truthy).
        out_channels: Number of output channels (must be truthy).
        kernel_size: Kernel size (must be truthy).
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of channel groups.
        enable_bias: Whether the layer owns a bias parameter.
        padding_mode: Padding mode string.
        conv_type: One of ``"conv2d"``, ``"causal_conv2d"`` or ``"conv3d"``.

    Returns:
        A ``Conv2dLayer``, ``CausalConv2dLayer`` or ``Conv3dLayer`` instance.

    Raises:
        ValueError: If ``conv_type`` is not one of the supported names.
    """
    assert in_channels and out_channels and kernel_size

    # One table instead of three copies of the same keyword-argument call.
    layer_classes = {
        "conv2d": Conv2dLayer,
        "causal_conv2d": CausalConv2dLayer,
        "conv3d": Conv3dLayer,
    }
    layer_cls = layer_classes.get(conv_type)
    if layer_cls is None:
        raise ValueError(f"Unknown conv layer type {conv_type}.")

    return layer_cls(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        enable_bias=enable_bias,
        padding_mode=padding_mode,
    )
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Conv Layer Class."""
4+
5+
import math
6+
from collections.abc import Callable
7+
8+
import torch
9+
import torch.nn.functional as F
10+
import torch.nn.parameter as Parameter
11+
12+
from vllm.model_executor.custom_op import CustomOp
13+
from vllm.model_executor.utils import set_weight_attrs
14+
15+
16+
@CustomOp.register("conv2d")
class Conv2dLayer(CustomOp):
    """2D convolution layer registered as the ``conv2d`` custom op.

    When the kernel matches the stride in every dimension, ``padding`` is 0
    and ``groups`` is 1, the weight is stored as a 2D matrix and
    ``forward_native`` calls ``F.linear`` instead of ``F.conv2d``.

    NOTE(review): on the linear path ``x`` is handed to ``F.linear`` as-is,
    so callers presumably pass already-flattened patches -- confirm against
    the call sites.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        enable_bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        # A scalar kernel size means a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding: int = padding
        self.dilation = dilation
        self.groups = groups
        self.enable_bias = enable_bias
        self.padding_mode = padding_mode

        # The linear weight has shape (out, in * prod(kernel)), which only
        # matches an ungrouped conv whose kernel equals the stride in every
        # dimension; checking the first dimension alone is not enough.
        self.enable_linear: bool = groups == 1 and all(
            _enable_linear(k, stride, padding) for k in kernel_size
        )

        _create_conv_weights(
            self,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.groups,
            enable_bias=self.enable_bias,
            enable_linear=self.enable_linear,
            weight_loader=self.weight_loader,
        )

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ):
        """Copy a checkpoint tensor into ``param``.

        This loader is attached to both the weight and the bias. Only
        tensors with more than one dimension (conv kernels) are flattened
        for the linear path; a 1-D bias is copied unchanged, since reshaping
        it to ``(out_channels, 1)`` would not fit the 1-D bias parameter.
        """
        if self.enable_linear and loaded_weight.dim() > 1:
            loaded_weight = _convert_conv_to_linear_weight(loaded_weight)
        param.data.copy_(loaded_weight)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer with ``F.linear`` or ``F.conv2d``."""
        if self.enable_linear:
            return F.linear(x, self.weight, self.bias)
        return F.conv2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation as ``forward_native``."""
        return self.forward_native(x)

    def extra_repr(self) -> str:
        # Report whether a bias exists rather than dumping the whole tensor.
        return (
            f"in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, "
            f"stride={self.stride}, "
            f"padding={self.padding}, "
            f"bias={self.bias is not None}"
        )
93+
94+
95+
@CustomOp.register("causal_conv2d")
class CausalConv2dLayer(CustomOp):
    """Causal variant of a 2D convolution, registered as ``causal_conv2d``.

    Instead of symmetric padding, the input's last dimension is padded by
    ``kernel_size[0] - 1`` on the left and ``stride - 1`` on the right
    before an unpadded ``F.conv2d`` is applied.

    All arguments are the same as for a regular 2D conv, except ``padding``,
    which must be ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple,
        stride: int = 1,
        padding: int | None = None,
        dilation: int = 1,
        groups: int = 1,
        enable_bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        # A scalar kernel size means a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.enable_bias = enable_bias
        self.padding_mode = padding_mode

        # Padding is derived from the kernel and stride below, so an explicit
        # value is rejected. The default is None so that constructing the
        # layer without a padding argument no longer raises.
        if padding is not None:
            raise ValueError(
                "Argument padding should be set to None for CausalConv2dLayer."
            )
        self._left_padding: int = kernel_size[0] - 1
        self._right_padding: int = stride - 1
        self.padding: int = 0

        # padding is None here, so _enable_linear's ``padding == 0`` test is
        # false and the linear path is never selected for this layer.
        self.enable_linear: bool = bool(
            _enable_linear(kernel_size[0], stride, padding)
        )

        _create_conv_weights(
            self,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.groups,
            enable_bias=self.enable_bias,
            enable_linear=self.enable_linear,
            weight_loader=self.weight_loader,
        )

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ):
        """Copy a checkpoint tensor into ``param``.

        This loader is attached to both the weight and the bias; only
        tensors with more than one dimension are flattened for the linear
        path, so a 1-D bias is always copied unchanged.
        """
        if self.enable_linear and loaded_weight.dim() > 1:
            loaded_weight = _convert_conv_to_linear_weight(loaded_weight)
        param.data.copy_(loaded_weight)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the last dimension causally, then apply ``F.conv2d``."""
        if self.enable_linear:
            return F.linear(x, self.weight, self.bias)
        # The pad tuple is (left, right, top, bottom) over the last two dims.
        x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
        return F.conv2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation as ``forward_native``."""
        return self.forward_native(x)

    def extra_repr(self) -> str:
        # Report whether a bias exists rather than dumping the whole tensor.
        return (
            f"in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, "
            f"stride={self.stride}, "
            f"padding={self.padding}, "
            f"bias={self.bias is not None}"
        )
185+
186+
187+
@CustomOp.register("conv3d")
class Conv3dLayer(CustomOp):
    """3D convolution layer registered as the ``conv3d`` custom op.

    When the kernel matches the stride in every dimension, ``padding`` is 0
    and ``groups`` is 1, the weight is stored as a 2D matrix and
    ``forward_native`` calls ``F.linear`` instead of ``F.conv3d``.

    NOTE(review): on the linear path ``x`` is handed to ``F.linear`` as-is,
    so callers presumably pass already-flattened patches -- confirm against
    the call sites.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        enable_bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        assert isinstance(kernel_size, tuple) and len(kernel_size) == 3

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding: int = padding
        self.dilation = dilation
        self.groups = groups
        # Must be stored: it is read when the weights are created below.
        self.enable_bias = enable_bias
        self.padding_mode = padding_mode
        # Not read anywhere in this class; kept so the attribute still exists.
        self.use_linear = False

        # The linear weight has shape (out, in * prod(kernel)), which only
        # matches an ungrouped conv whose kernel equals the stride in every
        # dimension; checking the first dimension alone is not enough.
        self.enable_linear: bool = groups == 1 and all(
            _enable_linear(k, stride, padding) for k in kernel_size
        )

        _create_conv_weights(
            self,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.groups,
            enable_bias=self.enable_bias,
            enable_linear=self.enable_linear,
            weight_loader=self.weight_loader,
        )

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ):
        """Copy a checkpoint tensor into ``param``.

        This loader is attached to both the weight and the bias. Only
        tensors with more than one dimension (conv kernels) are flattened
        for the linear path; a 1-D bias is copied unchanged, since reshaping
        it to ``(out_channels, 1)`` would not fit the 1-D bias parameter.
        """
        if self.enable_linear and loaded_weight.dim() > 1:
            loaded_weight = _convert_conv_to_linear_weight(loaded_weight)
        param.data.copy_(loaded_weight)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer with ``F.linear`` or ``F.conv3d``."""
        if self.enable_linear:
            return F.linear(x, self.weight, self.bias)
        # The weight has three kernel dimensions, so this must be conv3d.
        return F.conv3d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation as ``forward_native``."""
        return self.forward_native(x)

    def extra_repr(self) -> str:
        # Report whether a bias exists rather than dumping the whole tensor.
        return (
            f"in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, "
            f"stride={self.stride}, "
            f"padding={self.padding}, "
            f"bias={self.bias is not None}"
        )
263+
264+
265+
def _enable_linear(
266+
kernel_size: int,
267+
stride: int,
268+
padding: int,
269+
) -> bool:
270+
assert isinstance(kernel_size, int) and isinstance(stride, int)
271+
return kernel_size == stride and padding == 0
272+
273+
274+
def _create_conv_weights(
    layer: torch.nn.Module,
    in_channels: int,
    out_channels: int,
    kernel_size: tuple,
    groups: int,
    enable_bias: bool,
    enable_linear: bool,
    weight_loader: Callable,
) -> None:
    """Register uninitialised ``weight`` and ``bias`` parameters on ``layer``.

    Args:
        layer: Module that receives the parameters.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, one entry per spatial dimension.
        groups: Number of channel groups (used for the conv layout only).
        enable_bias: If False, ``bias`` is registered as ``None``.
        enable_linear: If True, the weight uses the flattened 2D layout
            ``(out_channels, in_channels * prod(kernel_size))``; otherwise
            the conv layout ``(out_channels, in_channels // groups, *kernel)``.
        weight_loader: Callable attached to each created parameter through
            ``set_weight_attrs``.
    """
    if enable_linear:
        # Use linear computation for better performance.
        weight_shape = (out_channels, in_channels * math.prod(kernel_size))
    else:
        # Use the regular convolution weight layout.
        weight_shape = (out_channels, in_channels // groups, *kernel_size)

    # torch.nn.Parameter is referenced explicitly: this file binds the bare
    # name ``Parameter`` to the module torch.nn.parameter, which is not
    # callable.
    weight = torch.nn.Parameter(torch.empty(weight_shape))
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, {"weight_loader": weight_loader})

    if not enable_bias:
        layer.register_parameter("bias", None)
        return

    bias = torch.nn.Parameter(torch.empty(out_channels))
    layer.register_parameter("bias", bias)
    set_weight_attrs(bias, {"weight_loader": weight_loader})
303+
304+
305+
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
306+
# Conv3D weights to Linear weights for better performance.
307+
# See: https://github.com/vllm-project/vllm/issues/27406
308+
# and https://github.com/pytorch/pytorch/issues/166122
309+
# FIXME(Isotr0py): Revert the PR introduces this workaround
310+
# (https://github.com/vllm-project/vllm/pull/27418),
311+
# once the performance issue is resolved in PyTorch.
312+
def _convert_conv_to_linear_weight(conv_weight: torch.Tensor) -> torch.Tensor:
313+
"""
314+
Reshape Conv2D or Conv3D weight to Linear weight.
315+
Only work when kernel_size==stride.
316+
"""
317+
out_channels = conv_weight.shape[0]
318+
linear_weight = conv_weight.reshape(out_channels, -1)
319+
return linear_weight

0 commit comments

Comments
 (0)