Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/summarization/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,12 @@ def main():
token=model_args.token,
)

if training_args.do_train and training_args.use_compiled_autograd:
from habana_frameworks.torch.dynamo.compile_backend.experimental import enable_compiled_autograd

enable_compiled_autograd()
torch._C._set_autograd_fallback_mode("nothing")

# Log on each process the small summary:
mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
Expand Down
4 changes: 3 additions & 1 deletion optimum/habana/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
step_scheduler_with_optimizer: bool = True,
kwargs_handlers: list[KwargsHandler] | None = None,
dynamo_backend: GaudiDynamoBackend | str | None = None,
dynamic: bool | None = None,
distribution_strategy: str = None,
force_autocast: bool = False,
):
Expand Down Expand Up @@ -310,6 +311,7 @@ def __init__(
FutureWarning,
)
self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
self.dynamic = dynamic

# Mixed precision attributes
self.scaler = None
Expand Down Expand Up @@ -776,7 +778,7 @@ def _prepare_deepspeed(self, *args):
if self.state.dynamo_plugin.backend == GaudiDynamoBackend.HPU_BACKEND and not is_compiled_module(
kwargs["model"]
):
engine.compile()
engine.compile(compile_kwargs={"dynamic": self.dynamic})
if optimizer is not None:
optimizer = DeepSpeedOptimizerWrapper(optimizer)
if scheduler is not None:
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,7 @@ def create_accelerator_and_postprocess(self):
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
"distribution_strategy": self.args.distribution_strategy,
"dynamic": self.args.compile_dynamic,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
Expand Down
14 changes: 14 additions & 0 deletions optimum/habana/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class GaudiTrainingArguments(TrainingArguments):
Whether to use HPU graphs for performing inference. It will speed up latency but may not be compatible with some operations.
use_hpu_graphs_for_training (`bool`, *optional*, defaults to `False`):
Whether to use HPU graphs for performing inference. It will speed up training but may not be compatible with some operations.
use_compiled_autograd (`bool`, *optional*, defaults to `False`):
Whether to use compiled autograd for training. Currently only for summarization models.
compile_dynamic (`bool|None`, *optional*, defaults to `None`):
Set value of 'dynamic' parameter for torch.compile.
disable_tensor_cache_hpu_graphs (`bool`, *optional*, defaults to `False`):
Whether to disable tensor cache when using hpu graphs. If True, tensors won't be cached in hpu graph and memory can be saved.
max_hpu_graphs (`int`, *optional*):
Expand Down Expand Up @@ -156,6 +160,16 @@ class GaudiTrainingArguments(TrainingArguments):
},
)

use_compiled_autograd: Optional[bool] = field(
default=False,
metadata={"help": ("Whether to use compiled autograd for training. Currently only for summarization models.")},
)

compile_dynamic: Optional[bool | None] = field(
default=None,
metadata={"help": ("Set value of 'dynamic' parameter for torch.compile.")},
)

disable_tensor_cache_hpu_graphs: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use a tensor cache for hpu graphs."},
Expand Down