From 8aa66d9b2e8c1c581d058fdc4f8206f5edce87eb Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 10:07:30 +0000 Subject: [PATCH 01/29] Add loss parallel to ParallelDims, and train context manager Signed-off-by: Nathan Azrak --- torchtune/training/__init__.py | 2 ++ torchtune/training/_distributed.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 1ddd76f57f..a7a26c1970 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -15,6 +15,7 @@ get_distributed_backend, get_full_optimizer_state_dict, get_shard_conditions, + get_train_context, get_world_size_and_rank, init_distributed, is_distributed, @@ -147,4 +148,5 @@ "disable_dropout", "DATALOADER_KEY", "get_context_parallel_manager", + "get_train_context", ] diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 4c12d45150..41c077cfae 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -64,6 +64,7 @@ class ParallelDims: tp: int cp: int world_size: int + enable_loss_parallel: bool def __post_init__(self): self._validate() @@ -152,6 +153,10 @@ def dp_shard_enabled(self): def tp_enabled(self): return self.tp > 1 + @property + def loss_parallel_enabled(self): + return self.enable_loss_parallel and self.tp > 1 + @cached_property def non_data_parallel_size(self): # update below as more parallelism options are implemented @@ -874,3 +879,26 @@ def context(model_inputs: list[torch.Tensor]): yield return context + + +def get_train_context( + enable_loss_parallel: bool, enable_compiled_autograd: bool +) -> Generator[None, None, None]: + @contextlib.contextmanager + def context(cp_context: Generator[None, None, None] | None = None): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) + + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + # because we create a noop ctx manager, this is never None in actual recipes + # leave condition so this can be used separately + if cp_context is not None: + stack.enter_context(cp_context) + yield + + return context From f9315ddea8be3f2648b8cb80e8d87097c01efdd4 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 10:08:13 +0000 Subject: [PATCH 02/29] Add loss parallel support to LinearCrossEntropyLoss Signed-off-by: Nathan Azrak --- torchtune/modules/loss/cross_entropy_loss.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 193ee181c5..7df9bdac0d 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -34,6 +34,7 @@ def __init__( self, num_output_chunks: int = 8, ignore_index: int = -100, + loss_parallel: bool = False, ): super().__init__() """ @@ -44,6 +45,7 @@ def __init__( self.linear_projection = None self.num_output_chunks = num_output_chunks self.ignore_index = ignore_index + self.loss_parallel = loss_parallel def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. 
@@ -92,6 +94,7 @@ def compute_cross_entropy( hidden_chunk = DTensor.from_local( local_hidden_chunk, mesh, placements ) # [num_valid, embed_dim] + else: hidden_chunk = hidden_chunk[mask_chunk] # [num_valid, embed_dim] @@ -99,15 +102,19 @@ def compute_cross_entropy( if self.linear_projection is None: raise AttributeError("forward called before update_model") logits = self.linear_projection(hidden_chunk) # [num_valid, vocab_size] - if isinstance(logits, DTensor): + if isinstance(logits, DTensor) and not self.loss_parallel: logits = logits.full_tensor() - return F.cross_entropy( + loss = F.cross_entropy( logits.float(), target_chunk, reduction="sum", ignore_index=self.ignore_index, ) + # the all-reduce later complains if a DTensor is returned + if isinstance(loss, DTensor): + loss = loss.full_tensor() + return loss def forward( self, From fd7872c5737aed061e1b207ecc1084d3d7d795ee Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 10:12:26 +0000 Subject: [PATCH 03/29] ungate fp8, clean up llama3 parallelism, add loss parallel to TP plans Signed-off-by: Nathan Azrak --- torchtune/models/llama3/__init__.py | 3 +- torchtune/models/llama3/_parallelism.py | 65 +++++++++++++++---------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py index 5cf4e6b616..4091765ab7 100644 --- a/torchtune/models/llama3/__init__.py +++ b/torchtune/models/llama3/__init__.py @@ -15,7 +15,7 @@ qlora_llama3_70b, qlora_llama3_8b, ) -from ._parallelism import base_llama_tp_plan +from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan from ._tokenizer import Llama3Tokenizer __all__ = [ @@ -30,4 +30,5 @@ "qlora_llama3_8b", "qlora_llama3_70b", "base_llama_tp_plan", + "fp8_llama_tp_plan", ] diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py index 8a2360c39f..73f8995d2b 100644 --- a/torchtune/models/llama3/_parallelism.py +++ b/torchtune/models/llama3/_parallelism.py @@ -26,6 +26,7 @@ def _get_base_llama_tp_training_plan( layerwise_colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel, layerwise_rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel, layerwise_prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput, + loss_parallel: bool = False, ) -> dict[str, ParallelStyle]: """ Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models. 
@@ -35,7 +36,11 @@ def _get_base_llama_tp_training_plan( input_layouts=Replicate(), output_layouts=Shard(1) ), "norm": SequenceParallel(), - "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), "layers.*.attn": layerwise_prepare_module_input_cls( input_layouts=(Shard(1), Shard(1)), desired_input_layouts=(Replicate(), Replicate()), @@ -58,51 +63,59 @@ def _get_base_llama_tp_training_plan( } -BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan() - -FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan( - layerwise_colwise_parallel_cls=Float8ColwiseParallel, - layerwise_rowwise_parallel_cls=Float8RowwiseParallel, - layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput, -) - -BASE_LLAMA_TP_INFERENCE_PLAN = { - "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), - "output": ColwiseParallel(output_layouts=Replicate()), - "layers.*.attn.q_proj": ColwiseParallel(), - "layers.*.attn.k_proj": ColwiseParallel(), - "layers.*.attn.v_proj": ColwiseParallel(), - "layers.*.attn.output_proj": RowwiseParallel(), - "layers.*.mlp.w1": ColwiseParallel(), - "layers.*.mlp.w2": RowwiseParallel(), - "layers.*.mlp.w3": ColwiseParallel(), -} +def _get_base_llama_tp_inference_plan(): + return { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + "layers.*.attn.q_proj": ColwiseParallel(), + "layers.*.attn.k_proj": ColwiseParallel(), + "layers.*.attn.v_proj": ColwiseParallel(), + "layers.*.attn.output_proj": RowwiseParallel(), + "layers.*.mlp.w1": ColwiseParallel(), + "layers.*.mlp.w2": RowwiseParallel(), + "layers.*.mlp.w3": ColwiseParallel(), + } def base_llama_tp_plan( - model: nn.Module, inference: bool = False + model: nn.Module, *, inference: bool = False, loss_parallel: bool = False ) -> dict[str, ParallelStyle]: """ Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models Args: model (nn.Module): Model to generate plan for (no-op) - inference (bool): Whether running inference or not. + inference (bool): Whether running inference or not + loss_parallel (bool): Whether to use loss parallelism after the output layer Returns: dict[str, Any]: The tensor parallel plan for Llama3 model. """ - return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN + return ( + _get_base_llama_tp_inference_plan() + if inference + else _get_base_llama_tp_training_plan(loss_parallel=loss_parallel) + ) -# TODO: expose this once tested -def _fp8_llama_tp_plan() -> dict[str, ParallelStyle]: +def fp8_llama_tp_plan( + model: nn.Module, *, loss_parallel: bool = False +) -> dict[str, ParallelStyle]: """ Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both rowwise and colwise computation, currently only compatible with float8 fine-tuning with "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models. + Args: + model (nn.Module): Model to generate plan for (no-op) + loss_parallel (bool): Whether to use loss parallelism after the output layer + Returns: dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. 
""" - return FP8_LLAMA_TP_TRAINING_PLAN + return _get_base_llama_tp_training_plan( + layerwise_colwise_parallel_cls=Float8ColwiseParallel, + layerwise_rowwise_parallel_cls=Float8RowwiseParallel, + layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput, + loss_parallel=loss_parallel, + ) From f0a40faf2d796287fda6986403f07c21150200a0 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 10:13:13 +0000 Subject: [PATCH 04/29] Add loss parallel support to full finetune recipe Signed-off-by: Nathan Azrak --- recipes/full_finetune_distributed.py | 38 +++++++++++++++++----------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index d33d2dad31..b5f9fc18ed 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -161,6 +161,7 @@ def __init__(self, cfg: DictConfig) -> None: self.cp_degree = cfg.get("context_parallel_dim", 1) data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer data_replicate = cfg.get("data_parallel_replicate_dim", 1) + enable_loss_parallel = cfg.get("enable_loss_parallel", True) # Set up n-d device mesh self.parallel_dims = training.ParallelDims( @@ -169,6 +170,7 @@ def __init__(self, cfg: DictConfig) -> None: tp=self.tp_degree, cp=self.cp_degree, world_size=self.world_size, + enable_loss_parallel=enable_loss_parallel, ) self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type) if self.parallel_dims.dp_enabled: @@ -339,6 +341,7 @@ def setup(self, cfg: DictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool + self._compile_autograd = False # TODO: work out why this isn't working. We leave it parameterisable for debug purposes. if isinstance(compile, DictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) @@ -418,6 +421,9 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) + if hasattr(self._loss_fn, "loss_parallel"): + self._loss_fn.loss_parallel = self.parallel_dims.loss_parallel_enabled + if isinstance(self._loss_fn, SFTLoss): self._loss_fn.set_model_output(self._model) @@ -603,11 +609,6 @@ def _setup_model( raise RuntimeError( "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later." ) - if self.tp_plan is not None: - raise ValueError( - "FP8 training does not support tensor parallelism yet. " - "This will be enabled in the near future." 
- ) if self.cp_degree > 1: raise ValueError( "Context Parallel for fp8 training is not currently supported" @@ -626,6 +627,7 @@ def _setup_model( self.tp_plan = config.instantiate( self.tp_plan, model=model, + loss_parallel=self.parallel_dims.loss_parallel_enabled, ) parallelize_module( model, @@ -674,13 +676,6 @@ def _setup_model( dp_mesh=self.world_mesh[dp_mesh_dim_names], ) - # Define context manager for context parallelism - self.context_parallel_manager = training.get_context_parallel_manager( - enabled=self.cp_degree > 1, - world_mesh=self.world_mesh, - model=model, - ) - with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): # RoPE is not covered in state dict @@ -701,6 +696,17 @@ def _setup_model( self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading, activation_offloading_use_streams ) + # context parallel + self.optional_context_parallel_manager = training.get_context_parallel_manager( + enabled=self.cp_degree > 1, + world_mesh=self.world_mesh, + model=model, + ) + # remaining context managers for fwd/bwd + self.train_context = training.get_train_context( + enable_loss_parallel=self.parallel_dims.loss_parallel_enabled, + enable_compiled_autograd=self._compile_autograd, + ) # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) @@ -934,9 +940,11 @@ def train(self) -> None: ).sum() num_tokens += current_num_tokens - # Loss is normalized by default so we multiply by the number of tokens - # This way we can normalize by the total number of tokens if we're accumulating gradients - with self.context_parallel_manager(list(batch.values())): + with self.train_context( + self.optional_context_parallel_manager(list(batch.values())) + ): + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients current_loss = self._loss_step(batch) * current_num_tokens running_loss += current_loss # For optimizer in backward, we need to normalize before calling backward From 6e93e834616c349f8a5a4ffa6bab3081170a3eb7 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 10:29:39 +0000 Subject: [PATCH 05/29] allow enabling autograd compile even though it doesn't work Signed-off-by: Nathan Azrak --- recipes/full_finetune_distributed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b5f9fc18ed..2344ab8168 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -347,6 +347,7 @@ def setup(self, cfg: DictConfig) -> None: self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) self._compile_scale_grads = compile.get("scale_grads", True) + self._compile_autograd = compile.get("autograd", False) if self._compile_model: # Capture scalar outputs is required to compile MoE torch._dynamo.config.capture_scalar_outputs = True From 566142a54e7ed264d9e73279c732c6cb8a7840de Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 3 Jun 2025 18:56:59 +0000 Subject: [PATCH 06/29] correct import paths in unit tests Signed-off-by: Nathan Azrak --- tests/torchtune/training/test_quantization.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py index 6581dca99c..35ad54ce61 100644 --- 
a/tests/torchtune/training/test_quantization.py +++ b/tests/torchtune/training/test_quantization.py @@ -10,8 +10,7 @@ from torchao.float8.float8_linear import Float8Linear -from torchtune.models.llama3 import base_llama_tp_plan -from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan +from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan from torchtune.training.quantization import ( _validate_float8_tp_plan, convert_to_float8_training, @@ -54,12 +53,12 @@ def _test_validate_float8_tp_plan(self): """ _validate_float8_tp_plan(base_llama_tp_plan()) _validate_float8_tp_plan(base_llama_tp_plan(), "anything") - _validate_float8_tp_plan(_fp8_llama_tp_plan()) - _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise") + _validate_float8_tp_plan(fp8_llama_tp_plan()) + _validate_float8_tp_plan(fp8_llama_tp_plan(), "tensorwise") with pytest.raises(ValueError): - _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise") + _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise") with pytest.raises(ValueError): - _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp") + _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise_with_gw_hp") def test_is_fp8_tensorwise_scaling(self): """ From defc48f8d172a173c30f5307d0a6f774b6c8137f Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 03:35:38 +0000 Subject: [PATCH 07/29] remove autograd compile for now --- recipes/full_finetune_distributed.py | 4 +--- torchtune/training/_distributed.py | 9 +-------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 2344ab8168..389af9ea09 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -341,13 +341,11 @@ def setup(self, cfg: DictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool - self._compile_autograd = False # TODO: work out why this isn't working. We leave it parameterisable for debug purposes. 
if isinstance(compile, DictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) self._compile_scale_grads = compile.get("scale_grads", True) - self._compile_autograd = compile.get("autograd", False) if self._compile_model: # Capture scalar outputs is required to compile MoE torch._dynamo.config.capture_scalar_outputs = True @@ -706,7 +704,6 @@ def _setup_model( # remaining context managers for fwd/bwd self.train_context = training.get_train_context( enable_loss_parallel=self.parallel_dims.loss_parallel_enabled, - enable_compiled_autograd=self._compile_autograd, ) # Ensure no params and buffers are on meta device @@ -1024,6 +1021,7 @@ def train(self) -> None: num_tokens / self.parallel_dims.non_data_parallel_size ) / (time_per_step * self.world_size), + "num_tokens": num_tokens, } if self._log_peak_memory_stats: log_dict.update( diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 5c046b0bc6..a31be3eb04 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -886,20 +886,13 @@ def context(model_inputs: list[torch.Tensor]): return context -def get_train_context( - enable_loss_parallel: bool, enable_compiled_autograd: bool -) -> Generator[None, None, None]: +def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: @contextlib.contextmanager def context(cp_context: Generator[None, None, None] | None = None): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - if enable_compiled_autograd: - stack.enter_context( - torch._dynamo.utils.maybe_enable_compiled_autograd(True) - ) - # because we create a noop ctx manager, this is never None in actual recipes # leave condition so this can be used separately if cp_context is not None: From 6eb0d8519c932da336944118739b92d0840e5964 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 05:30:07 +0000 Subject: [PATCH 08/29] Remove unnecessary `full_tensor` in LinearCrossEntropyLoss --- torchtune/modules/loss/cross_entropy_loss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 7df9bdac0d..ebebe0f718 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -102,8 +102,6 @@ def compute_cross_entropy( if self.linear_projection is None: raise AttributeError("forward called before update_model") logits = self.linear_projection(hidden_chunk) # [num_valid, vocab_size] - if isinstance(logits, DTensor) and not self.loss_parallel: - logits = logits.full_tensor() loss = F.cross_entropy( logits.float(), From ae4e751f6d8d62e985005bcf41dc7c37a4c6ae7e Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 05:38:08 +0000 Subject: [PATCH 09/29] remove layerwise prefix in llama3 tp plans --- torchtune/models/llama3/_parallelism.py | 32 ++++++++++++------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py index 73f8995d2b..d35a57fb54 100644 --- a/torchtune/models/llama3/_parallelism.py +++ b/torchtune/models/llama3/_parallelism.py @@ -23,9 +23,9 @@ def _get_base_llama_tp_training_plan( - layerwise_colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel, - layerwise_rowwise_parallel_cls: 
type[ParallelStyle] = RowwiseParallel, - layerwise_prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput, + colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel, + rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel, + prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput, loss_parallel: bool = False, ) -> dict[str, ParallelStyle]: """ @@ -41,25 +41,23 @@ def _get_base_llama_tp_training_plan( output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, ), - "layers.*.attn": layerwise_prepare_module_input_cls( + "layers.*.attn": prepare_module_input_cls( input_layouts=(Shard(1), Shard(1)), desired_input_layouts=(Replicate(), Replicate()), ), - "layers.*.mlp": layerwise_prepare_module_input_cls( + "layers.*.mlp": prepare_module_input_cls( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), "layers.*.sa_norm": SequenceParallel(), "layers.*.mlp_norm": SequenceParallel(), - "layers.*.attn.q_proj": layerwise_colwise_parallel_cls(), - "layers.*.attn.k_proj": layerwise_colwise_parallel_cls(), - "layers.*.attn.v_proj": layerwise_colwise_parallel_cls(), - "layers.*.attn.output_proj": layerwise_rowwise_parallel_cls( - output_layouts=Shard(1) - ), - "layers.*.mlp.w1": layerwise_colwise_parallel_cls(), - "layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)), - "layers.*.mlp.w3": layerwise_colwise_parallel_cls(), + "layers.*.attn.q_proj": colwise_parallel_cls(), + "layers.*.attn.k_proj": colwise_parallel_cls(), + "layers.*.attn.v_proj": colwise_parallel_cls(), + "layers.*.attn.output_proj": rowwise_parallel_cls(output_layouts=Shard(1)), + "layers.*.mlp.w1": colwise_parallel_cls(), + "layers.*.mlp.w2": rowwise_parallel_cls(output_layouts=Shard(1)), + "layers.*.mlp.w3": colwise_parallel_cls(), } @@ -114,8 +112,8 @@ def fp8_llama_tp_plan( dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. 
""" return _get_base_llama_tp_training_plan( - layerwise_colwise_parallel_cls=Float8ColwiseParallel, - layerwise_rowwise_parallel_cls=Float8RowwiseParallel, - layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput, + colwise_parallel_cls=Float8ColwiseParallel, + rowwise_parallel_cls=Float8RowwiseParallel, + prepare_module_input_cls=PrepareFloat8ModuleInput, loss_parallel=loss_parallel, ) From c4abec6fd7a425928e81a06e79d8d9bfa0209f91 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 05:57:11 +0000 Subject: [PATCH 10/29] Refactor tp plans to support fp8 training via arg --- recipes/full_finetune_distributed.py | 1 + tests/torchtune/training/test_quantization.py | 13 +++-- torchtune/models/llama3/__init__.py | 3 +- torchtune/models/llama3/_parallelism.py | 55 ++++++++++++------- torchtune/models/llama4/_parallelism.py | 11 +++- 5 files changed, 55 insertions(+), 28 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 389af9ea09..ce7f181e5d 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -627,6 +627,7 @@ def _setup_model( self.tp_plan, model=model, loss_parallel=self.parallel_dims.loss_parallel_enabled, + enable_fp8_training=self._enable_fp8_training, ) parallelize_module( model, diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py index 35ad54ce61..e7bd943427 100644 --- a/tests/torchtune/training/test_quantization.py +++ b/tests/torchtune/training/test_quantization.py @@ -10,7 +10,8 @@ from torchao.float8.float8_linear import Float8Linear -from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan +from torchtune.models.llama3 import base_llama_tp_plan +from torchtune.models.llama3._parallelism import _get_fp8_llama_tp_training_plan from torchtune.training.quantization import ( _validate_float8_tp_plan, convert_to_float8_training, @@ -53,12 +54,14 @@ def _test_validate_float8_tp_plan(self): """ _validate_float8_tp_plan(base_llama_tp_plan()) _validate_float8_tp_plan(base_llama_tp_plan(), "anything") - _validate_float8_tp_plan(fp8_llama_tp_plan()) - _validate_float8_tp_plan(fp8_llama_tp_plan(), "tensorwise") + _validate_float8_tp_plan(_get_fp8_llama_tp_training_plan()) + _validate_float8_tp_plan(_get_fp8_llama_tp_training_plan(), "tensorwise") with pytest.raises(ValueError): - _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise") + _validate_float8_tp_plan(_get_fp8_llama_tp_training_plan(), "rowwise") with pytest.raises(ValueError): - _validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise_with_gw_hp") + _validate_float8_tp_plan( + _get_fp8_llama_tp_training_plan(), "rowwise_with_gw_hp" + ) def test_is_fp8_tensorwise_scaling(self): """ diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py index 4091765ab7..5cf4e6b616 100644 --- a/torchtune/models/llama3/__init__.py +++ b/torchtune/models/llama3/__init__.py @@ -15,7 +15,7 @@ qlora_llama3_70b, qlora_llama3_8b, ) -from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan +from ._parallelism import base_llama_tp_plan from ._tokenizer import Llama3Tokenizer __all__ = [ @@ -30,5 +30,4 @@ "qlora_llama3_8b", "qlora_llama3_70b", "base_llama_tp_plan", - "fp8_llama_tp_plan", ] diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py index d35a57fb54..d9954a3558 100644 --- a/torchtune/models/llama3/_parallelism.py +++ b/torchtune/models/llama3/_parallelism.py @@ -75,45 +75,60 @@ def 
_get_base_llama_tp_inference_plan(): } -def base_llama_tp_plan( - model: nn.Module, *, inference: bool = False, loss_parallel: bool = False +def _get_fp8_llama_tp_training_plan( + model: nn.Module, *, loss_parallel: bool = False ) -> dict[str, ParallelStyle]: """ - Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models + Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both + rowwise and colwise computation, currently only compatible with float8 fine-tuning with + "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models. Args: model (nn.Module): Model to generate plan for (no-op) - inference (bool): Whether running inference or not loss_parallel (bool): Whether to use loss parallelism after the output layer Returns: - dict[str, Any]: The tensor parallel plan for Llama3 model. + dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. """ - return ( - _get_base_llama_tp_inference_plan() - if inference - else _get_base_llama_tp_training_plan(loss_parallel=loss_parallel) + return _get_base_llama_tp_training_plan( + colwise_parallel_cls=Float8ColwiseParallel, + rowwise_parallel_cls=Float8RowwiseParallel, + prepare_module_input_cls=PrepareFloat8ModuleInput, + loss_parallel=loss_parallel, ) -def fp8_llama_tp_plan( - model: nn.Module, *, loss_parallel: bool = False +def base_llama_tp_plan( + model: nn.Module, + *, + inference: bool = False, + loss_parallel: bool = False, + enable_fp8_training: bool = False, ) -> dict[str, ParallelStyle]: """ - Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both - rowwise and colwise computation, currently only compatible with float8 fine-tuning with - "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models. + Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models Args: model (nn.Module): Model to generate plan for (no-op) + inference (bool): Whether running inference or not loss_parallel (bool): Whether to use loss parallelism after the output layer + enable_fp8_training (bool): Whether to enable float8 training. Returns: - dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. + dict[str, Any]: The tensor parallel plan for Llama3 model. + + Raises: + ValueError: if FP8 training is enabled with inference. 
""" - return _get_base_llama_tp_training_plan( - colwise_parallel_cls=Float8ColwiseParallel, - rowwise_parallel_cls=Float8RowwiseParallel, - prepare_module_input_cls=PrepareFloat8ModuleInput, - loss_parallel=loss_parallel, + if enable_fp8_training: + if inference: + raise ValueError( + "FP8 training is not compatible with inference with LLaMA-3" + ) + return _get_fp8_llama_tp_training_plan(model, loss_parallel=loss_parallel) + + return ( + _get_base_llama_tp_inference_plan() + if inference + else _get_base_llama_tp_training_plan(loss_parallel=loss_parallel) ) diff --git a/torchtune/models/llama4/_parallelism.py b/torchtune/models/llama4/_parallelism.py index 8cc5433609..37bead9c4f 100644 --- a/torchtune/models/llama4/_parallelism.py +++ b/torchtune/models/llama4/_parallelism.py @@ -161,7 +161,10 @@ def decoder_only_tp_inference_plan(model: nn.Module) -> dict[str, ParallelStyle] def decoder_only_tp_plan( - model: nn.Module, inference: bool = False + model: nn.Module, + *, + inference: bool = False, + enable_fp8_training: bool = False, ) -> dict[str, ParallelStyle]: """ Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models @@ -169,10 +172,16 @@ def decoder_only_tp_plan( Args: model (nn.Module): Model to generate plan for (no-op) inference (bool): Whether running inference or not. + enable_fp8_training (bool): Whether to enable float8 training. Currently not supported for Llama4. Returns: dict[str, Any]: The tensor parallel plan for Llama3 model. + + Raises: + ValueError: if FP8 training is enabled. """ + if enable_fp8_training: + raise ValueError("FP8 training is not supported for Llama4") return ( decoder_only_tp_inference_plan(model) if inference From f037fab6662524837beae72af8eece4071381a45 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 10:50:28 +0000 Subject: [PATCH 11/29] Refactor loss parallel into custom loss classes --- recipes/full_finetune_distributed.py | 29 ++++++++--- torchtune/models/llama3/_parallelism.py | 16 ++---- torchtune/modules/loss/cross_entropy_loss.py | 30 ++++++++++-- torchtune/modules/loss/loss_types.py | 51 ++++++++++++++++++++ 4 files changed, 102 insertions(+), 24 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index ce7f181e5d..4b59ea0ca1 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -357,6 +357,24 @@ def setup(self, cfg: DictConfig) -> None: self._grad_scaler, backend=self._compile_backend ) + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + if isinstance(self._loss_fn, SFTLoss): + self._loss_fn.enable_loss_parallel = ( + self.parallel_dims.loss_parallel_enabled + ) + + # Whether to use the ctx manager. If the loss fn has the property, use that. Otherwise, assume it is supported. + # Useful if, for example, user opts to use the basic CrossEntropyLoss() instead of an SFTLoss subclass. + self.use_loss_parallel_ctx_manager = ( + self.parallel_dims.loss_parallel_enabled + and getattr( + self._loss_fn, + "use_loss_parallel_ctx_manager", + True, + ) + ) + self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, @@ -418,11 +436,6 @@ def setup(self, cfg: DictConfig) -> None: # Update the recipe state from the checkpoint state dict. 
self._update_recipe_state(checkpoint_dict) - # initialize loss - self._loss_fn = config.instantiate(cfg.loss) - if hasattr(self._loss_fn, "loss_parallel"): - self._loss_fn.loss_parallel = self.parallel_dims.loss_parallel_enabled - if isinstance(self._loss_fn, SFTLoss): self._loss_fn.set_model_output(self._model) @@ -626,9 +639,11 @@ def _setup_model( self.tp_plan = config.instantiate( self.tp_plan, model=model, - loss_parallel=self.parallel_dims.loss_parallel_enabled, enable_fp8_training=self._enable_fp8_training, ) + if isinstance(self._loss_fn, SFTLoss): + self.tp_plan = self._loss_fn.patch_tp_plan(self.tp_plan) + parallelize_module( model, self.world_mesh["tp"], @@ -704,7 +719,7 @@ def _setup_model( ) # remaining context managers for fwd/bwd self.train_context = training.get_train_context( - enable_loss_parallel=self.parallel_dims.loss_parallel_enabled, + enable_loss_parallel=self.use_loss_parallel_ctx_manager, ) # Ensure no params and buffers are on meta device diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py index d9954a3558..1c69406e79 100644 --- a/torchtune/models/llama3/_parallelism.py +++ b/torchtune/models/llama3/_parallelism.py @@ -26,7 +26,6 @@ def _get_base_llama_tp_training_plan( colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel, rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel, prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput, - loss_parallel: bool = False, ) -> dict[str, ParallelStyle]: """ Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models. @@ -38,8 +37,7 @@ def _get_base_llama_tp_training_plan( "norm": SequenceParallel(), "output": ColwiseParallel( input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, + output_layouts=Replicate(), ), "layers.*.attn": prepare_module_input_cls( input_layouts=(Shard(1), Shard(1)), @@ -75,9 +73,7 @@ def _get_base_llama_tp_inference_plan(): } -def _get_fp8_llama_tp_training_plan( - model: nn.Module, *, loss_parallel: bool = False -) -> dict[str, ParallelStyle]: +def _get_fp8_llama_tp_training_plan(model: nn.Module) -> dict[str, ParallelStyle]: """ Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both rowwise and colwise computation, currently only compatible with float8 fine-tuning with @@ -85,7 +81,6 @@ def _get_fp8_llama_tp_training_plan( Args: model (nn.Module): Model to generate plan for (no-op) - loss_parallel (bool): Whether to use loss parallelism after the output layer Returns: dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. @@ -94,7 +89,6 @@ def _get_fp8_llama_tp_training_plan( colwise_parallel_cls=Float8ColwiseParallel, rowwise_parallel_cls=Float8RowwiseParallel, prepare_module_input_cls=PrepareFloat8ModuleInput, - loss_parallel=loss_parallel, ) @@ -102,7 +96,6 @@ def base_llama_tp_plan( model: nn.Module, *, inference: bool = False, - loss_parallel: bool = False, enable_fp8_training: bool = False, ) -> dict[str, ParallelStyle]: """ @@ -111,7 +104,6 @@ def base_llama_tp_plan( Args: model (nn.Module): Model to generate plan for (no-op) inference (bool): Whether running inference or not - loss_parallel (bool): Whether to use loss parallelism after the output layer enable_fp8_training (bool): Whether to enable float8 training. 
Returns: @@ -125,10 +117,10 @@ def base_llama_tp_plan( raise ValueError( "FP8 training is not compatible with inference with LLaMA-3" ) - return _get_fp8_llama_tp_training_plan(model, loss_parallel=loss_parallel) + return _get_fp8_llama_tp_training_plan(model) return ( _get_base_llama_tp_inference_plan() if inference - else _get_base_llama_tp_training_plan(loss_parallel=loss_parallel) + else _get_base_llama_tp_training_plan() ) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index ebebe0f718..a7e1c8fed0 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -7,7 +7,8 @@ import torch import torch.nn.functional as F from torch import nn -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.parallel import ColwiseParallel from torchtune.modules.loss.loss_types import SFTLoss from torchtune.utils import get_logger @@ -15,7 +16,7 @@ log = get_logger() -class LinearCrossEntropyLoss(nn.Module, SFTLoss): +class LinearCrossEntropyLoss(SFTLoss, nn.Module): """Memory efficient Cross-entropy loss that incrementally computes loss for chunks of tokens by masking ignored tokens, calculating logits and then applying cross-entropy loss. Combines the linear projection with the cross-entropy calculation for further memory savings. @@ -34,9 +35,9 @@ def __init__( self, num_output_chunks: int = 8, ignore_index: int = -100, - loss_parallel: bool = False, + enable_loss_parallel: bool = False, ): - super().__init__() + super().__init__(enable_loss_parallel=enable_loss_parallel) """ Args: num_output_chunks (int): Number of chunks to split the output tensor into. Default is 8. @@ -45,7 +46,6 @@ def __init__( self.linear_projection = None self.num_output_chunks = num_output_chunks self.ignore_index = ignore_index - self.loss_parallel = loss_parallel def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. 
@@ -62,6 +62,26 @@ def set_model_output(self, model: nn.Module) -> None: model.skip_output_layer = True self.linear_projection = model.output + def patch_tp_plan(self, tp_plan) -> bool: + if self.loss_parallel_enabled: + if "output" not in tp_plan: + raise KeyError("`tp_plan` requires `output` key") + + tp_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ) + return tp_plan + + @property + def supports_loss_parallel(self) -> bool: + return True + + @property + def loss_parallel_requires_ctx_manager(self) -> bool: + return True + def compute_cross_entropy( self, hidden_chunk: torch.Tensor, diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index 7d1edf9082..9005be239e 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -14,6 +14,14 @@ class SFTLoss(ABC): """Interface for loss functions in torchtune used in sft recipes.""" + # makes subclasses with multiple inheritance including nn.Module play nicely + # https://github.com/pytorch/pytorch/pull/91819 + call_super_init = True + + def __init__(self, *, enable_loss_parallel: bool = False): + super().__init__() + self.enable_loss_parallel = enable_loss_parallel + def apply_compile_strategy(self, *args, **kwargs): """Compile the loss function for inference.""" self.forward = torch.compile(self.forward, *args, **kwargs) @@ -36,6 +44,49 @@ def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """ pass + @property + @abstractmethod + def supports_loss_parallel(self) -> bool: + """ + Whether the loss function supports loss parallel. + Set to loss if loss parallelism isn't tested with your loss class. + """ + pass + + @property + @abstractmethod + def loss_parallel_requires_ctx_manager(self) -> bool: + """ + Whether to use the loss parallel context manager for loss parallelism. Can be + used if the function relies on the standard cross_entropy() or CrossEntropyLoss. + Set to false if loss parallelism isn't tested with your loss class, or your loss + parallelism doesn't require the context manager.. + """ + pass + + def patch_tp_plan(self, tp_plan) -> bool: + """Whether the loss function supports loss parallel. Defaults to a noop.""" + return tp_plan + + @property + def loss_parallel_enabled(self) -> bool: + """ + The `enable_loss_parallel` flag is a directive from the user. + This property also validates that it is supported, or raises an error. + """ + if self.enable_loss_parallel and not self.supports_loss_parallel: + raise ValueError( + f"Loss function is enabled, but {self.__class__.__name__} does not support loss parallelism" + ) + return self.enable_loss_parallel + + @property + def use_loss_parallel_ctx_manager(self) -> bool: + """ + Whether to enable the loss parallelism ctx manager for this instance of the class. + """ + return self.loss_parallel_enabled and self.loss_parallel_requires_ctx_manager + class RLLoss(ABC): """Interface for loss functions in torchtune used in RL recipes.""" From ed6111329759a29279c4ca069df67c5e48d48461 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 11 Jun 2025 14:39:43 +0000 Subject: [PATCH 12/29] Fix Replicate caused by `tensor_split` in Linear CE, disallow masking for now. 
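
The hidden states reaching the loss are sequence-sharded (the TP plan keeps the final
norm as SequenceParallel), and calling tensor_split along that sharded dim decays the
placement to Replicate, i.e. every TP rank ends up holding the full activations.
Resharding onto the feature dim before chunking keeps the split local. A rough sketch
of that pattern, assuming a 2-rank mesh launched with torchrun (a CPU/gloo mesh and toy
shapes purely to keep the sketch easy to run, not the recipe's actual setup):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    mesh = init_device_mesh("cpu", (2,))
    # [bsz, seq, emb] hidden states, sharded on the sequence dim like the model output
    hidden = distribute_tensor(torch.randn(2, 64, 16), mesh, [Shard(1)])

    # reshard to the feature dim so splitting the sequence dim needs no redistribution
    hidden = hidden.redistribute(mesh, [Shard(-1)])
    chunks = hidden.tensor_split(8, dim=1)  # chunks stay sharded, no decay to Replicate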
--- torchtune/modules/loss/cross_entropy_loss.py | 49 +++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index a7e1c8fed0..1e9ce6b177 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -99,23 +99,17 @@ def compute_cross_entropy( Raises: AttributeError: if called before update_model """ - # Select hidden states and targets where mask is True - mask_chunk = target_chunk != self.ignore_index - if mask_chunk.sum() == 0: - # Unmask 1 token to allow loss to sync with all data parallel workers - mask_chunk[0] = True - - target_chunk = target_chunk[mask_chunk] # [num_valid] if isinstance(hidden_chunk, DTensor): - # DTensor doesn't support masks so we have to mask locally - mesh = hidden_chunk.device_mesh - placements = hidden_chunk.placements - local_hidden_chunk = hidden_chunk.to_local()[mask_chunk] - hidden_chunk = DTensor.from_local( - local_hidden_chunk, mesh, placements - ) # [num_valid, embed_dim] - + # TODO: work out if masking can still be done after avoiding the Replicate issue + pass else: + # Select hidden states and targets where mask is True + mask_chunk = target_chunk != self.ignore_index + if mask_chunk.sum() == 0: + # Unmask 1 token to allow loss to sync with all data parallel workers + mask_chunk[0] = True + + target_chunk = target_chunk[mask_chunk] # [num_valid] hidden_chunk = hidden_chunk[mask_chunk] # [num_valid, embed_dim] # [num_valid, embed_dim] @ [embed_dim, vocab_size] @@ -151,9 +145,28 @@ def forward( mask = targets != self.ignore_index total_elements = mask.sum() - # Chunk along sequence dimension - hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) - target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) + if isinstance(outputs, DTensor): + original_placements = outputs.placements + original_mesh = outputs.device_mesh + + # resharding to a different dim stops the sharding dim from decaying to Replicate during tensor_split + outputs = outputs.redistribute( + device_mesh=original_mesh, placements=[Shard(-1)] * original_mesh.ndim + ) + # perform the splitting on a different dimension + hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) + target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) + hidden_chunks = [ + h.flatten(0, 1).redistribute( + device_mesh=original_mesh, placements=original_placements + ) # this last redistribute is to remain consistent with the TP plan, which may cause overhead + for h in hidden_chunks + ] + target_chunks = [t.flatten(0, 1) for t in target_chunks] + + else: + hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) + target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) # Compute cross-entropy loss for the chunks total_loss = 0.0 From 587529233f84b26d1a81bb1e6373acd16f6d09ac Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 12 Jun 2025 00:23:48 +0000 Subject: [PATCH 13/29] re-introduce masking in tensor parallel linear CE loss. 
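
Boolean-mask indexing is not supported on a DTensor directly (as an earlier comment in
this loss noted), so the mask is applied to the local shard and the result is wrapped
back with DTensor.from_local. This is consistent across ranks when the sharding is on
the feature dim, since every rank then sees all tokens and computes the same mask. A
rough sketch of the local-mask-and-rewrap idea, again assuming a 2-rank torchrun launch
and illustrative shapes only:

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor

    ignore_index = -100
    mesh = init_device_mesh("cpu", (2,))
    # flattened [tokens, emb] hidden chunk, sharded on the feature dim
    hidden = distribute_tensor(torch.randn(64, 16), mesh, [Shard(-1)])
    target = torch.randint(0, 10, (64,))
    target[::4] = ignore_index  # pretend a quarter of the tokens are ignored

    mask = target != ignore_index
    local = hidden.to_local()[mask]  # mask the local shard: [num_valid, emb // ranks]
    hidden = DTensor.from_local(local, mesh, [Shard(-1)])  # re-wrap as [num_valid, emb]
    target = target[mask]  # [num_valid]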
Signed-off-by: Nathan Azrak --- torchtune/modules/loss/cross_entropy_loss.py | 52 +++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 1e9ce6b177..bbe8cafc7d 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -7,7 +7,8 @@ import torch import torch.nn.functional as F from torch import nn -from torch.distributed.tensor import DTensor, Shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Placement, Shard from torch.distributed.tensor.parallel import ColwiseParallel from torchtune.modules.loss.loss_types import SFTLoss @@ -86,12 +87,17 @@ def compute_cross_entropy( self, hidden_chunk: torch.Tensor, target_chunk: torch.Tensor, + *, + original_mesh: DeviceMesh | None = None, + original_placements: list[Placement] | None = None, ) -> torch.Tensor: """Computes cross-entropy by masking tokens, calculating logits and then applying cross-entropy loss. Args: hidden_chunk (torch.Tensor): [batch_size, chunk_size, embed_dim] target_chunk (torch.Tensor): [batch_size, chunk_size] + original_mesh (DeviceMesh | None): Device mesh of the original tensor if distributed + original_placements (list[Placement] | None): Placements of the original tensor if distributed Returns: torch.Tensor: Sum of cross-entropy loss for non-ignored tokens in the chunk @@ -99,17 +105,21 @@ def compute_cross_entropy( Raises: AttributeError: if called before update_model """ + # target_chunk = target_chunk.flatten(0, 1) + # hidden_chunk = hidden_chunk.flatten(0, 1) + mask_chunk = target_chunk != self.ignore_index + if mask_chunk.sum() == 0: + # Unmask 1 token to allow loss to sync with all data parallel workers + mask_chunk[0] = True + target_chunk = target_chunk[mask_chunk] + if isinstance(hidden_chunk, DTensor): - # TODO: work out if masking can still be done after avoiding the Replicate issue - pass - else: - # Select hidden states and targets where mask is True - mask_chunk = target_chunk != self.ignore_index - if mask_chunk.sum() == 0: - # Unmask 1 token to allow loss to sync with all data parallel workers - mask_chunk[0] = True + local_hidden_chunk = hidden_chunk.to_local()[mask_chunk] + hidden_chunk = DTensor.from_local( + local_hidden_chunk, original_mesh, original_placements + ) # [num_valid, embed_dim] - target_chunk = target_chunk[mask_chunk] # [num_valid] + else: hidden_chunk = hidden_chunk[mask_chunk] # [num_valid, embed_dim] # [num_valid, embed_dim] @ [embed_dim, vocab_size] @@ -145,28 +155,20 @@ def forward( mask = targets != self.ignore_index total_elements = mask.sum() + original_mesh = None + original_placements = None if isinstance(outputs, DTensor): original_placements = outputs.placements original_mesh = outputs.device_mesh - # resharding to a different dim stops the sharding dim from decaying to Replicate during tensor_split + # resharding on the feature dim stops the sharding dim (currently sequence) from + # decaying to Replicate during tensor_split outputs = outputs.redistribute( device_mesh=original_mesh, placements=[Shard(-1)] * original_mesh.ndim ) - # perform the splitting on a different dimension - hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) - target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) - hidden_chunks = [ - h.flatten(0, 1).redistribute( - device_mesh=original_mesh, placements=original_placements - ) # this last 
redistribute is to remain consistent with the TP plan, which may cause overhead - for h in hidden_chunks - ] - target_chunks = [t.flatten(0, 1) for t in target_chunks] - else: - hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) - target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) + hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) + target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) # Compute cross-entropy loss for the chunks total_loss = 0.0 @@ -174,6 +176,8 @@ def forward( total_loss += self.compute_cross_entropy( hidden_chunks[idx], target_chunks[idx], + original_mesh=original_mesh, + original_placements=original_placements, ) if total_elements == 0: From bba42d5aa691d58e5f30441c027c5bf50fe2047f Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 12 Jun 2025 00:24:02 +0000 Subject: [PATCH 14/29] revert accidental num_tokens metric in recipe. Signed-off-by: Nathan Azrak --- recipes/full_finetune_distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 4b59ea0ca1..629904e39f 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -1037,7 +1037,6 @@ def train(self) -> None: num_tokens / self.parallel_dims.non_data_parallel_size ) / (time_per_step * self.world_size), - "num_tokens": num_tokens, } if self._log_peak_memory_stats: log_dict.update( From 87d5fd6dc37492860a5523be9c22b43da40f6032 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 12 Jun 2025 00:47:16 +0000 Subject: [PATCH 15/29] clean up linear CE loss Signed-off-by: Nathan Azrak --- torchtune/modules/loss/cross_entropy_loss.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index bbe8cafc7d..4ca3aea901 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -105,20 +105,17 @@ def compute_cross_entropy( Raises: AttributeError: if called before update_model """ - # target_chunk = target_chunk.flatten(0, 1) - # hidden_chunk = hidden_chunk.flatten(0, 1) mask_chunk = target_chunk != self.ignore_index if mask_chunk.sum() == 0: # Unmask 1 token to allow loss to sync with all data parallel workers mask_chunk[0] = True - target_chunk = target_chunk[mask_chunk] + target_chunk = target_chunk[mask_chunk] # [num_valid] if isinstance(hidden_chunk, DTensor): local_hidden_chunk = hidden_chunk.to_local()[mask_chunk] hidden_chunk = DTensor.from_local( local_hidden_chunk, original_mesh, original_placements ) # [num_valid, embed_dim] - else: hidden_chunk = hidden_chunk[mask_chunk] # [num_valid, embed_dim] From f5efb0f5872ee35003013f358071289f53814234 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Sun, 15 Jun 2025 00:09:47 +0000 Subject: [PATCH 16/29] explicitly flatten target and hidden chunks Signed-off-by: Nathan Azrak --- torchtune/modules/loss/cross_entropy_loss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 4ca3aea901..da7b8e4b70 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -105,6 +105,10 @@ def compute_cross_entropy( Raises: AttributeError: if called before update_model """ + # this explicit flattening ensures same tensor dimension, regardless of if mask is all true + target_chunk = 
target_chunk.reshape(-1) + hidden_chunk = hidden_chunk.reshape(-1, hidden_chunk.shape[-1]) + mask_chunk = target_chunk != self.ignore_index if mask_chunk.sum() == 0: # Unmask 1 token to allow loss to sync with all data parallel workers From 40df86da0e5314016739996d218ffefeeb4ad0a2 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Mon, 16 Jun 2025 10:00:03 +0000 Subject: [PATCH 17/29] refactor to allow compile in non-parallel case --- recipes/full_finetune_distributed.py | 7 +- torchtune/modules/loss/cross_entropy_loss.py | 121 ++++++++++--------- 2 files changed, 67 insertions(+), 61 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 629904e39f..8a68398ad4 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -364,14 +364,15 @@ def setup(self, cfg: DictConfig) -> None: self.parallel_dims.loss_parallel_enabled ) - # Whether to use the ctx manager. If the loss fn has the property, use that. Otherwise, assume it is supported. - # Useful if, for example, user opts to use the basic CrossEntropyLoss() instead of an SFTLoss subclass. + # Whether to use the ctx manager. If the loss fn has the property, use that. Otherwise, assume it is not supported. + # Currently our TP plans assume replicating on the output of `output` so without a custom loss class, there is + # no memory benefit to the ctx manager self.use_loss_parallel_ctx_manager = ( self.parallel_dims.loss_parallel_enabled and getattr( self._loss_fn, "use_loss_parallel_ctx_manager", - True, + False, ) ) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index da7b8e4b70..803f53a7b6 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -7,8 +7,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Placement, Shard +from torch.distributed.tensor import DTensor, Shard from torch.distributed.tensor.parallel import ColwiseParallel from torchtune.modules.loss.loss_types import SFTLoss @@ -51,11 +50,14 @@ def __init__( def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. If compiling CE + chunking operation together, memory requirement is higher.""" - log.warning("Skipping compile loss, as it is not supported at this time") - # TODO fix compile and re-enable - # self.compute_cross_entropy = torch.compile( - # self.compute_cross_entropy, *args, **kwargs - # ) + if not self.loss_parallel_enabled: + self.compute_cross_entropy = torch.compile( + self.compute_cross_entropy, *args, **kwargs + ) + else: + log.warning( + "Skipping compile loss, as it is not supported with loss parallel enabled." + ) return self def set_model_output(self, model: nn.Module) -> None: @@ -87,17 +89,12 @@ def compute_cross_entropy( self, hidden_chunk: torch.Tensor, target_chunk: torch.Tensor, - *, - original_mesh: DeviceMesh | None = None, - original_placements: list[Placement] | None = None, ) -> torch.Tensor: """Computes cross-entropy by masking tokens, calculating logits and then applying cross-entropy loss. 
Args: hidden_chunk (torch.Tensor): [batch_size, chunk_size, embed_dim] target_chunk (torch.Tensor): [batch_size, chunk_size] - original_mesh (DeviceMesh | None): Device mesh of the original tensor if distributed - original_placements (list[Placement] | None): Placements of the original tensor if distributed Returns: torch.Tensor: Sum of cross-entropy loss for non-ignored tokens in the chunk @@ -105,24 +102,6 @@ def compute_cross_entropy( Raises: AttributeError: if called before update_model """ - # this explicit flattening ensures same tensor dimension, regardless of if mask is all true - target_chunk = target_chunk.reshape(-1) - hidden_chunk = hidden_chunk.reshape(-1, hidden_chunk.shape[-1]) - - mask_chunk = target_chunk != self.ignore_index - if mask_chunk.sum() == 0: - # Unmask 1 token to allow loss to sync with all data parallel workers - mask_chunk[0] = True - - target_chunk = target_chunk[mask_chunk] # [num_valid] - if isinstance(hidden_chunk, DTensor): - local_hidden_chunk = hidden_chunk.to_local()[mask_chunk] - hidden_chunk = DTensor.from_local( - local_hidden_chunk, original_mesh, original_placements - ) # [num_valid, embed_dim] - else: - hidden_chunk = hidden_chunk[mask_chunk] # [num_valid, embed_dim] - # [num_valid, embed_dim] @ [embed_dim, vocab_size] if self.linear_projection is None: raise AttributeError("forward called before update_model") @@ -134,11 +113,38 @@ def compute_cross_entropy( reduction="sum", ignore_index=self.ignore_index, ) - # the all-reduce later complains if a DTensor is returned - if isinstance(loss, DTensor): - loss = loss.full_tensor() return loss + def mask_inputs( + self, hidden: torch.Tensor, target: torch.Tensor, indices: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + hidden (torch.Tensor): Hidden state of the model, pre projection. Shape ``[bsz, seq_len, emb_dim]`` + target (torch.Tensor): Labels for the model. Shape ``[bsz, seq_len]`` + indices (torch.Tensor): Indices of the valid tokens. 
Shape ``[num_valid]`` + + Returns: + tuple[torch.Tensor, torch.Tensor]: returns a tuple of + - The indexed hidden states + - The indexed targets + """ + # slicing requires both tensors to be same type + # since hidden is sharded on the feature dim, slicing on seq dim is possible + if isinstance(hidden, DTensor): + device_mesh = hidden.device_mesh + hidden = hidden.to_local().index_select(0, indices) + hidden = DTensor.from_local( + hidden, + device_mesh=device_mesh, + placements=[Shard(-1)] * device_mesh.ndim, + ) + else: + hidden = hidden.index_select(0, indices) + + target = target.index_select(0, indices) + return hidden, target + def forward( self, outputs: torch.Tensor, @@ -152,37 +158,36 @@ def forward( Returns: torch.Tensor: loss tensor """ - # Total number of non-ignored tokens across the entire batch - mask = targets != self.ignore_index - total_elements = mask.sum() + total_valid_tokens = torch.where(targets != self.ignore_index)[0].numel() + if total_valid_tokens == 0: + return torch.tensor(0.0, device=targets.device) - original_mesh = None - original_placements = None + # this redistribute allows tensor spitting without replication if isinstance(outputs, DTensor): - original_placements = outputs.placements - original_mesh = outputs.device_mesh - - # resharding on the feature dim stops the sharding dim (currently sequence) from - # decaying to Replicate during tensor_split outputs = outputs.redistribute( - device_mesh=original_mesh, placements=[Shard(-1)] * original_mesh.ndim + device_mesh=outputs.device_mesh, + placements=[Shard(-1)] * outputs.device_mesh.ndim, ) - hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) - # Compute cross-entropy loss for the chunks - total_loss = 0.0 - for idx in range(len(hidden_chunks)): - total_loss += self.compute_cross_entropy( - hidden_chunks[idx], - target_chunks[idx], - original_mesh=original_mesh, - original_placements=original_placements, + total_loss = torch.tensor(0.0, device=targets.device) + for hidden_chunk, target_chunk in zip(hidden_chunks, target_chunks): + # forcefully reshaping ensures same dim tensor, even if mask is all True + target_chunk = target_chunk.reshape(-1) + hidden_chunk = hidden_chunk.reshape(-1, hidden_chunk.shape[-1]) + + # non-ignored indices + indices = torch.where(target_chunk != self.ignore_index)[0] + + hidden_chunk, target_chunk = self.mask_inputs( + hidden_chunk, target_chunk, indices ) - if total_elements == 0: - # must return after calling compute_cross_entropy to not hang during data parallel training - return total_loss - else: - return total_loss / total_elements + loss = self.compute_cross_entropy(hidden_chunk, target_chunk) + # without this backprop throws `'Tensor' object has no attribute '_local_tensor'` + if isinstance(loss, DTensor): + loss = loss.full_tensor() + total_loss += loss + + return total_loss / total_valid_tokens From c0df2b21879a203a7d176e86b4474c3162dd7abe Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Mon, 16 Jun 2025 10:27:12 +0000 Subject: [PATCH 18/29] comment clarity --- recipes/full_finetune_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 8a68398ad4..b6a782d986 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -365,8 +365,8 @@ def setup(self, cfg: DictConfig) -> None: ) # Whether to use the ctx manager. 
If the loss fn has the property, use that. Otherwise, assume it is not supported. - # Currently our TP plans assume replicating on the output of `output` so without a custom loss class, there is - # no memory benefit to the ctx manager + # Currently our TP plans assume replicating on the output of `output` so without a custom loss class patching the + # TP plan, there is no memory benefit to the ctx manager self.use_loss_parallel_ctx_manager = ( self.parallel_dims.loss_parallel_enabled and getattr( From def5ed44dce235a5e7b054a780eda2882c60649c Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Mon, 16 Jun 2025 16:33:58 +0000 Subject: [PATCH 19/29] reshape + mask before chunking, allow toggling masking in loss --- torchtune/modules/loss/cross_entropy_loss.py | 29 +++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 803f53a7b6..8cc3dbf136 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -36,21 +36,26 @@ def __init__( num_output_chunks: int = 8, ignore_index: int = -100, enable_loss_parallel: bool = False, + mask_ignored_tokens: bool = True, ): super().__init__(enable_loss_parallel=enable_loss_parallel) """ Args: num_output_chunks (int): Number of chunks to split the output tensor into. Default is 8. ignore_index (int): Index to ignore in the target tensor. Default is -100. + enable_loss_parallel (bool): Whether to enable loss parallel. Default is False. + mask_ignored_tokens (bool): Whether to mask out ignored tokens during loss computation. Default is True. """ self.linear_projection = None self.num_output_chunks = num_output_chunks self.ignore_index = ignore_index + self.mask_ignored_tokens = mask_ignored_tokens def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. If compiling CE + chunking operation together, memory requirement is higher.""" - if not self.loss_parallel_enabled: + # we might be able to compile in TP case if masking is disabled? 
+ if not self.loss_parallel_enabled or not self.mask_ignored: self.compute_cross_entropy = torch.compile( self.compute_cross_entropy, *args, **kwargs ) @@ -168,22 +173,20 @@ def forward( device_mesh=outputs.device_mesh, placements=[Shard(-1)] * outputs.device_mesh.ndim, ) - hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1) - target_chunks = targets.tensor_split(self.num_output_chunks, dim=1) - total_loss = torch.tensor(0.0, device=targets.device) - for hidden_chunk, target_chunk in zip(hidden_chunks, target_chunks): - # forcefully reshaping ensures same dim tensor, even if mask is all True - target_chunk = target_chunk.reshape(-1) - hidden_chunk = hidden_chunk.reshape(-1, hidden_chunk.shape[-1]) + # forcefully reshaping ensures same dim tensor, even if mask is all True + targets = targets.reshape(-1) + outputs = outputs.reshape(-1, outputs.shape[-1]) - # non-ignored indices - indices = torch.where(target_chunk != self.ignore_index)[0] + if self.mask_ignored: + indices = torch.where(targets != self.ignore_index)[0] + outputs, targets = self.mask_inputs(outputs, targets, indices) - hidden_chunk, target_chunk = self.mask_inputs( - hidden_chunk, target_chunk, indices - ) + hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=0) + target_chunks = targets.tensor_split(self.num_output_chunks, dim=0) + total_loss = torch.tensor(0.0, device=targets.device) + for hidden_chunk, target_chunk in zip(hidden_chunks, target_chunks): loss = self.compute_cross_entropy(hidden_chunk, target_chunk) # without this backprop throws `'Tensor' object has no attribute '_local_tensor'` if isinstance(loss, DTensor): From ef601911675258570a3b3f0ca59acad2ff08ebf9 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Mon, 16 Jun 2025 16:40:28 +0000 Subject: [PATCH 20/29] clean up comments --- torchtune/modules/loss/cross_entropy_loss.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 8cc3dbf136..ba05aab6c6 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -134,8 +134,6 @@ def mask_inputs( - The indexed hidden states - The indexed targets """ - # slicing requires both tensors to be same type - # since hidden is sharded on the feature dim, slicing on seq dim is possible if isinstance(hidden, DTensor): device_mesh = hidden.device_mesh hidden = hidden.to_local().index_select(0, indices) @@ -174,7 +172,6 @@ def forward( placements=[Shard(-1)] * outputs.device_mesh.ndim, ) - # forcefully reshaping ensures same dim tensor, even if mask is all True targets = targets.reshape(-1) outputs = outputs.reshape(-1, outputs.shape[-1]) From 65c1d3a5f37180e4d85249305e6271e49d60dc72 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 00:40:50 +0000 Subject: [PATCH 21/29] Clean up docstrings --- torchtune/modules/loss/cross_entropy_loss.py | 6 +++--- torchtune/modules/loss/loss_types.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index ba05aab6c6..00402ecfe7 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -54,7 +54,7 @@ def __init__( def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. 
If compiling CE + chunking operation together, memory requirement is higher.""" - # we might be able to compile in TP case if masking is disabled? + # compiling with loss parallelism appears to work without masking if not self.loss_parallel_enabled or not self.mask_ignored: self.compute_cross_entropy = torch.compile( self.compute_cross_entropy, *args, **kwargs @@ -125,8 +125,8 @@ def mask_inputs( ) -> tuple[torch.Tensor, torch.Tensor]: """ Args: - hidden (torch.Tensor): Hidden state of the model, pre projection. Shape ``[bsz, seq_len, emb_dim]`` - target (torch.Tensor): Labels for the model. Shape ``[bsz, seq_len]`` + hidden (torch.Tensor): Hidden state of the model, pre projection. Shape ``[bsz*seq_len, emb_dim]`` + target (torch.Tensor): Labels for the model. Shape ``[bsz*seq_len]`` indices (torch.Tensor): Indices of the valid tokens. Shape ``[num_valid]`` Returns: diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index 9005be239e..a6e99ba8ea 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -49,7 +49,7 @@ def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: def supports_loss_parallel(self) -> bool: """ Whether the loss function supports loss parallel. - Set to loss if loss parallelism isn't tested with your loss class. + Set to false if loss parallelism isn't tested with your loss class. """ pass @@ -60,7 +60,7 @@ def loss_parallel_requires_ctx_manager(self) -> bool: Whether to use the loss parallel context manager for loss parallelism. Can be used if the function relies on the standard cross_entropy() or CrossEntropyLoss. Set to false if loss parallelism isn't tested with your loss class, or your loss - parallelism doesn't require the context manager.. + parallelism doesn't require the context manager. """ pass From 8294215b4540435e9379950cf22b281c3c341b47 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 02:39:00 +0000 Subject: [PATCH 22/29] simplify SFTLoss contract --- recipes/full_finetune_distributed.py | 18 +++------ torchtune/modules/loss/cross_entropy_loss.py | 33 +++++++---------- torchtune/modules/loss/loss_types.py | 39 ++------------------ torchtune/training/_distributed.py | 5 --- 4 files changed, 23 insertions(+), 72 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b6a782d986..1ed40c4758 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -161,7 +161,6 @@ def __init__(self, cfg: DictConfig) -> None: self.cp_degree = cfg.get("context_parallel_dim", 1) data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer data_replicate = cfg.get("data_parallel_replicate_dim", 1) - enable_loss_parallel = cfg.get("enable_loss_parallel", True) # Set up n-d device mesh self.parallel_dims = training.ParallelDims( @@ -170,7 +169,6 @@ def __init__(self, cfg: DictConfig) -> None: tp=self.tp_degree, cp=self.cp_degree, world_size=self.world_size, - enable_loss_parallel=enable_loss_parallel, ) self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type) if self.parallel_dims.dp_enabled: @@ -359,21 +357,14 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if isinstance(self._loss_fn, SFTLoss): - self._loss_fn.enable_loss_parallel = ( - self.parallel_dims.loss_parallel_enabled - ) # Whether to use the ctx manager. If the loss fn has the property, use that. 
Otherwise, assume it is not supported. # Currently our TP plans assume replicating on the output of `output` so without a custom loss class patching the # TP plan, there is no memory benefit to the ctx manager - self.use_loss_parallel_ctx_manager = ( - self.parallel_dims.loss_parallel_enabled - and getattr( - self._loss_fn, - "use_loss_parallel_ctx_manager", - False, - ) + self.use_loss_parallel_ctx_manager = self.parallel_dims.tp_enabled and getattr( + self._loss_fn, + "tp_requires_loss_parallel_ctx_manager", + False, ) self._model = self._setup_model( @@ -643,6 +634,7 @@ def _setup_model( enable_fp8_training=self._enable_fp8_training, ) if isinstance(self._loss_fn, SFTLoss): + self._loss_fn.tp_enabled = True self.tp_plan = self._loss_fn.patch_tp_plan(self.tp_plan) parallelize_module( diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 00402ecfe7..2d510ab8ba 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -35,27 +35,27 @@ def __init__( self, num_output_chunks: int = 8, ignore_index: int = -100, - enable_loss_parallel: bool = False, + tp_enabled: bool = False, mask_ignored_tokens: bool = True, ): - super().__init__(enable_loss_parallel=enable_loss_parallel) + super().__init__(tp_enabled=tp_enabled) """ Args: num_output_chunks (int): Number of chunks to split the output tensor into. Default is 8. ignore_index (int): Index to ignore in the target tensor. Default is -100. - enable_loss_parallel (bool): Whether to enable loss parallel. Default is False. mask_ignored_tokens (bool): Whether to mask out ignored tokens during loss computation. Default is True. """ self.linear_projection = None self.num_output_chunks = num_output_chunks self.ignore_index = ignore_index self.mask_ignored_tokens = mask_ignored_tokens + self.tp_enabled = tp_enabled def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. 
If compiling CE + chunking operation together, memory requirement is higher.""" # compiling with loss parallelism appears to work without masking - if not self.loss_parallel_enabled or not self.mask_ignored: + if not self.tp_enabled or not self.mask_ignored_tokens: self.compute_cross_entropy = torch.compile( self.compute_cross_entropy, *args, **kwargs ) @@ -71,23 +71,18 @@ def set_model_output(self, model: nn.Module) -> None: self.linear_projection = model.output def patch_tp_plan(self, tp_plan) -> bool: - if self.loss_parallel_enabled: - if "output" not in tp_plan: - raise KeyError("`tp_plan` requires `output` key") - - tp_plan["output"] = ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1), - use_local_output=False, - ) - return tp_plan + if "output" not in tp_plan: + raise KeyError("`tp_plan` requires `output` key") - @property - def supports_loss_parallel(self) -> bool: - return True + tp_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ) + return tp_plan @property - def loss_parallel_requires_ctx_manager(self) -> bool: + def tp_requires_loss_parallel_ctx_manager(self) -> bool: return True def compute_cross_entropy( @@ -175,7 +170,7 @@ def forward( targets = targets.reshape(-1) outputs = outputs.reshape(-1, outputs.shape[-1]) - if self.mask_ignored: + if self.mask_ignored_tokens: indices = torch.where(targets != self.ignore_index)[0] outputs, targets = self.mask_inputs(outputs, targets, indices) diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index a6e99ba8ea..7b5eb2c284 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -18,9 +18,9 @@ class SFTLoss(ABC): # https://github.com/pytorch/pytorch/pull/91819 call_super_init = True - def __init__(self, *, enable_loss_parallel: bool = False): + def __init__(self, *, tp_enabled: bool = False): super().__init__() - self.enable_loss_parallel = enable_loss_parallel + self.tp_enabled = tp_enabled def apply_compile_strategy(self, *args, **kwargs): """Compile the loss function for inference.""" @@ -45,48 +45,17 @@ def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: pass @property - @abstractmethod - def supports_loss_parallel(self) -> bool: - """ - Whether the loss function supports loss parallel. - Set to false if loss parallelism isn't tested with your loss class. - """ - pass - - @property - @abstractmethod - def loss_parallel_requires_ctx_manager(self) -> bool: + def tp_requires_loss_parallel_ctx_manager(self) -> bool: """ Whether to use the loss parallel context manager for loss parallelism. Can be used if the function relies on the standard cross_entropy() or CrossEntropyLoss. - Set to false if loss parallelism isn't tested with your loss class, or your loss - parallelism doesn't require the context manager. """ - pass + return False def patch_tp_plan(self, tp_plan) -> bool: """Whether the loss function supports loss parallel. Defaults to a noop.""" return tp_plan - @property - def loss_parallel_enabled(self) -> bool: - """ - The `enable_loss_parallel` flag is a directive from the user. - This property also validates that it is supported, or raises an error. 
- """ - if self.enable_loss_parallel and not self.supports_loss_parallel: - raise ValueError( - f"Loss function is enabled, but {self.__class__.__name__} does not support loss parallelism" - ) - return self.enable_loss_parallel - - @property - def use_loss_parallel_ctx_manager(self) -> bool: - """ - Whether to enable the loss parallelism ctx manager for this instance of the class. - """ - return self.loss_parallel_enabled and self.loss_parallel_requires_ctx_manager - class RLLoss(ABC): """Interface for loss functions in torchtune used in RL recipes.""" diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index a31be3eb04..ccded67a2e 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -65,7 +65,6 @@ class ParallelDims: tp: int cp: int world_size: int - enable_loss_parallel: bool def __post_init__(self): self._validate() @@ -154,10 +153,6 @@ def dp_shard_enabled(self): def tp_enabled(self): return self.tp > 1 - @property - def loss_parallel_enabled(self): - return self.enable_loss_parallel and self.tp > 1 - @cached_property def non_data_parallel_size(self): # update below as more parallelism options are implemented From f3227eb53f082c3001a7e5cf3370d293109e7357 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 02:42:26 +0000 Subject: [PATCH 23/29] remove comment --- recipes/full_finetune_distributed.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 1ed40c4758..70443d0aa9 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -357,10 +357,6 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - - # Whether to use the ctx manager. If the loss fn has the property, use that. Otherwise, assume it is not supported. - # Currently our TP plans assume replicating on the output of `output` so without a custom loss class patching the - # TP plan, there is no memory benefit to the ctx manager self.use_loss_parallel_ctx_manager = self.parallel_dims.tp_enabled and getattr( self._loss_fn, "tp_requires_loss_parallel_ctx_manager", From f1e33302b1f82191fcf84449908482413b630991 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 23:00:23 +0000 Subject: [PATCH 24/29] cleanup --- torchtune/modules/loss/cross_entropy_loss.py | 15 +++++++-------- torchtune/modules/loss/loss_types.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 2d510ab8ba..4c5cf208b1 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -54,14 +54,13 @@ def __init__( def apply_compile_strategy(self, *args, **kwargs): """Applies compile only to the compute_cross_entropy function. If compiling CE + chunking operation together, memory requirement is higher.""" - # compiling with loss parallelism appears to work without masking - if not self.tp_enabled or not self.mask_ignored_tokens: - self.compute_cross_entropy = torch.compile( - self.compute_cross_entropy, *args, **kwargs + if self.tp_enabled and self.mask_ignored_tokens: + log.warning( + "Skipping compile loss, as it is not supported with both masking and tensor parallelism enabled." ) else: - log.warning( - "Skipping compile loss, as it is not supported with loss parallel enabled." 
+ self.compute_cross_entropy = torch.compile( + self.compute_cross_entropy, *args, **kwargs ) return self @@ -129,6 +128,7 @@ def mask_inputs( - The indexed hidden states - The indexed targets """ + indices = torch.where(target != self.ignore_index)[0] if isinstance(hidden, DTensor): device_mesh = hidden.device_mesh hidden = hidden.to_local().index_select(0, indices) @@ -171,8 +171,7 @@ def forward( outputs = outputs.reshape(-1, outputs.shape[-1]) if self.mask_ignored_tokens: - indices = torch.where(targets != self.ignore_index)[0] - outputs, targets = self.mask_inputs(outputs, targets, indices) + outputs, targets = self.mask_inputs(outputs, targets) hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=0) target_chunks = targets.tensor_split(self.num_output_chunks, dim=0) diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index 7b5eb2c284..022489253b 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -15,7 +15,7 @@ class SFTLoss(ABC): """Interface for loss functions in torchtune used in sft recipes.""" # makes subclasses with multiple inheritance including nn.Module play nicely - # https://github.com/pytorch/pytorch/pull/91819 + # https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.loss_parallel call_super_init = True def __init__(self, *, tp_enabled: bool = False): From 7ceec03d569ccd7c0bccaf2edc63dbb2795f12c0 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 23:03:06 +0000 Subject: [PATCH 25/29] fix docstrings in loss_types --- torchtune/modules/loss/loss_types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index 022489253b..ce4944840d 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -15,7 +15,7 @@ class SFTLoss(ABC): """Interface for loss functions in torchtune used in sft recipes.""" # makes subclasses with multiple inheritance including nn.Module play nicely - # https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.loss_parallel + # https://github.com/pytorch/pytorch/pull/91819 call_super_init = True def __init__(self, *, tp_enabled: bool = False): @@ -47,8 +47,8 @@ def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: @property def tp_requires_loss_parallel_ctx_manager(self) -> bool: """ - Whether to use the loss parallel context manager for loss parallelism. Can be - used if the function relies on the standard cross_entropy() or CrossEntropyLoss. + Whether to use the loss parallel context manager for loss parallelism. 
+ https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.loss_parallel """ return False From bb58300e2ed5b670ae2430270c89145da4797487 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 17 Jun 2025 23:19:18 +0000 Subject: [PATCH 26/29] remove unused `indices` arg --- torchtune/modules/loss/cross_entropy_loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index 4c5cf208b1..b17621e1a8 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -115,13 +115,14 @@ def compute_cross_entropy( return loss def mask_inputs( - self, hidden: torch.Tensor, target: torch.Tensor, indices: torch.Tensor + self, + hidden: torch.Tensor, + target: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ Args: hidden (torch.Tensor): Hidden state of the model, pre projection. Shape ``[bsz*seq_len, emb_dim]`` target (torch.Tensor): Labels for the model. Shape ``[bsz*seq_len]`` - indices (torch.Tensor): Indices of the valid tokens. Shape ``[num_valid]`` Returns: tuple[torch.Tensor, torch.Tensor]: returns a tuple of From b89260af9c5a07548a181d8819a048212f519f85 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 19 Jun 2025 00:38:57 +0000 Subject: [PATCH 27/29] make python3.9 happy Signed-off-by: Nathan Azrak --- torchtune/training/_distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index ccded67a2e..47217275bb 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Any, Callable, cast, Generator, Optional, Union +from typing import Any, Callable, cast, Generator, Optional import torch import torch.distributed as dist @@ -785,7 +785,7 @@ def _get_sdpa_context() -> ( """ @contextlib.contextmanager - def context(cp_context: Union[Generator[None, None, None], None] = None): + def context(cp_context: Optional[Generator[None, None, None]] = None): with contextlib.ExitStack() as stack: if cp_context is not None: stack.enter_context( @@ -883,7 +883,7 @@ def context(model_inputs: list[torch.Tensor]): def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: @contextlib.contextmanager - def context(cp_context: Generator[None, None, None] | None = None): + def context(cp_context: Optional[Generator[None, None, None]] = None): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) From f4060ac8c2b2f1ff988961c2b8a4a2b448611369 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 19 Jun 2025 00:39:30 +0000 Subject: [PATCH 28/29] Fix typehints, remove `__init__` from SFTLoss ABC Signed-off-by: Nathan Azrak --- torchtune/modules/loss/cross_entropy_loss.py | 4 ++-- torchtune/modules/loss/loss_types.py | 10 +--------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index b17621e1a8..18a1ad0065 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -38,7 +38,7 @@ def __init__( tp_enabled: bool = False, mask_ignored_tokens: bool = True, ): - 
super().__init__(tp_enabled=tp_enabled) + super().__init__() """ Args: num_output_chunks (int): Number of chunks to split the output tensor into. Default is 8. @@ -69,7 +69,7 @@ def set_model_output(self, model: nn.Module) -> None: model.skip_output_layer = True self.linear_projection = model.output - def patch_tp_plan(self, tp_plan) -> bool: + def patch_tp_plan(self, tp_plan) -> dict: if "output" not in tp_plan: raise KeyError("`tp_plan` requires `output` key") diff --git a/torchtune/modules/loss/loss_types.py b/torchtune/modules/loss/loss_types.py index ce4944840d..4c8d942995 100644 --- a/torchtune/modules/loss/loss_types.py +++ b/torchtune/modules/loss/loss_types.py @@ -14,14 +14,6 @@ class SFTLoss(ABC): """Interface for loss functions in torchtune used in sft recipes.""" - # makes subclasses with multiple inheritance including nn.Module play nicely - # https://github.com/pytorch/pytorch/pull/91819 - call_super_init = True - - def __init__(self, *, tp_enabled: bool = False): - super().__init__() - self.tp_enabled = tp_enabled - def apply_compile_strategy(self, *args, **kwargs): """Compile the loss function for inference.""" self.forward = torch.compile(self.forward, *args, **kwargs) @@ -52,7 +44,7 @@ def tp_requires_loss_parallel_ctx_manager(self) -> bool: """ return False - def patch_tp_plan(self, tp_plan) -> bool: + def patch_tp_plan(self, tp_plan) -> dict: """Whether the loss function supports loss parallel. Defaults to a noop.""" return tp_plan From 2a5714c6a51861a6a2610f2d64c8d6f860eaced1 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Thu, 19 Jun 2025 00:43:19 +0000 Subject: [PATCH 29/29] rename context parallel ctx manager Signed-off-by: Nathan Azrak --- recipes/full_finetune_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 70443d0aa9..10c64559d4 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -701,7 +701,7 @@ def _setup_model( model, enable_activation_offloading, activation_offloading_use_streams ) # context parallel - self.optional_context_parallel_manager = training.get_context_parallel_manager( + self.context_parallel_manager = training.get_context_parallel_manager( enabled=self.cp_degree > 1, world_mesh=self.world_mesh, model=model, @@ -944,7 +944,7 @@ def train(self) -> None: num_tokens += current_num_tokens with self.train_context( - self.optional_context_parallel_manager(list(batch.values())) + self.context_parallel_manager(list(batch.values())) ): # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients
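
The sketch below condenses how the pieces introduced in this series are meant to compose in a recipe-style training step. It is illustrative only: `model`, `dataloader`, `optimizer`, `world_mesh`, `tp_enabled`, and `cp_degree` are placeholders standing in for recipe state, and the call pattern mirrors the recipe hunks above rather than defining any new API.

from torchtune import training
from torchtune.modules.loss import LinearCrossEntropyLoss

# Chunked CE that masks ignored tokens; under TP the recipe also runs the plan
# through loss_fn.patch_tp_plan(...) (ColwiseParallel with Shard(-1) outputs)
# before parallelize_module, which is omitted here.
loss_fn = LinearCrossEntropyLoss(num_output_chunks=8, ignore_index=-100)
loss_fn.set_model_output(model)  # projection happens inside the loss, chunk by chunk

# Enter loss_parallel() only when TP is on and the loss opts in via the property.
use_loss_parallel_ctx_manager = tp_enabled and getattr(
    loss_fn, "tp_requires_loss_parallel_ctx_manager", False
)
train_context = training.get_train_context(use_loss_parallel_ctx_manager)
context_parallel_manager = training.get_context_parallel_manager(
    enabled=cp_degree > 1, world_mesh=world_mesh, model=model
)

for batch in dataloader:
    labels = batch.pop("labels")
    # Compose loss parallel with the (possibly no-op) context parallel context.
    with train_context(context_parallel_manager(list(batch.values()))):
        hidden = model(**batch)         # [bsz, seq_len, emb_dim]; a DTensor when TP shards the model
        loss = loss_fn(hidden, labels)  # summed CE over chunks / number of non-ignored tokens
        loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

With TP and CP disabled these wrappers should reduce to no-ops, so the same loop shape covers every parallelism configuration.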