42 changes: 20 additions & 22 deletions src/transformers/modeling_utils.py
@@ -4659,21 +4659,16 @@ def _fix_state_dict_keys_on_save(self, state_dict):

@classmethod
def _load_pretrained_model(
cls,
model: "PreTrainedModel",
state_dict: Optional[dict],
checkpoint_files: Optional[list[str]],
pretrained_model_name_or_path: Optional[str],
ignore_mismatched_sizes: bool = False,
sharded_metadata: Optional[dict] = None,
device_map: Optional[dict] = None,
disk_offload_folder: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
hf_quantizer: Optional[HfQuantizer] = None,
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[dict[str, str]] = None,
weights_only: bool = True,
model,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
):

# TODO: we should only be calling hf_quantizer.skip_placement or something like that
is_quantized = hf_quantizer is not None
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
@@ -5141,14 +5136,17 @@ def set_is_initialized_for_modules(module):
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
not_initialized_parameters = list(
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.initialize_weights()
else:
self.initialize_weights()
# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
not_initialized_parameters = list(
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.initialize_weights()
else:
# Skip reinitialization for quantized (int8) models
if not is_quantized:
self.initialize_weights()


def _adjust_missing_and_unexpected_keys(
self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
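The second modeling_utils.py hunk gates weight re-initialization on `is_quantized` and, on the DeepSpeed ZeRO-3 path, gathers only tensors that still lack the `_is_hf_initialized` flag. A minimal sketch of that flag-gating pattern, independent of DeepSpeed; the `initialize_uninitialized` helper and the `normal_` init scheme are illustrative assumptions, not the library's actual `initialize_weights`:

```python
import torch
from torch import nn


def initialize_uninitialized(model: nn.Module) -> None:
    # keep_vars=True returns the live Parameter/buffer objects, so custom
    # attributes such as "_is_hf_initialized" are still visible on the values.
    pending = [
        v
        for v in model.state_dict(keep_vars=True).values()
        if not getattr(v, "_is_hf_initialized", False)
    ]
    with torch.no_grad():
        for tensor in pending:
            if tensor.is_floating_point():
                nn.init.normal_(tensor, mean=0.0, std=0.02)  # placeholder scheme
            tensor._is_hf_initialized = True  # a second pass now skips this tensor
```

Mirroring the `if not is_quantized:` guard above, a quantized model would skip this call entirely, since overwriting packed int8 weights with fresh random values would corrupt them.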
10 changes: 10 additions & 0 deletions src/transformers/utils/constants.py
@@ -4,3 +4,13 @@
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
__all__ = [
"IMAGENET_DEFAULT_MEAN",
"IMAGENET_DEFAULT_STD",
"IMAGENET_STANDARD_MEAN",
"IMAGENET_STANDARD_STD",
"OPENAI_CLIP_MEAN",
"OPENAI_CLIP_STD",
"SAFE_WEIGHTS_INDEX_NAME",
]
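One note on the `__all__` addition above: it pins down exactly which names `from transformers.utils.constants import *` binds, and every listed name must exist at module scope, otherwise the star-import raises `AttributeError`. Only the tail of constants.py is visible in this diff, so whether `SAFE_WEIGHTS_INDEX_NAME` is defined there cannot be confirmed here. A small stand-in sketch (file and contents hypothetical):

```python
# constants_demo.py -- a stand-in for src/transformers/utils/constants.py
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]

# A star-import of this module binds exactly these two names.
__all__ = ["IMAGENET_STANDARD_STD", "OPENAI_CLIP_MEAN"]
```

Appending a never-defined name such as `"MISSING_NAME"` to this list would make `from constants_demo import *` fail with `AttributeError` at import time.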

113 changes: 21 additions & 92 deletions src/transformers/utils/logging.py
@@ -46,15 +46,16 @@
"critical": logging.CRITICAL,
}

_default_log_level = logging.WARNING
# Changed variable name from _default_log_level to DEFAULT_LOG_LEVEL
DEFAULT_LOG_LEVEL = logging.WARNING

_tqdm_active = not hf_hub_utils.are_progress_bars_disabled()


def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
not - fall back to DEFAULT_LOG_LEVEL
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
@@ -65,7 +66,7 @@ def _get_default_logging_level():
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level
return DEFAULT_LOG_LEVEL


def _get_library_name() -> str:
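Since `_get_default_logging_level` reads `TRANSFORMERS_VERBOSITY` when the root logger is first configured, the variable must be set before the library is imported. A usage sketch; the accepted strings are the keys of the `log_levels` mapping at the top of the file, `"info"` among them:

```python
import os

# Set before importing transformers: the library root logger is configured
# lazily on first use and keeps whatever level it saw at that moment.
os.environ["TRANSFORMERS_VERBOSITY"] = "info"

from transformers.utils import logging

logger = logging.get_logger()  # no name -> the library's root logger
logger.info("visible: 'info' overrides the WARNING default")
```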
@@ -160,152 +161,94 @@ def get_logger(name: str | None = None) -> logging.Logger:
def get_verbosity() -> int:
"""
Return the current level for the 🤗 Transformers's root logger as an int.

Returns:
`int`: The logging level.

<Tip>

🤗 Transformers has following logging levels:

- 50: `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
- 40: `transformers.logging.ERROR`
- 30: `transformers.logging.WARNING` or `transformers.logging.WARN`
- 20: `transformers.logging.INFO`
- 10: `transformers.logging.DEBUG`

</Tip>"""

"""
_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()


def set_verbosity(verbosity: int) -> None:
"""
Set the verbosity level for the 🤗 Transformers's root logger.

Args:
verbosity (`int`):
Logging level, e.g., one of:

- `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
- `transformers.logging.ERROR`
- `transformers.logging.WARNING` or `transformers.logging.WARN`
- `transformers.logging.INFO`
- `transformers.logging.DEBUG`
"""

"""Set the verbosity level for the 🤗 Transformers's root logger."""
_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)


def set_verbosity_info():
"""Set the verbosity to the `INFO` level."""
"""Set the verbosity to the INFO level."""
return set_verbosity(INFO)


def set_verbosity_warning():
"""Set the verbosity to the `WARNING` level."""
"""Set the verbosity to the WARNING level."""
return set_verbosity(WARNING)


def set_verbosity_debug():
"""Set the verbosity to the `DEBUG` level."""
"""Set the verbosity to the DEBUG level."""
return set_verbosity(DEBUG)


def set_verbosity_error():
"""Set the verbosity to the `ERROR` level."""
"""Set the verbosity to the ERROR level."""
return set_verbosity(ERROR)
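The four `set_verbosity_*` helpers above are thin wrappers that call `set_verbosity` with the matching level constant, so the following calls are interchangeable:

```python
from transformers.utils import logging

logging.set_verbosity_error()           # only ERROR and CRITICAL pass through
assert logging.get_verbosity() == logging.ERROR
logging.set_verbosity(logging.WARNING)  # same effect as set_verbosity_warning()
```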


def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger."""

_configure_library_root_logger()

assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)


def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger."""

_configure_library_root_logger()

assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)


def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger."""

"""Adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger()

assert handler is not None
_get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger."""

"""Removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger()

assert handler is not None and handler not in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)


def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""

"""Disable propagation of the library log outputs."""
_configure_library_root_logger()
_get_library_root_logger().propagate = False


def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""

"""Enable propagation of the library log outputs."""
_configure_library_root_logger()
_get_library_root_logger().propagate = True


def enable_explicit_format() -> None:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
"""Enable explicit formatting for every HuggingFace Transformers's logger."""
handlers = _get_library_root_logger().handlers

for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)


def reset_format() -> None:
"""
Resets the formatting for HuggingFace Transformers's loggers.

All handlers currently bound to the root logger are affected by this method.
"""
"""Resets the formatting for HuggingFace Transformers's loggers."""
handlers = _get_library_root_logger().handlers

for handler in handlers:
handler.setFormatter(None)
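The explicit format that the trimmed docstring used to spell out is still visible in the `Formatter` call above. A short round-trip, with the logged line shown as an illustrative placeholder:

```python
from transformers.utils import logging

logging.enable_default_handler()
logging.enable_explicit_format()
logging.get_logger().warning("checkpoint loaded")
# e.g. [WARNING|example.py:7] 2024-01-01 12:00:00,000 >> checkpoint loaded

logging.reset_format()  # handlers fall back to plain "%(message)s" output
```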


def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
"""Suppress advisory warnings when TRANSFORMERS_NO_ADVISORY_WARNINGS=1."""
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS")
if no_advisory_warnings:
return
@@ -317,13 +260,7 @@ def warning_advice(self, *args, **kwargs):

@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once

Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
"""Like logger.warning(), but emits the same message only once."""
self.warning(*args, **kwargs)
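The `functools.lru_cache(None)` decorator is what makes `warning_once` emit a given message only once: the unbounded cache is keyed on the call arguments, so a repeated call is a cache hit and the body never runs. A standalone sketch of the same trick using the stdlib logger:

```python
import functools
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")


@functools.lru_cache(None)  # unbounded cache, keyed on the argument tuple
def warning_once(msg: str) -> None:
    logger.warning(msg)


warning_once("deprecated code path")  # emitted
warning_once("deprecated code path")  # cache hit: body skipped, nothing logged
warning_once("a different message")   # new key: emitted
```

As the deleted docstring pointed out, the cache is shared across call sites: two callers passing an identical message log it only once between them.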


@@ -332,13 +269,7 @@ def warning_once(self, *args, **kwargs):

@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
"""
This method is identical to `logger.info()`, but will emit the info with the same message only once

Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
"""Like logger.info(), but emits the same message only once."""
self.info(*args, **kwargs)


@@ -348,16 +279,14 @@ def info_once(self, *args, **kwargs):
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""

def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
def __init__(self, *args, **kwargs):
self._iterator = args[0] if args else None

def __iter__(self):
return iter(self._iterator)

def __getattr__(self, _):
"""Return empty function."""

def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
def empty_fn(*args, **kwargs):
return

return empty_fn
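A quick usage sketch for the `EmptyTqdm` fallback above, assuming it is imported from `transformers.utils.logging`, where the class is defined: any attribute lookup that misses falls through `__getattr__` to a no-op, so tqdm-style method calls are silently swallowed while iteration keeps working.

```python
from transformers.utils.logging import EmptyTqdm

bar = EmptyTqdm(range(3), desc="ignored", total=3)
bar.set_description("also ignored")  # unknown attribute resolves to a no-op
print(list(bar))                     # [0, 1, 2]
```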