diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index 55fb099afb..36abe4ad0b 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+
+        from torch._inductor.config import aten_distributed_optimizations as dist_opts
         from torch._inductor.fx_passes.overlap_scheduling import (
             schedule_overlap_bucketing,
         )
 
-        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
-        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        dist_opts.collective_bucketing = True
+        dist_opts.insert_overlap_deps = False
         torch._inductor.config.allow_buffer_reuse = False
 
         def aten_autobucketing_reordering_pass(
diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 9ca74601e9..a7cef3ca9a 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -7,7 +7,6 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -19,7 +18,7 @@
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -45,8 +44,8 @@ def disable_active_parametrization():
 
 @dataclass(frozen=True)
 class MixedPrecisionPolicy:
-    param_dtype: Optional[torch.dtype] = None
-    reduce_dtype: Optional[torch.dtype] = None
+    param_dtype: torch.dtype | None = None
+    reduce_dtype: torch.dtype | None = None
 
 
 class _ScaledPartial(Partial):
@@ -95,19 +94,7 @@ def _distribute_dtensor(
     """
     inner_spec = tensor._spec
    outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
-    outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
-    inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
-    if outer_global_mesh != inner_global_mesh or (
-        outer_global_mesh is None or inner_global_mesh is None
-    ):
-        raise AssertionError(
-            "Cannot distribute tensor across two meshes without the same root mesh: \n"
-            f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
-        )
-    assert outer_mesh.mesh_dim_names is not None
-    assert inner_mesh.mesh_dim_names is not None
-    submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
-    spanned_mesh = outer_global_mesh[submesh_names]
+    spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])
 
     if len(dp_placements) == 1:
         assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
@@ -173,8 +160,8 @@ def _distribute_dtensor(
 
 
 def _register_parametrization(
-    module: nn.Module, param_names: List[str], parametrization: nn.Module
-):
+    module: nn.Module, param_names: list[str], parametrization: nn.Module
+) -> None:
     """
     It works with state_dict without incurring parametrization calls because
     state_dict accesses parameters directly from self._parameters, not from getters
@@ -242,7 +229,7 @@ def __init__(
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
 
-    def replicate_compute(self, x):
+    def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
@@ -250,8 +237,6 @@ def replicate_compute(self, x):
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
-            # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
-            # after DeviceMesh supports slicing a non-root mesh
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -295,7 +280,7 @@ def replicate_compute(self, x):
 
         return output
 
-    def forward(self, x):
+    def forward(self, x: DTensor) -> torch.Tensor:
         global _active_parametrization
         # This should never be set to true during forward, only outside for model
         # inspection / debugging / initialization
@@ -308,7 +293,10 @@ def forward(self, x):
         if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
             # apply checkpointing to implement reshard_after_forward
             output = checkpoint(
-                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
+                self.replicate_compute,
+                x,
+                use_reentrant=False,
+                context_fn=fsdp_policy,
             )
         else:
             output = self.replicate_compute(x)
@@ -317,13 +305,13 @@ def forward(self, x):
 
 
 def data_parallel(
-    model,
-    device_mesh,
-    mode="replicate",
+    model: nn.Module,
+    device_mesh: DeviceMesh,
+    mode: str = "replicate",
     ac_mode: str = "none",
-    mp_policy: Optional[MixedPrecisionPolicy] = None,
+    mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
-    reduction_divide_factor: Optional[float] = None,
+    reduction_divide_factor: float | None = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
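
For reviewers, a minimal usage sketch (not part of the diff) of the now-annotated data_parallel signature together with the compile backend touched above. The 8-GPU "dp" mesh, the toy model, the bf16/fp32 mixed-precision choice, and the "aot_eager_autobucketing" backend name are illustrative assumptions, not requirements of this change.

# Sketch only: assumes a distributed launch (e.g. torchrun --nproc_per_node=8)
# on a CUDA host; mesh shape, dtypes, and backend name are example choices.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend
from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    MixedPrecisionPolicy,
    data_parallel,
)

# 1D data-parallel mesh over 8 ranks, named "dp".
dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# "fully_shard" shards parameters FSDP-style across the dp mesh; per the
# MixedPrecisionPolicy, compute uses bf16 params and gradients reduce in fp32.
model = data_parallel(
    model,
    dp_mesh,
    mode="fully_shard",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
)

# The backend branch above flips the new aten_distributed_optimizations knobs
# before running the bucketing/reordering pass under aot_eager.
compiled_model = torch.compile(
    model, backend=get_compile_backend("aot_eager_autobucketing")
)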