Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions vllm/model_executor/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
Expand All @@ -40,21 +41,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GemmaRMSNorm(nn.Module):

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * (1 + self.weight)


class GemmaMLP(nn.Module):

def __init__(
Expand Down Expand Up @@ -185,36 +171,38 @@ def __init__(
intermediate_size=config.intermediate_size,
linear_method=linear_method,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states
return hidden_states, residual


class GemmaModel(nn.Module):
Expand All @@ -235,7 +223,7 @@ def __init__(
GemmaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
Expand All @@ -246,17 +234,19 @@ def forward(
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Normalize the embedding by sqrt(hidden_size)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
hidden_states *= self.config.hidden_size**0.5

residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
residual,
)
hidden_states = self.norm(hidden_states)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


Expand Down Expand Up @@ -321,6 +311,10 @@ def load_weights(self,
# Skip loading extra layer for lora models.
if "lm_head" in name:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
loaded_weight += 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand All @@ -329,5 +323,5 @@ def load_weights(self,
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")