From 630b3b809b2296e7ae4df1faef1453b8ffe15af4 Mon Sep 17 00:00:00 2001 From: roy Date: Mon, 26 Feb 2024 22:01:36 +0800 Subject: [PATCH 1/3] fix --- vllm/model_executor/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 411814f2f5d0..d16c3c8cb57e 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -43,6 +43,7 @@ "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLMForCausalLM": ("stablelm", "StablelmForCausalLM"), } # Models not supported by ROCm. From 402471b94388bd24ef4ec93af4db73ae253e7eaa Mon Sep 17 00:00:00 2001 From: roy Date: Mon, 26 Feb 2024 22:20:36 +0800 Subject: [PATCH 2/3] fix model --- vllm/model_executor/models/stablelm.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 95e5ad8ede63..44c57e5a6d4f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -94,7 +94,9 @@ def __init__(self, 1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - self.rotary_ndims = int(self.head_dim * self.config.rope_pct) + rope_pct = getattr(config, "rope_pct", + getattr(config, "partial_rotary_factor", 1)) + self.rotary_ndims = int(self.head_dim * rope_pct) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim @@ -114,7 +116,6 @@ def __init__(self, self.hidden_size, bias=False, linear_method=linear_method) - self.rotary_ndims = int(self.head_dim * self.config.rope_pct) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, @@ -152,10 +153,11 @@ def __init__( super().__init__() self.self_attn = StablelmAttention(config) self.mlp = StablelmMLP(config, linear_method) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps) + eps=norm_eps) def forward( self, @@ -199,7 +201,9 @@ def __init__(self, StablelmDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) def forward( self, From 5e380e9012fbd93268d4064be7afdb2f989cf717 Mon Sep 17 00:00:00 2001 From: roy Date: Mon, 26 Feb 2024 22:21:30 +0800 Subject: [PATCH 3/3] fix type --- vllm/model_executor/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d16c3c8cb57e..40b375bb6fbe 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -43,7 +43,7 @@ "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), - "StableLMForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), } # Models not supported by ROCm.