Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,22 @@ def __init__(self,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False):
disable_tp: bool = False):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel)
disable_tp=disable_tp)

self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel)
disable_tp=disable_tp)
self.act_fn = act_fn

def forward(self, x: torch.Tensor):
Expand Down Expand Up @@ -271,13 +271,13 @@ def __init__(
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
disable_tp: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
self.tp_size = (1 if use_data_parallel else
self.tp_size = (1 if disable_tp else
parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
Expand All @@ -293,13 +293,13 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
disable_tp=disable_tp)

self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
disable_tp=disable_tp)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.is_flash_attn_backend = self.attn_backend in {
Expand Down Expand Up @@ -425,7 +425,7 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
disable_tp: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
Expand All @@ -434,22 +434,21 @@ def __init__(
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
disable_tp=disable_tp,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen2_5_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
disable_tp=disable_tp)

def forward(
self,
Expand Down Expand Up @@ -640,7 +639,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
disable_tp=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
])
Expand Down
45 changes: 22 additions & 23 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,22 @@ def __init__(self,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False):
disable_tp: bool = False):
super().__init__()
self.linear_fc1 = ColumnParallelLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
disable_tp=disable_tp)
self.linear_fc2 = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)
disable_tp=disable_tp)
self.act_fn = act_fn

def forward(self, x: torch.Tensor):
Expand All @@ -162,7 +162,7 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
disable_tp: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
Expand All @@ -171,22 +171,21 @@ def __init__(
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
disable_tp=disable_tp,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
disable_tp=disable_tp)

def forward(
self,
Expand Down Expand Up @@ -217,7 +216,7 @@ def __init__(
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
disable_tp: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
Expand All @@ -234,14 +233,14 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
disable_tp=disable_tp)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)
disable_tp=disable_tp)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
Expand All @@ -263,7 +262,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
disable_tp: bool = False,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
Expand All @@ -274,7 +273,7 @@ def __init__(
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.disable_tp = disable_tp
self.num_grid_per_side = int(self.num_position_embeddings**0.5)

# NOTE: This is used for creating empty tensor for all_gather for
Expand Down Expand Up @@ -303,7 +302,7 @@ def __init__(
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
disable_tp=disable_tp,
)

self.deepstack_merger_list = nn.ModuleList([
Expand All @@ -315,7 +314,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel)
disable_tp=disable_tp)
for layer_idx in range(len(self.deepstack_visual_indexes))
])

Expand Down Expand Up @@ -344,7 +343,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
disable_tp=disable_tp,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa)
for layer_idx in range(vision_config.depth)
Expand Down Expand Up @@ -1134,7 +1133,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
disable_tp=self.use_data_parallel,
)

self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
parallel_config = vllm_config.parallel_config

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.sequence_parallel = parallel_config.use_sequence_parallel_moe

if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
Expand All @@ -328,7 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
disable_tp=self.use_data_parallel or self.sequence_parallel,
)

self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
Expand Down