From dee601f21d6b123198f29f6d67a7576245a9beeb Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:33:43 +0200 Subject: [PATCH 1/5] Enable `RMSNorm` substitution for Transformers backend This change should enable quant fusions which depend on the `RMSNorm` op being present Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 47e829861284..a9c6566ebc74 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -43,7 +43,7 @@ from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, @@ -194,15 +194,22 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: - `var_hidden_size` is only ever used for Intern vision encoder in vLLM and Transformers doesn't appear to have the same concept. """ - kwargs = { - "hidden_size": hidden_size, - "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), - "has_weight": getattr(rms_norm, "with_scale", True), - } - if (weight := getattr(rms_norm, "weight", None)) is not None: - # If weight is a Parameter, get its data tensor - weight = getattr(weight, "data", weight) - kwargs["dtype"] = weight.dtype + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + # Update hidden size if weight is available + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + # Check if weight is all zeros, which indicates GemmaRMSNorm + # We must create a new instance because rms_norm is on meta + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + if weight_test is not None and torch.all(weight_test == 0): + return GemmaRMSNorm(**kwargs) + # Otherwise assume it's a regular RMSNorm + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["dtype"] = weight_meta.dtype else: # No weight, fall back to weightless RMSNorm kwargs["has_weight"] = False @@ -645,11 +652,10 @@ def _recursive_replace(module: nn.Module, prefix: str): new_module = replace_linear_class( child_module, style, self.quant_config, prefix=qual_name ) - # TODO(hmellor): Enable RMSNorm replacement once we have a way - # to choose RMSNorm vs GemmaRMSNorm - # elif child_module.__class__.__name__.endswith("RMSNorm"): - # new_module = replace_rms_norm_class( - # child_module, self.config.hidden_size) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, self.config.hidden_size + ) else: _recursive_replace(child_module, prefix=qual_name) From 423e4eb97e8dc72ac3095a1d5159fac8bc93f26b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:48:18 +0200 Subject: [PATCH 2/5] Error handling for `rms_norm.__class__(1)` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- 
vllm/model_executor/models/transformers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a9c6566ebc74..b434b19646da 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -203,7 +203,14 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: # Check if weight is all zeros, which indicates GemmaRMSNorm # We must create a new instance because rms_norm is on meta with torch.device("cpu"): - weight_test = getattr(rms_norm.__class__(1), "weight", None) + try: + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." + ) + weight_test = None if weight_test is not None and torch.all(weight_test == 0): return GemmaRMSNorm(**kwargs) # Otherwise assume it's a regular RMSNorm From 8f945a1ed5a490cecfaae714a2bdb1ff1527c20c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:51:53 +0200 Subject: [PATCH 3/5] Decrease indentation Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b434b19646da..48291c5022e2 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -202,15 +202,15 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: kwargs["hidden_size"] = weight_meta.size(0) # Check if weight is all zeros, which indicates GemmaRMSNorm # We must create a new instance because rms_norm is on meta - with torch.device("cpu"): - try: + try: + with torch.device("cpu"): weight_test = getattr(rms_norm.__class__(1), "weight", None) - except Exception: - logger.warning( - "Failed to determine if RMSNorm weight is centered on zero or one. " - "Defaulting to one." - ) - weight_test = None + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." 
+ ) + weight_test = None if weight_test is not None and torch.all(weight_test == 0): return GemmaRMSNorm(**kwargs) # Otherwise assume it's a regular RMSNorm From c8eba5027f2ab7ceeb780437d3b144d2eb2ad230 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 8 Oct 2025 15:50:00 +0200 Subject: [PATCH 4/5] Get hidden_size from text config Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 48291c5022e2..786c4b026f15 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -661,7 +661,7 @@ def _recursive_replace(module: nn.Module, prefix: str): ) elif child_module.__class__.__name__.endswith("RMSNorm"): new_module = replace_rms_norm_class( - child_module, self.config.hidden_size + child_module, self.config.text_config.hidden_size ) else: _recursive_replace(child_module, prefix=qual_name) From 2568fa53388ec009e2b3b9ef1157677403c8b905 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 8 Oct 2025 18:15:13 +0200 Subject: [PATCH 5/5] Typo... Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 786c4b026f15..1cfe401b243c 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -661,7 +661,7 @@ def _recursive_replace(module: nn.Module, prefix: str): ) elif child_module.__class__.__name__.endswith("RMSNorm"): new_module = replace_rms_norm_class( - child_module, self.config.text_config.hidden_size + child_module, self.text_config.hidden_size ) else: _recursive_replace(child_module, prefix=qual_name)
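
Note on the `getattr_iter` helper used in patch 1: as the call site implies, it tries a sequence of attribute names and returns the first one present, falling back to a default. An illustrative re-implementation, inferred from that call site rather than copied from vLLM:

def getattr_iter(obj, names, default):
    # Return the first attribute in `names` that `obj` defines; the
    # fallback covers norm classes that name epsilon differently
    # (e.g. `eps` vs `variance_epsilon`).
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default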
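
Note on the zero-weight heuristic in patches 1-3: Gemma-style RMSNorm layers initialize their scale to zeros and apply it as `x * (1 + weight)`, while Llama-style layers initialize to ones and apply `x * weight`. Because the live module typically sits on the meta device, where tensor values cannot be read, the patch instantiates a throwaway copy of the class on CPU and inspects its freshly initialized weight. A minimal, self-contained sketch of that heuristic (the two norm classes below are illustrative stand-ins, not imports from Transformers):

import torch
from torch import nn

class LlamaStyleRMSNorm(nn.Module):
    """Llama-style: weight starts at ones, applied as `x * weight`."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

class GemmaStyleRMSNorm(nn.Module):
    """Gemma-style: weight starts at zeros, applied as `x * (1 + weight)`."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))

def is_gemma_style(rms_norm: nn.Module) -> bool:
    # Instantiate a fresh copy on CPU because the real module may live
    # on the meta device, where weight values are unreadable.
    try:
        with torch.device("cpu"):
            weight = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        # Constructor signature unknown; assume Llama-style (ones),
        # matching the warning added in patch 2.
        return False
    return weight is not None and bool(torch.all(weight == 0))

print(is_gemma_style(LlamaStyleRMSNorm(8)))  # False
print(is_gemma_style(GemmaStyleRMSNorm(8)))  # True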
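
Note on patches 4 and 5: for multimodal models the top-level config is a composite, and the language model's `hidden_size` lives on a nested text sub-config; the final patch therefore reads it from the `self.text_config` attribute the model class already holds. A hedged sketch of the usual resolution, assuming the `get_text_config()` accessor available on recent Transformers `PretrainedConfig` objects:

from transformers import PretrainedConfig

def text_hidden_size(config: PretrainedConfig) -> int:
    # For text-only models this returns `config` itself; for composite
    # (e.g. vision-language) configs it returns the nested text config.
    return config.get_text_config().hidden_size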