Commit b27e801

sgugger authored and Magnus Pierrau committed
Migrate torchdynamo to torch.compile (huggingface#20634)
* Migrate torchdynamo to torch.compile
* Add docstring and generic option
* Properly use the function...
* Reorg args
1 parent af927c7 · commit b27e801

File tree

5 files changed (+70 −15 lines)


docs/source/en/perf_train_gpu_one.mdx

Lines changed: 3 additions & 3 deletions
```diff
@@ -718,11 +718,11 @@ For some applications, such as pretraining large language models, applying all t
 
 Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
 
-## Inference with torchdynamo
+## Using torch.compile
 
-TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for performance boost.
+PyTorch 2.0 introduces a new compile function; you can learn more about it [in their documentation](https://pytorch.org/get-started/pytorch-2.0/). It uses Python’s frame evaluation API to automatically create a graph from existing PyTorch programs. After capturing the graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for a performance boost.
 
-TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
+`torch.compile` has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
 or `torchdynamo.list_backends()` each of which with its optional dependencies.
 
 Some of the most commonly used backends are
```
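To make the documented flow concrete, here is a minimal sketch of direct `torch.compile` usage, assuming a PyTorch build that ships the function (the toy model and shapes are illustrative, not part of this commit):

```python
import torch
import torch.nn as nn

# Any PyTorch module works; this toy network stands in for a Transformers model.
model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 2))

# Capture the graph and lower it with the default "inductor" backend.
compiled_model = torch.compile(model, backend="inductor")

# The first call triggers graph capture and compilation; later calls reuse it.
out = compiled_model(torch.randn(8, 128))
```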

src/transformers/trainer.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -144,8 +144,8 @@
     is_ipex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
+    is_torch_compile_available,
     is_torch_tpu_available,
-    is_torchdynamo_available,
     logging,
 )
 from .utils.generic import ContextManagers
@@ -642,9 +642,9 @@ def __init__(
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
-        # torchdynamo
-        if args.torchdynamo is not None and not is_torchdynamo_available():
-            raise RuntimeError("Using torchdynamo requires a nighly install of PyTorch.")
+        # torch.compile
+        if args.torch_compile and not is_torch_compile_available():
+            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
 
     def add_callback(self, callback):
         """
@@ -1321,10 +1321,9 @@ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
         return model
 
     def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torchdynamo is not None:
-            import torch._dynamo as dynamo
+        if self.args.torch_compile:
+            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
 
-            model = dynamo.optimize(self.args.torchdynamo)(model)
         if self.args.use_ipex:
             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
             model = self.ipex_optimize_model(model, training, dtype=dtype)
```
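From a user's point of view, the `_wrap_model` change above is driven entirely by the new training arguments. A hedged sketch of opting in, assuming this version of `transformers` (the model and dataset wiring is elided):

```python
from transformers import TrainingArguments

# With `torch_compile=True`, `_wrap_model` calls
# `torch.compile(model, backend="inductor", mode=None)` under the hood,
# since the backend defaults to "inductor" and the mode stays unset.
args = TrainingArguments(
    output_dir="out",
    torch_compile=True,
    # torch_compile_backend="inductor",  # optional; setting it alone also enables compilation
    # torch_compile_mode="default",      # optional; same implicit opt-in
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```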

src/transformers/training_args.py

Lines changed: 51 additions & 5 deletions
```diff
@@ -73,7 +73,7 @@
 trainer_log_levels = dict(**log_levels, passive=-1)
 
 
-DYNAMO_BACKENDS = [
+TORCH_COMPILE_BACKENDS = [
     "eager",
     "aot_eager",
     "inductor",
@@ -514,6 +514,21 @@ class TrainingArguments:
             information.
         use_mps_device (`bool`, *optional*, defaults to `False`):
             Whether to use Apple Silicon chip based `mps` device.
+        torch_compile (`bool`, *optional*, defaults to `False`):
+            Whether or not to compile the model using PyTorch 2.0
+            [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/) (requires a nightly install of PyTorch).
+
+            If set, the backend will default to `"inductor"` (can be customized with `torch_compile_backend`) and the
+            mode will default to `"default"` (can be customized with `torch_compile_mode`).
+        torch_compile_backend (`str`, *optional*):
+            The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+            Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, `"nvfuser"`, `"aot_nvfuser"`,
+            `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
+        torch_compile_mode (`str`, *optional*):
+            The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+            Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.
     """
 
     framework = "pt"
@@ -983,8 +998,8 @@ class TrainingArguments:
     torchdynamo: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Sets up the backend compiler for TorchDynamo.",
-            "choices": DYNAMO_BACKENDS,
+            "help": "This argument is deprecated, use `--torch_compile_backend` instead.",
+            "choices": TORCH_COMPILE_BACKENDS,
         },
     )
     ray_scope: Optional[str] = field(
@@ -1006,6 +1021,23 @@
             "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
         },
     )
+    torch_compile: bool = field(
+        default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."}
+    )
+    torch_compile_backend: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
+            "choices": TORCH_COMPILE_BACKENDS,
+        },
+    )
+    torch_compile_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
+            "choices": ["default", "reduce-overhead", "max-autotune"],
+        },
+    )
 
     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1148,10 +1180,24 @@ def __post_init__(self):
                 " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
             )
 
-        if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
+        if self.torchdynamo is not None:
+            warnings.warn(
+                "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+                " `torch_compile_backend` instead",
+                FutureWarning,
+            )
+            self.torch_compile_backend = self.torchdynamo
+        if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
+            self.torch_compile = True
+        if self.torch_compile and self.torch_compile_backend is None:
+            self.torch_compile_backend = "inductor"
+        if self.framework == "pt" and is_torch_available() and self.torch_compile:
             if is_torch_tf32_available():
                 if self.tf32 is None and not self.fp16 or self.bf16:
-                    logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
+                    logger.info(
+                        "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement"
+                        " otherwise."
+                    )
                     torch.backends.cuda.matmul.allow_tf32 = True
                 else:
                     logger.warning(
```
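The `__post_init__` changes above resolve the three new flags against each other and route the deprecated `torchdynamo` argument onto the new path. A small sketch of the resulting behavior, assuming this version of the library:

```python
from transformers import TrainingArguments

# Supplying only a backend implicitly enables compilation.
args = TrainingArguments(output_dir="out", torch_compile_backend="inductor")
assert args.torch_compile is True

# Enabling compilation without a backend falls back to "inductor".
args = TrainingArguments(output_dir="out", torch_compile=True)
assert args.torch_compile_backend == "inductor"

# The deprecated `torchdynamo` argument is forwarded (and emits a FutureWarning).
args = TrainingArguments(output_dir="out", torchdynamo="eager")
assert args.torch_compile and args.torch_compile_backend == "eager"
```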

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -148,6 +148,7 @@
     is_torch_bf16_available,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
+    is_torch_compile_available,
     is_torch_cuda_available,
     is_torch_fx_available,
     is_torch_fx_proxy,
```

src/transformers/utils/import_utils.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -455,6 +455,15 @@ def is_torchdynamo_available():
     return False
 
 
+def is_torch_compile_available():
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    return hasattr(torch, "compile")
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
```
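The new helper feature-detects `torch.compile` at runtime, which is what the Trainer's constructor check relies on. A minimal sketch of the same guard in user code (`maybe_compile` is a hypothetical helper, not part of the commit):

```python
from transformers.utils import is_torch_compile_available


def maybe_compile(model):
    # True only when the installed torch exposes `torch.compile`
    # (PyTorch 2.0 or a recent nightly build).
    if is_torch_compile_available():
        import torch

        return torch.compile(model)
    # On older PyTorch, fall back to eager execution unchanged.
    return model
```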
