diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 051f2b75e48b..281e50129a8c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -402,7 +402,7 @@ def model_specific_adjustment(self): else: server_args.attention_backend = "triton" logger.info( - f"Attention backend not set. Use {server_args.attention_backend} backend by default." + f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default." ) elif self.use_mla_backend: if server_args.device != "cpu": @@ -454,7 +454,7 @@ def model_specific_adjustment(self): if not self.is_multimodal_chunked_prefill_supported: server_args.chunked_prefill_size = -1 logger.info( - f"Automatically turn of --chunked-prefill-size as it is not supported for " + f"Automatically turn off --chunked-prefill-size as it is not supported for " f"{self.model_config.hf_config.model_type}" ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ce9710985df2..2fdf235053db 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -325,8 +325,52 @@ def __post_init__(self): # Multimodal models need more memory for the image processor model_config = ModelConfig.from_server_args(self) - if model_config.is_multimodal: - self.mem_fraction_static *= 0.90 + + vision_config = getattr(model_config.hf_config, "vision_config", None) + + if model_config.is_multimodal and vision_config: + # roughly reduce the mem_fraction_static base on params of Vit + original_server_arg_mem_fraction = self.mem_fraction_static + # a base mem_fraction_static factor for regular Vit + base_mem_fraction_reduction_ratio = 0.95 + + vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) + vit_hidden_size = getattr(vision_config, "hidden_size", 1024) + + # baseline ViT params (ViT-L/14) + baseline_vit_layers = 24 + baseline_vit_hidden_size = 1024 + + # weight params count + current_complexity_score = vit_num_layers * (vit_hidden_size**2) + baseline_complexity_score = baseline_vit_layers * ( + baseline_vit_hidden_size**2 + ) + complexity_ratio = ( + current_complexity_score / baseline_complexity_score + if baseline_complexity_score > 0 + else 1.0 + ) + + # every time the complexity grows 100%, adjust final factor for 10% + sensitivity_scale = 0.1 + dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( + complexity_ratio - 1.0 + ) + dynamic_adjustment_factor = max( + 0.8, min(1.05, dynamic_adjustment_factor) + ) + + final_overall_factor = ( + base_mem_fraction_reduction_ratio * dynamic_adjustment_factor + ) + self.mem_fraction_static = ( + original_server_arg_mem_fraction * final_overall_factor + ) + logger.warning( + f"Multimodal model: Dynamically adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." + ) # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index a4a2e770dbf7..f96e0dc7a9d7 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -30,7 +30,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -49,7 +49,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -69,7 +69,7 @@ def setUpClass(cls): other_args=[ "--context-length", "300", - "--mem-fraction-static=0.80", + "--mem-fraction-static=0.75", ], ) cls.base_url += "/v1" @@ -141,7 +141,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -175,7 +175,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.7", + "0.65", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index b1ca951dfed9..b676e48c2ec6 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -22,7 +22,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.73", + "0.70", ], ) cls.base_url += "/v1" @@ -44,7 +44,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.8", + "0.75", ], ) cls.base_url += "/v1" @@ -88,7 +88,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -197,7 +197,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.75", + "0.70", "--disable-radix-cache", "--max-loras-per-batch", "1",