[LLM] Add MTP for Deepseekv3#9876
Conversation

Thanks for your contribution!
Codecov Report — Attention: patch coverage is low.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           develop    #9876      +/-   ##
===========================================
- Coverage    51.34%   51.28%   -0.07%
===========================================
  Files          745      745
  Lines       118590   118778     +188
===========================================
+ Hits         60886    60910      +24
- Misses       57704    57868     +164
```
```python
    dtype=paddle.get_default_dtype(),
    default_initializer=nn.initializer.Constant(0.0),
)
self.e_score_correction_bias.is_distributed = True
```
```python
k_pe = GatherOp.apply(k_pe)
k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand(
    [-1, q_len, self.num_heads, self.qk_rope_head_dim]
)
```
Sequence parallel (sp) is not fixed yet; part of the code is kept here for that reason.
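As a shape sketch of the `reshape` + `expand` in the quoted hunk, the snippet below uses numpy as a stand-in for Paddle (sizes are hypothetical and `GatherOp` is omitted): the shared rotary key slice gains a head axis of size 1 and is then broadcast across all attention heads.

```python
import numpy as np

# Hypothetical sizes, not DeepSeek-V3's real config.
q_len, num_heads, qk_rope_head_dim = 4, 2, 8

# One rotary-embedded key slice per position, shared by all heads.
k_pe = np.random.rand(q_len * qk_rope_head_dim)

# Mirrors k_pe.reshape([-1, q_len, 1, dim]).expand([-1, q_len, num_heads, dim]):
# insert a singleton head axis, then broadcast it across num_heads.
k_pe = k_pe.reshape(-1, q_len, 1, qk_rope_head_dim)
k_pe = np.broadcast_to(k_pe, (k_pe.shape[0], q_len, num_heads, qk_rope_head_dim))

print(k_pe.shape)  # (1, 4, 2, 8)
```

Since the expand only broadcasts, every head sees the identical rotary slice; no extra memory is materialized by `broadcast_to` (Paddle's `expand` behaves similarly for read-only use).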
```python
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
query_states = paddle.concat([q_nope, q_pe], axis=-1)
key_states = paddle.concat([k_nope, k_pe], axis=-1)
```
This follows the auto-parallel implementation; the results are identical.
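A minimal shape check of the concat in the hunk above, with numpy standing in for Paddle (head sizes are illustrative, not DeepSeek-V3's actual config): the no-RoPE and RoPE halves are joined along the last axis into one full query head.

```python
import numpy as np

# Illustrative sizes only.
batch, seq, heads = 1, 3, 2
qk_nope_head_dim, qk_rope_head_dim = 16, 8

q_nope = np.zeros((batch, seq, heads, qk_nope_head_dim))
q_pe = np.ones((batch, seq, heads, qk_rope_head_dim))

# Equivalent of paddle.concat([q_nope, q_pe], axis=-1): the resulting head
# dim is qk_nope_head_dim + qk_rope_head_dim.
query_states = np.concatenate([q_nope, q_pe], axis=-1)
print(query_states.shape)  # (1, 3, 2, 24)
```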
```python
if self.sequence_parallel:
    inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
    inputs_embeds = ScatterOp.apply(inputs_embeds)
return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
```
There may be a latent GPU memory issue here. Right now, are both input_emb and mtp_emb sent backward together through the pipeline-parallel (pp) stages?
There does indeed seem to be a memory issue: every layer carries its own mtp_emb. Let me think this part over again.
For now there is no better way to compute this. Later we could try sending the full input_embed backward directly, which should cost only one extra hidden_state of memory.
```python
if self.config.num_nextn_predict_layers > 0:
    hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
    inputs_embeds_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
```
This is quite inefficient. If the last layer shares weights, we should fetch from the last layer beforehand instead. For the last layer, check whether the values can be gathered by label index.
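To make the split in the hunk above concrete, here is a numpy sketch (sizes and `num_nextn_predict_layers` are illustrative): the stacked tensor carries the main hidden states plus one extra embedding chunk per MTP layer along dim 0, so splitting into `n + 1` equal chunks recovers the MTP inputs as the last `n` chunks.

```python
import numpy as np

num_nextn_predict_layers = 2  # hypothetical number of MTP layers
seq_per_chunk, hidden = 4, 8

# (n + 1) equal chunks stacked along dim 0: main hidden states + n MTP chunks.
hidden_states = np.zeros(((num_nextn_predict_layers + 1) * seq_per_chunk, hidden))

# numpy stand-in for paddle.split(hidden_states, n + 1).
hidden_states_list = np.split(hidden_states, num_nextn_predict_layers + 1)
inputs_embeds_mtp = hidden_states_list[-num_nextn_predict_layers:]

print(len(inputs_embeds_mtp), inputs_embeds_mtp[0].shape)  # 2 (4, 8)
```

This also illustrates the reviewer's memory concern: each pipeline send carries all `n + 1` chunks, even though only the last stage consumes the MTP ones.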
```python
inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
)
```
On compute/memory cost: the current approach uses an extra [batch_size, n, hidden_size] of GPU memory, where n is the number of MTP layers. This overhead is still acceptable.
```python
hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])
```
```python
inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
```
The input is [1,2,3,4,5] and its embedding is [1,2,3,4,5]. The [1,2,3,4] part is forwarded exactly as in the standard decoder architecture, while the MTP layer processes [2,3,4,5]. The concat here joins [2,3,4] with [5].
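The [1,2,3,4,5] example above can be sketched in numpy (hidden dim omitted for readability; the tensor names mirror the diff, but the values are purely illustrative): the decoder consumes [1,2,3,4], and the MTP input at depth `nextn = 0` is built by shifting off the leading tokens and appending the kept tail.

```python
import numpy as np

# Token ids as in the comment above; shape [batch, seq], hidden dim omitted.
inputs_embeds_ori = np.array([[1, 2, 3, 4]])   # what the decoder sees
inputs_embeds_extra = np.array([[5]])          # tail tokens, one per MTP depth
nextn = 0

# Same shift as paddle.concat([ori[:, nextn+1:, :], extra[:, :nextn+1, :]], axis=1):
# drop the first (nextn + 1) tokens, append (nextn + 1) tail tokens.
inputs_embeds_cur_depth = np.concatenate(
    [inputs_embeds_ori[:, (nextn + 1):], inputs_embeds_extra[:, : (nextn + 1)]], axis=1
)
print(inputs_embeds_cur_depth.tolist())  # [[2, 3, 4, 5]]
```

So each MTP depth sees the same sequence length as the decoder, shifted one step further into the future, which is what multi-token prediction requires.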
| ) | ||
| return BaseModelOutputWithPast( | ||
| last_hidden_state=hidden_states, | ||
| past_key_values=next_cache, |
Before submitting

- Add tests to the `tests` folder. If there are codecov issues, please add test cases first.

PR types

New features

PR changes

Models

Description