Commit 3a9acfc

add get_forward_context_manager interface

Signed-off-by: shen-shanshan <[email protected]>
1 parent: a5b476a

2 files changed: 7 additions, 44 deletions


vllm_ascend/patch/worker/patch_qwen2_5_vl.py

Lines changed: 1 addition & 44 deletions
@@ -22,55 +22,14 @@
 import torch.nn.functional as F
 import torch_npu
 from einops import rearrange
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VisionAttention, Qwen2_5_VLForConditionalGeneration)
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 
 MIN_PAD_SIZE = 64 # min_size to pad weight
 MAX_PAD_SIZE = 128 # max_size to pad weight
 
 
-class AscendQwen2_5_VLForConditionalGeneration(nn.Module):
-
-    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
-
-        grid_thw = image_input["image_grid_thw"]
-        assert grid_thw.ndim == 2
-
-        if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
-        else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            with set_ascend_forward_context(None, self.vllm_config):
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
-
-        # Split concatenated embeddings for each image item.
-        merge_size = self.visual.spatial_merge_size
-        sizes = grid_thw.prod(-1) // merge_size // merge_size
-        return image_embeds.split(sizes.tolist())
-
-    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
-
-        grid_thw = video_input["video_grid_thw"]
-        assert grid_thw.ndim == 2
-
-        if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
-        else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
-            with set_ascend_forward_context(None, self.vllm_config):
-                video_embeds = self.visual(pixel_values_videos,
-                                           grid_thw=grid_thw)
-
-        # Split concatenated embeddings for each video item.
-        merge_size = self.visual.spatial_merge_size
-        sizes = grid_thw.prod(-1) // merge_size // merge_size
-        return video_embeds.split(sizes.tolist())
-
-
 @contextmanager
 def _padding_manager(
     q: torch.Tensor,
@@ -189,5 +148,3 @@ def forward(
 
 
 Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
-Qwen2_5_VLForConditionalGeneration._process_image_input = AscendQwen2_5_VLForConditionalGeneration._process_image_input
-Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input
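Note: the removed _process_image_input / _process_video_input overrides existed only to wrap the vision-encoder call in set_ascend_forward_context. With the platform hook added below in platform.py, that wrapping can be obtained generically instead of being monkey-patched per model. A minimal sketch of such a call site, assuming the manager is resolved through vLLM's current_platform (how upstream actually consumes the hook is not shown in this commit, and this helper is hypothetical):

from vllm.platforms import current_platform


def encode_images(visual, pixel_values, grid_thw, vllm_config):
    # Hypothetical helper: resolve the platform-provided forward-context
    # manager instead of importing set_ascend_forward_context directly.
    forward_ctx = current_platform.get_forward_context_manager()
    # Same arguments the deleted patch passed: no attn_metadata, just the
    # engine's vllm_config.
    with forward_ctx(None, vllm_config):
        return visual(pixel_values, grid_thw=grid_thw)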

vllm_ascend/platform.py

Lines changed: 6 additions & 0 deletions
@@ -413,3 +413,9 @@ def support_hybrid_kv_cache(cls) -> bool:
     @classmethod
     def support_static_graph_mode(cls) -> bool:
         return True
+
+    @classmethod
+    def get_forward_context_manager(cls):
+        from vllm_ascend.ascend_forward_context import \
+            set_ascend_forward_context
+        return set_ascend_forward_context
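The accessor simply returns set_ascend_forward_context, imported lazily so callers need no direct dependency on vllm_ascend.ascend_forward_context. A quick sketch of calling it, assuming the method sits on the NPUPlatform class defined in this file (the class name is an assumption here):

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.platform import NPUPlatform  # assumed platform class name

# The new classmethod hands back the Ascend forward-context manager.
ctx_manager = NPUPlatform.get_forward_context_manager()
assert ctx_manager is set_ascend_forward_context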
