Skip to content
35 changes: 25 additions & 10 deletions python/sglang/srt/hf_transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Utilities for Huggingface Transformers."""

import contextlib
import logging
import os
import warnings
from pathlib import Path
Expand Down Expand Up @@ -45,7 +44,7 @@
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url, lru_cache_frozenset
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
Expand Down Expand Up @@ -317,15 +316,31 @@ def get_processor(

if config.model_type not in {"llava", "clip"}:
kwargs["use_fast"] = use_fast
try:
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)

processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)

except ValueError as e:
error_message = str(e)
if "does not have a slow version" in error_message:
logger.info(
f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
)
kwargs["use_fast"] = True
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
raise e
tokenizer = get_tokenizer_from_processor(processor)

attach_additional_stop_token_ids(tokenizer)
Expand Down
37 changes: 27 additions & 10 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import math
from functools import lru_cache, partial
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -308,6 +308,7 @@ def forward(
cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = seq_lens.max().item()

output = flash_attn_varlen_func(
q,
k,
Expand Down Expand Up @@ -358,6 +359,9 @@ def __init__(
qkv_bias: bool = True,
qk_normalization: bool = False,
layer_norm_eps: float = 1e-06,
customized_position_embedding_applier: Callable[
[torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
] = None,
**kwargs,
):
super().__init__()
Expand Down Expand Up @@ -392,6 +396,7 @@ def __init__(
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
)

# priority: server_args > passed qkv_backend > sdpa
if global_server_args_dict["mm_attention_backend"] is None:
if qkv_backend is None:
qkv_backend = "sdpa"
Expand All @@ -401,6 +406,9 @@ def __init__(

print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

self.customized_position_embedding_applier = (
customized_position_embedding_applier
)
self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
head_dim=self.head_size,
num_heads=self.num_attention_heads_per_partition,
Expand Down Expand Up @@ -473,13 +481,13 @@ def forward(
if x.dim() == 2:
x = x.unsqueeze(0)
assert x.dim() == 3, x.shape
bsz, s, _ = x.shape
x_shape = x.shape
bsz, s, _ = x_shape
head = self.num_attention_heads_per_partition
kv_head = self.num_attention_kv_heads_per_partition
if self.use_qkv_parallel:
# [b, s, embed_dim] --> [b, s, embed_dim]
qkv, _ = self.qkv_proj(x)

q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# [b, s, embed_dim] --> [b * s, head, head_size]
Expand Down Expand Up @@ -508,16 +516,25 @@ def forward(
]

if position_embeddings is not None:
cos, sin = position_embeddings
original_shape = q.shape
# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)

q, k = apply_rotary_pos_emb(q, k, cos, sin)
if self.customized_position_embedding_applier is not None:
q, k = self.customized_position_embedding_applier(
q, k, position_embeddings, x_shape
)
q = q.view(original_shape)
k = k.view(original_shape)
else:
cos, sin = position_embeddings

# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)

q, k = apply_rotary_pos_emb(q, k, cos, sin)

q = q.view(original_shape)
k = k.view(original_shape)
q = q.view(original_shape)
k = k.view(original_shape)

if q.dim() == 4:
# [b, s, head, head_size] --> [b * s, head, head_size]
Expand Down
31 changes: 23 additions & 8 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BlockReqType,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
Expand Down Expand Up @@ -202,13 +201,29 @@ def __init__(

if self.model_config.is_multimodal:
import_processors()
_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
try:
_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
except ValueError as e:
error_message = str(e)
if "does not have a slow version" in error_message:
logger.info(
f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
)
_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=True,
)
else:
raise e
transport_mode = _determine_tensor_transport_mode(self.server_args)

# We want to parallelize the image pre-processing so we create an executor for it
Expand Down
13 changes: 11 additions & 2 deletions python/sglang/srt/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,22 @@ def __init__(
if self.use_qk_norm
else None
)

qkv_quant_config = quant_config
o_quant_config = quant_config
if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
if add_prefix("q_proj", prefix) in quant_config.ignore:
qkv_quant_config = None
if add_prefix("o_proj", prefix) in quant_config.ignore:
o_quant_config = None

self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
quant_config=qkv_quant_config,
prefix=add_prefix("qkv_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
Expand All @@ -257,7 +266,7 @@ def __init__(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
quant_config=o_quant_config,
prefix=add_prefix("o_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
Expand Down
Loading
Loading