From 4e64df55399e7f71b47e8b7caef57be0fdbeef39 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Wed, 26 Nov 2025 14:26:08 -0800 Subject: [PATCH] re-use higher-level config override util in tutorials Signed-off-by: Ananth Subramaniam --- .../recipes/llama/02_pretrain_with_yaml.py | 34 ++++--------------- .../recipes/llama/03_finetune_with_yaml.py | 34 ++++--------------- 2 files changed, 12 insertions(+), 56 deletions(-) diff --git a/tutorials/recipes/llama/02_pretrain_with_yaml.py b/tutorials/recipes/llama/02_pretrain_with_yaml.py index 999922324..b0ac1a6cd 100644 --- a/tutorials/recipes/llama/02_pretrain_with_yaml.py +++ b/tutorials/recipes/llama/02_pretrain_with_yaml.py @@ -48,21 +48,14 @@ import argparse import logging -import sys from pathlib import Path from typing import Tuple -from omegaconf import OmegaConf - from megatron.bridge.recipes.llama import llama32_1b_pretrain_config from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.gpt_step import forward_step from megatron.bridge.training.pretrain import pretrain -from megatron.bridge.training.utils.omegaconf_utils import ( - apply_overrides, - create_omegaconf_dict_config, - parse_hydra_overrides, -) +from megatron.bridge.training.utils.omegaconf_utils import process_config_with_overrides logger = logging.getLogger(__name__) @@ -98,26 +91,11 @@ def main() -> None: # Load base configuration from recipe config: ConfigContainer = llama32_1b_pretrain_config() - # Convert to OmegaConf for merging - omega_conf, excluded_fields = create_omegaconf_dict_config(config) - - # Apply YAML overrides if provided - if args.config_file: - config_file_path = Path(args.config_file) - if not config_file_path.exists(): - logger.error(f"Config file not found: {config_file_path}") - sys.exit(1) - - yaml_conf = OmegaConf.load(config_file_path) - omega_conf = OmegaConf.merge(omega_conf, yaml_conf) - - # Apply command-line overrides - if cli_overrides: - omega_conf = parse_hydra_overrides(omega_conf, cli_overrides) - - # Convert back to ConfigContainer - final_config_dict = OmegaConf.to_container(omega_conf, resolve=True) - apply_overrides(config, final_config_dict, excluded_fields) + config = process_config_with_overrides( + config, + config_filepath=args.config_file, + cli_overrides=cli_overrides or None, + ) # Start pretraining pretrain(config=config, forward_step_func=forward_step) diff --git a/tutorials/recipes/llama/03_finetune_with_yaml.py b/tutorials/recipes/llama/03_finetune_with_yaml.py index ead33540f..4f2baee1b 100644 --- a/tutorials/recipes/llama/03_finetune_with_yaml.py +++ b/tutorials/recipes/llama/03_finetune_with_yaml.py @@ -54,21 +54,14 @@ import argparse import logging -import sys from pathlib import Path from typing import Tuple -from omegaconf import OmegaConf - from megatron.bridge.recipes.llama import llama32_1b_finetune_config from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.finetune import finetune from megatron.bridge.training.gpt_step import forward_step -from megatron.bridge.training.utils.omegaconf_utils import ( - apply_overrides, - create_omegaconf_dict_config, - parse_hydra_overrides, -) +from megatron.bridge.training.utils.omegaconf_utils import process_config_with_overrides logger = logging.getLogger(__name__) @@ -112,26 +105,11 @@ def main() -> None: peft_method = None if args.peft == "none" else args.peft config: ConfigContainer = llama32_1b_finetune_config(peft=peft_method) - # Convert to OmegaConf for merging - omega_conf, excluded_fields = create_omegaconf_dict_config(config) - - # Apply YAML overrides if provided - if args.config_file: - config_file_path = Path(args.config_file) - if not config_file_path.exists(): - logger.error(f"Config file not found: {config_file_path}") - sys.exit(1) - - yaml_conf = OmegaConf.load(config_file_path) - omega_conf = OmegaConf.merge(omega_conf, yaml_conf) - - # Apply command-line overrides - if cli_overrides: - omega_conf = parse_hydra_overrides(omega_conf, cli_overrides) - - # Convert back to ConfigContainer - final_config_dict = OmegaConf.to_container(omega_conf, resolve=True) - apply_overrides(config, final_config_dict, excluded_fields) + config = process_config_with_overrides( + config, + config_filepath=args.config_file, + cli_overrides=cli_overrides or None, + ) # Start finetuning finetune(config=config, forward_step_func=forward_step)