
Commit dd90bd1

esmeetu authored and jimpang committed
Refactor llama family models (vllm-project#2637)
1 parent 008c0ae · commit dd90bd1

17 files changed (+236 −2720 lines)

vllm/model_executor/layers/layernorm.py

Lines changed: 25 additions & 0 deletions
@@ -7,6 +7,31 @@
 from vllm._C import ops
 
 
+class LayerNorm(nn.LayerNorm):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__(hidden_size, eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """normalization."""
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = super().forward(x)
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+
 class RMSNorm(nn.Module):
     """Root mean square normalization.
 
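For illustration, a minimal usage sketch of the LayerNorm class added above, assuming only that torch is installed and that the class is importable from the file path shown in this diff; the tensor shapes are arbitrary. Called without a residual it behaves like nn.LayerNorm; called with one, it adds the residual to the input before normalizing and also returns the pre-norm sum so the caller can reuse it as the next layer's residual.

import torch

from vllm.model_executor.layers.layernorm import LayerNorm

norm = LayerNorm(hidden_size=4096, eps=1e-6)
x = torch.randn(2, 4096)

# Without a residual, the call matches nn.LayerNorm and returns a single tensor.
out = norm(x)

# With a residual, x + residual is normalized, and the pre-norm sum is returned
# alongside the output for reuse as the next layer's residual input.
residual = torch.randn(2, 4096)
out, next_residual = norm(x, residual)
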
vllm/model_executor/models/__init__.py

Lines changed: 4 additions & 5 deletions
@@ -10,8 +10,8 @@
 
 # Architecture -> (module, class).
 _MODELS = {
-    "AquilaModel": ("aquila", "AquilaForCausalLM"),
-    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "AquilaModel": ("llama", "LlamaForCausalLM"),
+    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
@@ -24,12 +24,12 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
@@ -41,7 +41,6 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
-    "YiForCausalLM": ("yi", "YiForCausalLM")
 }
 
 # Models not supported by ROCm.
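
For context, the _MODELS table maps a Hugging Face architecture string to a (module, class) pair under vllm.model_executor.models; after this commit the Aquila, InternLM, and Mistral entries all resolve to the shared Llama implementation. Below is a minimal sketch of how such a registry lookup can be resolved with importlib; the helper name load_model_cls and the trimmed table are illustrative, not necessarily vLLM's exact code.

import importlib
from typing import Optional, Tuple, Type

# Trimmed copy of the registry for illustration; the full table lives in
# vllm/model_executor/models/__init__.py.
_MODELS = {
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
}


def load_model_cls(architecture: str) -> Optional[Type]:
    """Resolve an architecture string to its model class, or None if unknown."""
    entry: Optional[Tuple[str, str]] = _MODELS.get(architecture)
    if entry is None:
        return None
    module_name, class_name = entry
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name, None)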
