From 2df5181ff25cf1a62e1f02f5f1f4ab1b1ebc8f06 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 5 Jun 2025 10:50:39 +0800 Subject: [PATCH 01/11] vlm: tune mem-fraction-static --- .../sglang/srt/model_executor/model_runner.py | 65 ++++++++++++++++++- test/srt/test_vision_openai_server_a.py | 12 ---- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 051f2b75e48b..5a8dd00b2aae 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -402,7 +402,7 @@ def model_specific_adjustment(self): else: server_args.attention_backend = "triton" logger.info( - f"Attention backend not set. Use {server_args.attention_backend} backend by default." + f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default." ) elif self.use_mla_backend: if server_args.device != "cpu": @@ -454,10 +454,71 @@ def model_specific_adjustment(self): if not self.is_multimodal_chunked_prefill_supported: server_args.chunked_prefill_size = -1 logger.info( - f"Automatically turn of --chunked-prefill-size as it is not supported for " + f"Automatically turn off --chunked-prefill-size as it is not supported for " f"{self.model_config.hf_config.model_type}" ) + # roughly reduce the mem_fraction_static base on params of Vit + original_server_arg_mem_fraction = self.mem_fraction_static + # a base mem_fraction_static factor for regular Vit + base_mem_fraction_reduction_ratio = 0.95 + if ( + hasattr(self.model_config.hf_config, "vision_config") + and self.model_config.hf_config.vision_config is not None + ): + vision_config = self.model_config.hf_config.vision_config + + vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) + vit_hidden_size = getattr(vision_config, "hidden_size", 1024) + + # baseline ViT params (ViT-L/14) + baseline_vit_layers = 24 + baseline_vit_hidden_size = 1024 + + # weight params count + current_complexity_score = vit_num_layers * (vit_hidden_size**2) + baseline_complexity_score = baseline_vit_layers * ( + baseline_vit_hidden_size**2 + ) + complexity_ratio = ( + current_complexity_score / baseline_complexity_score + if baseline_complexity_score > 0 + else 1.0 + ) + + # everytime the complexity grows 100%, adjust final factor for 3% + sensitivity_scale = 0.02 + dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( + complexity_ratio - 1.0 + ) + print(f"{dynamic_adjustment_factor=}") + dynamic_adjustment_factor = max( + 0.8, min(1.05, dynamic_adjustment_factor) + ) + + final_overall_factor = ( + base_mem_fraction_reduction_ratio * dynamic_adjustment_factor + ) + self.mem_fraction_static = ( + original_server_arg_mem_fraction * final_overall_factor + ) + + print(f"{vit_hidden_size=}") + print(f"{vit_num_layers=}") + + logger.info( + f"Multimodal model: Dynamically adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." + ) + else: + self.mem_fraction_static = ( + original_server_arg_mem_fraction * base_mem_fraction_reduction_ratio + ) + logger.info( + f"Multimodal model: No detailed vision_config, fixed adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." + ) + if not self.use_mla_backend: server_args.disable_chunked_prefix_cache = True elif self.page_size > 1: diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index a4a2e770dbf7..4355b65be124 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -28,10 +28,6 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, - other_args=[ - "--mem-fraction-static", - "0.4", - ], ) cls.base_url += "/v1" @@ -47,10 +43,6 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, - other_args=[ - "--mem-fraction-static", - "0.4", - ], ) cls.base_url += "/v1" @@ -140,8 +132,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.4", ], ) cls.base_url += "/v1" @@ -174,8 +164,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.7", ], ) cls.base_url += "/v1" From e8afad042e656ca9fbaf532bd48956970df8ac35 Mon Sep 17 00:00:00 2001 From: Mick Date: Fri, 6 Jun 2025 09:06:31 +0800 Subject: [PATCH 02/11] cleanup --- python/sglang/srt/model_executor/model_runner.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5a8dd00b2aae..070728da6c74 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -486,12 +486,11 @@ def model_specific_adjustment(self): else 1.0 ) - # everytime the complexity grows 100%, adjust final factor for 3% + # every time the complexity grows 100%, adjust final factor for 2% sensitivity_scale = 0.02 dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( complexity_ratio - 1.0 ) - print(f"{dynamic_adjustment_factor=}") dynamic_adjustment_factor = max( 0.8, min(1.05, dynamic_adjustment_factor) ) @@ -503,9 +502,6 @@ def model_specific_adjustment(self): original_server_arg_mem_fraction * final_overall_factor ) - print(f"{vit_hidden_size=}") - print(f"{vit_num_layers=}") - logger.info( f"Multimodal model: Dynamically adjusted --mem-fraction-static " f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." From 8feed3425119fce00b6cb26798698319da4a987b Mon Sep 17 00:00:00 2001 From: Mick Date: Fri, 6 Jun 2025 10:44:18 +0800 Subject: [PATCH 03/11] remove mem-fraction-static args in test --- test/srt/test_vision_openai_server_b.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index b1ca951dfed9..3eb10562f3c1 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -21,8 +21,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.73", ], ) cls.base_url += "/v1" @@ -43,8 +41,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.8", ], ) cls.base_url += "/v1" @@ -87,8 +83,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.4", ], ) cls.base_url += "/v1" @@ -115,8 +109,6 @@ def test_single_image_chat_completion(self): # other_args=[ # "--chat-template", # "llama-4", -# "--mem-fraction-static", -# "0.8", # "--tp-size=8", # "--context-length=8192", # ], @@ -139,8 +131,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.70", "--enable-multimodal", ], ) @@ -196,8 +186,6 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", - "--mem-fraction-static", - "0.75", "--disable-radix-cache", "--max-loras-per-batch", "1", From 7557eac83e2ae573157fbe3c750c932de057a0e9 Mon Sep 17 00:00:00 2001 From: Mick Date: Sat, 7 Jun 2025 21:15:50 +0800 Subject: [PATCH 04/11] rollback some mem-fraction-static of ci --- test/srt/test_vision_openai_server_a.py | 2 ++ test/srt/test_vision_openai_server_b.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 4355b65be124..909d372cb638 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -164,6 +164,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.7", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 3eb10562f3c1..ffd595c441da 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -131,6 +131,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.75", "--enable-multimodal", ], ) From 4d28cc41d0c93bccd9a3c9d99611c9e57192c8ee Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 12 Jun 2025 09:04:53 +0800 Subject: [PATCH 05/11] update --- test/srt/test_vision_openai_server_a.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 909d372cb638..3ed4e4f423cd 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -28,6 +28,10 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, + other_args=[ + "--mem-fraction-static", + "0.5", + ], ) cls.base_url += "/v1" @@ -132,6 +136,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.7", ], ) cls.base_url += "/v1" From 3f94e8ff37a3c792b9543f36ebf686865d9b07d0 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 12 Jun 2025 10:12:00 +0800 Subject: [PATCH 06/11] update --- python/sglang/srt/model_executor/model_runner.py | 5 +++-- test/srt/test_vision_openai_server_a.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 070728da6c74..1fd5c3f275e7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -485,9 +485,10 @@ def model_specific_adjustment(self): if baseline_complexity_score > 0 else 1.0 ) + print(f"{complexity_ratio=}") - # every time the complexity grows 100%, adjust final factor for 2% - sensitivity_scale = 0.02 + # every time the complexity grows 100%, adjust final factor for 3% + sensitivity_scale = 0.03 dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( complexity_ratio - 1.0 ) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 3ed4e4f423cd..731495329016 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -47,6 +47,10 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, + other_args=[ + "--mem-fraction-static", + "0.5", + ], ) cls.base_url += "/v1" From bd5f9de1a2b45080e5a7383fa07bbc6ff031497f Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 12 Jun 2025 12:48:42 +0800 Subject: [PATCH 07/11] adjust sensitivity_scale --- python/sglang/srt/model_executor/model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1fd5c3f275e7..d3fa0ac58cc6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -485,10 +485,9 @@ def model_specific_adjustment(self): if baseline_complexity_score > 0 else 1.0 ) - print(f"{complexity_ratio=}") - # every time the complexity grows 100%, adjust final factor for 3% - sensitivity_scale = 0.03 + # every time the complexity grows 100%, adjust final factor for 10% + sensitivity_scale = 0.1 dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( complexity_ratio - 1.0 ) From f9a86f02ab72f787925d255812c1048d453a25f3 Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 2 Jul 2025 09:24:04 +0800 Subject: [PATCH 08/11] update --- .../sglang/srt/model_executor/model_runner.py | 57 -------- python/sglang/srt/server_args.py | 138 ++++++++++++------ 2 files changed, 91 insertions(+), 104 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d3fa0ac58cc6..281e50129a8c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -458,63 +458,6 @@ def model_specific_adjustment(self): f"{self.model_config.hf_config.model_type}" ) - # roughly reduce the mem_fraction_static base on params of Vit - original_server_arg_mem_fraction = self.mem_fraction_static - # a base mem_fraction_static factor for regular Vit - base_mem_fraction_reduction_ratio = 0.95 - if ( - hasattr(self.model_config.hf_config, "vision_config") - and self.model_config.hf_config.vision_config is not None - ): - vision_config = self.model_config.hf_config.vision_config - - vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) - vit_hidden_size = getattr(vision_config, "hidden_size", 1024) - - # baseline ViT params (ViT-L/14) - baseline_vit_layers = 24 - baseline_vit_hidden_size = 1024 - - # weight params count - current_complexity_score = vit_num_layers * (vit_hidden_size**2) - baseline_complexity_score = baseline_vit_layers * ( - baseline_vit_hidden_size**2 - ) - complexity_ratio = ( - current_complexity_score / baseline_complexity_score - if baseline_complexity_score > 0 - else 1.0 - ) - - # every time the complexity grows 100%, adjust final factor for 10% - sensitivity_scale = 0.1 - dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( - complexity_ratio - 1.0 - ) - dynamic_adjustment_factor = max( - 0.8, min(1.05, dynamic_adjustment_factor) - ) - - final_overall_factor = ( - base_mem_fraction_reduction_ratio * dynamic_adjustment_factor - ) - self.mem_fraction_static = ( - original_server_arg_mem_fraction * final_overall_factor - ) - - logger.info( - f"Multimodal model: Dynamically adjusted --mem-fraction-static " - f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." - ) - else: - self.mem_fraction_static = ( - original_server_arg_mem_fraction * base_mem_fraction_reduction_ratio - ) - logger.info( - f"Multimodal model: No detailed vision_config, fixed adjusted --mem-fraction-static " - f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." - ) - if not self.use_mla_backend: server_args.disable_chunked_prefix_cache = True elif self.page_size > 1: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ce9710985df2..1fd5a77fbcdd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -325,8 +325,52 @@ def __post_init__(self): # Multimodal models need more memory for the image processor model_config = ModelConfig.from_server_args(self) - if model_config.is_multimodal: - self.mem_fraction_static *= 0.90 + + vision_config = getattr(model_config.hf_config, "vision_config", None) + + if model_config.is_multimodal and vision_config: + # roughly reduce the mem_fraction_static base on params of Vit + original_server_arg_mem_fraction = self.mem_fraction_static + # a base mem_fraction_static factor for regular Vit + base_mem_fraction_reduction_ratio = 0.95 + + vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) + vit_hidden_size = getattr(vision_config, "hidden_size", 1024) + + # baseline ViT params (ViT-L/14) + baseline_vit_layers = 24 + baseline_vit_hidden_size = 1024 + + # weight params count + current_complexity_score = vit_num_layers * (vit_hidden_size ** 2) + baseline_complexity_score = baseline_vit_layers * ( + baseline_vit_hidden_size ** 2 + ) + complexity_ratio = ( + current_complexity_score / baseline_complexity_score + if baseline_complexity_score > 0 + else 1.0 + ) + + # every time the complexity grows 100%, adjust final factor for 10% + sensitivity_scale = 0.1 + dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( + complexity_ratio - 1.0 + ) + dynamic_adjustment_factor = max( + 0.8, min(1.05, dynamic_adjustment_factor) + ) + + final_overall_factor = ( + base_mem_fraction_reduction_ratio * dynamic_adjustment_factor + ) + self.mem_fraction_static = ( + original_server_arg_mem_fraction * final_overall_factor + ) + logger.warning( + f"Multimodal model: Dynamically adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." + ) # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: @@ -608,8 +652,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -638,27 +682,27 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--model-loader-extra-config", type=str, help="Extra config for model loader. " - "This will be passed to the model loader corresponding to the chosen load_format.", + "This will be passed to the model loader corresponding to the chosen load_format.", default=ServerArgs.model_loader_extra_config, ) parser.add_argument( @@ -672,13 +716,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -714,9 +758,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -764,20 +808,20 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) parser.add_argument( "--impl", type=str, default=ServerArgs.impl, help="Which implementation of the model to use.\n\n" - '* "auto" will try to use the SGLang implementation if it exists ' - "and fall back to the Transformers implementation if no SGLang " - "implementation is available.\n" - '* "sglang" will use the SGLang model implementation.\n' - '* "transformers" will use the Transformers model ' - "implementation.\n", + '* "auto" will try to use the SGLang implementation if it exists ' + "and fall back to the Transformers implementation if no SGLang " + "implementation is available.\n" + '* "sglang" will use the SGLang model implementation.\n' + '* "transformers" will use the Transformers model ' + "implementation.\n", ) # Memory and scheduling @@ -798,7 +842,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1469,7 +1513,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1482,8 +1526,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1559,7 +1603,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps @@ -1631,8 +1675,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=ServerArgs.disaggregation_ib_device, help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) " - "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " - "Default is None, which triggers automatic device detection when mooncake backend is enabled.", + "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " + "Default is None, which triggers automatic device detection when mooncake backend is enabled.", ) parser.add_argument( "--num-reserved-decode-tokens", @@ -1676,8 +1720,8 @@ def url(self): def check_server_args(self): assert ( - self.tp_size * self.pp_size - ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" + self.tp_size * self.pp_size + ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" # FIXME pp constraints if self.pp_size > 1: From 1506ad0136bb28d6903e03f295be5b18b5fa759a Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 2 Jul 2025 09:51:33 +0800 Subject: [PATCH 09/11] revert test --- test/srt/test_vision_openai_server_a.py | 6 +++--- test/srt/test_vision_openai_server_b.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 731495329016..a4a2e770dbf7 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -30,7 +30,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.5", + "0.4", ], ) cls.base_url += "/v1" @@ -49,7 +49,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.5", + "0.4", ], ) cls.base_url += "/v1" @@ -141,7 +141,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.7", + "0.4", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index ffd595c441da..0e42defa5edb 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -21,6 +21,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.73", ], ) cls.base_url += "/v1" @@ -41,6 +43,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.8", ], ) cls.base_url += "/v1" @@ -83,6 +87,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.4", ], ) cls.base_url += "/v1" @@ -109,6 +115,8 @@ def test_single_image_chat_completion(self): # other_args=[ # "--chat-template", # "llama-4", +# "--mem-fraction-static", +# "0.8", # "--tp-size=8", # "--context-length=8192", # ], @@ -188,6 +196,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--mem-fraction-static", + "0.75", "--disable-radix-cache", "--max-loras-per-batch", "1", From 1ee4eef8c3b3995e9813732951e7bfd4801d809f Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 2 Jul 2025 11:14:09 +0800 Subject: [PATCH 10/11] update test --- test/srt/test_vision_openai_server_a.py | 10 +++++----- test/srt/test_vision_openai_server_b.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index a4a2e770dbf7..f96e0dc7a9d7 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -30,7 +30,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -49,7 +49,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -69,7 +69,7 @@ def setUpClass(cls): other_args=[ "--context-length", "300", - "--mem-fraction-static=0.80", + "--mem-fraction-static=0.75", ], ) cls.base_url += "/v1" @@ -141,7 +141,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -175,7 +175,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.7", + "0.65", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 0e42defa5edb..b676e48c2ec6 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -22,7 +22,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.73", + "0.70", ], ) cls.base_url += "/v1" @@ -44,7 +44,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.8", + "0.75", ], ) cls.base_url += "/v1" @@ -88,7 +88,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -140,7 +140,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.75", + "0.70", "--enable-multimodal", ], ) @@ -197,7 +197,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.75", + "0.70", "--disable-radix-cache", "--max-loras-per-batch", "1", From 5ccac38e3457407c4a0203cf08101b2414fcdfbd Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 6 Jul 2025 09:38:51 +0800 Subject: [PATCH 11/11] upd --- python/sglang/srt/server_args.py | 94 ++++++++++++++++---------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1fd5a77fbcdd..2fdf235053db 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -342,9 +342,9 @@ def __post_init__(self): baseline_vit_hidden_size = 1024 # weight params count - current_complexity_score = vit_num_layers * (vit_hidden_size ** 2) + current_complexity_score = vit_num_layers * (vit_hidden_size**2) baseline_complexity_score = baseline_vit_layers * ( - baseline_vit_hidden_size ** 2 + baseline_vit_hidden_size**2 ) complexity_ratio = ( current_complexity_score / baseline_complexity_score @@ -652,8 +652,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -682,27 +682,27 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--model-loader-extra-config", type=str, help="Extra config for model loader. " - "This will be passed to the model loader corresponding to the chosen load_format.", + "This will be passed to the model loader corresponding to the chosen load_format.", default=ServerArgs.model_loader_extra_config, ) parser.add_argument( @@ -716,13 +716,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -758,9 +758,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -808,20 +808,20 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) parser.add_argument( "--impl", type=str, default=ServerArgs.impl, help="Which implementation of the model to use.\n\n" - '* "auto" will try to use the SGLang implementation if it exists ' - "and fall back to the Transformers implementation if no SGLang " - "implementation is available.\n" - '* "sglang" will use the SGLang model implementation.\n' - '* "transformers" will use the Transformers model ' - "implementation.\n", + '* "auto" will try to use the SGLang implementation if it exists ' + "and fall back to the Transformers implementation if no SGLang " + "implementation is available.\n" + '* "sglang" will use the SGLang model implementation.\n' + '* "transformers" will use the Transformers model ' + "implementation.\n", ) # Memory and scheduling @@ -842,7 +842,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1513,7 +1513,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1526,8 +1526,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1603,7 +1603,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps @@ -1675,8 +1675,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=ServerArgs.disaggregation_ib_device, help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) " - "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " - "Default is None, which triggers automatic device detection when mooncake backend is enabled.", + "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " + "Default is None, which triggers automatic device detection when mooncake backend is enabled.", ) parser.add_argument( "--num-reserved-decode-tokens", @@ -1720,8 +1720,8 @@ def url(self): def check_server_args(self): assert ( - self.tp_size * self.pp_size - ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" + self.tp_size * self.pp_size + ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" # FIXME pp constraints if self.pp_size > 1: