43 | 43 | from vllm.distributed import get_pp_group, get_tp_group |
44 | 44 | from vllm.distributed.utils import get_pp_indices |
45 | 45 | from vllm.logger import init_logger |
46 | | -from vllm.model_executor.layers.layernorm import RMSNorm |
| 46 | +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm |
47 | 47 | from vllm.model_executor.layers.linear import ( |
48 | 48 | ColumnParallelLinear, |
49 | 49 | ReplicatedLinear, |
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: |
194 | 194 | - `var_hidden_size` is only ever used for the Intern vision encoder in vLLM |
195 | 195 | and Transformers doesn't appear to have the same concept. |
196 | 196 | """ |
197 | | - kwargs = { |
198 | | - "hidden_size": hidden_size, |
199 | | - "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), |
200 | | - "has_weight": getattr(rms_norm, "with_scale", True), |
201 | | - } |
202 | | - if (weight := getattr(rms_norm, "weight", None)) is not None: |
203 | | - # If weight is a Parameter, get its data tensor |
204 | | - weight = getattr(weight, "data", weight) |
205 | | - kwargs["dtype"] = weight.dtype |
| 197 | + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) |
| 198 | + kwargs = {"hidden_size": hidden_size, "eps": eps} |
| 199 | + # Update hidden size if weight is available |
| 200 | + weight_meta = getattr(rms_norm, "weight", None) |
| 201 | + if weight_meta is not None: |
| 202 | + kwargs["hidden_size"] = weight_meta.size(0) |
| 203 | + # Check if weight is all zeros, which indicates GemmaRMSNorm |
| 204 | + # We must create a new instance because rms_norm is on the meta device |
| 205 | + try: |
| 206 | + with torch.device("cpu"): |
| 207 | + weight_test = getattr(rms_norm.__class__(1), "weight", None) |
| 208 | + except Exception: |
| 209 | + logger.warning( |
| 210 | + "Failed to determine if RMSNorm weight is centered on zero or one. " |
| 211 | + "Defaulting to one." |
| 212 | + ) |
| 213 | + weight_test = None |
| 214 | + if weight_test is not None and torch.all(weight_test == 0): |
| 215 | + return GemmaRMSNorm(**kwargs) |
| 216 | + # Otherwise assume it's a regular RMSNorm |
| 217 | + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) |
| 218 | + if weight_meta is not None: |
| 219 | + kwargs["dtype"] = weight_meta.dtype |
206 | 220 | else: |
207 | 221 | # No weight, fall back to weightless RMSNorm |
208 | 222 | kwargs["has_weight"] = False |
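For background on the zero-weight probe above: Gemma-style RMSNorm layers initialize their scale to zeros and compute `x * (1 + weight)`, while standard RMSNorm layers initialize to ones and compute `x * weight`, so the default weight value distinguishes the two. A minimal, self-contained sketch of the same heuristic (the two classes are illustrative stand-ins, not the Transformers implementations):

    import torch
    from torch import nn

    class StandardRMSNorm(nn.Module):
        """Stand-in: weight defaults to ones, output is x * weight."""
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

    class GemmaStyleRMSNorm(nn.Module):
        """Stand-in: weight defaults to zeros, output is x * (1 + weight)."""
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

    def is_gemma_style(norm_cls: type) -> bool:
        # Instantiate on CPU: the module being probed may live on the meta
        # device, where tensor values cannot be read.
        with torch.device("cpu"):
            probe = norm_cls(1)
        weight = getattr(probe, "weight", None)
        return weight is not None and bool(torch.all(weight == 0))

    assert not is_gemma_style(StandardRMSNorm)
    assert is_gemma_style(GemmaStyleRMSNorm)

Probing a fresh CPU instance (rather than the module at hand) is what makes the try/except in the diff necessary: some norm classes may not accept a bare `hidden_size`, in which case the code falls back to the standard, ones-centered RMSNorm.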
@@ -645,11 +659,10 @@ def _recursive_replace(module: nn.Module, prefix: str): |
645 | 659 | new_module = replace_linear_class( |
646 | 660 | child_module, style, self.quant_config, prefix=qual_name |
647 | 661 | ) |
648 | | - # TODO(hmellor): Enable RMSNorm replacement once we have a way |
649 | | - # to choose RMSNorm vs GemmaRMSNorm |
650 | | - # elif child_module.__class__.__name__.endswith("RMSNorm"): |
651 | | - # new_module = replace_rms_norm_class( |
652 | | - # child_module, self.config.hidden_size) |
| 662 | + elif child_module.__class__.__name__.endswith("RMSNorm"): |
| 663 | + new_module = replace_rms_norm_class( |
| 664 | + child_module, self.text_config.hidden_size |
| 665 | + ) |
653 | 666 | else: |
654 | 667 | _recursive_replace(child_module, prefix=qual_name) |
655 | 668 |
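The second hunk re-enables norm replacement inside `_recursive_replace`, which walks the module tree and swaps matching children in place; matching on the class-name suffix catches `LlamaRMSNorm`, `Qwen2RMSNorm`, `GemmaRMSNorm`, and so on. A simplified sketch of the pattern, with `make_replacement` standing in for `replace_rms_norm_class(child_module, self.text_config.hidden_size)` and the linear-layer branch and quantization handling omitted:

    from collections.abc import Callable
    import torch.nn as nn

    def recursive_replace(
        module: nn.Module,
        prefix: str,
        make_replacement: Callable[[nn.Module], nn.Module],
    ) -> None:
        for child_name, child_module in module.named_children():
            qual_name = f"{prefix}.{child_name}"
            if child_module.__class__.__name__.endswith("RMSNorm"):
                # Swap the Transformers norm for its vLLM equivalent; setattr
                # on the parent rebinds the child under the same attribute
                # name, so state-dict loading by qualified name still works.
                setattr(module, child_name, make_replacement(child_module))
            else:
                # Only recurse into children that were not replaced.
                recursive_replace(child_module, qual_name, make_replacement)

Note the diff also switches the hidden size source from `self.config.hidden_size` to `self.text_config.hidden_size`, so multimodal models whose norms belong to the language backbone pick up the correct dimension.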
|
|