Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/recipes/test_knowledge_distillation_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):

def _fetch_expected_loss_values(self, model_type):
loss_values_map = {
"llama3": [11.0651, 11.0577, 11.0540, 11.7671],
"llama3": [11.7898, 11.7825, 11.7788, 11.7671],
}
return loss_values_map[model_type]

Expand Down
12 changes: 10 additions & 2 deletions torchtune/modules/loss/kd_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
normalize: bool = True,
) -> torch.Tensor:
"""
Args:
Expand All @@ -39,6 +40,7 @@ def forward(
(batch_size*num_tokens, vocab_size).
labels (torch.Tensor): Ground truth labels of shape
(batch_size, vocab_size).
normalize (bool): Whether to normalize the loss by the number of unmasked elements.

Returns:
torch.Tensor: KL divergence loss of shape (1,).
Expand All @@ -50,6 +52,8 @@ def forward(
prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (labels != self.ignore_index).int()
if not normalize:
return -torch.sum(x * mask.view(-1), dim=0)
if torch.sum(mask.view(-1), dim=0) == 0:
return torch.tensor(0.0, device=x.device)
return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW we have masked_mean which we use in other parts of the codebase for this https://github.com/pytorch/torchtune/blob/17ba37d281548e5d60215f741888066717ad5b3e/torchtune/rlhf/rewards.py#L101 but not hugely fussed about whether it's used here.

Expand Down Expand Up @@ -118,15 +122,19 @@ def forward(
student_logits_chunk.reshape(-1, student_logits_chunk.size(-1))
for student_logits_chunk in student_logits
]
mask = (labels != self.ignore_index).int()
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
labels = [
target_chunk.reshape(-1)
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
]

total_fkl_loss = 0.0
for student_chunk, teacher_chunk, label_chunk in zip(
student_logits, teacher_logits, labels
):
total_fkl_loss += self.fkl_loss(student_chunk, teacher_chunk, label_chunk)
total_fkl_loss += self.fkl_loss(
student_chunk, teacher_chunk, label_chunk, normalize=False
)

return total_fkl_loss / self.num_output_chunks
return total_fkl_loss / torch.sum(mask.view(-1), dim=0)
Loading