Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.nn.functional as F
from einops import rearrange

from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.utils import is_cuda, print_info_once

_is_cuda = is_cuda()
Expand Down Expand Up @@ -365,19 +366,20 @@ def __init__(
**kwargs,
):
super().__init__()
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = world_size
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.tp_size = attn_tp_size
self.tp_rank = attn_tp_rank
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_dummy_heads + num_heads, world_size
num_dummy_heads + num_heads, self.tp_size
)
self.num_attention_kv_heads_per_partition = dist_utils.divide(
num_dummy_heads + num_heads, world_size
num_dummy_heads + num_heads, self.tp_size
)

self.q_size = self.num_attention_heads_per_partition * self.head_size
Expand Down Expand Up @@ -427,6 +429,8 @@ def __init__(
total_num_kv_heads=num_dummy_heads + num_heads,
bias=qkv_bias,
quant_config=quant_config,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
else:
Expand All @@ -435,13 +439,17 @@ def __init__(
output_size=3 * self.dummy_dim,
bias=qkv_bias,
quant_config=quant_config,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.proj = RowParallelLinear(
input_size=self.dummy_dim,
output_size=embed_dim,
bias=proj_bias,
quant_config=quant_config,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
prefix=add_prefix("proj", prefix),
)

Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/models/step3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,18 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()
# Since this is a dense model,
# the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
# This choice may not represent the optimal solution and remains open to further deliberation.
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.fc1 = ColumnParallelLinear(
dim,
intermediate_size,
bias=bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("gate_proj", prefix),
)
self.act = ACT2FN[hidden_act] # quick_gelu
Expand All @@ -544,6 +551,8 @@ def __init__(
dim,
bias=bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("down_proj", prefix),
)

Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/multimodal/processors/step3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, TensorType
from transformers import BatchFeature, ProcessorMixin, TensorType

from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
Expand Down Expand Up @@ -276,6 +276,8 @@ def __init__(
super().__init__()

self.config = config
if isinstance(tokenizer, ProcessorMixin):
tokenizer = tokenizer.tokenizer
self.tokenizer = tokenizer

self.image_size = 728
Expand Down
Loading