Skip to content

use repeat_interleave for input_features in kd_trainer#121

Open
zhuxiaoxuhit wants to merge 1 commit intowenet-e2e:mainfrom
zhuxiaoxuhit:fix/input-features-repeat-interleave-in-kd-trainer
Open

use repeat_interleave for input_features in kd_trainer#121
zhuxiaoxuhit wants to merge 1 commit intowenet-e2e:mainfrom
zhuxiaoxuhit:fix/input-features-repeat-interleave-in-kd-trainer

Conversation

@zhuxiaoxuhit
Copy link
Copy Markdown
Contributor

Same issue as #119 but in KnowledgeDistillationTrainer._prepare_logprob_inputs.

input_features uses .repeat(num_generations, 1, 1) while all other tensors in the same function use .repeat_interleave(num_generations, dim=0). These two produce different element orderings when batch size > 1:

  • repeat_interleave: [s0, s0, s0, s1, s1, s1, ...] which matches how generated_ids are laid out
  • repeat: [s0, s1, s0, s1, s0, s1, ...] which does not

So when batch size > 1, input_features ends up paired with the wrong completions during log probability computation, making both the student and teacher logprob calculations incorrect and corrupting the KD training signal.

@yuekaizhang
Copy link
Copy Markdown
Contributor

@zhuxiaoxuhit Thanks.

@yuekaizhang
Copy link
Copy Markdown
Contributor

@robin1001 Would you mind helping merge this one also? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants