Skip to content

Commit 0cafcf8

Browse files
committed
Rename helper loss
1 parent 4ee2e7d commit 0cafcf8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def move_to_device(batch: Any) -> Any:
167167
batch = move_to_device(batch)
168168

169169
model.zero_grad(set_to_none=True)
170-
loss = self.compute_dpo_loss(model, batch)
170+
loss = self.compute_dpo_loss_for_fisher(model, batch)
171171
loss.backward()
172172

173173
for name, param in model.named_parameters():
@@ -183,7 +183,7 @@ def move_to_device(batch: Any) -> Any:
183183
fisher[name] /= sample_count
184184
return fisher
185185

186-
def compute_dpo_loss(
186+
def compute_dpo_loss_for_fisher(
187187
self,
188188
model: Union[PreTrainedModel, nn.Module],
189189
batch: dict[str, Union[torch.Tensor, Any]],

0 commit comments

Comments
 (0)