Merged
Changes from 2 commits
50 changes: 20 additions & 30 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -7,7 +7,7 @@
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import List, Optional
+from typing import cast

import torch
import torch.nn as nn
@@ -19,7 +19,7 @@
Replicate,
Shard,
)
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -45,8 +45,8 @@ def disable_active_parametrization():

@dataclass(frozen=True)
class MixedPrecisionPolicy:
-    param_dtype: Optional[torch.dtype] = None
-    reduce_dtype: Optional[torch.dtype] = None
+    param_dtype: torch.dtype | None = None
+    reduce_dtype: torch.dtype | None = None


class _ScaledPartial(Partial):
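The MixedPrecisionPolicy hunk above only modernizes the annotations, so construction is unchanged. A minimal sketch, assuming the module path from the diff header and illustrative dtype choices (bf16 compute with fp32 gradient reduction) that this PR does not mandate:

```python
# Construction sketch; the dtypes are illustrative, not taken from this PR.
import torch

from torchtitan.experiments.simple_fsdp.simple_fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
```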
@@ -95,19 +95,7 @@ def _distribute_dtensor(
"""
inner_spec = tensor._spec
outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
-    outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
-    inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
-    if outer_global_mesh != inner_global_mesh or (
-        outer_global_mesh is None or inner_global_mesh is None
-    ):
-        raise AssertionError(
-            "Cannot distribute tensor across two meshes without the same root mesh: \n"
-            f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
-        )
-    assert outer_mesh.mesh_dim_names is not None
-    assert inner_mesh.mesh_dim_names is not None
-    submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
-    spanned_mesh = outer_global_mesh[submesh_names]
+    spanned_mesh = DeviceMesh._concatenate((outer_mesh, inner_mesh))

if len(dp_placements) == 1:
assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
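To see what the `_distribute_dtensor` hunk changes in practice: the old path fetched a shared root mesh and re-sliced it by concatenated dim names, while the new path concatenates the two sub-meshes directly. A sketch under assumed conditions (a 2x2 world launched with `torchrun --nproc_per_node=4`; `DeviceMesh._concatenate` is the private helper adopted in the hunk above):

```python
# Illustrative setup only; the mesh shape and dim names are assumptions.
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

world = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
outer_mesh, inner_mesh = world["dp"], world["tp"]

# Old: root = _mesh_resources.get_root_mesh(outer_mesh); spanned = root[("dp", "tp")]
# New (as in the hunk above): no root-mesh lookup or same-root assertion needed.
spanned_mesh = DeviceMesh._concatenate((outer_mesh, inner_mesh))
```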
Expand Down Expand Up @@ -173,8 +161,8 @@ def _distribute_dtensor(


def _register_parametrization(
-    module: nn.Module, param_names: List[str], parametrization: nn.Module
-):
+    module: nn.Module, param_names: list[str], parametrization: nn.Module
+) -> None:
"""
It works with state_dict without incurring parametrization calls because
state_dict accesses parameters directly from self._parameters, not from getters
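A toy illustration of the annotated `_register_parametrization` call shape. The identity parametrization below is hypothetical (SimpleFSDP installs its own parametrization elsewhere in this file); the point is only the signature and the state_dict behavior described in the docstring above:

```python
# Hypothetical toy parametrization; not the one SimpleFSDP uses.
import torch.nn as nn

from torchtitan.experiments.simple_fsdp.simple_fsdp import _register_parametrization


class _Identity(nn.Module):
    def forward(self, x):
        return x


linear = nn.Linear(8, 8)
# One parametrization module is registered for a list of parameter names.
_register_parametrization(linear, ["weight", "bias"], _Identity())
# Per the docstring, state_dict() reads self._parameters directly,
# so _Identity.forward is not invoked here.
state = linear.state_dict()
```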
@@ -242,16 +230,14 @@ def __init__(
self.param_dtype = mp_policy.param_dtype
self.reduce_dtype = mp_policy.reduce_dtype

-    def replicate_compute(self, x):
+    def replicate_compute(self, x: DTensor) -> torch.Tensor:
# data parallel runtime replicate parameters and do local compute
# the gradients are partial tensors that needs to perform reduction
# (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
# support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim)
non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
if non_dp_mesh_dims > 0:
-            # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
-            # after DeviceMesh supports slicing a non-root mesh
dp_mesh = self.device_mesh
# re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
sharded_local_tensor = x.to_local()
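A worked example of the mesh-dim arithmetic in `replicate_compute`, with illustrative numbers for a DP + TP layout (nothing here is taken from the PR):

```python
# x._spec.mesh spans ("dp", "tp") -> ndim 2; self.device_mesh is the 1D dp mesh.
x_mesh_ndim = 2
dp_mesh_ndim = 1
non_dp_mesh_dims = x_mesh_ndim - dp_mesh_ndim  # 1: one non-DP dim (TP)
assert non_dp_mesh_dims <= 2  # DP + EP, DP + TP, or DP + EP + TP are supported
# Because non_dp_mesh_dims > 0, the parameter is unwrapped with to_local() and
# re-wrapped as a 1D DTensor on the dp mesh for an efficient FSDP all-gather.
```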
@@ -295,7 +281,7 @@ def replicate_compute(self, x):

return output

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
global _active_parametrization
# This should never be set to true during forward, only outside for model
# inspection / debugging / initialization
@@ -305,25 +291,29 @@
if not _active_parametrization:
return x

-        assert isinstance(x, DTensor)
if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
# apply checkpointing to implement reshard_after_forward
output = checkpoint(
-                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
+                self.replicate_compute,
+                cast(DTensor, x),
+                use_reentrant=False,
+                context_fn=fsdp_policy,
)
else:
-            output = self.replicate_compute(x)
+            output = self.replicate_compute(cast(DTensor, x))

return output


def data_parallel(
-    model,
-    device_mesh,
-    mode="replicate",
+    model: nn.Module,
+    device_mesh: DeviceMesh,
+    mode: str = "replicate",
ac_mode: str = "none",
-    mp_policy: Optional[MixedPrecisionPolicy] = None,
+    mp_policy: MixedPrecisionPolicy | None = None,
shard_dim: int = 0,
-    reduction_divide_factor: Optional[float] = None,
+    reduction_divide_factor: float | None = None,
):
if mode == "replicate":
param_sharding = (Replicate(),)
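To close the loop on the annotated `data_parallel` signature, a minimal end-to-end usage sketch. The mesh shape, toy model, and dtypes are illustrative, and the sketch assumes `data_parallel` returns the parametrized model (that return is outside the lines shown here). Run under torchrun, e.g. `torchrun --nproc_per_node=2 example.py`:

```python
# Illustrative usage of the annotated API; not taken from this PR.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    MixedPrecisionPolicy,
    data_parallel,
)

dp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("dp",))
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# "fully_shard" is one of the modes referenced in forward() above; the default is "replicate".
model = data_parallel(
    model,
    dp_mesh,
    mode="fully_shard",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
)
```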