Support for Phi-1.5 & Phi-2 models #7862
Merged
Commits (23):

- 5f4e6e8 (ppraneth): Added phi 1.5 model support
- 635b0ad (ppraneth): Added microsoft/phi-1_5 to test_generation_models.py
- fb60fa2 (ppraneth): Modified phi.py
- 33dfcde (ppraneth): Updated code for Phi-1.5
- adc05a5 (ppraneth): removed comments and unused import statements for Phi-1.5 (phi.py)
- dce59cb (ppraneth): Merge branch 'main' into support-phi-1.5
- 3586ecc (ppraneth): removed compute_logits function from phi.py
- 1d888b8 (ppraneth): Merge branch 'main' into support-phi-1.5
- a7623ed (ppraneth): Modified phi.py to fix errors and added gelu_new activation function …
- d48f0cb (ppraneth): Modified load_weights function in phi.py
- f672db1 (ppraneth): Merge branch 'main' into support-phi-1.5
- ccf669f (lifuhuang): Merge branch 'main' into support-phi-1.5
- 876bcb4 (ppraneth): Applied suggested changes to phi.py
- b6d8263 (ppraneth): Applied suggested changes to phi.py removed pylint comment
- 23d7c45 (ppraneth): Updated generative_models.md
- 8a9bc15 (ppraneth): Updated generative_models.md
- 59e0da9 (ppraneth): Merge branch 'main' into support-phi-1.5
- 376323d (ppraneth): Merge branch 'main' into support-phi-1.5
- 35e8118 (ppraneth): Merge branch 'main' into support-phi-1.5
- 8e871e5 (ppraneth): Merge branch 'main' into support-phi-1.5
- 435f3a0 (ppraneth): Merge branch 'main' into support-phi-1.5
- bcc9570 (lifuhuang): Merge branch 'main' into support-phi-1.5
- 66ac35e (lifuhuang): Merge branch 'main' into support-phi-1.5
New file (+321 lines):

```python
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi.py
from typing import Iterable, Optional, Union

import torch
from torch import nn
from transformers import PhiConfig

from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
```
```python
class PhiAttention(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_id: int = 0,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        rotary_dim = int(
            config.partial_rotary_factor
            * (config.hidden_size // config.num_attention_heads)
        )
        assert rotary_dim % 2 == 0

        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_size,
            scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.dense(attn_output)
        return output
```
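Phi uses partial rotary embeddings: only a `partial_rotary_factor` fraction of each head's dimensions gets rotated, which is why `rotary_dim` is derived from the per-head size rather than the full hidden size. A minimal sketch of that computation follows; the config numbers are illustrative, not read from a real checkpoint.

```python
def partial_rotary_dim(hidden_size: int, num_heads: int, factor: float) -> int:
    # Per-head dimension, matching head_size in PhiAttention.__init__.
    head_size = hidden_size // num_heads
    # Only a fraction of each head is rotated by RoPE.
    rotary_dim = int(factor * head_size)
    # RoPE rotates dimension pairs, so the rotary dim must be even.
    assert rotary_dim % 2 == 0
    return rotary_dim

# Illustrative values (e.g. hidden_size=2048, 32 heads, factor 0.5).
dim = partial_rotary_dim(hidden_size=2048, num_heads=32, factor=0.5)
```

The `assert rotary_dim % 2 == 0` mirrors the check in the module above.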
```python
class PhiMLP(nn.Module):

    def __init__(
        self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None
    ):
        super().__init__()

        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
```
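The activation comes from `get_act_fn(config.hidden_act)`; for Phi models `hidden_act` is `gelu_new`, the tanh approximation of GELU. As a reference point (this is a plain-Python sketch of the formula, not the kernel sglang actually dispatches to):

```python
import math

def gelu_new(x: float) -> float:
    # Tanh approximation of GELU, known as "gelu_new" in HF configs:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (
        1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
    )
```

For large positive inputs it approaches the identity; for large negative inputs it approaches zero.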
```python
class PhiLayer(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        idx: int = 0,
    ):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.self_attn = PhiAttention(
            config,
            quant_config,
            prefix=add_prefix("self_attn", prefix),
            layer_id=idx,
        )
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states
```
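Note the parallel-block structure in `PhiLayer.forward`: unlike a standard sequential pre-norm transformer, attention and MLP both read the same layer-normed input, and their outputs are summed with the residual in one step. A toy scalar sketch (the lambdas stand in for the real sublayers):

```python
def parallel_block(x, attn, mlp, norm):
    # Both sublayers consume the same normed input; outputs and the
    # untouched residual are combined in a single addition.
    h = norm(x)
    return attn(h) + mlp(h) + x

# With norm = identity, attn = *2, mlp = +1, an input of 3.0 yields
# 6.0 + 4.0 + 3.0.
out = parallel_block(3.0, attn=lambda v: v * 2, mlp=lambda v: v + 1, norm=lambda v: v)
```

This single-LayerNorm, parallel residual layout matches the original Phi checkpoints, which is why there is no `post_attention_layernorm` here.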
```python
class PhiModel(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )

        pp_group = get_pp_group()
        pp_size = pp_group.world_size
        pp_rank = pp_group.rank

        self.start_layer = pp_rank * config.num_hidden_layers // pp_size
        self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: PhiLayer(
                config, quant_config=quant_config, prefix=prefix, idx=idx
            ),
            prefix=add_prefix("layers", prefix),
        )

        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
```
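`start_layer` and `end_layer` carve the layer stack into contiguous slices, one per pipeline-parallel rank, using the integer arithmetic above. A sketch of that partitioning (24 layers matches Phi-1.5's depth; the other values are illustrative):

```python
def pp_layer_range(num_layers: int, pp_rank: int, pp_size: int) -> tuple:
    # Same formulas as PhiModel.__init__: floor division guarantees the
    # per-rank [start, end) slices tile 0..num_layers exactly.
    start = pp_rank * num_layers // pp_size
    end = (pp_rank + 1) * num_layers // pp_size
    return start, end

ranges = [pp_layer_range(24, r, 2) for r in range(2)]
```

Because each rank's `end` equals the next rank's `start`, the slices never overlap and never leave a gap, even when `num_layers` is not divisible by `pp_size`.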
```python
class PhiForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = PhiModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(
            input_ids=input_ids,
            forward_batch=forward_batch,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
```
```python
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        weights = dict(weights)
        loaded_keys = set()

        for name, param in params_dict.items():
            if name in loaded_keys:
                continue

            # Handle packed weights
            is_packed = False
            for packed_name, src_names in self.packed_modules_mapping.items():
                if packed_name not in name:
                    continue

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                for src_name in src_names:
                    full_src_name = name.replace(packed_name, src_name)
                    if full_src_name in weights:
                        loaded_weight = weights[full_src_name]
                        # The shard_id for QKVParallelLinear is 'q', 'k', 'v'.
                        shard_id = src_name.split("_")[0]
                        weight_loader(param, loaded_weight, shard_id)
                        loaded_keys.add(full_src_name)

                loaded_keys.add(name)
                is_packed = True
                break
            if is_packed:
                continue

            # Handle non-packed weights
            if name not in weights:
                # Redundant with the check in the loop, but good for safety
                continue

            loaded_weight = weights[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_keys.add(name)


EntryClass = PhiForCausalLM
```
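The packed-weight branch of `load_weights` maps each fused parameter name back to the checkpoint's separate projection names and derives a shard id from the prefix. A minimal sketch of that remapping; `checkpoint_names` is a hypothetical helper for illustration, not part of the PR:

```python
# Same mapping as PhiForCausalLM.packed_modules_mapping.
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

def checkpoint_names(param_name: str) -> list:
    """Return (checkpoint_name, shard_id) pairs for a model parameter.

    A packed parameter like "...qkv_proj.weight" expands to the three
    checkpoint names it is assembled from; anything else maps to itself.
    """
    for packed, srcs in packed_modules_mapping.items():
        if packed in param_name:
            # shard_id is the prefix before "_proj", i.e. 'q', 'k', or 'v'.
            return [
                (param_name.replace(packed, src), src.split("_")[0])
                for src in srcs
            ]
    return [(param_name, "")]

pairs = checkpoint_names("model.layers.0.self_attn.qkv_proj.weight")
```

This mirrors how the loader above calls `weight_loader(param, loaded_weight, shard_id)` so `QKVParallelLinear` can place each slice into the fused tensor.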