Commit e23c08a

esmeetu authored and jimpang committed
Fix stablelm (vllm-project#3038)
1 parent d55c43c commit e23c08a

2 files changed, +11 −6 lines


vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }
 
 # Models not supported by ROCm.
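The added "StableLmForCausalLM" entry maps the architecture string reported by newer StableLM checkpoints onto the same stablelm implementation that already serves "StableLMEpochForCausalLM". A minimal sketch of how such a name-to-(module, class) table can be resolved lazily; the load_model_cls helper and the module path below are illustrative, not part of this diff:

import importlib
from typing import Optional

# Architecture name -> (module name, class name), mirroring the table above.
_MODELS = {
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

def load_model_cls(arch: str) -> Optional[type]:
    # Import the target module only when the architecture is requested,
    # and return None for unknown architecture names.
    if arch not in _MODELS:
        return None
    module_name, cls_name = _MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name, None)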

vllm/model_executor/models/stablelm.py

Lines changed: 10 additions & 6 deletions
@@ -94,7 +94,9 @@ def __init__(self,
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        rope_pct = getattr(config, "rope_pct",
+                           getattr(config, "partial_rotary_factor", 1))
+        self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
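Older StableLM configs expose rope_pct, while newer Hugging Face StableLM configs call the same value partial_rotary_factor; the nested getattr accepts either name and defaults to 1, i.e. rotary embeddings over the full head dimension. A small self-contained check of that fallback, using SimpleNamespace stand-ins for the config object:

from types import SimpleNamespace

def rotary_ndims(config, head_dim: int) -> int:
    # Prefer the legacy "rope_pct" field, then "partial_rotary_factor",
    # and fall back to rotating the full head dimension.
    rope_pct = getattr(config, "rope_pct",
                       getattr(config, "partial_rotary_factor", 1))
    return int(head_dim * rope_pct)

old_cfg = SimpleNamespace(rope_pct=0.25)
new_cfg = SimpleNamespace(partial_rotary_factor=0.25)
assert rotary_ndims(old_cfg, 64) == rotary_ndims(new_cfg, 64) == 16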
@@ -114,7 +116,6 @@ def __init__(self,
                                         self.hidden_size,
                                         bias=False,
                                         linear_method=linear_method)
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -152,10 +153,11 @@ def __init__(
         super().__init__()
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.norm_eps)
+                                                     eps=norm_eps)
 
     def forward(
         self,
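The layer-norm epsilon gets the same treatment: older configs name it norm_eps, newer ones layer_norm_eps, with 1e-05 (the nn.LayerNorm default) as the last resort. A hedged sketch of the fallback in isolation, again with a SimpleNamespace standing in for the config:

from types import SimpleNamespace
import torch.nn as nn

def make_layernorm(config, hidden_size: int) -> nn.LayerNorm:
    # Accept either config field name; 1e-05 matches the nn.LayerNorm default.
    norm_eps = getattr(config, "norm_eps",
                       getattr(config, "layer_norm_eps", 1e-05))
    return nn.LayerNorm(hidden_size, eps=norm_eps)

ln = make_layernorm(SimpleNamespace(layer_norm_eps=1e-06), 2048)
assert ln.eps == 1e-06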
@@ -199,7 +201,9 @@ def __init__(self,
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
 
     def forward(
         self,
