Merged

Changes from 4 commits (23 commits total)
5f4e6e8
Added phi 1.5 model support
ppraneth Jul 8, 2025
635b0ad
Added microsoft/phi-1_5 to test_generation_models.py
ppraneth Jul 8, 2025
fb60fa2
Modified phi.py
ppraneth Jul 8, 2025
33dfcde
Updated code for Phi-1.5
ppraneth Jul 8, 2025
adc05a5
removed comments and unused import statements for Phi-1.5(phi.py)
ppraneth Jul 8, 2025
dce59cb
Merge branch 'main' into support-phi-1.5
ppraneth Jul 8, 2025
3586ecc
removed compute_logits function from phi.py
ppraneth Jul 9, 2025
1d888b8
Merge branch 'main' into support-phi-1.5
ppraneth Jul 9, 2025
a7623ed
Modified phi.py to fix errors and added gelu_new activation function …
ppraneth Jul 9, 2025
d48f0cb
Modified load_weights function in phi.py
ppraneth Jul 10, 2025
f672db1
Merge branch 'main' into support-phi-1.5
ppraneth Jul 10, 2025
ccf669f
Merge branch 'main' into support-phi-1.5
lifuhuang Jul 10, 2025
876bcb4
Applied suggested changes to phi.py
ppraneth Jul 11, 2025
b6d8263
Applied suggested changes to phi.py removed pylint comment
ppraneth Jul 11, 2025
23d7c45
Updated generative_models.md
ppraneth Jul 11, 2025
8a9bc15
Updated generative_models.md
ppraneth Jul 11, 2025
59e0da9
Merge branch 'main' into support-phi-1.5
ppraneth Jul 11, 2025
376323d
Merge branch 'main' into support-phi-1.5
ppraneth Jul 12, 2025
35e8118
Merge branch 'main' into support-phi-1.5
ppraneth Jul 12, 2025
8e871e5
Merge branch 'main' into support-phi-1.5
ppraneth Jul 12, 2025
435f3a0
Merge branch 'main' into support-phi-1.5
ppraneth Jul 13, 2025
bcc9570
Merge branch 'main' into support-phi-1.5
lifuhuang Jul 14, 2025
66ac35e
Merge branch 'main' into support-phi-1.5
lifuhuang Jul 14, 2025
2 changes: 1 addition & 1 deletion docs/supported_models/generative_models.md
@@ -30,7 +30,7 @@ in the GitHub search bar.
| **Llama** (2, 3.x, 4 series) | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | Meta’s open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](https://docs.sglang.ai/references/llama4) |
| **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2` | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. |
| **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. |
| **Phi** (Phi-3, Phi-4 series) | `microsoft/Phi-4-multimodal-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-mini is a high-accuracy text model and Phi-4-multimodal (5.6B) processes text, images, and speech in one compact model. |
| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4 series) | `microsoft/Phi-4-multimodal-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-mini is a high-accuracy text model and Phi-4-multimodal (5.6B) processes text, images, and speech in one compact model. |
Contributor (review comment, severity: medium):
Thanks for adding Phi-1.5 support! To make the documentation clearer for users, could you please update the example model and description for the Phi model family? The current example and description are specific to Phi-4. Since this PR adds support for Phi-1.5, it would be great to reflect that.

Suggested change
| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4 series) | `microsoft/Phi-4-multimodal-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-mini is a high-accuracy text model and Phi-4-multimodal (5.6B) processes text, images, and speech in one compact model. |
| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4 series) | `microsoft/phi-1_5` | Microsoft’s family of small language models (1.3B–5.6B), known for strong performance at small sizes. Newer variants like Phi-4 also support multimodal inputs. |

| **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |
| **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |
| **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. |
292 changes: 292 additions & 0 deletions python/sglang/srt/models/phi.py
@@ -0,0 +1,292 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi.py
import math
from typing import Iterable, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PhiConfig

from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers


class PhiAttention(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        # pylint: disable=C0103
Collaborator (review comment):
@ppraneth , do we need this?

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        rotary_dim = int(
            config.partial_rotary_factor
            * (config.hidden_size // config.num_attention_heads)
        )
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_size,
            scaling,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):

    def __init__(
        self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None
    ):
        super().__init__()

        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.self_attn = PhiAttention(
            config, quant_config, prefix=add_prefix("self_attn", prefix)
        )
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


class PhiModel(nn.Module):

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(config, prefix=prefix),
            prefix=add_prefix("layers", prefix),
        )
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor]:
Contributor (review comment, severity: medium):

The forward method of PhiModel is declared to return Union[torch.Tensor], but it always returns a torch.Tensor. Using the more specific type hint torch.Tensor improves clarity and type-checking accuracy.

Suggested change
) -> Union[torch.Tensor]:
) -> torch.Tensor:

Collaborator (review comment):
@ppraneth , can you apply this suggested change? Thanks!

Contributor (review comment, severity: medium):

The type hint Union[torch.Tensor] is redundant and can be simplified to torch.Tensor. The Union import on line 2 can then be removed.

Suggested change
) -> Union[torch.Tensor]:
) -> torch.Tensor:

Collaborator (review comment):
Same here

        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # hidden_states = layer(positions, hidden_states, forward_batch=forward_batch)
            hidden_states = layer(
                position_ids=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


# Pending
class PhiForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(
        self,
        config: PhiConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = PhiModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorOutput:
        # hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) changed it to below
        hidden_states = self.model(
            input_ids=input_ids,
            forward_batch=forward_batch,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(
            self.lm_head, hidden_states, sampling_metadata, self.lm_head.bias
        )
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name.endswith(".bias") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

Contributor (review comment, severity: medium):
Consider adding a check at the end of this method to warn about any uninitialized parameters. This can help debug issues with missing weights.


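A minimal sketch of what the reviewer is suggesting: track which parameter names `load_weights` actually touched, then warn about the remainder. The helper below is illustrative, not the PR's implementation; it uses a plain dict in place of `named_parameters()` so it stands alone.

```python
import logging

logger = logging.getLogger("phi_loader")

def load_with_check(params_dict, weights):
    """Load (name, value) pairs and return the names left uninitialized."""
    loaded = set()
    for name, value in weights:
        if name not in params_dict:
            continue
        params_dict[name] = value
        loaded.add(name)
    missing = sorted(set(params_dict) - loaded)
    if missing:
        # Surfacing this early makes checkpoint/name mismatches easy to debug.
        logger.warning("uninitialized parameters: %s", missing)
    return missing

# Toy usage: one of two parameters never receives a weight.
params = {"model.embed_tokens.weight": None, "lm_head.weight": None}
missing = load_with_check(params, [("model.embed_tokens.weight", 1.0)])
print(missing)  # ['lm_head.weight']
```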
EntryClass = PhiForCausalLM
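The sizing arithmetic in `PhiAttention` and `PhiMLP` above can be checked with plain numbers. The values below are the commonly published `microsoft/phi-1_5` config (hidden size 2048, 32 heads, rotary applied to half of each head) and are assumptions for illustration, not read from this PR.

```python
# Assumed phi-1_5 config values (for illustration only).
hidden_size = 2048
num_attention_heads = 32
partial_rotary_factor = 0.5  # RoPE covers only part of each head in Phi

# PhiAttention: per-head size and partial rotary dimension.
head_size = hidden_size // num_attention_heads        # 64
rotary_dim = int(partial_rotary_factor * head_size)   # 32
assert rotary_dim % 2 == 0

# PhiMLP: when config.n_inner is absent, the MLP widens to 4x hidden size.
n_inner = None
n_inner = n_inner if n_inner is not None else 4 * hidden_size  # 8192

print(head_size, rotary_dim, n_inner)  # 64 32 8192
```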
1 change: 1 addition & 0 deletions test/srt/models/test_generation_models.py
@@ -65,6 +65,7 @@ class ModelCase:
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
ModelCase("openai-community/gpt2"),
ModelCase("microsoft/phi-1_5", trust_remote_code=True),
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
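One structural point worth noting about the `PhiLayer.forward` added in this PR: attention and the MLP both consume the same layernormed input, and their outputs are summed with the residual (a parallel, GPT-J-style block rather than the sequential attn-then-MLP layout). A toy numeric sketch, with scalar functions standing in for the real modules:

```python
# Stand-ins for the real submodules, chosen only to make the dataflow visible.
def layernorm(x):
    return x  # identity stand-in for nn.LayerNorm

def attn(x):
    return 2 * x  # stand-in for PhiAttention

def mlp(x):
    return 3 * x  # stand-in for PhiMLP

x = 1.0
normed = layernorm(x)
# Parallel residual: both branches read `normed`, then sum with the input.
out = attn(normed) + mlp(normed) + x
print(out)  # 6.0
```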