diff --git a/.gitignore b/.gitignore
index 5e057e7caf..d939038aae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ out
 wandb
 *.model
 *.json
+*.watchman
diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py
new file mode 100644
index 0000000000..3eafb01221
--- /dev/null
+++ b/torchtrain/meta_init.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+import torch
+from torch import nn
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+from contextlib import contextmanager
+
+
+@contextmanager
+def meta_model_init():
+    """init model on meta device"""
+    saved_register_parameter = nn.Module.register_parameter
+    saved_register_buffer = nn.Module.register_buffer
+
+    def register_meta_param(module, name, param):
+        saved_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            module._parameters[name] = param_cls(
+                module._parameters[name].to(torch.device("meta")), **kwargs
+            )
+
+    def register_meta_buffer(module, name, buffer):
+        saved_register_buffer(module, name, buffer)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(torch.device("meta"))
+
+    try:
+        nn.Module.register_parameter = register_meta_param
+        nn.Module.register_buffer = register_meta_buffer
+        yield
+    finally:
+        nn.Module.register_parameter = saved_register_parameter
+        nn.Module.register_buffer = saved_register_buffer
+
+
+@torch.no_grad()
+def meta_to_real_init_fn(module: nn.Module):
+    for submodule in module.modules():
+        for param_name, param in submodule.named_parameters(recurse=False):
+            if not _is_fsdp_flattened(param) and param.is_meta:
+                materialized_param = nn.Parameter(
+                    torch.randn_like(param, device=torch.device("cuda"))
+                )
+                setattr(submodule, param_name, materialized_param)
diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index a3fbeb4bad..feec26aa40 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -24,6 +24,8 @@ class ModelArgs:
     max_batch_size: int = 32
     max_seq_len: int = 32768
 
+    use_meta_init: Optional[bool] = False  # controlled via global settings
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 46bca9ff49..dcebcf2e3e 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -19,18 +19,18 @@
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
 
 from torchtrain.logging_utils import rank0_log
-
+from torchtrain.meta_init import meta_to_real_init_fn
 
 # Uses PTD FSDP AC wrapper
 def checkpoint_wrapper(module, config):
     return ptd_checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False)
 
 
-def parallelize_llama(model, args):
+def parallelize_llama(model, args, use_meta_init=False):
     """
     Apply parallelisms to the model, including PTD parallelisms, and AC.
 
-    NOTE: the model passed in preferrablably shoule be a meta device model,
+    NOTE: the model passed in preferably should be a meta device model,
     otherwise the model needs to be small enough on GPU or can fit into CPU.
     # TODO: apply SP
     """
@@ -51,6 +51,8 @@ def parallelize_llama(model, args):
         dp_mesh = world_mesh
 
     # apply PTD parallelisms
+    meta_init_fn = meta_to_real_init_fn if use_meta_init else None
+
     fsdp_config = {
         "mixed_precision": MixedPrecision(
             param_dtype=torch.bfloat16,
@@ -62,20 +64,30 @@
         # When torch.compile is active, it requires us to set use_orig_params=True
         "use_orig_params": True,
         "device_mesh": dp_mesh,
+        "param_init_fn": meta_init_fn,
     }
 
     with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
+
+        using_meta_init = fsdp_config["param_init_fn"]
+
         for layer_id, transformer_block in enumerate(model.layers):
             # apply AC to each layer
             # before wrapping with FSDP, we need to make sure the layer is on GPU
-            transformer_block = transformer_block.cuda()
+            # unless using meta init:
+
+            if not using_meta_init:
+                transformer_block = transformer_block.cuda()
+
             transformer_block = checkpoint_wrapper(transformer_block, args)
 
             # Wraps each layer with FSDP
             model.layers[layer_id]= wrap(transformer_block)
 
-        # wrap the rest layers with FSDP
-        model = wrap(model.cuda())
+        # wrap the remaining layers with FSDP
+        if not using_meta_init:
+            model.cuda()
+        model = wrap(model)
 
     rank0_log(f"Applied parallelisms to the model...")
diff --git a/train.py b/train.py
index 073e5b1755..9e8a8cbea7 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.meta_init import meta_model_init
 
 
 @dataclass
@@ -57,17 +58,29 @@ def main(args):
     )
 
     # build model
-    # TODO: add meta initialization
+
     model_cls = model_name_to_cls[model_name]
     model_config = models_config[model_name][args.model_conf]
+
     model_config.vocab_size = tokenizer.n_words
-    model = model_cls.from_model_args(model_config)
+    # meta initialization
+    _use_meta_init = args.meta_init  # todo - add to toml
+    model_config.use_meta_init = _use_meta_init  # append this to model config
+
+    if _use_meta_init:
+        with meta_model_init():
+            model = model_cls.from_model_args(model_config)
+    else:
+        model = model_cls.from_model_args(model_config)
 
     # apply PTD parallelisms + AC
-    model = models_parallelize_fns[model_name](model, args)
+    model = models_parallelize_fns[model_name](
+        model, args, use_meta_init=_use_meta_init
+    )
 
     # build optimizer after apply parallelisms to the model
+    # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
@@ -85,6 +98,8 @@
 
     # train loop
     model.train()
+    # use fsdp
+
     with maybe_run_profiler() as torch_profiler:
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
@@ -160,5 +175,11 @@
         "--compile", action="store_true", help="Whether to compile the model."
     )
 
+    parser.add_argument(
+        "--meta_init",
+        action="store_true",
+        help="Whether to use meta init for the model.",
+    )
+
     args = parser.parse_args()
     main(args)
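
Usage sketch (illustration only, not part of the patch): a minimal example of how the two new helpers are intended to compose. The model is built under meta_model_init() so its parameters carry shape/dtype but no storage, and FSDP then materializes them on GPU through param_init_fn=meta_to_real_init_fn. ToyModel, the file name, and the single-rank torchrun launch are assumptions made for illustration; in torchtrain itself this path is driven by the new --meta_init flag shown above, and the materialized values are random (randn_like), not the model's usual initialization.

# usage_sketch.py (hypothetical) -- run with: torchrun --nproc_per_node=1 usage_sketch.py
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn


class ToyModel(nn.Module):  # hypothetical stand-in for the llama Transformer
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.proj(x)


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# 1) Construct on the meta device: no real memory is allocated for parameters.
with meta_model_init():
    model = ToyModel()
assert next(model.parameters()).is_meta

# 2) FSDP materializes each meta parameter on GPU via param_init_fn before sharding,
#    so the full unsharded model never has to fit on a single device.
model = FSDP(model, param_init_fn=meta_to_real_init_fn)

out = model(torch.randn(8, 1024, device="cuda"))
print(out.shape)  # torch.Size([8, 1024])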