|
18 | 18 | # limitations under the License. |
19 | 19 |
|
20 | 20 | from functools import partial |
21 | | -from typing import Callable, Iterable, Optional, Set, Tuple, Union |
| 21 | +from typing import Iterable, Optional, Set, Tuple, Union |
22 | 22 |
|
23 | 23 | import torch |
24 | 24 | import torch.nn as nn |
25 | 25 | import torch.nn.functional as F |
26 | | -import torch_npu |
27 | 26 | from einops import rearrange |
28 | 27 | from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( |
29 | 28 | Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) |
|
35 | 34 | from vllm.model_executor.layers.quantization import QuantizationConfig |
36 | 35 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
37 | 36 | from vllm.model_executor.models.qwen2_5_vl import ( |
38 | | - Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, |
| 37 | + Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, |
39 | 38 | Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer, |
40 | 39 | Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration, |
41 | 40 | Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo) |
42 | 41 | from vllm.model_executor.models.utils import maybe_prefix |
43 | 42 | from vllm.multimodal import MULTIMODAL_REGISTRY |
44 | 43 |
|
45 | 44 | from vllm_ascend.ascend_forward_context import set_ascend_forward_context |
46 | | -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz, |
47 | | - vllm_version_is) |
| 45 | +from vllm_ascend.utils import vllm_version_is |
48 | 46 |
|
49 | 47 | if not vllm_version_is("0.11.0"): |
50 | 48 | from vllm.model_executor.models.vision import conv3d_to_linear_weight |
|
53 | 51 | MAX_PAD_SIZE = 128 # max_size to pad weight |
54 | 52 |
|
55 | 53 |
|
56 | | -class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): |
57 | | - |
58 | | - def __init__( |
59 | | - self, |
60 | | - embed_dim: int, |
61 | | - num_heads: int, |
62 | | - projection_size: int, |
63 | | - quant_config: Optional[QuantizationConfig] = None, |
64 | | - prefix: str = "", |
65 | | - ) -> None: |
66 | | - super().__init__( |
67 | | - embed_dim, |
68 | | - num_heads, |
69 | | - projection_size, |
70 | | - quant_config, |
71 | | - prefix, |
72 | | - ) |
73 | | - self.embed_dim = embed_dim |
74 | | - self.hidden_size_per_attention_head = dist_utils.divide( |
75 | | - projection_size, num_heads) |
76 | | - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head |
77 | | - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: |
78 | | - self.hidden_size_per_attention_head = MAX_PAD_SIZE |
79 | | - |
80 | | - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: |
81 | | - # [s, b, 3 * head * head_dim] |
82 | | - seq_len, bs, _ = qkv.shape |
83 | | - |
84 | | - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] |
85 | | - q, k, v = qkv.chunk(3, dim=2) |
86 | | - |
87 | | - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] |
88 | | - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, |
89 | | - self.hidden_size_per_attention_head) |
90 | | - q, k, v = (x.view(*new_shape) for x in (q, k, v)) |
91 | | - return q, k, v |
| 54 | +class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): |
92 | 55 |
|
93 | 56 | def forward( |
94 | 57 | self, |
95 | 58 | x: torch.Tensor, |
96 | 59 | cu_seqlens: torch.Tensor, |
97 | | - cos: torch.Tensor, |
98 | | - sin: torch.Tensor, |
| 60 | + rotary_pos_emb_cos: torch.Tensor, |
| 61 | + rotary_pos_emb_sin: torch.Tensor, |
99 | 62 | ) -> torch.Tensor: |
100 | | - # [s, b, c] --> [s, b, head * 3 * head_dim] |
101 | | - x, _ = self.qkv(x) |
102 | | - |
103 | | - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] |
104 | | - q, k, v = self.split_qkv(x) |
105 | | - batch_size = q.shape[1] |
106 | | - |
107 | | - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() |
108 | | - for x in (q, k, v)) |
109 | | - q = torch_npu.npu_rotary_mul(q, cos, sin) |
110 | | - k = torch_npu.npu_rotary_mul(k, cos, sin) |
111 | | - |
112 | | - q, k, v = [ |
113 | | - rearrange(x, "b s h d -> (b s) h d").contiguous() |
114 | | - for x in (q, k, v) |
115 | | - ] |
116 | | - |
117 | | - context_layer = torch.empty_like(q) |
118 | | - |
119 | | - # operator requires pta version >= 2.5.1 |
120 | | - torch_npu._npu_flash_attention_unpad( |
121 | | - query=q, |
122 | | - key=k, |
123 | | - value=v, |
124 | | - seq_len=cu_seqlens, |
125 | | - scale_value=self.origin_hidden_size_per_attention_head**-0.5, |
126 | | - num_heads=self.num_attention_heads_per_partition, |
127 | | - num_kv_heads=self.num_attention_heads_per_partition, |
128 | | - out=context_layer) |
129 | | - |
130 | | - context_layer = rearrange(context_layer, |
131 | | - "(b s) h d -> s b (h d)", |
132 | | - b=batch_size).contiguous() |
133 | | - |
134 | | - output, _ = self.proj(context_layer) |
135 | | - return output |
136 | | - |
137 | | - |
138 | | -class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): |
139 | | - |
140 | | - def __init__( |
141 | | - self, |
142 | | - dim: int, |
143 | | - num_heads: int, |
144 | | - mlp_hidden_dim: int, |
145 | | - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, |
146 | | - norm_layer: Optional[Callable[[int], nn.Module]] = None, |
147 | | - quant_config: Optional[QuantizationConfig] = None, |
148 | | - prefix: str = "", |
149 | | - ) -> None: |
150 | | - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, |
151 | | - quant_config, prefix) |
152 | | - self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, |
153 | | - num_heads=num_heads, |
154 | | - projection_size=dim, |
155 | | - quant_config=quant_config, |
156 | | - prefix=f"{prefix}.attn") |
157 | | - |
158 | | - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, |
159 | | - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
160 | | - x = x + self.attn( |
161 | | - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) |
162 | | - |
| 63 | + x = x + self.attn(self.norm1(x), |
| 64 | + cu_seqlens=cu_seqlens, |
| 65 | + rotary_pos_emb_cos=rotary_pos_emb_cos, |
| 66 | + rotary_pos_emb_sin=rotary_pos_emb_sin) |
163 | 67 | x = x + self.mlp(self.norm2(x)) |
164 | 68 | return x |
165 | 69 |
|
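For context on the renamed arguments: `rotary_pos_emb_cos` / `rotary_pos_emb_sin` carry the same precomputed rotary tensors that the removed Ascend attention consumed as `cos` / `sin`, where they were applied with `torch_npu.npu_rotary_mul` ahead of the unpadded flash-attention kernel. A minimal sketch of that rotary step, assuming the `[s, b, heads, head_dim]` layout produced by `split_qkv` and an NPU build that provides `npu_rotary_mul` (the helper name here is hypothetical):

```python
import torch
import torch_npu  # Ascend PyTorch adapter; npu_rotary_mul needs a recent PTA release
from einops import rearrange


def apply_vision_rope(q: torch.Tensor, k: torch.Tensor,
                      cos: torch.Tensor, sin: torch.Tensor):
    """Rotate q/k with precomputed cos/sin tensors (hypothetical helper).

    q, k      -- [seq, batch, heads, head_dim], the layout split_qkv produces
    cos, sin  -- broadcastable against [batch, seq, heads, head_dim]
    """
    # Move to batch-first layout and make contiguous for the NPU kernel.
    q, k = (rearrange(t, "s b h d -> b s h d").contiguous() for t in (q, k))
    q = torch_npu.npu_rotary_mul(q, cos, sin)
    k = torch_npu.npu_rotary_mul(k, cos, sin)
    return q, k
```

The sketch only assumes `cos` / `sin` are broadcastable against the batch-first query/key shape; their exact shape is whatever `cal_cos_sin` returns.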
@@ -255,98 +159,6 @@ def cal_cos_sin(self, rotary_pos_emb): |
255 | 159 | self.hidden_size_per_attention_head) |
256 | 160 | return cos_new, sin_new |
257 | 161 |
|
258 | | - def pad_qkv_bias(self, bias): |
259 | | - first_half = bias.reshape( |
260 | | - -1, 3, self.origin_hidden_size_per_attention_head |
261 | | - )[:, :, :self.half_origin_hidden_size_per_attention_head] |
262 | | - second_half = bias.reshape( |
263 | | - -1, 3, self.origin_hidden_size_per_attention_head |
264 | | - )[:, :, self.half_origin_hidden_size_per_attention_head:] |
265 | | - first_half_padded = torch.nn.functional.pad( |
266 | | - first_half, (0, self.half_pad_hidden_size_per_attention_head)) |
267 | | - second_half_padded = torch.nn.functional.pad( |
268 | | - second_half, (0, self.half_pad_hidden_size_per_attention_head)) |
269 | | - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) |
270 | | - bias_final = bias_padded.reshape(-1) |
271 | | - return bias_final |
272 | | - |
273 | | - def pad_qkv_weight(self, data): |
274 | | - qkv_weight_first_half = data.reshape( |
275 | | - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size |
276 | | - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] |
277 | | - qkv_weight_second_half = data.reshape( |
278 | | - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size |
279 | | - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] |
280 | | - |
281 | | - qkv_weight_first_half_padded = torch.nn.functional.pad( |
282 | | - qkv_weight_first_half, |
283 | | - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) |
284 | | - qkv_weight_second_half_padded = torch.nn.functional.pad( |
285 | | - qkv_weight_second_half, |
286 | | - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) |
287 | | - qkv_weight_padded = torch.cat( |
288 | | - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], |
289 | | - dim=2) |
290 | | - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) |
291 | | - |
292 | | - if is_enable_nz(): |
293 | | - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( |
294 | | - qkv_weight_final) |
295 | | - qkv_weight_final_copy = torch_npu.npu_format_cast( |
296 | | - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) |
297 | | - return qkv_weight_final_copy |
298 | | - |
299 | | - return qkv_weight_final |
300 | | - |
301 | | - def pad_proj_weight(self, data): |
302 | | - out_weight = torch.nn.functional.pad( |
303 | | - data.reshape(self.hidden_size, -1, |
304 | | - self.half_origin_hidden_size_per_attention_head), |
305 | | - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( |
306 | | - self.hidden_size, -1) |
307 | | - |
308 | | - if is_enable_nz(): |
309 | | - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) |
310 | | - out_weight_copy = torch_npu.npu_format_cast( |
311 | | - out_weight_copy, ACL_FORMAT_FRACTAL_ND) |
312 | | - return out_weight_copy |
313 | | - |
314 | | - return out_weight |
315 | | - |
316 | | - def pad_qkv_weight_scale_offset(self, data): |
317 | | - reshaped_data = data.reshape( |
318 | | - -1, 3, self.origin_hidden_size_per_attention_head, 1) |
319 | | - data1 = reshaped_data[:, :, :self. |
320 | | - half_origin_hidden_size_per_attention_head, :] |
321 | | - data2 = reshaped_data[:, :, self. |
322 | | - half_origin_hidden_size_per_attention_head:, :] |
323 | | - data1_paded = torch.nn.functional.pad( |
324 | | - data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, |
325 | | - 0, 0, 0)) |
326 | | - data2_paded = torch.nn.functional.pad( |
327 | | - data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, |
328 | | - 0, 0, 0)) |
329 | | - res = torch.cat([data1_paded, data2_paded], dim=2) |
330 | | - res = res.reshape(-1, 1) |
331 | | - return res |
332 | | - |
333 | | - def pad_qkv_deq_scale_quant_bias(self, data): |
334 | | - reshaped_data = data.reshape( |
335 | | - -1, 3, self.origin_hidden_size_per_attention_head) |
336 | | - data1 = reshaped_data[:, :, :self. |
337 | | - half_origin_hidden_size_per_attention_head] |
338 | | - data2 = reshaped_data[:, :, |
339 | | - self.half_origin_hidden_size_per_attention_head:] |
340 | | - |
341 | | - data1_paded = torch.nn.functional.pad( |
342 | | - data1, (0, self.half_pad_hidden_size_per_attention_head)) |
343 | | - data2_paded = torch.nn.functional.pad( |
344 | | - data2, (0, self.half_pad_hidden_size_per_attention_head)) |
345 | | - |
346 | | - res = torch.cat([data1_paded, data2_paded], dim=2) |
347 | | - res = res.reshape(-1) |
348 | | - return res |
349 | | - |
350 | 162 | def load_weights(self, weights: Iterable[Tuple[str, |
351 | 163 | torch.Tensor]]) -> Set[str]: |
352 | 164 | stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ |
@@ -377,24 +189,6 @@ def load_weights(self, weights: Iterable[Tuple[str, |
377 | 189 | weight_loader = getattr(param, "weight_loader", |
378 | 190 | default_weight_loader) |
379 | 191 | weight_loader(param, loaded_weight) |
380 | | - if ("attn.proj.weight_scale" in name or |
381 | | - "attn.proj.weight_offset" in name) and self.enable_pad: |
382 | | - continue |
383 | | - elif ("attn.proj.deq_scale" in name |
384 | | - or "attn.proj.quant_bias" in name) and self.enable_pad: |
385 | | - continue |
386 | | - elif ("attn.qkv.weight_scale" in name |
387 | | - or "attn.qkv.weight_offset" in name) and self.enable_pad: |
388 | | - param.data = self.pad_qkv_weight_scale_offset(param.data) |
389 | | - elif ("attn.qkv.deq_scale" in name |
390 | | - or "attn.qkv.quant_bias" in name) and self.enable_pad: |
391 | | - param.data = self.pad_qkv_deq_scale_quant_bias(param.data) |
392 | | - elif ("attn.proj.weight" in name) and self.enable_pad: |
393 | | - param.data = self.pad_proj_weight(param.data) |
394 | | - elif ("attn.qkv.weight" in name) and self.enable_pad: |
395 | | - param.data = self.pad_qkv_weight(param.data) |
396 | | - elif ("attn.qkv.bias" in name) and self.enable_pad: |
397 | | - param.data = self.pad_qkv_bias(param.data) |
398 | 192 | loaded_params.add(name) |
399 | 193 | return loaded_params |
400 | 194 |
|
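All of the deleted `pad_*` helpers implement one scheme: view each head's `head_dim` as its two rotary halves and right-pad each half so the head reaches `MAX_PAD_SIZE` (128), which the old kernel path required whenever `MIN_PAD_SIZE < head_dim < MAX_PAD_SIZE`. A standalone sketch of that transform on a fused QKV bias, with `head_dim = 80` assumed purely for illustration (each 40-wide half is padded to 64):

```python
import torch
import torch.nn.functional as F

MAX_PAD_SIZE = 128  # target head_dim after padding, as in the removed path


def pad_qkv_bias_halves(bias: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Right-pad each rotary half of every (q, k, v) head up to MAX_PAD_SIZE // 2."""
    half = head_dim // 2
    pad = (MAX_PAD_SIZE - head_dim) // 2
    b = bias.reshape(-1, 3, head_dim)          # [heads, qkv, head_dim]
    first, second = b[:, :, :half], b[:, :, half:]
    first = F.pad(first, (0, pad))             # pad the last dim on the right
    second = F.pad(second, (0, pad))
    return torch.cat([first, second], dim=2).reshape(-1)


bias = torch.randn(16 * 3 * 80)                # e.g. 16 heads with head_dim 80
assert pad_qkv_bias_halves(bias, head_dim=80).numel() == 16 * 3 * 128
```

The removed weight, `weight_scale`/`weight_offset`, and `deq_scale`/`quant_bias` variants apply the same half-split padding along the `head_dim` axis of their respective shapes.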
@@ -501,7 +295,10 @@ def forward( |
501 | 295 | cu_seqlens_now = cu_seqlens |
502 | 296 | else: |
503 | 297 | cu_seqlens_now = cu_window_seqlens |
504 | | - x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) |
| 298 | + x = blk(x, |
| 299 | + cu_seqlens=cu_seqlens_now, |
| 300 | + rotary_pos_emb_cos=cos, |
| 301 | + rotary_pos_emb_sin=sin) |
505 | 302 |
|
506 | 303 | # adapter |
507 | 304 | x = self.merger(x) |
|