From fc574039e080de70a85097a3466730704b8b9c5a Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Wed, 18 Jun 2025 12:13:23 -0700 Subject: [PATCH 1/4] purge VerlEngine Signed-off-by: Ata Fatahi --- .../offline_batch_inference_torchrun.py | 81 ---- .../srt/entrypoints/http_server_engine.py | 97 ---- python/sglang/srt/entrypoints/verl_engine.py | 179 -------- test/srt/test_verl_engine_server.py | 415 ------------------ 4 files changed, 772 deletions(-) delete mode 100644 examples/runtime/engine/offline_batch_inference_torchrun.py delete mode 100644 python/sglang/srt/entrypoints/verl_engine.py delete mode 100644 test/srt/test_verl_engine_server.py diff --git a/examples/runtime/engine/offline_batch_inference_torchrun.py b/examples/runtime/engine/offline_batch_inference_torchrun.py deleted file mode 100644 index d2185da0985a..000000000000 --- a/examples/runtime/engine/offline_batch_inference_torchrun.py +++ /dev/null @@ -1,81 +0,0 @@ -import datetime -import os -import sys - -from torch.distributed.device_mesh import init_device_mesh - -from sglang.srt.entrypoints.verl_engine import VerlEngine - - -def run(): - """ - Example command: - ``` - torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py - ``` - """ - - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - def _log(text): - t = datetime.datetime.now().strftime("%H:%M:%S") - print(f"[{t}] [rank={rank}] {text}") - - _log( - f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}' - ) - - tp_size = 4 - dp_size = 2 - assert world_size == tp_size * dp_size - - device_mesh_kwargs = dict( - mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"] - ) - device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) - _log(f"{device_mesh_cpu=}") - - tp_rank = device_mesh_cpu.get_local_rank("tp") - dp_rank = device_mesh_cpu.get_local_rank("dp") - _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}") - - model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1 - # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models - # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8 - - for k in ["TORCHELASTIC_USE_AGENT_STORE"]: - if k in os.environ: - del os.environ[k] - - fragment = VerlEngine( - model_path=model_name, - mem_fraction_static=mem_fraction_static, - device_mesh_cpu=device_mesh_cpu["tp"], - base_gpu_id=dp_rank, - gpu_id_step=dp_size, - port=30000, - # for DeepSeek-V2-Lite + DP Attention - # enable_dp_attention=True, port=30000 + dp_rank * 100, - ) - _log(f"{fragment=}") - - prompt_all = [ - ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="], - ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="], - ] - prompt = prompt_all[dp_rank] - - output = fragment.generate( - prompt=prompt, - sampling_params=dict(max_new_tokens=16, temperature=0.0), - ) - _log(f"{prompt=} {output=}") - - fragment.shutdown() - _log(f"End script") - - -if __name__ == "__main__": - run() diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index ace569e56a4d..2600645b28ca 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -46,100 +46,3 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: p.terminate() raise TimeoutError("Server failed to start within the timeout period.") - - -class HttpServerEngineAdapter(EngineBase): - """ - You can use this class to launch a server from a VerlEngine instance. - We recommend using this class only you need to use http server. - Otherwise, you can use Engine directly. - """ - - def __init__(self, **kwargs): - self.server_args = ServerArgs(**kwargs) - print( - f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}" - ) - self.process = launch_server_process(self.server_args) - - def _make_request(self, endpoint: str, payload: Optional[dict] = None): - """Make a POST request to the specified endpoint with the given payload. - - Args: - endpoint: The API endpoint to call - payload: The JSON payload to send (default: empty dict) - - Returns: - The JSON response from the server - """ - url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" - response = requests.post(url, json=payload or {}) - response.raise_for_status() - return response.json() - - def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, torch.Tensor]], - load_format: Optional[str] = None, - flush_cache: bool = False, - ): - """ - Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. - - Note: The model should be on GPUs rather than CPU for this functionality to work properly. - If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. - """ - - return self._make_request( - "update_weights_from_tensor", - { - "serialized_named_tensors": [ - MultiprocessingSerializer.serialize(named_tensors, output_str=True) - for _ in range(self.server_args.tp_size) - ], - "load_format": load_format, - "flush_cache": flush_cache, - }, - ) - - def shutdown(self): - kill_process_tree(self.process.pid) - - def generate( - self, - prompt=None, - sampling_params=None, - input_ids=None, - image_data=None, - return_logprob=False, - logprob_start_len=None, - top_logprobs_num=None, - token_ids_logprob=None, - lora_path=None, - custom_logit_processor=None, - ): - payload = { - "text": prompt, - "sampling_params": sampling_params, - "input_ids": input_ids, - "image_data": image_data, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - "token_ids_logprob": token_ids_logprob, - "lora_path": lora_path, - "custom_logit_processor": custom_logit_processor, - } - # Filter out None values - payload = {k: v for k, v in payload.items() if v is not None} - - return self._make_request("generate", payload) - - def release_memory_occupation(self): - return self._make_request("release_memory_occupation") - - def resume_memory_occupation(self): - return self._make_request("resume_memory_occupation") - - def flush_cache(self): - return self._make_request("flush_cache") diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py deleted file mode 100644 index ab1ce8e165a9..000000000000 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import os -from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from PIL.Image import Image -from torch.distributed.tensor import DeviceMesh, DTensor - -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter -from sglang.srt.model_executor.model_runner import LocalSerializedTensor -from sglang.srt.patch_torch import monkey_patch_torch_reductions -from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj - - -class VerlEngine: - def __init__( - self, - device_mesh_cpu: DeviceMesh, - nnodes: int = 1, - backend: Literal["engine", "server"] = "engine", - **kwargs, - ): - monkey_patch_torch_reductions() - self._device_mesh_cpu = device_mesh_cpu - self._tp_rank = device_mesh_cpu.get_local_rank() - self._rank = device_mesh_cpu.get_rank() - self._tp_size = device_mesh_cpu.size() - tp_size_per_node = self._tp_size // nnodes - node_rank = self._tp_rank // tp_size_per_node - first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - - # Common engine keyword arguments - engine_kwargs = dict( - **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes - ) - - if backend == "engine": - if first_rank_in_node: - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine(**engine_kwargs) - else: - self._engine = None - - elif backend == "server": - if self._tp_rank == 0: - self._engine = HttpServerEngineAdapter(**engine_kwargs) - else: - self._engine = None - else: - raise ValueError(f"Unsupported backend: {backend}") - - dist.barrier(group=self._device_mesh_cpu.get_group()) - - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be an image instance, file name, URL, or base64 encoded string. - # Can be formatted as: - # - Single image for a single request - # - List of images (one per request in a batch) - # - List of lists of images (multiple images per request) - # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - ) -> Dict: - """ - The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. - Please refer to `GenerateReqInput` for the documentation. - """ - if self._tp_rank == 0: - output = self._engine.generate( - prompt=prompt, - sampling_params=sampling_params, - input_ids=input_ids, - image_data=image_data, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - token_ids_logprob=token_ids_logprob, - lora_path=lora_path, - custom_logit_processor=custom_logit_processor, - ) - else: - output = None - - # Most naive implementation, can extract tensor and send via gloo if too slow - [output] = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu.get_group(), - src=self._device_mesh_cpu.mesh[0].item(), - force_cpu_device=False, - ) - - return output - - def update_weights_from_tensor( - self, - named_tensors: Iterable[Tuple[str, torch.Tensor]], - load_format: Optional[str] = None, - ): - # Most naive implementation, can optimize a lot if it is bottleneck - for tensor_index, (name, tensor) in enumerate(named_tensors): - serialized_tensor = MultiprocessingSerializer.serialize( - _preprocess_tensor_for_update_weights(tensor) - ) - - if self._tp_rank == 0: - gathered_serialized_tensors = [None for _ in range(self._tp_size)] - else: - gathered_serialized_tensors = None - dist.gather_object( - obj=serialized_tensor, - object_gather_list=gathered_serialized_tensors, - dst=self._device_mesh_cpu.mesh.tolist()[0], - group=self._device_mesh_cpu.get_group(), - ) - - if self._tp_rank == 0: - self._engine.update_weights_from_tensor( - named_tensors=[ - ( - name, - LocalSerializedTensor(values=gathered_serialized_tensors), - ) - ], - load_format=load_format, - flush_cache=False, - ) - - if self._tp_rank == 0: - self._engine.flush_cache() - - def release_memory_occupation(self): - if self._tp_rank == 0: - self._engine.release_memory_occupation() - - def resume_memory_occupation(self): - if self._tp_rank == 0: - self._engine.resume_memory_occupation() - - def shutdown(self): - if self._engine is not None: - self._engine.shutdown() - - -def _preprocess_tensor_for_update_weights(tensor: torch.Tensor): - if isinstance(tensor, DTensor): - return tensor.full_tensor() - return tensor diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py deleted file mode 100644 index 6b7cbd0bf6fd..000000000000 --- a/test/srt/test_verl_engine_server.py +++ /dev/null @@ -1,415 +0,0 @@ -import multiprocessing -import multiprocessing as mp -import os -import random -import time -import traceback -import unittest -from multiprocessing import Process - -import requests -import torch -from openai import OpenAI -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import CPUOffload -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp.api import ( - ShardedStateDictConfig, - ShardingStrategy, - StateDictType, -) -from transformers import AutoModelForCausalLM - -from sglang.srt.entrypoints.verl_engine import VerlEngine -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import is_port_available -from sglang.test.runners import ( - HFRunner, - SRTRunner, - check_close_model_outputs, - get_dtype_str, -) -from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci - -_MAX_NEW_TOKENS = 8 -_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] -_TORCH_DTYPE = torch.float16 - -# Set to false to temporarily debug issues unrelated to weight update -_ENABLE_UPDATE_WEIGHTS = True - -CI_MODELS = [ - dict(model_path="meta-llama/Llama-3.1-8B-Instruct"), - # Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0 - # dict(model_path="google/gemma-2-2b"), -] -ALL_OTHER_MODELS = [ - dict(model_path="meta-llama/Llama-3.2-1B-Instruct", tp_size=1), - dict(model_path="Qwen/Qwen2-1.5B"), - # dict( - # model_path="Qwen/Qwen2.5-14B-Instruct", - # mem_fraction_static=0.4, - # tp_size=8, - # tight_memory=True, - # decode_tolerance=1.3, - # ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error - dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3), - # dict(model_path="allenai/OLMo-1B-0724-hf"), - # dict( - # model_path="THUDM/glm-4-9b-chat", - # mem_fraction_static=0.1, - # tp_size=8, - # tight_memory=True, - # ), - # dict(model_path="allenai/OLMo-2-1124-7B-Instruct"), - # dict( - # model_path="ibm-granite/granite-3.0-2b-instruct", - # prefill_tolerance=0.22, - # decode_tolerance=0.22, - # ), -] - -# This port is used for HTTP API communication with the VerlEngine server -# It handles client requests for text generation, weight updates, and memory management -# This port must be available and not used by other processes -PORT = find_available_port(2345) - -# Master port is used for PyTorch's distributed communication setup -# It enables tensor-parallel processes to communicate with each other -# Default is 23456, but we find an available port dynamically in assert_fragment_e2e_execution -# This port is critical for torch.distributed.init_process_group to function properly -# Each test needs a unique master_port to avoid conflicts between parallel test executions -# master_port = find_available_port(23456) # This is set in assert_fragment_e2e_execution method - - -class TestVerlEngine(CustomTestCase): - @classmethod - def setUpClass(cls): - multiprocessing.set_start_method("spawn") - - def assert_fragment_e2e_execution( - self, - index: int, - model_path: str, - mem_fraction_static: float = 0.4, - tp_size: int = 2, - tight_memory: bool = False, - prefill_tolerance: float = 0.1, - decode_tolerance: float = 0.1, - ): - """ - Tests VerlEngine with tensor parallelism across multiple processes. - - Spawns tp_size processes to test distributed execution, including: - - Model inference via direct API and HTTP server - - Weight updating functionality - - Memory management (release/resume) - - The test validates output correctness against a reference implementation - within specified tolerance bounds. - - Parameters: - ----------- - index: int - Test index for logging - model_path: str - HuggingFace model identifier - mem_fraction_static: float - Memory fraction for static tensors - tp_size: int - Number of tensor parallel processes - tight_memory: bool - Enable memory optimization - prefill_tolerance: float - Max error for prefill computation - decode_tolerance: float - Max error for decoding computation - """ - - master_port = find_available_port(23456) - - print(f"assert_fragment_e2e_execution START {index=} {model_path=}") - - processes = [] - output_reader, output_writer = mp.Pipe(duplex=False) - for tp_rank in range(tp_size): - p = Process( - target=_run_subprocess, - kwargs=dict( - tp_rank=tp_rank, - tp_size=tp_size, - master_port=master_port, - output_writer=output_writer, - model_path=model_path, - mem_fraction_static=mem_fraction_static, - tight_memory=tight_memory, - prefill_tolerance=prefill_tolerance, - decode_tolerance=decode_tolerance, - ), - ) - p.start() - processes.append(p) - - for _ in range(tp_size): - self.assertTrue( - output_reader.recv(), - f"Subprocess has error, please see logs above. ({index=} {model_path=})", - ) - - for p in processes: - p.join() - - def test_models(self): - """ - Orchestrates end-to-end testing across configured model sets. - - In CI environments: Randomly selects one model for faster testing. - In development: Tests all configured models for comprehensive validation. - - Each model configuration specifies model path, memory settings, - tensor-parallel size, and error tolerance bounds. - """ - test_models = ALL_OTHER_MODELS - if is_in_ci(): - # Randomly select one model in CI for faster testing - test_models = [random.choice(ALL_OTHER_MODELS)] - # Test all models in development environment - print(f"Development environment: Testing all {len(ALL_OTHER_MODELS)} models") - for index, model_info in enumerate(test_models): - self.assert_fragment_e2e_execution(index=index, **model_info) - - -def _run_subprocess( - tp_rank: int, - tp_size: int, - master_port: int, - output_writer, - model_path: str, - mem_fraction_static: float, - tight_memory: bool, - prefill_tolerance: float, - decode_tolerance: float, -): - """ - Executes a single tensor-parallel process for testing VerlEngine. - - Performs the core test operations: - 1. Initializes distributed environment - 2. Loads HuggingFace model for reference - 3. Tests VerlEngine API (generation, memory management, weight updates) - 4. Tests OpenAI-compatible endpoints on rank 0 - - Reports success/failure via output_writer pipe. - - Parameters: - tp_rank: int - Process rank in tensor parallel group - tp_size: int - Total processes in tensor parallel group - master_port: int - Port for distributed communication - output_writer - Pipe for result communication - model_path: str - HuggingFace model identifier - mem_fraction_static: float - Static memory allocation fraction - tight_memory: bool - Memory optimization flag - prefill_tolerance: float - Acceptable prefill error - decode_tolerance: float - Acceptable decode error - """ - try: - print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}") - - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(master_port) - torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size) - torch.cuda.set_device(tp_rank) - - mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"]) - inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs) - inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs) - # Print basic information about this subprocess including: - # - Current tensor-parallel rank - # - Device mesh configuration for both CUDA and CPU - # - This subprocess's role in testing tensor-parallel execution - # - How it contributes to the distributed model testing - print( - f"subprocess[{tp_rank=}] initialized for VerlEngine testing - " - f"Role: Shard {tp_rank+1}/{tp_size} of tensor-parallel model execution | " - f"Device meshes: CUDA={inference_device_mesh_device}, CPU={inference_device_mesh_cpu}" - ) - - # hf model is used for comparison - hf_model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True - ).cuda() - hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True) - - hf_outputs = HFRunner.forward_generation_raw( - base_model=hf_model, - prompts=_PROMPTS, - max_new_tokens=_MAX_NEW_TOKENS, - tokenizer=hf_tokenizer, - lora_paths=None, - torch_dtype=_TORCH_DTYPE, - output_str_only=False, - ) - - if _ENABLE_UPDATE_WEIGHTS: - if tight_memory: - # If tight_memory is True, we need to move the model to CPU to save memory - hf_model.cpu() - torch.cuda.empty_cache() - - # test update weights - print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True) - fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size) - - engine = VerlEngine( - model_path=model_path, - load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto", - mem_fraction_static=mem_fraction_static, - random_seed=42, - trust_remote_code=True, - dtype=get_dtype_str(_TORCH_DTYPE), - device_mesh_cpu=inference_device_mesh_cpu["tp"], - backend="server", - enable_memory_saver=True, - port=PORT, - ) - # test direct generate API with multiple different requests - print( - f"subprocess[{tp_rank=}] testing direct generate API with multiple requests" - ) - - # Request 1: Basic generation with temperature - print(f"subprocess[{tp_rank=}] test request 1: Basic generation") - direct_response = engine.generate( - prompt="Hello, world!", - sampling_params={"temperature": 0.7, "max_new_tokens": 20}, - ) - print(f"Response 1: {direct_response}") - - # Request 2: Zero temperature (greedy) generation - print(f"subprocess[{tp_rank=}] test request 2: Greedy generation") - direct_response = engine.generate( - prompt="Complete this sequence: 1, 2, 3,", - sampling_params={"temperature": 0.0, "max_new_tokens": 10}, - ) - print(f"Response 2: {direct_response}") - - # Request 3: Batch generation - print(f"subprocess[{tp_rank=}] test request 3: Batch generation") - batch_response = engine.generate( - prompt=["Translate 'hello' to French:", "Translate 'goodbye' to Spanish:"], - sampling_params={"temperature": 0.8, "max_new_tokens": 15}, - ) - print(f"Response 3: {batch_response}") - - # test memory occupation APIs - print(f"subprocess[{tp_rank=}] testing memory occupation APIs") - engine.release_memory_occupation() - print("Memory released") - # time.sleep(1) - engine.resume_memory_occupation() - print("Memory resumed") - - # openai API test for reference - torch.distributed.barrier() - if tp_rank == 0: - client = OpenAI(api_key="None", base_url=f"http://localhost:{PORT}/v1") - print(client.models.list().data[0].id) - - # Multiple HTTP API requests - print("Testing HTTP API with multiple requests") - - # Request 1 - url = f"http://localhost:{PORT}/generate" - data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} - response = requests.post(url, json=data) - print(f"HTTP Response 1: {response.json()}") - - # Request 2 - data = { - "text": "The capital of France is", - "sampling_params": {"temperature": 0.2}, - } - response = requests.post(url, json=data) - print(f"HTTP Response 2: {response.json()}") - - # Request 3 - data = { - "text": "List three colors:", - "sampling_params": {"top_p": 0.95, "max_new_tokens": 25}, - } - response = requests.post(url, json=data) - print(f"HTTP Response 3: {response.json()}") - - if _ENABLE_UPDATE_WEIGHTS: - print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) - - engine.update_weights_from_tensor( - [(k, v) for k, v in fsdp_state_dict.items()] - ) - - # Final generation test after weight update - print(f"subprocess[{tp_rank=}] testing generation after weight update") - direct_response = engine.generate( - prompt="After weight update: Hello, world!", - sampling_params={"temperature": 0.7, "max_new_tokens": 20}, - ) - print(f"Post-update response: {direct_response}") - - execution_ok = True - - except Exception as e: - print(f"subprocess[{tp_rank=}] has error: {e}", flush=True) - traceback.print_exc() - execution_ok = False - - output_writer.send(execution_ok) - output_writer.close() - - engine.shutdown() - print(f"subprocess[{tp_rank=}] end", flush=True) - - -# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py -def _get_fsdp_state_dict(hf_model, tp_size: int): - """ - Creates a sharded state dictionary for weight update testing. - - Wraps the HuggingFace model with FSDP (FullyShardedDataParallel), - configures precision settings, and returns a sharded state dict - for testing VerlEngine's weight update capabilities. - - Parameters: - hf_model - HuggingFace model to wrap - tp_size: int - Number of tensor-parallel shards - - Returns: - dict - Sharded state dict for update_weights_from_tensor - """ - device_mesh = init_device_mesh( - "cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"] - ) - - mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, - ) - fsdp_model = FSDP( - hf_model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh, - ) - print(f"{fsdp_model=}") - - FSDP.set_state_dict_type( - fsdp_model, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - - return fsdp_model.state_dict() - - -if __name__ == "__main__": - unittest.main() From b30bd55c08912f4191bc04e0312c5ab1d318ebc7 Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Wed, 18 Jun 2025 13:13:14 -0700 Subject: [PATCH 2/4] exclude verl tests from suite Signed-off-by: Ata Fatahi --- test/srt/run_suite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 42e52de4ba97..327d084d9913 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -144,7 +144,6 @@ class TestFile: TestFile("test_moe_ep.py", 181), TestFile("test_patch_torch.py", 19), TestFile("test_update_weights_from_distributed.py", 103), - TestFile("test_verl_engine_2_gpu.py", 64), TestFile("test_release_memory_occupation.py", 44), ], "per-commit-2-gpu-amd": [ @@ -157,7 +156,6 @@ class TestFile: "per-commit-4-gpu": [ TestFile("test_local_attn.py", 250), TestFile("test_pp_single_node.py", 150), - TestFile("test_verl_engine_4_gpu.py", 64), ], "per-commit-4-gpu-amd": [ TestFile("test_pp_single_node.py", 150), From 4e97971b618b9d4f9d3929f107442ae550c4cba4 Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Thu, 19 Jun 2025 22:15:23 -0700 Subject: [PATCH 3/4] add back HttpServerEngineAdapter Signed-off-by: Ata Fatahi --- .../srt/entrypoints/http_server_engine.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 2600645b28ca..abbd3fc3d000 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -46,3 +46,97 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: p.terminate() raise TimeoutError("Server failed to start within the timeout period.") + + +class HttpServerEngineAdapter(EngineBase): + """ + You can use this class to launch a server from a VerlEngine instance. + We recommend using this class only you need to use http server. + Otherwise, you can use Engine directly. + """ + + def __init__(self, **kwargs): + self.server_args = ServerArgs(**kwargs) + print( + f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}" + ) + self.process = launch_server_process(self.server_args) + + def _make_request(self, endpoint: str, payload: Optional[dict] = None): + """Make a POST request to the specified endpoint with the given payload. + Args: + endpoint: The API endpoint to call + payload: The JSON payload to send (default: empty dict) + Returns: + The JSON response from the server + """ + url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + return response.json() + + def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, torch.Tensor]], + load_format: Optional[str] = None, + flush_cache: bool = False, + ): + """ + Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. + Note: The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. + """ + + return self._make_request( + "update_weights_from_tensor", + { + "serialized_named_tensors": [ + MultiprocessingSerializer.serialize(named_tensors, output_str=True) + for _ in range(self.server_args.tp_size) + ], + "load_format": load_format, + "flush_cache": flush_cache, + }, + ) + + def shutdown(self): + kill_process_tree(self.process.pid) + + def generate( + self, + prompt=None, + sampling_params=None, + input_ids=None, + image_data=None, + return_logprob=False, + logprob_start_len=None, + top_logprobs_num=None, + token_ids_logprob=None, + lora_path=None, + custom_logit_processor=None, + ): + payload = { + "text": prompt, + "sampling_params": sampling_params, + "input_ids": input_ids, + "image_data": image_data, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "token_ids_logprob": token_ids_logprob, + "lora_path": lora_path, + "custom_logit_processor": custom_logit_processor, + } + # Filter out None values + payload = {k: v for k, v in payload.items() if v is not None} + + return self._make_request("generate", payload) + + def release_memory_occupation(self): + return self._make_request("release_memory_occupation") + + def resume_memory_occupation(self): + return self._make_request("resume_memory_occupation") + + def flush_cache(self): + return self._make_request("flush_cache") \ No newline at end of file From c203d350a605b95ad5e855ab652939cfe14fb051 Mon Sep 17 00:00:00 2001 From: Ata Fatahi Date: Thu, 19 Jun 2025 22:16:02 -0700 Subject: [PATCH 4/4] add empty new line Signed-off-by: Ata Fatahi --- python/sglang/srt/entrypoints/http_server_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index abbd3fc3d000..b2edf1abe61b 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -139,4 +139,4 @@ def resume_memory_occupation(self): return self._make_request("resume_memory_occupation") def flush_cache(self): - return self._make_request("flush_cache") \ No newline at end of file + return self._make_request("flush_cache")