python/sglang/srt/entrypoints/http_server.py (23 changes: 11 additions & 12 deletions)
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
obj = GenerateReqInput(
input_embeds=input_embeds,
sampling_params={
"repetition_penalty": 1.2,
"temperature": 0.2,
"temperature": 0.0,
"max_new_tokens": 512,
},
)
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)


@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)


@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
"""Flush the radix cache."""
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
)


@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)


def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
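For reference, the /v1/rerank endpoint is only relocated in this file (it reappears after v1_score_request above) and still accepts POST and PUT. A minimal client sketch, assuming the default local server address and a V1RerankReqInput payload with query and documents fields (both field names are assumptions, not shown in this diff):

import requests  # hypothetical client example; not part of this PR

resp = requests.post(
    "http://localhost:30000/v1/rerank",  # assumed default SGLang address and port
    json={
        "query": "What is the capital of France?",  # assumed field name
        "documents": [  # assumed field name
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
)
print(resp.json())  # expected: one relevance score per document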
python/sglang/srt/managers/io_struct.py (8 changes: 4 additions & 4 deletions)
@@ -22,17 +22,16 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams

# handle serialization of Image for pydantic
# Handle serialization of Image for pydantic
if TYPE_CHECKING:
from PIL.Image import Image
else:
Image = Any

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class SessionParams:
@@ -182,6 +181,7 @@ def _handle_parallel_sampling(self):
# Determine parallel sample count
if self.sampling_params is None:
self.parallel_sample_num = 1
return
elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):
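The added return exits _handle_parallel_sampling as soon as sampling_params is None; presumably the code after this if/elif chain (not shown in the diff) inspects self.sampling_params and would fail on None. A simplified sketch of the fixed control flow, with the elided branch stubbed out:

def _handle_parallel_sampling(self):
    # Determine parallel sample count
    if self.sampling_params is None:
        self.parallel_sample_num = 1
        return  # nothing below should touch the absent params
    elif isinstance(self.sampling_params, dict):
        self.parallel_sample_num = self.sampling_params.get("n", 1)
    else:  # isinstance(self.sampling_params, list)
        ...  # per-request params; body not shown in this diff
    # ...subsequent code that assumes sampling_params is set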
python/sglang/srt/managers/multimodal_processor.py (1 change: 0 additions & 1 deletion)
@@ -25,7 +25,6 @@ def get_dummy_processor():
return DummyMultimodalProcessor()


@lru_cache()
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
python/sglang/srt/managers/schedule_batch.py (42 changes: 22 additions & 20 deletions)
@@ -180,46 +180,48 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others
A single multimodal data, from a single image/video/audio or others.

We put the common fields first and the model-specific fields last.
"""

modality: Modality

hash: int = None
pad_value: int = None

aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None

image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None

# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

# For qwen-vl
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None

# For deepseek-vl
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None

# For minicpmv
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None

# kimi-vl related
image_grid_hws: Optional[List[torch.Tensor]] = None
# For mllama
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None

audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
# For kimi-vl
image_grid_hws: Optional[List[torch.Tensor]] = None

# gemma3n related
# For gemma3n
input_features: Optional[torch.Tensor] = None
input_features_mask: Optional[torch.Tensor] = None

precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

@staticmethod
def is_empty_list(l):
if l is None:
@@ -339,10 +341,6 @@ class MultimodalInputs:
image_pad_len: Optional[list] = None
num_image_tokens: Optional[int] = None

# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None

# image
im_token_id: Optional[int] = None
im_start_id: Optional[int] = None
@@ -358,6 +356,10 @@
audio_start_id: Optional[int] = None
audio_end_id: Optional[int] = None

# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None

@staticmethod
def from_dict(obj: dict):
ret = MultimodalInputs(
python/sglang/srt/managers/tokenizer_manager.py (113 changes: 49 additions & 64 deletions)
@@ -150,7 +150,9 @@ class ReqState:

# For streaming output
last_output_offset: int = 0

# For incremental state update.
# TODO(lianmin): do not initialize some lists if not needed.
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
@@ -199,7 +201,6 @@ def __init__(
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.model_config = ModelConfig.from_server_args(server_args)

self.is_generation = self.model_config.is_generation
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
@@ -251,19 +252,36 @@ def __init__(
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None

# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self.asyncio_tasks = set()

# For session info
self.session_futures = {} # session_id -> asyncio event
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)

# Set after scheduler is initialized
self.max_req_input_len = None
# For load balancing
self.current_load = 0
self.current_load_lock = asyncio.Lock()

# Metrics
if self.enable_metrics:
@@ -393,56 +411,38 @@ def __init__(
]
)

# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)

self.current_load = 0
self.current_load_lock = asyncio.Lock()

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()

self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()

if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)

obj.normalize_batch_and_arguments()
if (
obj.return_hidden_states
and not self.server_args.enable_return_hidden_states
):
raise ValueError(
"The server is not configured to return the hidden states. "
"Please set `--enable-return-hidden-states` to enable this feature."
)

if isinstance(obj, GenerateReqInput):
return_hidden_states = obj.return_hidden_states
has_return_hidden_states = return_hidden_states == True or (
isinstance(return_hidden_states, list) and any(return_hidden_states)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if (
not self.server_args.enable_return_hidden_states
and has_return_hidden_states
):
raise ValueError(
"return_hidden_states=True requires the server to be started "
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
)

if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
@@ -451,8 +451,7 @@ async def generate_request(
)

async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
state = self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, state, request):
@@ -558,24 +557,6 @@ def _create_tokenized_object(
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""

if self.is_generation:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
token_ids_logprob = obj.token_ids_logprob
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)

# Parse sampling parameters
# Note: if there are preferred sampling params, we use them if they are not
# explicitly passed in sampling_params
@@ -589,16 +570,20 @@

# Build return object
if isinstance(obj, GenerateReqInput):
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)

tokenized_obj = TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
token_ids_logprob,
obj.return_logprob,
obj.logprob_start_len,
obj.top_logprobs_num,
obj.token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
@@ -98,6 +98,7 @@ def __init__(self, hf_config, server_args, _processor):
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args

# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
