@@ -417,6 +417,76 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
         return nn.Parameter(parameter)
 
+class ReplicateParallel(TensorParallelLayer):
421+ """
422+ Replicate a nn.Module.
423+ Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
424+ Fully distributed model is needed for gradient clipping.
425+
426+ Keyword Args:
427+ input_layouts (Placement, optional):
428+ The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
429+ become a DTensor. If not specified, we assume the input tensor to be replicated.
430+ output_layouts (Placement, optional):
431+ The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
432+ with the user desired layout. If not specified, we assume the output tensor to be replicated.
433+ use_local_output (bool, optional):
434+ Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
435+ Returns:
436+ A :class:`ParallelStyle` object that represents replication of nn.Module.
437+
438+ Example::
439+ >>> # xdoctest: +SKIP(failing)
440+ >>> from torch.distributed.tensor.parallel import parallelize_module, ReplicateParallel
441+ >>> from torch.distributed.device_mesh import init_device_mesh
442+ >>> ...
443+ >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule
444+ >>> tp_mesh = init_device_mesh("cuda", (8,))
445+ >>>
446+ >>> # By default, the input and output of the "w1" Linear will be converted to Replicated DTensor
447+ >>>
448+ >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
449+ >>> ...
450+
451+ """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+        use_dtensor=True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Replicate(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+        self.use_dtensor = use_dtensor
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # nn.Linear and nn.Embedding take a single input; since the layout is
+        # replicated, support for other modules can be extended later.
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+            input_tensor = input_tensor.trigger_wait()
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        return outputs.to_local() if use_local_output else outputs
 
 SUPPORTED_TP_STYLES = {
     "colwise",
@@ -428,6 +498,8 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
428498 "local" ,
429499 "gather" ,
430500 "local_packed_rowwise" ,
501+ "replicate" ,
502+ "replicate_output_dtensor"
431503}
432504
433505
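With the two new entries registered, a tensor-parallel plan can map submodule patterns to `"replicate"` or `"replicate_output_dtensor"` just like the existing styles. The dictionary below is purely illustrative (the module-name patterns are hypothetical and not taken from any model in this diff); it only shows how plan entries relate to `SUPPORTED_TP_STYLES`.

```python
# Hypothetical plan; module-name patterns are made up for illustration.
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.mlp.gate": "replicate",       # keep this submodule replicated
    "lm_head": "replicate_output_dtensor",        # same, but outputs stay DTensors
}

for pattern, style in tp_plan.items():
    if style not in SUPPORTED_TP_STYLES:
        raise ValueError(f"Unsupported parallel style value: {style}")
```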
@@ -459,6 +531,10 @@ def translate_to_torch_parallel_style(style: str):
         return GatherParallel()
     elif style == "local_packed_rowwise":
         return PackedRowwiseParallel(use_dtensor=False)
+    elif style == "replicate":
+        return ReplicateParallel()
+    elif style == "replicate_output_dtensor":
+        return ReplicateParallel(use_local_output=False)
     else:
         raise ValueError(f"Unsupported parallel style value: {style}")
 
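A quick sanity check of the new branches, based only on the code added above: `"replicate"` resolves to a `ReplicateParallel` that returns local tensors, `"replicate_output_dtensor"` keeps module outputs as DTensors, and any other string still raises.

```python
# Based on the branches added above; no new APIs assumed.
style = translate_to_torch_parallel_style("replicate")
assert isinstance(style, ReplicateParallel) and style.use_local_output

style = translate_to_torch_parallel_style("replicate_output_dtensor")
assert isinstance(style, ReplicateParallel) and not style.use_local_output

try:
    translate_to_torch_parallel_style("not_a_style")
except ValueError as e:
    print(e)  # Unsupported parallel style value: not_a_style
```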