Merged
31 changes: 4 additions & 27 deletions src/megatron/hub/training/checkpointing.py
@@ -41,7 +41,6 @@
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.num_microbatches_calculator import update_num_microbatches
from megatron.core.rerun_state_machine import get_rerun_state_machine

@@ -1016,9 +1015,6 @@ def _load_checkpoint_from_path(
rerun_state=gen_sd_rerun_state,
)

# When "--fp8-param-gather" is disabled, this function doesn't modify anything.
fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs["sharded_state_dict"])

# For PEFT, check if resuming from a checkpoint saved during training, which contains only the PEFT adapter states
# This situation occurs when:
# 1. The PEFT config is set
@@ -1114,8 +1110,10 @@ def _load_checkpoint_from_path(
raise e
else:
if (cfg.model.fp16 or cfg.model.bf16) and optimizer is not None:
optimizer.reload_model_params()

if cfg.checkpoint.load_main_params_from_ckpt:
optimizer.reload_model_params(state_dict=state_dict)
else:
optimizer.reload_model_params()
# rerun state
try:
if "rerun_state_machine" in state_dict:
@@ -1250,27 +1248,6 @@ def init_checkpointing_context(checkpoint_config: CheckpointConfig) -> dict[str,
return checkpointing_context


def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict: dict[str, Any]) -> None:
"""Workaround for FP8 parameters losing precision during distributed checkpoint loading.

When loading a distributed checkpoint, FP8 tensors within the model's state_dict
can sometimes lose precision. This function iterates through the model state
dictionary entries (keys starting with "model") and converts any ShardedTensors
containing FP8 data back to a higher precision format (via `.from_float8()`)
and moves them to the CPU before they are loaded into the model.

Args:
state_dict: The state dictionary loaded from the checkpoint, potentially
containing FP8 tensors within model states. This dictionary
is modified in-place.
"""
for key in state_dict.keys():
if key.startswith("model"):
for _, sharded_tensor in state_dict[key].items():
if is_float8tensor(sharded_tensor.data):
sharded_tensor.data = sharded_tensor.data.from_float8().cpu()


def apply_peft_adapter_filter_to_state_dict(state_dict: dict[str, Any], peft_config: PEFT) -> dict[str, Any]:
"""Filter state dict to contain only PEFT adapter parameters in model sections.

10 changes: 10 additions & 0 deletions src/megatron/hub/training/config.py
@@ -355,6 +355,12 @@ class CheckpointConfig:
load_optim: bool = True
"""Do not load optimizer when loading checkpoint."""

load_main_params_from_ckpt: bool = False
"""Load main parameters from checkpoint. When loading a model from a checkpoint without loading
the optimizer, the model parameters are updated but for fp16 optimizer with main parameters,
the main parameters need to also be updated.
"""

load_rng: bool = True
"""Do not load rng state when loading checkpoint."""

@@ -446,6 +452,10 @@ class CheckpointConfig:
replication_factor: int = 2
"""Number of machines storing the replica of a given rank's data."""

def __post_init__(self) -> None:
if self.load_main_params_from_ckpt:
assert not self.load_optim, "load_main_params_from_ckpt must be used with load_optim=False"
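For readers skimming the config change, a minimal sketch of how the new flag and the __post_init__ guard interact is below. It re-declares a simplified stand-in class rather than importing the real CheckpointConfig, so the field set and defaults beyond the two flags are assumptions.

from dataclasses import dataclass


# Simplified stand-in for CheckpointConfig; it mirrors only the two flags
# and the assertion added in this diff.
@dataclass(kw_only=True)
class CheckpointConfigSketch:
    load_optim: bool = True
    load_main_params_from_ckpt: bool = False

    def __post_init__(self) -> None:
        if self.load_main_params_from_ckpt:
            assert not self.load_optim, "load_main_params_from_ckpt must be used with load_optim=False"


# Valid: main params come from the checkpoint while optimizer state is skipped.
CheckpointConfigSketch(load_optim=False, load_main_params_from_ckpt=True)

# Invalid: the guard rejects loading the optimizer and main params together.
try:
    CheckpointConfigSketch(load_optim=True, load_main_params_from_ckpt=True)
except AssertionError as err:
    print(err)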


@dataclass(kw_only=True)
class LoggerConfig:
38 changes: 38 additions & 0 deletions tests/unit_tests/training/test_config.py
@@ -783,6 +783,19 @@ def test_nvrx_straggler_config(self):
cfg.report_time_interval = -100.0
cfg.__post_init__()

def test_checkpoint_config(self):
self._check_post_init_idempotency(create_test_checkpoint_config)

# Test rerun of post-init with valid and invalid changes
cfg = create_test_checkpoint_config(ckpt_format="torch_dist")
cfg.save = "/tmp/test_checkpoint_config"
cfg.__post_init__()

with pytest.raises(AssertionError, match="load_main_params_from_ckpt must be used with load_optim=False"):
cfg.load_main_params_from_ckpt = True
cfg.load_optim = True
cfg.__post_init__()

def test_rerun_validate_config_container(self):
import copy
from dataclasses import fields
@@ -822,3 +835,28 @@ def check_container_state_matches(cfg1, cfg2):
full_cfg.validate()
finally:
restore_get_world_size_safe(og_ws, cfg_mod)


class TestCheckpointConfig:
"""Tests for CheckpointConfig class."""

@pytest.mark.parametrize(
"load_main_params_from_ckpt, load_optim, expect_assertion_error",
[
(True, False, False), # Valid combination
(True, True, True), # Invalid combination - should raise error
(False, False, False), # Valid combination
(False, True, False), # Valid combination
],
)
def test_load_main_params_from_ckpt_validation_parametrized(
self, load_main_params_from_ckpt, load_optim, expect_assertion_error
):
"""Parametrized test for load_main_params_from_ckpt validation."""
if expect_assertion_error:
with pytest.raises(AssertionError, match="load_main_params_from_ckpt must be used with load_optim=False"):
create_test_checkpoint_config(
load_main_params_from_ckpt=load_main_params_from_ckpt, load_optim=load_optim
)
else:
create_test_checkpoint_config(load_main_params_from_ckpt=load_main_params_from_ckpt, load_optim=load_optim)
3 changes: 0 additions & 3 deletions tests/unit_tests/training/test_peft_checkpointing.py
@@ -482,7 +482,6 @@ def test_load_checkpoint_peft_resume_detection(
patch("megatron.hub.training.checkpointing.update_num_microbatches"),
patch("megatron.hub.training.checkpointing.get_checkpoint_version") as mock_get_version,
patch("megatron.hub.training.checkpointing.set_checkpoint_version"),
patch("megatron.hub.training.checkpointing.fix_fp8_params_lose_precision_when_loading_dist_ckpt"),
patch("megatron.hub.training.checkpointing.restore_sharded_modelopt_state"),
patch("torch.distributed.barrier"),
patch("megatron.hub.training.checkpointing.print_rank_0"),
@@ -583,7 +582,6 @@ def test_load_checkpoint_non_peft_regular_loading(self, mock_checkpoint_exists,
patch("megatron.hub.training.checkpointing.update_num_microbatches"),
patch("megatron.hub.training.checkpointing.get_checkpoint_version") as mock_get_version,
patch("megatron.hub.training.checkpointing.set_checkpoint_version"),
patch("megatron.hub.training.checkpointing.fix_fp8_params_lose_precision_when_loading_dist_ckpt"),
patch("megatron.hub.training.checkpointing.restore_sharded_modelopt_state"),
patch("torch.distributed.barrier"),
patch("megatron.hub.training.checkpointing.print_rank_0"),
@@ -700,7 +698,6 @@ def test_load_checkpoint_peft_resume_multi_model(
patch("megatron.hub.training.checkpointing.update_num_microbatches"),
patch("megatron.hub.training.checkpointing.get_checkpoint_version") as mock_get_version,
patch("megatron.hub.training.checkpointing.set_checkpoint_version"),
patch("megatron.hub.training.checkpointing.fix_fp8_params_lose_precision_when_loading_dist_ckpt"),
patch("megatron.hub.training.checkpointing.restore_sharded_modelopt_state"),
patch("megatron.core.mpu.set_virtual_pipeline_model_parallel_rank"),
patch("torch.distributed.barrier"),