Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/internal/trainer_utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Most of those are only useful if you are studying the code of the Trainer in the

[[autodoc]] torch_distributed_zero_first

[[autodoc]] load_pretrained_model_only_on_rank0

## Callbacks internals

[[autodoc]] trainer_callback.CallbackHandler
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3059,7 +3059,10 @@
_import_structure["sagemaker"] = []
_import_structure["time_series_utils"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_pt_utils"] = [
"load_pretrained_model_only_on_rank0",
"torch_distributed_zero_first",
]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]

# TensorFlow-backed objects
Expand Down Expand Up @@ -6598,7 +6601,7 @@

# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_pt_utils import load_pretrained_model_only_on_rank0, torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer

# TensorFlow
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,3 +1125,16 @@ def smp_nested_concat(tensor):
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()


def load_pretrained_model_only_on_rank0(model_cls, config_cls, model_name_or_path):
    """Materialize pretrained weights only on the main process.

    On the main process this loads the full checkpoint via
    ``model_cls.from_pretrained``; every other rank builds a weightless
    skeleton of the same architecture on the ``meta`` device, so only one
    copy of the weights is ever read from disk / downloaded.

    Args:
        model_cls: Model class exposing ``from_pretrained`` and ``from_config``.
        config_cls: Config class exposing ``from_pretrained``.
        model_name_or_path: Checkpoint identifier or local path.

    Returns:
        The instantiated model (real weights on the main process, meta-device
        parameters elsewhere).
    """
    # Imported lazily so `accelerate` stays an optional dependency.
    from accelerate.state import PartialState

    if PartialState().is_main_process:
        return model_cls.from_pretrained(model_name_or_path, return_dict=True)
    # Non-main ranks: build the architecture without allocating real storage.
    with torch.device("meta"):
        meta_config = config_cls.from_pretrained(model_name_or_path)
        return model_cls.from_config(meta_config)
4 changes: 4 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8493,6 +8493,10 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


# Placeholder used when PyTorch is not installed: calling it raises the
# standard missing-backend error via `requires_backends` instead of the real
# implementation in `trainer_pt_utils`. NOTE(review): this file appears to be
# auto-generated (dummy objects) — keep entries in sync with the real API.
def load_pretrained_model_only_on_rank0(*args, **kwargs):
    requires_backends(load_pretrained_model_only_on_rank0, ["torch"])


# Placeholder used when PyTorch is not installed: delegates to
# `requires_backends` to raise the standard missing-backend error.
def torch_distributed_zero_first(*args, **kwargs):
    requires_backends(torch_distributed_zero_first, ["torch"])

Expand Down