 import torch
 from torch import nn
 from transformers import PersimmonConfig
-from transformers.activations import ReLUSquaredActivation

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -57,7 +57,7 @@ def __init__(self,
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = ReLUSquaredActivation()
+        self.act = get_act_fn(config.hidden_act, quant_config)

     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
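Note on the activation change: the MLP now resolves its activation by name through vLLM's get_act_fn instead of instantiating transformers' ReLUSquaredActivation directly, so the module no longer depends on that transformers-internal symbol. Assuming the checkpoint keeps the HF PersimmonConfig default of hidden_act = "relu2", the looked-up function computes the same squared ReLU. A minimal plain-PyTorch sketch of that computation, for reference only (not the vLLM registry implementation):

import torch
from torch import nn


class ReLUSquared(nn.Module):
    """Reference "relu2" activation: relu(x) squared."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.square(torch.relu(x))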
@@ -96,7 +96,7 @@ def __init__(self,
             quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
-            self.num_heads * self.head_dim,
+            self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
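Note on the dense projection size: self.num_heads is the per-rank head count (the global count divided by the tensor-parallel world size), while RowParallelLinear is constructed with the global input dimension and shards the weight along that dimension itself, all-reducing the partial outputs. Passing the per-rank value would under-size the projection whenever tensor parallelism is enabled. A rough sketch of the arithmetic, with illustrative numbers rather than values taken from this file:

# Illustrative sizes only; the real values come from the model config.
total_num_heads = 64                      # global attention heads
head_dim = 64
tp_size = 4                               # tensor-parallel world size
num_heads = total_num_heads // tp_size    # heads actually held by one rank

# RowParallelLinear expects the *global* input size and splits it across
# ranks internally, so each rank multiplies its local slice and the partial
# results are summed with an all-reduce.
global_in = total_num_heads * head_dim    # what the layer should be given
per_rank_in = num_heads * head_dim        # what a single rank computes with
assert per_rank_in == global_in // tp_size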
@@ -213,10 +213,10 @@ def __init__(self,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size

-        self.embed_tokens = VocabParallelEmbedding(
-            config.text_config.vocab_size, config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             PersimmonDecoderLayer(config,
                                   cache_config=cache_config,
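Note on the config access: PersimmonConfig is a plain text-model config, so vocab_size and hidden_size live directly on it; the nested config.text_config accessor belongs to wrapper configs such as the multimodal Fuyu config, which embeds a Persimmon-style text config. A quick sanity check, assuming a transformers install that ships PersimmonConfig:

from transformers import PersimmonConfig

cfg = PersimmonConfig()
print(cfg.vocab_size)   # top-level field; there is no cfg.text_config on this config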
@@ -252,19 +252,19 @@ def forward(
 class PersimmonForCausalLM(nn.Module):

     def __init__(self,
-                 config,
+                 config: PersimmonConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = PersimmonModel(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+        self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       bias=False)
-        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

     def forward(
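Note on the vocab size plumbing: the embedding table, the LM head, and the logits processor must all agree on one vocabulary size, which is why every site above now reads config.vocab_size. A toy PyTorch illustration of why the shapes have to line up (hypothetical sizes, not the real Persimmon ones):

import torch
from torch import nn

vocab_size, hidden_size = 1000, 64
embed = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden = embed(torch.tensor([[1, 2, 3]]))   # [1, 3, hidden_size]
logits = lm_head(hidden)                    # [1, 3, vocab_size]
assert logits.shape[-1] == vocab_size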