"""
SGLang Scoring Benchmark Script

This script benchmarks SGLang's scoring API performance using HTTP requests.

Current Features:
- HTTP-only implementation (open source compatible)
- Uses /v1/score API endpoint directly
- Single item scoring with batching support
- Configurable RPS, duration, and batch sizes
- Progress tracking and detailed metrics
- Poisson and constant request distributions

Usage:
- Update configuration variables at the top of the file
- Ensure SGLang server is running on the configured HTTP_URL
- Run: python bench_score.py
- Each request will contain ITEM_COUNT_VALUES items for batch scoring

Note: This script has been updated to remove all GRPC dependencies and
focuses solely on HTTP-based scoring using the /v1/score endpoint.
"""

import random
import asyncio
import os
import json
import aiohttp
from statistics import mean
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm

import concurrent.futures  # For parallel prompt generation

###############################################################################
# CONFIG
###############################################################################
# Server Configuration
SERVER_TYPE = "HTTP"  # Fixed to HTTP for open source

# HTTP Configuration
HTTP_URL = "http://localhost:30000/v1/score"  # Use score API directly

# Score API Config
# ITEM_COUNT_VALUES determines number of items per score request (batch size)
SCORE_QUERY_TOKENS = 120   # tokens in the shared query text of each request
SCORE_ITEM_TOKENS = 180    # tokens in each scored item
SCORE_MODEL_PATH = ("/shared/public/sharing/job-rank/kbehdin/"
                    "f389cde308efd4dbb8d9-2025-06-06-18-31-30/best_model/"
                    "epoch=0-step=498-HF")
SCORE_LABEL_TOKEN_IDS = [9454, 2753]  # Yes/No token IDs

# Array of RPS values to test
RPS_VALUES = [70]
# Array of duration values to test
DURATION_SECS_VALUES = [120]  # Duration values in seconds
# Array of item count values to test
ITEM_COUNT_VALUES = [10]  # Number of items per request
# Number of unique requests to generate (will be reused)
NUM_UNIQUE_REQUESTS = 100
DISTRIBUTION = "POISSON"  # Options: "CONSTANT", "POISSON"

# Profiling Configuration
PROFILE = False  # Enable profiling with START_PROFILE/STOP_PROFILE prompts
# Directory for profiler output
SGLANG_TORCH_PROFILER_DIR = "/shared/user/sglang-oss-trace/remove-decode"
if PROFILE:
    # Must be set before the server-side profiler starts writing traces.
    os.environ["SGLANG_TORCH_PROFILER_DIR"] = SGLANG_TORCH_PROFILER_DIR

# Special token to replicate for precise token counting.
# Repeating a token that encodes to exactly one id gives exact prompt lengths.
SPECIAL_REPLICATED_TOKEN = "<|im_start|>"


###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def prepare_all_requests_parallel(num_requests, item_count):
    """
    Generates unique requests in parallel, then reuses them to create the
    full request list. Returns a list of str prompts for HTTP.

    Each element of the returned list is a dict in /v1/score request format
    (query, items, label_token_ids, model); the same NUM_UNIQUE_REQUESTS
    payloads are cycled to reach num_requests entries.
    """
    # Load tokenizer once here to verify special token and get precise counts
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)

    # Verify that our special token produces exactly 1 token
    special_token_count = len(tokenizer.encode(
        SPECIAL_REPLICATED_TOKEN, add_special_tokens=False))
    print(f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces "
          f"{special_token_count} token(s)")

    def generate_text_with_token_count(num_toks):
        """Generate text with precise token count using replicated token."""
        if special_token_count == 1:
            # Simple case: token maps to exactly 1 token
            return SPECIAL_REPLICATED_TOKEN * num_toks
        else:
            # NOTE(review): this warning is printed on every call, not once —
            # noisy when building many requests.
            print(f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces more than 1 token!!!")
            # Handle case where special token produces multiple tokens
            # Repeat the token enough times to get at least num_toks tokens
            repetitions = ((num_toks + special_token_count - 1) //
                          special_token_count)
            text = SPECIAL_REPLICATED_TOKEN * repetitions

            # Verify we got the expected token count (approximately)
            actual_tokens = len(tokenizer.encode(
                text, add_special_tokens=False))
            if actual_tokens < num_toks:
                print(f"Warning: Generated {actual_tokens} tokens, "
                      f"expected {num_toks}")

            return text

    def build_request(index):
        """Build a single request using the shared tokenizer."""
        try:
            # Generate query and items for score API
            query = generate_text_with_token_count(SCORE_QUERY_TOKENS)
            items = [
                generate_text_with_token_count(SCORE_ITEM_TOKENS)
                for _ in range(item_count)]

            # Return as dict for score API format
            score_data = {
                "query": query,
                "items": items,
                "label_token_ids": SCORE_LABEL_TOKEN_IDS,
                "model": SCORE_MODEL_PATH
            }
            return (index, score_data)

        except Exception as e:
            print(f"Error building request {index}: {e}")
            return (index, None)

    # Generate only the unique requests
    unique_requests = [None] * NUM_UNIQUE_REQUESTS

    # Use ThreadPoolExecutor instead of ProcessPoolExecutor to avoid
    # tokenizer loading issues across processes
    max_workers = min(8, os.cpu_count() or 1)  # Limit to 8 threads max

    with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers) as executor:
        futures = []
        for i in tqdm(range(NUM_UNIQUE_REQUESTS),
                      desc="Submitting prompt generation tasks"):
            future = executor.submit(build_request, i)
            futures.append(future)

        # Collect results as they complete
        for f in tqdm(concurrent.futures.as_completed(futures),
                      desc="Building unique requests",
                      total=NUM_UNIQUE_REQUESTS):
            try:
                index, req_data = f.result()
                if req_data is not None:
                    unique_requests[index] = req_data
                else:
                    print(f"Failed to build request {index}")
            except Exception as e:
                print(f"Error processing request result: {e}")

    # Check if we have any valid requests
    valid_requests = [req for req in unique_requests if req is not None]
    if not valid_requests:
        raise RuntimeError("Failed to generate any valid requests")

    print(f"Successfully generated {len(valid_requests)} out of "
          f"{NUM_UNIQUE_REQUESTS} unique requests")

    # Create the full request list by cycling through unique requests
    print(f"Reusing {len(valid_requests)} unique requests to create "
          f"{num_requests} total requests...")
    all_requests = []
    for i in tqdm(range(num_requests), desc="Reusing requests"):
        unique_index = i % len(valid_requests)
        all_requests.append(valid_requests[unique_index])

    print("All prompts/requests prepared.\n")
    return all_requests


###############################################################################
# PROFILING HELPERS
###############################################################################
async def send_profile_request(profile_text, item_count, session=None):
    """Send a profile request and wait for completion.

    profile_text is either "START_PROFILE" or "STOP_PROFILE"; the request is
    POSTed to the server's /start_profile or /stop_profile endpoint derived
    from HTTP_URL.
    NOTE(review): item_count is accepted but never used — presumably kept for
    call-site symmetry; confirm before removing.
    """
    try:
        if session:
            print(f"Sending {profile_text} request via HTTP...")

            # Determine the correct endpoint
            base_url = HTTP_URL.rsplit('/', 2)[0]  # Remove /v1/score
            if profile_text == "START_PROFILE":
                endpoint_url = f"{base_url}/start_profile"
            elif profile_text == "STOP_PROFILE":
                endpoint_url = f"{base_url}/stop_profile"
            else:
                print(f"Unknown profile request: {profile_text}")
                return

            headers = {"Content-Type": "application/json"}

            async with session.post(endpoint_url, headers=headers) as resp:
                resp_text = await resp.text()
                if resp.status == 200:
                    print(f"{profile_text} request completed")
                else:
                    print(f"{profile_text} request failed with status "
                          f"{resp.status}: {resp_text}")
        else:
            print(f"Cannot send {profile_text} request - missing session")

    except Exception as e:
        print(f"Error sending {profile_text} request: {e}")


###############################################################################
# HTTP CALLS
###############################################################################
def build_http_request_json(score_data):
    """Build HTTP request JSON for /v1/score endpoint.

    Score API format:
    {
        "query": "Generated query text with SCORE_QUERY_TOKENS tokens",
        "items": ["item1", "item2", ...],  # Items to score with SCORE_ITEM_TOKENS each
        "label_token_ids": [token_id1, token_id2],  # Target token IDs
        "model": "/path/to/model"
    }

    Args:
        score_data: A dict containing query, items, label_token_ids, and model
    """
    # score_data is already in the correct format from build_request
    return json.dumps(score_data)


async def make_http_call(session, score_data, request_id, results_queue):
    """HTTP call to /v1/score endpoint.

    Pushes a (request_id, elapsed_ms, success, completion_time) tuple onto
    results_queue; failures are reported with elapsed_ms == 0.
    """
    try:
        start_time = asyncio.get_event_loop().time()

        request_json = build_http_request_json(score_data)
        headers = {"Content-Type": "application/json"}

        async with session.post(HTTP_URL, data=request_json,
                                headers=headers) as resp:
            resp_text = await resp.text()

            if resp.status != 200:
                print(f"[HTTP] Request {request_id} failed with status "
                      f"{resp.status}: {resp_text}")
                completion_time = asyncio.get_event_loop().time()
                await results_queue.put((request_id, 0, False,
                                         completion_time))
                return

            # Parse score API response
            try:
                response_data = json.loads(resp_text)
                # Score API returns scores for each item
                # For now, just verify we got a valid response
                if "scores" in response_data or "logprobs" in response_data:
                    success = True
                else:
                    print(f"[HTTP] Request {request_id} missing expected fields in response")
                    success = False
            except json.JSONDecodeError:
                print(f"[HTTP] Request {request_id} failed to parse JSON response")
                success = False

            # Latency is measured request-to-body-parse, reported in ms.
            completion_time = asyncio.get_event_loop().time()
            elapsed_time = (completion_time - start_time) * 1000
            await results_queue.put((request_id, elapsed_time, success,
                                     completion_time))

    except Exception as e:
        print(f"[HTTP] Error for request {request_id}: {e}")
        completion_time = asyncio.get_event_loop().time()
        await results_queue.put((request_id, 0, False, completion_time))
###############################################################################
# RESULTS
###############################################################################
async def process_results(results_queue, num_requests, send_duration,
                          total_duration, rps, duration_secs, item_count,
                          test_start_time):
    """Processes results and groups them by minute intervals.
    Returns a list of dictionaries, one for each minute.

    NOTE(review): minute buckets only span duration_secs from test_start_time;
    requests that complete after that window are excluded from the per-minute
    rows but still counted in the overall summary — confirm this is intended.
    """
    all_results = []

    # Collect all results
    for _ in range(num_requests):
        result = await results_queue.get()
        request_id, elapsed_time, success, completion_time = result
        all_results.append({
            'request_id': request_id,
            'elapsed_time': elapsed_time,
            'success': success,
            'completion_time': completion_time
        })

    # Group results by minute intervals
    minute_results = []
    num_minutes = int(duration_secs // 60) + (
        1 if duration_secs % 60 > 0 else 0)

    for minute in range(num_minutes):
        minute_start = test_start_time + (minute * 60)
        minute_end = test_start_time + ((minute + 1) * 60)

        # Filter results that completed in this minute
        minute_data = [r for r in all_results
                       if minute_start <= r['completion_time'] < minute_end]

        response_times = [r['elapsed_time'] for r in minute_data
                          if r['success']]
        successful_requests = len([r for r in minute_data
                                   if r['success']])
        failed_requests = len([r for r in minute_data
                               if not r['success']])

        avg_response_time = mean(response_times) if response_times else 0

        # Calculate percentiles using numpy
        if response_times:
            p50 = np.percentile(response_times, 50)
            p90 = np.percentile(response_times, 90)
            p99 = np.percentile(response_times, 99)
        else:
            p50 = p90 = p99 = 0

        minute_result = {
            'test_duration_secs': duration_secs,
            'minute_interval': minute + 1,
            'target_rps': rps,
            'item_count': item_count,
            'server_type': SERVER_TYPE,
            'distribution': DISTRIBUTION,
            'unique_requests': NUM_UNIQUE_REQUESTS,
            'total_requests': len(minute_data),
            'successful_requests': successful_requests,
            'failed_requests': failed_requests,
            'send_duration_secs': send_duration,
            'total_duration_secs': total_duration,
            'avg_response_time_ms': avg_response_time,
            'p50_response_time_ms': p50,
            'p90_response_time_ms': p90,
            'p99_response_time_ms': p99
        }

        minute_results.append(minute_result)

        print(f"\nMinute {minute + 1} Summary for RPS {rps}, "
              f"Duration {duration_secs}s, Item Count {item_count}:")
        print(f"  Requests completed in minute: {len(minute_data)}")
        print(f"  Successful requests: {successful_requests}")
        print(f"  Failed requests: {failed_requests}")
        print(f"  Average response time: {avg_response_time:.2f} ms")
        print(f"  P50 response time: {p50:.2f} ms")
        print(f"  P90 response time: {p90:.2f} ms")
        print(f"  P99 response time: {p99:.2f} ms")

    # Also print overall summary
    all_response_times = [r['elapsed_time'] for r in all_results
                          if r['success']]
    total_successful = len([r for r in all_results if r['success']])
    total_failed = len([r for r in all_results if not r['success']])

    overall_avg = mean(all_response_times) if all_response_times else 0
    if all_response_times:
        overall_p50 = np.percentile(all_response_times, 50)
        overall_p90 = np.percentile(all_response_times, 90)
        overall_p99 = np.percentile(all_response_times, 99)
    else:
        overall_p50 = overall_p90 = overall_p99 = 0

    print(f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, "
          f"Item Count {item_count}:")
    print(f"  Test duration: {duration_secs} seconds")
    print(f"  Server type: {SERVER_TYPE}")
    print(f"  HTTP mode: SINGLE_ITEM_SCORING")
    print(f"  Target RPS: {rps}")
    print(f"  Item count: {item_count}")
    print(f"  Distribution: {DISTRIBUTION}")
    print(f"  Unique requests generated: {NUM_UNIQUE_REQUESTS}")
    print(f"  Total requests sent: {num_requests}")
    print(f"  Successful requests: {total_successful}")
    print(f"  Failed requests: {total_failed}")
    print(f"  Time to send all requests: {send_duration:.2f} seconds")
    print(f"  Time for all requests to complete: {total_duration:.2f} seconds")
    print(f"  Average response time: {overall_avg:.2f} ms")
    print(f"  P50 response time: {overall_p50:.2f} ms")
    print(f"  P90 response time: {overall_p90:.2f} ms")
    print(f"  P99 response time: {overall_p99:.2f} ms\n")

    return minute_results


###############################################################################
# MAIN
###############################################################################
async def run_benchmark(rps, duration_secs, item_count):
    """Run a single benchmark with the given RPS value.

    Fires num_requests = rps * duration_secs score requests at the configured
    distribution, waits for completion, and returns the per-minute result rows
    from process_results.
    """
    num_requests = int(rps * duration_secs)
    print(f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, "
          f"Item Count={item_count}, num_requests={num_requests}")
    print(f"Server Type: {SERVER_TYPE}")
    print(f"HTTP Mode: SINGLE_ITEM_SCORING")
    print(f"Profiling Enabled: {PROFILE}")

    # Build requests in parallel (unmeasured)
    all_requests = prepare_all_requests_parallel(num_requests, item_count)

    results_queue = asyncio.Queue()
    tasks = []

    # Track timing for sending requests
    send_start_time = asyncio.get_event_loop().time()

    # HTTP implementation (open source only supports HTTP with /v1/score API)
    async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)) as session:

        # Send START_PROFILE if profiling is enabled
        if PROFILE:
            await send_profile_request("START_PROFILE", item_count,
                                       session=session)

        # Add progress bar for sending requests
        with tqdm(total=len(all_requests),
                  desc=f"Sending HTTP score requests at {rps} RPS",
                  unit="req") as pbar:
            for i, score_data in enumerate(all_requests):
                request_id = i + 1
                tasks.append(
                    asyncio.create_task(
                        make_http_call(session, score_data, request_id,
                                       results_queue)
                    )
                )

                # Update progress bar
                pbar.update(1)

                # Throttle based on distribution
                if i < len(all_requests) - 1:
                    if DISTRIBUTION == "CONSTANT":
                        interval = 1 / rps
                        await asyncio.sleep(interval)
                    elif DISTRIBUTION == "POISSON":
                        # For Poisson process, inter-arrival times follow
                        # exponential distribution
                        interval = random.expovariate(rps)
                        await asyncio.sleep(interval)
                    else:
                        raise ValueError(
                            f"Unknown distribution: {DISTRIBUTION}. "
                            f"Use 'CONSTANT' or 'POISSON'.")

        send_end_time = asyncio.get_event_loop().time()
        send_duration = send_end_time - send_start_time

        # Wait for all requests to complete with progress tracking
        print(f"Waiting for {len(tasks)} HTTP score requests to complete...")
        with tqdm(total=len(tasks), desc="Completing HTTP score requests",
                  unit="req") as completion_pbar:
            # NOTE(review): completed_tasks is accumulated but never read.
            completed_tasks = []
            for task in asyncio.as_completed(tasks):
                await task
                completed_tasks.append(task)
                completion_pbar.update(1)

        # Send STOP_PROFILE if profiling is enabled
        if PROFILE:
            await send_profile_request("STOP_PROFILE", item_count,
                                       session=session)

    completion_end_time = asyncio.get_event_loop().time()
    total_duration = completion_end_time - send_start_time

    return await process_results(
        results_queue, num_requests, send_duration, total_duration,
        rps, duration_secs, item_count, send_start_time)


async def main():
    """Main function that runs benchmarks for all RPS values."""
    total_combinations = (len(DURATION_SECS_VALUES) * len(RPS_VALUES) *
                          len(ITEM_COUNT_VALUES))
    print(f"Running benchmarks for {len(DURATION_SECS_VALUES)} duration "
          f"values, {len(RPS_VALUES)} RPS values, and "
          f"{len(ITEM_COUNT_VALUES)} item count values = "
          f"{total_combinations} total combinations")
    print(f"Server Type: {SERVER_TYPE}")
    print(f"HTTP Mode: SINGLE_ITEM_SCORING")
    print(f"Score API URL: {HTTP_URL}")
    print(f"Query tokens per request: {SCORE_QUERY_TOKENS}")
    print(f"Item tokens per item: {SCORE_ITEM_TOKENS}")
    print(f"Items per request (batch size): {ITEM_COUNT_VALUES}")
    print(f"Profiling Enabled: {PROFILE}")
    print(f"Duration values: {DURATION_SECS_VALUES}")
    print(f"RPS values: {RPS_VALUES}")
    print(f"Item count values: {ITEM_COUNT_VALUES}")
    print("="*80)

    all_results = []

    # Full cross product of duration x rps x item_count configurations.
    for duration_secs in DURATION_SECS_VALUES:
        for rps in RPS_VALUES:
            for item_count in ITEM_COUNT_VALUES:
                result = await run_benchmark(rps, duration_secs, item_count)
                all_results.extend(result)  # Extend with minute results

    # Print CSV header and results
    print("\n" + "="*80)
    print("FINAL CSV RESULTS:")
    print("="*80)

    # CSV Header
    headers = [
        "test_duration_secs", "minute_interval", "target_rps", "item_count",
        "server_type", "distribution", "unique_requests",
        "total_requests", "successful_requests", "failed_requests",
        "send_duration_secs", "total_duration_secs", "avg_response_time_ms",
        "p50_response_time_ms", "p90_response_time_ms", "p99_response_time_ms"
    ]
    print(",".join(headers))

    # CSV Data
    for result in all_results:
        row = [
            result['test_duration_secs'],
            result['minute_interval'],
            result['target_rps'],
            result['item_count'],
            result['server_type'],
            result['distribution'],
            result['unique_requests'],
            result['total_requests'],
            result['successful_requests'],
            result['failed_requests'],
            f"{result['send_duration_secs']:.2f}",
            f"{result['total_duration_secs']:.2f}",
            f"{result['avg_response_time_ms']:.2f}",
            f"{result['p50_response_time_ms']:.2f}",
            f"{result['p90_response_time_ms']:.2f}",
            f"{result['p99_response_time_ms']:.2f}"
        ]
        print(",".join(map(str, row)))


if __name__ == "__main__":
    asyncio.run(main())


# =============================================================================
# From python/sglang/srt/managers/io_struct.py (patch hunks: new classes added
# between EmbeddingReqInput/TokenizedGenerateReqInput context; the surrounding
# classes themselves are not visible here).
# =============================================================================

@dataclass
class ScoreReqInput:
    """
    A request input for scoring/ranking tasks.

    This class is designed for prefill-only operations that compute logprobs for specific tokens
    without generating new tokens. It's similar to EmbeddingReqInput but focused on scoring.
    """

    # Core input - at least one must be provided
    text: Optional[Union[List[str], str]] = None
    input_ids: Optional[Union[List[List[int]], List[int]]] = None

    # Request management
    rid: Optional[Union[List[str], str]] = None

    # Token IDs to score (required for scoring operations)
    # This is always a List[int] that applies to all prompts in a batch
    token_ids_logprob: Optional[List[int]] = None

    # Whether to log metrics for this request
    log_metrics: bool = True

    def normalize_batch_and_arguments(self):
        """Normalize the batch size and arguments for scoring requests."""
        # At least one of text, input_ids, or input_embeds should be provided
        if (
            self.text is None and self.input_ids is None
        ):
            raise ValueError(
                "At least one of text, input_ids should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Determine batch size
        self._determine_batch_size()

        # Fill in default arguments
        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size.

        Sets self.batch_size and self.is_single as side effects.
        """
        self.batch_size = 0
        self.is_single = True

        if self.text is not None:
            if isinstance(self.text, str):
                self.batch_size = 1
                self.is_single = True
            else:
                self.batch_size = len(self.text)
                self.is_single = False
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            # A flat list of ints is a single prompt; a list of lists is a batch.
            if isinstance(self.input_ids[0], int):
                self.batch_size = 1
                self.is_single = True
            else:
                self.batch_size = len(self.input_ids)
                self.is_single = False

    def _normalize_single_inputs(self):
        """Normalize inputs for a single scoring example."""
        if self.rid is None:
            self.rid = uuid.uuid4().hex

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of scoring examples."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
        else:
            assert isinstance(self.rid, list), "The rid should be a list."

    def __getitem__(self, i):
        """Get a single scoring request from a batch."""
        return ScoreReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            token_ids_logprob=self.token_ids_logprob,  # Same token_ids_logprob for all batch items
            rid=self.rid[i],
            log_metrics=self.log_metrics,
        )


@dataclass
class TokenizedScoreReqInput:
    """
    A specialized tokenized request input for scoring/ranking tasks.

    This class is designed specifically for scoring operations and contains only
    the fields necessary for scoring tasks.
    """
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Token IDs to score (required for scoring operations)
    token_ids_logprob: List[int]
    # Whether to log metrics
    log_metrics: bool
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e698bf85b768..d6b3a0b7a56c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -434,6 +434,7 @@ def __init__( bootstrap_room: Optional[int] = None, data_parallel_rank: Optional[int] = None, vocab_size: Optional[int] = None, + is_scoring_request: bool = False, ): # Input and output info self.rid = rid @@ -602,6 +603,9 @@ def __init__( # For data parallel rank routing self.data_parallel_rank: Optional[int] = data_parallel_rank + # Whether this is a scoring request (prefill-only) for the decoder model + self.is_scoring_request: bool = is_scoring_request + # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. # After every chunk forward, we do the following: @@ -896,6 +900,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return hidden states return_hidden_states: bool = False + # Whether this batch contains scoring requests (prefill-only) + is_scoring_batch: bool = False + # hicache pointer for synchronizing data loading from CPU to GPU hicache_consumer_index: int = 0 @@ -936,6 +943,7 @@ def init_new( spec_algorithm=spec_algorithm, enable_custom_logit_processor=enable_custom_logit_processor, return_hidden_states=any(req.return_hidden_states for req in reqs), + is_scoring_batch=any(req.is_scoring_request for req in reqs), chunked_req=chunked_req, ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b6cf72d4e553..97014457e901 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -91,6 +91,7 @@ SlowDownReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + TokenizedScoreReqInput, UnloadLoRAAdapterReqInput, UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, @@ -184,6 +185,12 @@ class EmbeddingBatchResult: bid: int 
# =============================================================================
# From python/sglang/srt/managers/scheduler.py (patch hunks). Items marked
# "fragment" are hunk context of definitions whose full bodies are not visible
# here; they are reproduced verbatim, post-patch.
# =============================================================================

@dataclass
class ScoreBatchResult:
    # Logits from the prefill-only forward pass; None on non-last PP ranks.
    logits_output: Optional[LogitsProcessorOutput]
    # Batch id of the model worker batch that produced this result.
    bid: int


# --- fragment: Scheduler.__init__ request-dispatcher table (scoring entry added) ---
#            (TokenizedGenerateReqInput, self.handle_generate_request),
#            (TokenizedScoreReqInput, self.handle_score_request),  # Use dedicated scoring handler
#            (TokenizedEmbeddingReqInput, self.handle_embedding_request),
#            (FlushCacheReqInput, self.flush_cache_wrapped),
#            (AbortReq, self.abort_request),

# --- method added to class Scheduler ---
    def handle_score_request(
        self,
        recv_req: TokenizedScoreReqInput,
    ):
        """
        Handle scoring requests separately from generation requests.

        Scoring requests are prefill-only operations that compute logprobs for specific tokens
        without generating new tokens. This method creates a specialized request object
        optimized for scoring tasks.
        """
        # Create a specialized Req object for scoring
        # For scoring requests, we need to create a minimal SamplingParams object
        # since scoring doesn't use sampling but the Req constructor requires it
        from sglang.srt.sampling.sampling_params import SamplingParams
        dummy_sampling_params = SamplingParams(max_new_tokens=0)

        req = Req(
            rid=recv_req.rid,
            origin_input_text=recv_req.input_text,
            origin_input_ids=recv_req.input_ids,
            sampling_params=dummy_sampling_params,
            return_logprob=True,  # Scoring requests always return logprobs
            top_logprobs_num=0,  # Scoring requests don't need top logprobs
            token_ids_logprob=recv_req.token_ids_logprob,
            stream=False,  # Scoring requests don't stream
            return_hidden_states=False,  # No hidden states for scoring
            eos_token_ids=self.model_config.hf_eos_token_id,
            vocab_size=self.model_config.vocab_size,
            is_scoring_request=True,
        )

        # Always set logprob_start_len to last token for scoring requests,
        # so only the next-token distribution at the final position is computed.
        req.logprob_start_len = len(recv_req.input_ids) - 1
        req.tokenizer = self.tokenizer

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # For scoring, we don't need grammar processing or other generation-specific features
        # Just add the request to the queue for processing
        self._add_request_to_queue(req)

# --- fragment: Scheduler.get_new_batch_prefill (hunk adds conditional call) ---
#            self.running_batch.filter_batch()
#            if not self.running_batch.is_empty():
#                self.running_batch.prepare_for_decode()
#                # Only prepare for decode if this is not a scoring batch
#                # Scoring batches are prefill-only and should not allocate decode tokens
#                if not self.running_batch.is_scoring_batch:
#                    self.running_batch.prepare_for_decode()
#                new_batch.mix_with_running(self.running_batch)
#                new_batch.decoding_reqs = self.running_batch.reqs
#                self.running_batch = ScheduleBatch(
# NOTE(review): this hunk adds the conditional prepare_for_decode() WITHOUT
# removing the pre-existing unconditional call above it, so non-scoring batches
# appear to call prepare_for_decode() twice here. The update_running_batch hunk
# below removes the old call correctly — this one likely should too; confirm.

# --- fragment: Scheduler.update_running_batch (post-patch tail) ---
#            if batch.batch_size() < initial_bs:
#                batch.batch_is_full = False
#
#            # Only prepare for decode if this is not a scoring batch
#            # Scoring batches are prefill-only and should not allocate decode tokens
#            if not batch.is_scoring_batch:
#                batch.prepare_for_decode()
#            return batch

# --- fragment: Scheduler.run_batch head (scoring dispatch added; body continues
# --- beyond this hunk) ---
#        def run_batch(
#            self, batch: ScheduleBatch
#        ) -> Union[GenerationBatchResult, EmbeddingBatchResult, ScoreBatchResult]:
#            """Run a batch."""
#            self.forward_ct += 1
#            ...
#                time.sleep(self.forward_sleep_time)
#            # Run forward
#            if batch.is_scoring_batch:
#                ret = self._run_scoring_batch(batch)
#            elif self.is_generation:
#                if self.spec_algorithm.is_none():
#                    model_worker_batch = batch.get_model_worker_batch()

# --- method added to class Scheduler ---
    def _run_scoring_batch(self, batch: ScheduleBatch) -> ScoreBatchResult:
        """Handle scoring requests (prefill-only) - separate from generation path.

        Runs a single forward pass and marks non-chunked scoring requests as
        finished; no decode step is scheduled.
        """
        model_worker_batch = batch.get_model_worker_batch()

        # update the consumer index of hicache to the running batch
        self.tp_worker.set_hicache_consumer(
            model_worker_batch.hicache_consumer_index
        )

        # Only the last pipeline-parallel rank holds the logits output;
        # the conditional in the return below keeps other ranks from
        # referencing the unbound name.
        if self.pp_group.is_last_rank:
            logits_output, _, can_run_cuda_graph = (
                self.tp_worker.forward_batch_generation(model_worker_batch)
            )
        else:
            _, _, can_run_cuda_graph = (
                self.tp_worker.forward_batch_generation(model_worker_batch)
            )
        bid = model_worker_batch.bid

        # Mark scoring requests as finished after forward pass, but only if not chunked
        for req in batch.reqs:
            if req.is_chunked <= 0:
                req.check_finished()

        return ScoreBatchResult(
            logits_output=logits_output if self.pp_group.is_last_rank else None,
            bid=bid,
        )

# --- fragment: Scheduler.process_batch_result head (scoring branch added;
# --- remaining branches of the method are outside this hunk) ---
#        def process_batch_result(
#            self,
#            batch: ScheduleBatch,
#            result: Union[GenerationBatchResult, EmbeddingBatchResult, ScoreBatchResult],
#            launch_done: Optional[threading.Event] = None,
#        ):
#            if isinstance(result, ScoreBatchResult):
#                # Handle scoring results (prefill-only)
#                self.process_batch_result_score(batch, result, launch_done)
#            elif batch.forward_mode.is_decode():
#                self.process_batch_result_decode(batch, result, launch_done)
#            elif batch.forward_mode.is_extend():
#                self.process_batch_result_prefill(batch, result, launch_done)


def is_work_request(recv_req):
    # Scoring requests now count as schedulable work alongside generate/embed.
    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, TokenizedScoreReqInput))


# =============================================================================
# From python/sglang/srt/managers/scheduler_output_processor_mixin.py
# (import hunk regroups schedule_batch imports and adds ScoreBatchResult to the
# TYPE_CHECKING block; method below is added to SchedulerOutputProcessorMixin).
# =============================================================================

    def process_batch_result_score(
        self: Scheduler,
        batch: ScheduleBatch,
        result: ScoreBatchResult,
        launch_done: Optional[threading.Event] = None,
    ):
        """
        Process scoring batch results.

        Scoring requests are prefill-only operations that compute logprobs for specific tokens
        without generating new tokens. They are marked as finished in run_batch.
        """
        logits_output = result.logits_output

        if self.enable_overlap:
            # Resolve the last batch result for overlap scheduling
            logits_output, _, _ = self.tp_worker.resolve_last_batch_result(launch_done)
        else:
            # Move logprobs to CPU if needed
            if batch.return_logprob and logits_output is not None:
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = tuple(
                        logits_output.input_token_logprobs.tolist()
                    )

        # Free tokens from the memory pool (same as decode processing)
        self.token_to_kv_pool_allocator.free_group_begin()

        # Process logprobs for scoring requests
        if batch.return_logprob and logits_output is not None:
            for i, req in enumerate(batch.reqs):
                if req.is_scoring_request:
                    # For scoring requests, we use next_token logprobs since we want
                    # to know the probability of specific tokens at the next position
                    if (logits_output.next_token_token_ids_logprobs_val is not None and
                            logits_output.next_token_token_ids_logprobs_idx is not None):

                        # Initialize all the logprob fields for scoring request
                        # (downstream stream_output expects non-None lists).
                        if req.input_token_logprobs_val is None:
                            req.input_token_logprobs_val = []
                        if req.input_token_logprobs_idx is None:
                            req.input_token_logprobs_idx = []
                        if req.input_top_logprobs_val is None:
                            req.input_top_logprobs_val = []
                        if req.input_top_logprobs_idx is None:
                            req.input_top_logprobs_idx = []
                        if req.output_token_logprobs_val is None:
                            req.output_token_logprobs_val = []
                        if req.output_token_logprobs_idx is None:
                            req.output_token_logprobs_idx = []
                        if req.output_top_logprobs_val is None:
                            req.output_top_logprobs_val = []
                        if req.output_top_logprobs_idx is None:
                            req.output_top_logprobs_idx = []
                        if req.input_token_ids_logprobs_val is None:
                            req.input_token_ids_logprobs_val = []
                        if req.input_token_ids_logprobs_idx is None:
                            req.input_token_ids_logprobs_idx = []
                        if req.output_token_ids_logprobs_val is None:
                            req.output_token_ids_logprobs_val = []
                        if req.output_token_ids_logprobs_idx is None:
                            req.output_token_ids_logprobs_idx = []

                        # Store the next token logprobs as output logprobs for scoring
                        req.output_token_ids_logprobs_val.append(
                            logits_output.next_token_token_ids_logprobs_val[i]
                        )
                        req.output_token_ids_logprobs_idx.append(
                            logits_output.next_token_token_ids_logprobs_idx[i]
                        )

        # Handle cache cleanup for scoring requests (similar to prefill logic)
        skip_stream_req = None
        for i, req in enumerate(batch.reqs):
            if req.is_scoring_request:
                if req.is_retracted:
                    continue

                # Handle chunked scoring requests
                if req.is_chunked <= 0:
                    # For scoring requests, check_finished() was already called in _run_scoring_batch()
                    # So we just need to cache the requests based on their current state
                    if req.finished():
                        # Handle the "one extra delayed token" pattern for overlap scheduling
                        # Similar to decode processor, but only for non-ChunkCache types
                        from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
                        is_chunk_cache = isinstance(self.tree_cache, (ChunkCache, SWAChunkCache))

                        if self.enable_overlap and not is_chunk_cache:
                            if self.page_size == 1:
                                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                            else:
                                # Only free when the extra token is in a new page
                                if (
                                    len(req.origin_input_ids) + len(req.output_ids) - 1
                                ) % self.page_size == 0:
                                    self.token_to_kv_pool_allocator.free(
                                        batch.out_cache_loc[i : i + 1]
                                    )

                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # Scoring request is still being chunked, don't stream yet
                    req.is_chunked -= 1
                    skip_stream_req = req

        # Stream the results back to the tokenizer manager
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

        # Complete memory cleanup (same as decode processing)
        self.token_to_kv_pool_allocator.free_group_end()

# --- fragment: add_input_logprob_return_values def line only; body runs past
# --- the end of this view and is omitted ---
a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index 34ee663ca036..67ede40e18b4 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -14,7 +14,7 @@ import uuid from typing import Dict, Optional -from sglang.srt.managers.io_struct import TokenizedGenerateReqInput +from sglang.srt.managers.io_struct import TokenizedGenerateReqInput, TokenizedScoreReqInput from sglang.srt.managers.schedule_batch import Req @@ -144,6 +144,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): return_logprob=req.return_logprob, top_logprobs_num=req.top_logprobs_num, token_ids_logprob=req.token_ids_logprob, + is_scoring_request=isinstance(req, TokenizedScoreReqInput), ) if last_req is not None: new_req.multimodal_inputs = last_req.multimodal_inputs diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9250c6866eef..5aff99f52340 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -98,6 +98,7 @@ ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, + ScoreReqInput, SessionParams, SetInternalStateReq, SetInternalStateReqOutput, @@ -105,6 +106,7 @@ SlowDownReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + TokenizedScoreReqInput, UnloadLoRAAdapterReqInput, UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, @@ -140,7 +142,7 @@ class ReqState: out_list: List[Dict[Any, Any]] finished: bool event: asyncio.Event - obj: Union[GenerateReqInput, EmbeddingReqInput] + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput] # For metrics created_time: float @@ -451,7 +453,7 @@ def __init__( async def generate_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], request: 
Optional[fastapi.Request] = None, ): created_time = time.time() @@ -481,7 +483,7 @@ async def generate_request( async def _tokenize_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], ): """Tokenize one request.""" # Tokenize @@ -491,7 +493,7 @@ async def _tokenize_one_request( is_cross_encoder_request = ( isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request ) - if obj.input_embeds is not None: + if hasattr(obj, 'input_embeds') and obj.input_embeds is not None: if not self.server_args.disable_radix_cache: raise ValueError( "input_embeds is provided while disable_radix_cache is False. " @@ -546,7 +548,7 @@ async def _tokenize_one_request( ) def _validate_one_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] + self, obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], input_ids: List[int] ) -> None: """Validates that the input token count and the requested token count doesn't exceed the model's context length.""" @@ -558,6 +560,13 @@ def _validate_one_request( f"model's context length ({self.context_len} tokens)." ) + if isinstance(obj, ScoreReqInput): + if obj.token_ids_logprob is None: + raise ValueError( + "token_ids_logprob is required for scoring requests." + ) + return + if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( "This model does not appear to be an embedding model by default. 
" @@ -608,27 +617,39 @@ def _validate_input_ids_in_vocab( def _create_tokenized_object( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], input_text: str, input_ids: List[int], input_embeds: Optional[Union[List[float], None]] = None, mm_inputs: Optional[Dict] = None, token_type_ids: Optional[List[int]] = None, - ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]: + ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, TokenizedScoreReqInput]: """Create a tokenized request object from common parameters.""" - # Parse sampling parameters - # Note: if there are preferred sampling params, we use them if they are not - # explicitly passed in sampling_params - if self.preferred_sampling_params: - sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params} + # Parse sampling parameters - skip for ScoreReqInput + if isinstance(obj, ScoreReqInput): + # ScoreReqInput doesn't use sampling parameters + sampling_params = SamplingParams(max_new_tokens=0) else: - sampling_kwargs = obj.sampling_params - sampling_params = SamplingParams(**sampling_kwargs) - sampling_params.normalize(self.tokenizer) - sampling_params.verify(self.model_config.vocab_size) + # Note: if there are preferred sampling params, we use them if they are not + # explicitly passed in sampling_params + if self.preferred_sampling_params: + sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params} + else: + sampling_kwargs = obj.sampling_params + sampling_params = SamplingParams(**sampling_kwargs) + sampling_params.normalize(self.tokenizer) + sampling_params.verify(self.model_config.vocab_size) # Build return object - if isinstance(obj, GenerateReqInput): + if isinstance(obj, ScoreReqInput): + tokenized_obj = TokenizedScoreReqInput( + rid=obj.rid, + input_text=input_text, + input_ids=input_ids, + token_ids_logprob=obj.token_ids_logprob or [], + log_metrics=obj.log_metrics, + ) + elif 
isinstance(obj, GenerateReqInput): session_params = ( SessionParams(**obj.session_params) if obj.session_params else None ) @@ -667,8 +688,8 @@ def _create_tokenized_object( return tokenized_obj async def _batch_tokenize_and_process( - self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] - ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: + self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput] + ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, TokenizedScoreReqInput]]: """Handle batch tokenization for text inputs only.""" logger.debug(f"Starting batch tokenization for {batch_size} text requests") @@ -683,7 +704,7 @@ async def _batch_tokenize_and_process( # Process all requests tokenized_objs = [] for i, req in enumerate(requests): - self._validate_token_len(obj[i], input_ids_list[i]) + self._validate_one_request(obj[i], input_ids_list[i]) tokenized_objs.append( self._create_tokenized_object( req, req.text, input_ids_list[i], None, None @@ -693,11 +714,12 @@ async def _batch_tokenize_and_process( return tokenized_objs def _validate_batch_tokenization_constraints( - self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] + self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput] ) -> None: """Validate constraints for batch tokenization processing.""" for i in range(batch_size): - if self.is_generation and obj[i].contains_mm_input(): + # Skip multimodal validation for ScoreReqInput since they don't support multimodal + if self.is_generation and not isinstance(obj[i], ScoreReqInput) and obj[i].contains_mm_input(): raise ValueError( "For multimodal input processing do not set `enable_tokenizer_batch_encode`." ) @@ -705,15 +727,15 @@ def _validate_batch_tokenization_constraints( raise ValueError( "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`." 
) - if obj[i].input_embeds is not None: + if hasattr(obj[i], 'input_embeds') and obj[i].input_embeds is not None: raise ValueError( "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." ) def _send_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], + tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, TokenizedScoreReqInput], created_time: Optional[float] = None, ): self.send_to_scheduler.send_pyobj(tokenized_obj) @@ -723,7 +745,7 @@ def _send_one_request( async def _wait_one_response( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], state: ReqState, request: Optional[fastapi.Request] = None, ): @@ -797,7 +819,7 @@ async def _wait_one_response( async def _handle_batch_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, ScoreReqInput], request: Optional[fastapi.Request] = None, created_time: Optional[float] = None, ): @@ -1289,7 +1311,7 @@ def configure_logging(self, obj: ConfigureLoggingReq): logging.info(f"Config logging: {obj=}") self.log_request_metadata = self.get_log_request_metadata() - def create_abort_task(self, obj: GenerateReqInput): + def create_abort_task(self, obj: Union[GenerateReqInput, ScoreReqInput]): # Abort the request if the client is disconnected. 
async def abort_request(): await asyncio.sleep(2) @@ -1483,13 +1505,20 @@ def _handle_batch_output( "prompt_tokens": recv_obj.prompt_tokens[i], } - if getattr(state.obj, "return_logprob", False): + # Call convert_logprob_style for regular requests with return_logprob=True + # and for scoring requests (which have token_ids_logprob) + if (getattr(state.obj, "return_logprob", False) or + hasattr(state.obj, "token_ids_logprob")): + + # Use token_ids_logprob for both regular and scoring requests + token_ids_logprob = state.obj.token_ids_logprob + self.convert_logprob_style( meta_info, state, - state.obj.top_logprobs_num, - state.obj.token_ids_logprob, - state.obj.return_text_in_logprobs + getattr(state.obj, "top_logprobs_num", 0), + token_ids_logprob, + getattr(state.obj, "return_text_in_logprobs", False) and not self.server_args.skip_tokenizer_init, recv_obj, i, @@ -1616,6 +1645,7 @@ def convert_logprob_style( ) if token_ids_logprob is not None: + if len(recv_obj.input_token_ids_logprobs_val) > 0: state.input_token_ids_logprobs_val.extend( recv_obj.input_token_ids_logprobs_val[recv_obj_index] @@ -1818,6 +1848,8 @@ async def score_request( """ See Engine.score() for more details. 
""" + + if label_token_ids is None: raise ValueError("label_token_ids must be provided") @@ -1840,13 +1872,12 @@ async def score_request( prompts = [f"{item}{query}" for item in items_list] else: prompts = [f"{query}{item}" for item in items_list] - batch_request = GenerateReqInput( + + batch_request = ScoreReqInput( text=prompts, - return_logprob=True, token_ids_logprob=label_token_ids, - stream=False, - sampling_params={"max_new_tokens": 1}, ) + elif ( isinstance(query, list) and isinstance(items, list) @@ -1858,13 +1889,12 @@ async def score_request( input_ids_list = [item + query for item in items] else: input_ids_list = [query + item for item in items] - batch_request = GenerateReqInput( + + batch_request = ScoreReqInput( input_ids=input_ids_list, - return_logprob=True, token_ids_logprob=label_token_ids, - stream=False, - sampling_params={"max_new_tokens": 1}, ) + else: raise ValueError( "Invalid combination of query/items types for score_request." @@ -1876,11 +1906,17 @@ async def score_request( for result in results: # Get logprobs for each token logprobs = {} - for logprob, token_id, _ in result["meta_info"].get( - "output_token_ids_logprobs", [] - )[0]: - if token_id in label_token_ids: - logprobs[token_id] = logprob + + # For scoring requests, we read from output_token_ids_logprobs since we want + # the logprobs for specific tokens at the next position (not input tokens) + output_logprobs = result["meta_info"].get("output_token_ids_logprobs", []) + + + + if output_logprobs and output_logprobs[0] is not None: + for logprob, token_id, _ in output_logprobs[0]: + if token_id in label_token_ids: + logprobs[token_id] = logprob # Get scores in order of label_token_ids score_list = [