@@ -238,7 +238,7 @@ def sharding_specs(self):
 def shard_tensor(
     data: Tensor | TensorLike | NestedNumericSequence,
     mesh: ProcessMesh,
-    placements: list[Placement],
+    placements: Sequence[Placement],
     dtype: DTypeLike | None = None,
     place: PlaceLike | None = None,
     stop_gradient: bool | None = None,
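With the annotation loosened to Sequence[Placement], any sequence (a tuple as well as a list) now type-checks for the placements argument; runtime behavior is unchanged. A minimal sketch, assuming a two-process mesh:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dense = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
# A tuple of placements now satisfies the Sequence[Placement] annotation.
d_tensor = dist.shard_tensor(dense, mesh, (dist.Shard(0),))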
@@ -780,15 +780,15 @@ def dtensor_to_local(dist_tensor, mesh, placements):
 def dtensor_from_fn(
     fn: Callable[..., Tensor],
     mesh: ProcessMesh,
-    placements: list[Placement],
+    placements: Sequence[Placement],
     *args: Any,
     **kwargs: Any,
 ) -> Tensor:
     """
     Construct a Distributed Tensor from a function of arguments.
 
     Args:
-        fn (callable): A callable function that takes arguments of Distributed Tensor and returns tensor.
+        fn (callable): A callable function that creates and returns a tensor, such as paddle.ones, paddle.zeros, etc.
         mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
         placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on ProcessMesh, it can
             be Shard, Replicate and Partial.
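The reworded docstring makes fn's role concrete: it is a tensor-creation function, and any extra positional or keyword arguments are forwarded to it through *args/**kwargs. A sketch using paddle.ones, again assuming a two-process mesh:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# paddle.ones receives shape=[2, 3] via **kwargs; its output is then
# distributed on the mesh according to the given placements.
d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[2, 3])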
@@ -818,7 +818,7 @@ def dtensor_from_fn(
 
 
 def reshard(
-    dist_tensor: Tensor, mesh: ProcessMesh, placements: list[Placement]
+    dist_tensor: Tensor, mesh: ProcessMesh, placements: Sequence[Placement]
 ) -> Tensor:
     """
     Reshard a distributed ``paddle.Tensor`` with given distributed attributes.
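The same relaxation applies to reshard. A sketch converting a replicated distributed tensor to one sharded along dim 0, assuming a two-process mesh:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
d_tensor = dist.shard_tensor(paddle.ones([4, 4]), mesh, (dist.Replicate(),))
# Passing a tuple of placements is now consistent with the annotation.
resharded = dist.reshard(d_tensor, mesh, (dist.Shard(0),))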