optimum/habana/transformers/trainer.py (5 changes: 4 additions & 1 deletion)
```diff
@@ -986,7 +986,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
                             inputs["flash_attention_causal_mask"] = True
                 if self.model.config is not None:
                     if self.model.config.model_type in ["llama", "qwen2", "mistral", "starcoder2"]:
-                        inputs["lazy_mode"] = args.use_lazy_mode
+                        forward_method = getattr(self.model, "forward")
+                        signature = inspect.signature(forward_method)
+                        if "lazy_mode" in signature.parameters:
+                            inputs["lazy_mode"] = args.use_lazy_mode
                 # TODO: keep syncs for fast DDP?
                 with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)
```
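Substantively, the diff replaces the unconditional `inputs["lazy_mode"] = args.use_lazy_mode` assignment with an introspection guard: the trainer now forwards `lazy_mode` only when the model's `forward` method actually declares that parameter, so model variants whose `forward` lacks it no longer receive an unexpected keyword argument. Below is a minimal, self-contained sketch of the same pattern; the `LazyAwareModel`/`PlainModel` classes and the `add_lazy_mode_if_supported` helper are illustrative stand-ins, not code from the PR.

```python
import inspect

def add_lazy_mode_if_supported(model, inputs: dict, use_lazy_mode: bool) -> dict:
    """Add `lazy_mode` to the kwargs only if the model's forward() declares it."""
    # For a bound method, inspect.signature() already excludes `self`.
    signature = inspect.signature(model.forward)
    if "lazy_mode" in signature.parameters:
        inputs["lazy_mode"] = use_lazy_mode
    return inputs

class LazyAwareModel:
    def forward(self, input_ids, lazy_mode=False):
        return f"lazy_mode={lazy_mode}"

class PlainModel:
    def forward(self, input_ids):
        return "forward() has no lazy_mode parameter"

for model in (LazyAwareModel(), PlainModel()):
    inputs = add_lazy_mode_if_supported(model, {"input_ids": [1, 2, 3]}, True)
    print(model.forward(**inputs))  # PlainModel is never passed lazy_mode
```

Note that `signature.parameters` only matches explicitly declared parameters: a `forward` that accepts keywords solely through a `**kwargs` catch-all would fail the `"lazy_mode" in signature.parameters` test, so the PR's guard is conservative and simply skips `lazy_mode` in that case.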