Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 78 additions & 43 deletions vllm/model_executor/models/hunyuan_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return loaded_params


class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand All @@ -902,6 +902,50 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
],
}

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

Expand Down Expand Up @@ -989,54 +1033,45 @@ def update_physical_experts_metadata(
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
class HunYuanDenseV1Base(HunyuanV1ModelBase):

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config

self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()


class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
pass


class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
pass
class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
pass