Merged

Changes from 4 commits
4 changes: 2 additions & 2 deletions fastdeploy/model_executor/layers/normalization.py
@@ -105,14 +105,14 @@ def __init__(
         self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
         self.tp_group = self.fd_config.parallel_config.tp_group
         is_input_norm = prefix.endswith(".input_layernorm")
-        is_last_norm = prefix.endswith(".norm")
+        self.is_last_norm = prefix.endswith(".norm")
         self.split_x = (
             self.fd_config.parallel_config.use_sequence_parallel_moe
             and self.layer_id == self.fd_config.model_config.moe_layer_start_index
             and is_input_norm
         )
         self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
-            (self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
+            (self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm)
        )

         self.init_weight()
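
For context: under sequence-parallel MoE, `split_x` makes the input norm of the first MoE layer scatter the token dimension across tensor-parallel ranks, and `allgather_out` marks the input norms of subsequent layers (and, before this change, the final `.norm`) as reassembling it. With `or is_last_norm` removed, the final norm no longer all-gathers inside the layer, which is why each model forward below now calls `self.norm.allgather` explicitly. A minimal sketch of such a split/all-gather pair, assuming an evenly padded token dimension; `paddle.distributed.all_gather` is a real API, but the helper functions and their names are illustrative only, not the FastDeploy implementation:

import paddle
import paddle.distributed as dist

def split_sequence(x, tp_rank, tp_degree):
    # Keep this rank's contiguous slice of the token dimension (split_x path).
    chunk = x.shape[0] // tp_degree
    return x[tp_rank * chunk : (tp_rank + 1) * chunk]

def allgather_sequence(x, num_tokens, tp_group=None):
    # Reassemble the full token dimension across ranks, then drop padding;
    # num_tokens plays the role of ids_remove_padding.shape[0] in the diffs below.
    parts = []
    dist.all_gather(parts, x, group=tp_group)
    return paddle.concat(parts, axis=0)[:num_tokens]
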
7 changes: 6 additions & 1 deletion fastdeploy/model_executor/models/deepseek_v3.py
@@ -589,7 +589,12 @@ def forward(
                 position_ids,
                 mask_encoder_batch,
             )
-        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]
Collaborator
Can the all_gather be placed after the norm?

Contributor Author
Right, that makes sense, done~

Contributor Author
Tested qwen, no issues.


         return out
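
The thread above suggests moving the all_gather after the norm, and this four-commit view shows the state before that change landed. Roughly, the adopted ordering would look like the following, reusing this diff's identifiers; a sketch, not the merged code:

hidden_states = hidden_states + residual
out = self.norm(hidden_states, forward_meta=forward_meta)[0]
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
    # Same collective, but the norm now runs on the rank-local slice only.
    out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
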

7 changes: 6 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -447,7 +447,12 @@ def forward(
         for i in range(self.num_layers):
             hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)

-        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]

         if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
             out = forward_meta.attn_backend.reverse_transpose(out)
7 changes: 6 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_mtp.py
@@ -326,7 +326,12 @@ def forward(
         for i in range(self.num_layers):
             hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)

-        hidden_states = self.norm(hidden_states, residual)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        hidden_states = self.norm(hidden_states, forward_meta=forward_meta)[0]

         return hidden_states

@@ -541,7 +541,12 @@ def forward(
                 vl_moe_meta,
             )

-        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]

         return out

7 changes: 6 additions & 1 deletion fastdeploy/model_executor/models/glm4_moe.py
@@ -368,7 +368,12 @@ def forward(
         for i in range(self.num_layers):
             hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)

-        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]

         return out

10 changes: 8 additions & 2 deletions fastdeploy/model_executor/models/gpt_oss.py
@@ -214,8 +214,14 @@ def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
         for i in range(self.num_layers):
             hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)

-        hidden_states = self.norm(hidden_states, residual)[0]
-        return hidden_states
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]
+
+        return out


@ModelRegistry.register_model_class(
7 changes: 6 additions & 1 deletion fastdeploy/model_executor/models/qwen3moe.py
@@ -282,7 +282,12 @@ def forward(
         for i in range(self.num_layers):
             hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)

-        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
+        hidden_states = hidden_states + residual
+
+        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
+            hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
+
+        out = self.norm(hidden_states, forward_meta=forward_meta)[0]

         return out

2 changes: 1 addition & 1 deletion fastdeploy/spec_decode/mtp.py
@@ -876,7 +876,7 @@ def _propose(self, step_use_cudagraph: bool = False):
         )

         # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hidden_states)
+        logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
         if self.enable_logprob and substep == 0:
             first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
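
The draft-model path now threads `forward_meta` into `compute_logits`. A hypothetical sketch of what that could enable; only the signature change is from this PR, and the body, including the `allgather` placement, is an assumption:

def compute_logits(self, hidden_states, forward_meta=None):
    # Hypothetical: if the final hidden states are still rank-local under
    # sequence-parallel MoE, reassemble them before the vocab projection.
    if forward_meta is not None and self.fd_config.parallel_config.use_sequence_parallel_moe:
        hidden_states = self.norm.allgather(
            hidden_states, forward_meta.ids_remove_padding.shape[0]
        )
    return self.lm_head(hidden_states)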
