diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index 55fb099afb..36abe4ad0b 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+
+        from torch._inductor.config import aten_distributed_optimizations as dist_opts
         from torch._inductor.fx_passes.overlap_scheduling import (
             schedule_overlap_bucketing,
         )

-        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
-        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        dist_opts.collective_bucketing = True
+        dist_opts.insert_overlap_deps = False
         torch._inductor.config.allow_buffer_reuse = False

         def aten_autobucketing_reordering_pass(
diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 9ca74601e9..674e9b44c9 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -19,7 +19,7 @@
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -95,19 +95,7 @@ def _distribute_dtensor(
     """
     inner_spec = tensor._spec
     outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
-    outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
-    inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
-    if outer_global_mesh != inner_global_mesh or (
-        outer_global_mesh is None or inner_global_mesh is None
-    ):
-        raise AssertionError(
-            "Cannot distribute tensor across two meshes without the same root mesh: \n"
-            f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
-        )
-    assert outer_mesh.mesh_dim_names is not None
-    assert inner_mesh.mesh_dim_names is not None
-    submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
-    spanned_mesh = outer_global_mesh[submesh_names]
+    spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])

     if len(dp_placements) == 1:
         assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()