
Commit 0ea192b

hmellor authored and xuebwang-amd committed
Enable RMSNorm substitution for Transformers backend (vllm-project#26353)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent 5e257c4 commit 0ea192b

File tree

1 file changed: +28 -15 lines

vllm/model_executor/models/transformers.py

Lines changed: 28 additions & 15 deletions
```diff
@@ -43,7 +43,7 @@
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     ReplicatedLinear,
```
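
For context on the import change: a regular RMSNorm scales the normalized activations by a weight initialized to ones, while a Gemma-style RMSNorm scales by `(1 + weight)` with a weight initialized to zeros. A minimal sketch of the two forward passes, assuming the standard RMSNorm formulation (vLLM's real implementations add dtype handling and fused-kernel paths, which are omitted here):

```python
import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared normalization step: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def rms_norm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Regular RMSNorm: weight initialized to ones, scales directly
    return rms_normalize(x, eps) * weight

def gemma_rms_norm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Gemma-style RMSNorm: weight initialized to zeros, scales by (1 + weight)
    return rms_normalize(x, eps) * (1.0 + weight)
```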
```diff
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
     - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
       and Transformers doesn't appear to have the same concept.
     """
-    kwargs = {
-        "hidden_size": hidden_size,
-        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
-        "has_weight": getattr(rms_norm, "with_scale", True),
-    }
-    if (weight := getattr(rms_norm, "weight", None)) is not None:
-        # If weight is a Parameter, get its data tensor
-        weight = getattr(weight, "data", weight)
-        kwargs["dtype"] = weight.dtype
+    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
+    kwargs = {"hidden_size": hidden_size, "eps": eps}
+    # Update hidden size if weight is available
+    weight_meta = getattr(rms_norm, "weight", None)
+    if weight_meta is not None:
+        kwargs["hidden_size"] = weight_meta.size(0)
+    # Check if weight is all zeros, which indicates GemmaRMSNorm
+    # We must create a new instance because rms_norm is on meta
+    try:
+        with torch.device("cpu"):
+            weight_test = getattr(rms_norm.__class__(1), "weight", None)
+    except Exception:
+        logger.warning(
+            "Failed to determine if RMSNorm weight is centered on zero or one. "
+            "Defaulting to one."
+        )
+        weight_test = None
+    if weight_test is not None and torch.all(weight_test == 0):
+        return GemmaRMSNorm(**kwargs)
+    # Otherwise assume it's a regular RMSNorm
+    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
+    if weight_meta is not None:
+        kwargs["dtype"] = weight_meta.dtype
     else:
         # No weight, fall back to weightless RMSNorm
         kwargs["has_weight"] = False
```
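
The notable trick above is the zero-weight probe. The module being replaced may live on the meta device, so its own weight has no data to inspect; instead, a throwaway instance of the same class is constructed on the CPU and its freshly initialized weight is examined. All-zeros initialization identifies a Gemma-style norm (which applies `1 + weight`); ones identify a regular RMSNorm. A self-contained sketch of the heuristic, where `PlainRMSNorm` and `GemmaStyleRMSNorm` are hypothetical stand-ins for the Transformers classes:

```python
import torch
from torch import nn

class PlainRMSNorm(nn.Module):
    # Stand-in for a standard Transformers RMSNorm: weight starts at ones
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

class GemmaStyleRMSNorm(nn.Module):
    # Stand-in for a Gemma-style RMSNorm: weight starts at zeros
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

def is_gemma_style(rms_norm: nn.Module) -> bool:
    # Build a fresh hidden_size=1 instance on CPU so the initializer
    # runs with real (non-meta) storage, then inspect its weight.
    try:
        with torch.device("cpu"):
            weight = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        return False  # cannot re-instantiate; default to regular RMSNorm
    return weight is not None and bool(torch.all(weight == 0))

print(is_gemma_style(PlainRMSNorm(8)))       # False
print(is_gemma_style(GemmaStyleRMSNorm(8)))  # True
```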
```diff
@@ -645,11 +659,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
             new_module = replace_linear_class(
                 child_module, style, self.quant_config, prefix=qual_name
             )
-            # TODO(hmellor): Enable RMSNorm replacement once we have a way
-            # to choose RMSNorm vs GemmaRMSNorm
-            # elif child_module.__class__.__name__.endswith("RMSNorm"):
-            #     new_module = replace_rms_norm_class(
-            #         child_module, self.config.hidden_size)
+        elif child_module.__class__.__name__.endswith("RMSNorm"):
+            new_module = replace_rms_norm_class(
+                child_module, self.text_config.hidden_size
+            )
         else:
             _recursive_replace(child_module, prefix=qual_name)
 
```
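
The substitution is driven by a depth-first walk over the module tree: any child whose class name ends in "RMSNorm" is swapped for the vLLM equivalent, and everything else is recursed into. A minimal sketch of that pattern, with a hypothetical `replace_norm` standing in for `replace_rms_norm_class` and PyTorch's `nn.RMSNorm` (available in recent PyTorch) used purely as an illustrative replacement:

```python
import torch
from torch import nn

def replace_norm(child: nn.Module, hidden_size: int) -> nn.Module:
    # Hypothetical stand-in for replace_rms_norm_class: build the
    # replacement from attributes read off the original module.
    eps = getattr(child, "eps", 1e-6)
    return nn.RMSNorm(hidden_size, eps=eps)

def recursive_replace(module: nn.Module, hidden_size: int) -> None:
    # Depth-first walk over the module tree
    for child_name, child_module in module.named_children():
        if child_module.__class__.__name__.endswith("RMSNorm"):
            # Assigning via setattr swaps the submodule in place
            setattr(module, child_name, replace_norm(child_module, hidden_size))
        else:
            recursive_replace(child_module, hidden_size)
```

Matching on the class name rather than a concrete class is what makes this generic: each Transformers model family defines its own `*RMSNorm` class, and the suffix check catches all of them without importing any.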
