Merged
Commits (31)
8aa66d9
Add loss parallel to ParallelDims, and train context manager
nazrak-atlassian Jun 3, 2025
f9315dd
Add loss parallel support to LinearCrossEntropyLoss
nazrak-atlassian Jun 3, 2025
fd7872c
ungate fp8, clean up llama3 parallelism, add loss parallel to TP plans
nazrak-atlassian Jun 3, 2025
f0a40fa
Add loss parallel support to full finetune recipe
nazrak-atlassian Jun 3, 2025
6e93e83
allow enabling autograd compile even though it doesn't work
nazrak-atlassian Jun 3, 2025
566142a
correct import paths in unit tests
nazrak-atlassian Jun 3, 2025
b3ae7c4
Merge branch 'pytorch:main' into enable_loss_parallel
nathan-az Jun 10, 2025
defc48f
remove autograd compile for now
nathan-az Jun 11, 2025
6eb0d85
Remove unnecessary `full_tensor` in LinearCrossEntropyLoss
nathan-az Jun 11, 2025
ae4e751
remove layerwise prefix in llama3 tp plans
nathan-az Jun 11, 2025
c4abec6
Refactor tp plans to support fp8 training via arg
nathan-az Jun 11, 2025
f037fab
Refactor loss parallel into custom loss classes
nathan-az Jun 11, 2025
ed61113
Fix Replicate caused by `tensor_split` in Linear CE, disallow masking…
nathan-az Jun 11, 2025
5875292
re-introduce masking in tensor parallel linear CE loss.
nathan-az Jun 12, 2025
bba42d5
revert accidental num_tokens metric in recipe.
nathan-az Jun 12, 2025
87d5fd6
clean up linear CE loss
nathan-az Jun 12, 2025
f5efb0f
explicitly flatten target and hidden chunks
nathan-az Jun 15, 2025
40df86d
refactor to allow compile in non-parallel case
nathan-az Jun 16, 2025
c0df2b2
comment clarity
nathan-az Jun 16, 2025
def5ed4
reshape + mask before chunking, allow toggling masking in loss
nathan-az Jun 16, 2025
ef60191
clean up comments
nathan-az Jun 16, 2025
65c1d3a
Clean up docstrings
nathan-az Jun 17, 2025
8294215
simplify SFTLoss contract
nazrak-atlassian Jun 17, 2025
f3227eb
remove comment
nathan-az Jun 17, 2025
f1e3330
cleanup
nathan-az Jun 17, 2025
7ceec03
fix docstrings in loss_types
nathan-az Jun 17, 2025
bb58300
remove unused `indices` arg
nathan-az Jun 17, 2025
b89260a
make python3.9 happy
nathan-az Jun 19, 2025
f4060ac
Fix typehints, remove `__init__` from SFTLoss ABC
nathan-az Jun 19, 2025
2a5714c
rename context parallel ctx manager
nathan-az Jun 19, 2025
e275c88
Merge branch 'pytorch:main' into enable_loss_parallel
nathan-az Jun 23, 2025
45 changes: 28 additions & 17 deletions recipes/full_finetune_distributed.py
@@ -355,6 +355,14 @@ def setup(self, cfg: DictConfig) -> None:
self._grad_scaler, backend=self._compile_backend
)

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)
self.use_loss_parallel_ctx_manager = self.parallel_dims.tp_enabled and getattr(
self._loss_fn,
"tp_requires_loss_parallel_ctx_manager",
False,
)

self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -416,8 +424,6 @@ def setup(self, cfg: DictConfig) -> None:
# Update the recipe state from the checkpoint state dict.
self._update_recipe_state(checkpoint_dict)

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)
if isinstance(self._loss_fn, SFTLoss):
self._loss_fn.set_model_output(self._model)

@@ -603,11 +609,6 @@ def _setup_model(
raise RuntimeError(
"Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
)
if self.tp_plan is not None:
raise ValueError(
"FP8 training does not support tensor parallelism yet. "
"This will be enabled in the near future."
)
if self.cp_degree > 1:
raise ValueError(
"Context Parallel for fp8 training is not currently supported"
@@ -626,7 +627,12 @@
self.tp_plan = config.instantiate(
self.tp_plan,
model=model,
enable_fp8_training=self._enable_fp8_training,
)
if isinstance(self._loss_fn, SFTLoss):
self._loss_fn.tp_enabled = True
self.tp_plan = self._loss_fn.patch_tp_plan(self.tp_plan)

parallelize_module(
model,
self.world_mesh["tp"],
@@ -674,13 +680,6 @@ def _setup_model(
dp_mesh=self.world_mesh[dp_mesh_dim_names],
)

# Define context manager for context parallelism
self.context_parallel_manager = training.get_context_parallel_manager(
enabled=self.cp_degree > 1,
world_mesh=self.world_mesh,
model=model,
)

with training.set_default_dtype(self._dtype), self._device:
for m in model.modules():
# RoPE is not covered in state dict
@@ -701,6 +700,16 @@
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading, activation_offloading_use_streams
)
# context parallel
self.optional_context_parallel_manager = training.get_context_parallel_manager(
enabled=self.cp_degree > 1,
world_mesh=self.world_mesh,
model=model,
)
# remaining context managers for fwd/bwd
self.train_context = training.get_train_context(
enable_loss_parallel=self.use_loss_parallel_ctx_manager,
)

# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)
@@ -934,9 +943,11 @@ def train(self) -> None:
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
with self.context_parallel_manager(list(batch.values())):
with self.train_context(
self.optional_context_parallel_manager(list(batch.values()))
):
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
# For optimizer in backward, we need to normalize before calling backward
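To make the nested context managers above concrete, here is a minimal sketch, not torchtune's actual `training.get_train_context`, of how a train context could compose PyTorch's `loss_parallel()` with an optional inner context manager such as the context-parallel manager; `inner_ctx` is an illustrative parameter name.

```python
# Minimal sketch only: assumes the same call pattern as the recipe above,
# i.e. train_context(optional_context_parallel_manager(list(batch.values()))).
import contextlib

from torch.distributed.tensor.parallel import loss_parallel


def get_train_context(enable_loss_parallel: bool):
    @contextlib.contextmanager
    def train_context(inner_ctx=None):
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                # Keeps the cross-entropy computation sharded over the vocab dim.
                stack.enter_context(loss_parallel())
            if inner_ctx is not None:
                # e.g. the context-parallel manager wrapping the batch tensors.
                stack.enter_context(inner_ctx)
            yield

    return train_context
```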
12 changes: 7 additions & 5 deletions tests/torchtune/training/test_quantization.py
@@ -11,7 +11,7 @@
from torchao.float8.float8_linear import Float8Linear

from torchtune.models.llama3 import base_llama_tp_plan
from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
from torchtune.models.llama3._parallelism import _get_fp8_llama_tp_training_plan
from torchtune.training.quantization import (
_validate_float8_tp_plan,
convert_to_float8_training,
@@ -54,12 +54,14 @@ def _test_validate_float8_tp_plan(self):
"""
_validate_float8_tp_plan(base_llama_tp_plan())
_validate_float8_tp_plan(base_llama_tp_plan(), "anything")
_validate_float8_tp_plan(_fp8_llama_tp_plan())
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
_validate_float8_tp_plan(_get_fp8_llama_tp_training_plan())
_validate_float8_tp_plan(_get_fp8_llama_tp_training_plan(), "tensorwise")
with pytest.raises(ValueError):
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
_validate_float8_tp_plan(_get_fp8_llama_tp_training_plan(), "rowwise")
with pytest.raises(ValueError):
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")
_validate_float8_tp_plan(
_get_fp8_llama_tp_training_plan(), "rowwise_with_gw_hp"
)

def test_is_fp8_tensorwise_scaling(self):
"""
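For reference, the behaviour the updated test exercises, mirroring its calls; both helpers are private (underscore-prefixed) APIs, so treat this as a sketch rather than a supported interface.

```python
from torchtune.models.llama3._parallelism import _get_fp8_llama_tp_training_plan
from torchtune.training.quantization import _validate_float8_tp_plan

plan = _get_fp8_llama_tp_training_plan()
_validate_float8_tp_plan(plan, "tensorwise")  # compatible: tensorwise scaling
# Both of the following raise ValueError for the float8 TP plan:
#   _validate_float8_tp_plan(plan, "rowwise")
#   _validate_float8_tp_plan(plan, "rowwise_with_gw_hp")
```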
112 changes: 65 additions & 47 deletions torchtune/models/llama3/_parallelism.py
@@ -23,9 +23,9 @@


def _get_base_llama_tp_training_plan(
layerwise_colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel,
layerwise_rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel,
layerwise_prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput,
colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel,
rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel,
prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput,
) -> dict[str, ParallelStyle]:
"""
Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
@@ -35,74 +35,92 @@ def _get_base_llama_tp_training_plan(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"norm": SequenceParallel(),
"output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
"layers.*.attn": layerwise_prepare_module_input_cls(
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
),
"layers.*.attn": prepare_module_input_cls(
input_layouts=(Shard(1), Shard(1)),
desired_input_layouts=(Replicate(), Replicate()),
),
"layers.*.mlp": layerwise_prepare_module_input_cls(
"layers.*.mlp": prepare_module_input_cls(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"layers.*.sa_norm": SequenceParallel(),
"layers.*.mlp_norm": SequenceParallel(),
"layers.*.attn.q_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.k_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.v_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.output_proj": layerwise_rowwise_parallel_cls(
output_layouts=Shard(1)
),
"layers.*.mlp.w1": layerwise_colwise_parallel_cls(),
"layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)),
"layers.*.mlp.w3": layerwise_colwise_parallel_cls(),
"layers.*.attn.q_proj": colwise_parallel_cls(),
"layers.*.attn.k_proj": colwise_parallel_cls(),
"layers.*.attn.v_proj": colwise_parallel_cls(),
"layers.*.attn.output_proj": rowwise_parallel_cls(output_layouts=Shard(1)),
"layers.*.mlp.w1": colwise_parallel_cls(),
"layers.*.mlp.w2": rowwise_parallel_cls(output_layouts=Shard(1)),
"layers.*.mlp.w3": colwise_parallel_cls(),
}


BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan()
def _get_base_llama_tp_inference_plan():
return {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
"layers.*.attn.q_proj": ColwiseParallel(),
"layers.*.attn.k_proj": ColwiseParallel(),
"layers.*.attn.v_proj": ColwiseParallel(),
"layers.*.attn.output_proj": RowwiseParallel(),
"layers.*.mlp.w1": ColwiseParallel(),
"layers.*.mlp.w2": RowwiseParallel(),
"layers.*.mlp.w3": ColwiseParallel(),
}

FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan(
layerwise_colwise_parallel_cls=Float8ColwiseParallel,
layerwise_rowwise_parallel_cls=Float8RowwiseParallel,
layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput,
)

BASE_LLAMA_TP_INFERENCE_PLAN = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
"layers.*.attn.q_proj": ColwiseParallel(),
"layers.*.attn.k_proj": ColwiseParallel(),
"layers.*.attn.v_proj": ColwiseParallel(),
"layers.*.attn.output_proj": RowwiseParallel(),
"layers.*.mlp.w1": ColwiseParallel(),
"layers.*.mlp.w2": RowwiseParallel(),
"layers.*.mlp.w3": ColwiseParallel(),
}
def _get_fp8_llama_tp_training_plan(model: nn.Module) -> dict[str, ParallelStyle]:
"""
Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
rowwise and colwise computation, currently only compatible with float8 fine-tuning with
"tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.

Args:
model (nn.Module): Model to generate plan for (no-op)

Returns:
dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
"""
return _get_base_llama_tp_training_plan(
colwise_parallel_cls=Float8ColwiseParallel,
rowwise_parallel_cls=Float8RowwiseParallel,
prepare_module_input_cls=PrepareFloat8ModuleInput,
)


def base_llama_tp_plan(
model: nn.Module, inference: bool = False
model: nn.Module,
*,
inference: bool = False,
enable_fp8_training: bool = False,
) -> dict[str, ParallelStyle]:
"""
Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models

Args:
model (nn.Module): Model to generate plan for (no-op)
inference (bool): Whether running inference or not.
inference (bool): Whether running inference or not
enable_fp8_training (bool): Whether to enable float8 training.

Returns:
dict[str, Any]: The tensor parallel plan for Llama3 model.
"""
return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN


# TODO: expose this once tested
def _fp8_llama_tp_plan() -> dict[str, ParallelStyle]:
"""
Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
rowwise and colwise computation, currently only compatible with float8 fine-tuning with
"tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.

Returns:
dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
Raises:
ValueError: if FP8 training is enabled with inference.
"""
return FP8_LLAMA_TP_TRAINING_PLAN
if enable_fp8_training:
if inference:
raise ValueError(
"FP8 training is not compatible with inference with LLaMA-3"
)
return _get_fp8_llama_tp_training_plan(model)

return (
_get_base_llama_tp_inference_plan()
if inference
else _get_base_llama_tp_training_plan()
)
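A hedged usage sketch of the reworked `base_llama_tp_plan` entry point; `llama3_model` and `tp_mesh` are placeholders for the actual Llama3 module and a 1-D tensor-parallel `DeviceMesh` (the model argument is documented as a no-op).

```python
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import base_llama_tp_plan

# Training plan with float8 all-gather (tensorwise scaling only).
tp_plan = base_llama_tp_plan(llama3_model, enable_fp8_training=True)
parallelize_module(llama3_model, tp_mesh, tp_plan)

# Inference plan; combining inference=True with enable_fp8_training=True raises ValueError.
inference_plan = base_llama_tp_plan(llama3_model, inference=True)
```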
11 changes: 10 additions & 1 deletion torchtune/models/llama4/_parallelism.py
@@ -161,18 +161,27 @@ def decoder_only_tp_inference_plan(model: nn.Module) -> dict[str, ParallelStyle]


def decoder_only_tp_plan(
model: nn.Module, inference: bool = False
model: nn.Module,
*,
inference: bool = False,
enable_fp8_training: bool = False,
) -> dict[str, ParallelStyle]:
"""
Helper function to get the decoder-only tensor parallel plan for the Llama4 model.

Args:
model (nn.Module): Model to generate plan for (no-op)
inference (bool): Whether running inference or not.
enable_fp8_training (bool): Whether to enable float8 training. Currently not supported for Llama4.

Returns:
dict[str, Any]: The tensor parallel plan for the Llama4 model.

Raises:
ValueError: if FP8 training is enabled.
"""
if enable_fp8_training:
raise ValueError("FP8 training is not supported for Llama4")
return (
decoder_only_tp_inference_plan(model)
if inference
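And a short sketch of the new Llama4 guard; as above, the model argument is documented as a no-op, so a bare `nn.Module()` stands in here.

```python
import pytest
import torch.nn as nn

from torchtune.models.llama4._parallelism import decoder_only_tp_plan

model = nn.Module()  # placeholder; the plan helper does not inspect the model

# Requesting float8 training with the Llama4 decoder-only plan is rejected.
with pytest.raises(ValueError):
    decoder_only_tp_plan(model, enable_fp8_training=True)

# Non-FP8 training and inference plans remain available.
training_plan = decoder_only_tp_plan(model)
inference_plan = decoder_only_tp_plan(model, inference=True)
```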