15 changes: 12 additions & 3 deletions paddle/fluid/pybind/eager.cc
@@ -244,9 +244,18 @@ void InitDistTensorWithTensor(TensorObject* self,
std::make_shared<DistTensor>(tensor, process_mesh, placements));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
std::shared_ptr<phi::DenseTensor> tensor;
if (src.initialized()) {
tensor = std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
} else {
// Lazy init branch: the src tensor is on an undefined place.
PADDLE_ENFORCE(
src.place().GetType() == phi::AllocationType::UNDEFINED,
phi::errors::InvalidArgument("Only an undefined place is supported "
"for an uninitialized input tensor."));
tensor = std::static_pointer_cast<phi::DenseTensor>(src.impl());
}
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, process_mesh, placements));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
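This new branch fires only when the source tensor has no allocation yet (its place is `UNDEFINED`), which is the state a parameter is in when constructed under `LazyGuard`; in that case the dense impl is reused directly instead of being copied to the target place. A minimal, hedged sketch of the Python-side call that reaches this path (assumes a 2-card launch; mirrors the test added later in this PR):

```python
import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

with LazyGuard():
    # Under LazyGuard the weight has a registered init func but no storage,
    # so its place is still undefined.
    linear = paddle.nn.Linear(8, 8)

# shard_tensor reaches InitDistTensorWithTensor with an uninitialized src,
# taking the new UNDEFINED-place branch instead of copy_to.
linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate()])
assert not linear.weight._is_initialized()
```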
38 changes: 34 additions & 4 deletions python/paddle/distributed/auto_parallel/api.py
@@ -45,6 +45,7 @@
from paddle.framework import core

from .placement_type import check_placements_equal, get_shard_spec
from .random import determinate_rng, rng_state

# These are the auto parallel APIs of the unified version of dynamic and static mode.
# Some APIs have the same name with the previous APIs implementation, which are
@@ -171,19 +172,48 @@ def shard_tensor(
# `paddle.to_tensor` supports both dynamic and static mode
if stop_gradient is None:
stop_gradient = getattr(data, "stop_gradient", True)
tensor = paddle.to_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)
if isinstance(data, EagerParamBase) and not data._is_initialized():
assert (
data._init_func is not None
), "Get an uninitialized param with an unregistered init_func."
tensor = data
else:
tensor = paddle.to_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):
return EagerParamBase.from_tensor(

def lazy_init_hook(param, origin_hook):
# lazy init hook with randomness control
def _init_func(var, block):
# get the unique rng name
rng_name = determinate_rng(
dist.get_rank(),
process_mesh=param.process_mesh,
placements=param.placements,
)
# actually call the original init function
with rng_state(rng_name):
origin_hook(var, block)

return _init_func

dist_param = EagerParamBase.from_tensor(
tensor,
process_mesh=mesh,
placements=placements,
**tensor.__dict__,
)
if tensor._init_func is not None:
origin_init_func = tensor._init_func
dist_param.set_init_func(
lazy_init_hook(dist_param, origin_init_func)
)

return dist_param
else:
return paddle.Tensor(
tensor, process_mesh=mesh, placements=placements, place=place
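Net effect of the new `shard_tensor` branch: a lazy, uninitialized `EagerParamBase` is passed through as-is, and the dist parameter built from it gets its `_init_func` wrapped so that it runs under a placement-specific RNG state. A hedged end-to-end sketch (mirrors the test added in this PR; the 2-card launch and seed value are illustrative):

```python
import paddle
import paddle.distributed as dist
from paddle import LazyGuard

paddle.distributed.auto_parallel.parallel_manual_seed(2023)
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

with LazyGuard():
    linear = paddle.nn.Linear(10, 10)

linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate()])
linear.bias = dist.shard_tensor(linear.bias, mesh, [dist.Replicate()])

for param in linear.parameters():
    assert not param._is_initialized()
    # initialize() runs the wrapped _init_func: determinate_rng derives the
    # seed from (rank, mesh, placements) and rng_state scopes it, so both
    # replicated ranks materialize identical local values.
    param.initialize()
    assert param._is_initialized()
```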
40 changes: 37 additions & 3 deletions python/paddle/distributed/auto_parallel/random.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging

import paddle
@@ -22,6 +23,7 @@
_logger = get_logger(logging.INFO)

_rng_name_to_seed = {}
_rng_name_to_states = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
@@ -71,7 +73,16 @@ def parallel_manual_seed(seed, name=""):
_basic_name = name


def determinate_rng(rank, dims_mapping, process_mesh):
def determinate_rng(
rank, dims_mapping=None, process_mesh=None, placements=None
):
assert process_mesh is not None, "Must provide process mesh"
assert (
dims_mapping is not None or placements is not None
), "Must provide one of dims mapping or placements."
assert not (
dims_mapping is not None and placements is not None
), "Cannot provide dims mapping and placements at same time."
# TODO(JZ-LIANG) Support Mesh with any high rank
# use a string-to-unique-integer hashing algorithm for seed computation,
# instead of using offsets to coordinate seeds across devices.
@@ -98,7 +109,9 @@ def determinate_rng(rank, dims_mapping, process_mesh):
seed_ += _mesh_offset * (unique_id + 1)

for i in range(len(process_mesh.shape)):
if i not in dims_mapping:
if (dims_mapping is not None and i not in dims_mapping) or (
placements is not None and not placements[i].is_shard()
):
relative_idx = -1
else:
relative_idx = _get_idx_in_axis(
@@ -112,6 +125,7 @@ def determinate_rng(rank, dims_mapping, process_mesh):
seed_ += _dim_offsets[i] * (relative_idx + 1)

global _rng_name_to_seed
global _rng_name_to_states
if sharding_expr in _rng_name_to_seed:
assert _rng_name_to_seed[sharding_expr] == seed_
else:
@@ -121,10 +135,30 @@ def determinate_rng(rank, dims_mapping, process_mesh):
seed_, sharding_expr, _rng_name_to_seed
)
_rng_name_to_seed[sharding_expr] = seed_

if paddle.in_dynamic_mode():
# for dygraph, just init the seed when meeting a new seed
orig_rng_state = paddle.get_rng_state()
paddle.seed(seed_)
_rng_name_to_states[sharding_expr] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)
return sharding_expr


@contextlib.contextmanager
def rng_state(name):
global _rng_name_to_states
assert (
name in _rng_name_to_states
), f"The rng state name {name} haven't been init. "
orig_rng_state = paddle.get_rng_state()
paddle.set_rng_state(_rng_name_to_states[name])
try:
yield
finally:
_rng_name_to_states[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)


def init_auto_parallel_rng():
if not is_enable_auto_rand_ctrl():
return
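A hedged sketch of how the two additions compose in dygraph: `determinate_rng` now also accepts `placements`, caches a generator state per sharding expression on first use, and `rng_state` temporarily swaps that cached state in and writes it back on exit (the mesh and seed below are illustrative):

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.random import (
    determinate_rng,
    parallel_manual_seed,
    rng_state,
)

parallel_manual_seed(2023)
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Same (rank, mesh, placements) -> same sharding expression -> same cached state.
name = determinate_rng(
    dist.get_rank(), process_mesh=mesh, placements=[dist.Replicate()]
)

with rng_state(name):
    # Randomness consumed here draws from the cached state for `name`; on
    # exit the advanced state is stored back and the global state restored.
    w = paddle.randn([4, 4])
```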
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/Bilinear.py
@@ -89,6 +89,9 @@ def forward(self, var, block=None):
Returns:
The initialization op
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, Bilinear initializer not support lazy init for dist param."
block = self._check_block(block)

if not isinstance(var, (framework.Variable, pir.core.ParameterMeta)):
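The same guard is added to every initializer below that has no lazy-init path for dist parameters yet (assign, dirac, kaiming, normal, orthogonal, uniform). A hedged sketch of what it rejects; the layer, initializer choice, and mesh are illustrative:

```python
import paddle
import paddle.distributed as dist
from paddle import LazyGuard

paddle.distributed.auto_parallel.parallel_manual_seed(2023)
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
w_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(0.0, 0.02))

with LazyGuard():
    linear = paddle.nn.Linear(8, 8, weight_attr=w_attr)

linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate()])
try:
    # initialize() calls NormalInitializer.forward on a dist param, which
    # trips the new guard because that initializer has no dist lazy-init
    # support yet.
    linear.weight.initialize()
except AssertionError as err:
    print(err)
```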
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/assign.py
@@ -56,6 +56,9 @@ def forward(self, var, block=None):
Returns:
The initialization op
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, assign initializer not support lazy init for dist param."
block = self._check_block(block)

assert isinstance(
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/dirac.py
@@ -106,6 +106,9 @@ def __call__(self, var, block=None):
Returns:
The most critical OP(scatter) in this initializer, which contains 7~8 ops in total.
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, dirac initializer not support lazy init for dist param."
block = self._check_block(block)
assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
assert isinstance(block, (framework.Block, pir.Block))
12 changes: 10 additions & 2 deletions python/paddle/nn/initializer/initializer.py
@@ -17,7 +17,11 @@

import numpy as np

from ...base.framework import default_main_program, in_dygraph_mode
from ...base.framework import (
EagerParamBase,
default_main_program,
in_dygraph_mode,
)
from .lazy_init import lazy_init_helper

__all__ = []
@@ -86,7 +90,11 @@ def _compute_fans(self, var):
Returns:
tuple of two integers (fan_in, fan_out).
"""
shape = var.shape
shape = (
var._local_shape
if (isinstance(var, EagerParamBase) and var.is_dist())
else var.shape
)
if not shape or len(shape) == 0:
fan_in = fan_out = 1
elif len(shape) == 1:
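With the switch to `_local_shape` above, fan-in/fan-out for a dist parameter are computed from its local shard rather than the global shape, so each rank derives its scale from what it will actually initialize. A small hedged arithmetic sketch for the 2-D case (shapes and mesh are illustrative; the fan rules follow the unchanged remainder of `_compute_fans`):

```python
# Global weight [1024, 1024] sharded on axis 0 over a 2-card mesh gives a
# local shard of [512, 1024]; the fans come from the local shard.
local_shape = [512, 1024]                          # var._local_shape on each rank
fan_in, fan_out = local_shape[0], local_shape[1]   # 2-D rule
xavier_uniform_limit = (6.0 / (fan_in + fan_out)) ** 0.5
print(fan_in, fan_out, xavier_uniform_limit)       # 512 1024 0.0625
```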
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/kaiming.py
@@ -91,6 +91,9 @@ def forward(self, var, block=None):
Returns:
The initialization op.
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, kaiming initializer not support lazy init for dist param."
block = self._check_block(block)
assert isinstance(
var, (framework.Variable, paddle.pir.core.ParameterMeta)
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/normal.py
@@ -56,6 +56,9 @@ def forward(self, var, block=None):
Returns:
The initialization op.
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, normal initializer not support lazy init for dist param."
block = self._check_block(block)

assert isinstance(block, (framework.Block, pir.Block))
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/orthogonal.py
@@ -81,6 +81,9 @@ def __call__(self, var, block=None):
Returns:
The last initialization op; it contains 8 ops in the orthogonal initializer.
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, orthogonal initializer not support lazy init for dist param."
block = self._check_block(block)
assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
assert isinstance(block, (framework.Block, pir.Block))
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/uniform.py
@@ -73,6 +73,9 @@ def forward(self, var, block=None):
Returns:
The initialization op
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, uniform initializer not support lazy init for dist param."
block = self._check_block(block)

assert isinstance(block, (framework.Block, pir.Block))
17 changes: 12 additions & 5 deletions python/paddle/nn/initializer/xavier.py
@@ -114,7 +114,9 @@ def forward(self, var, block=None):
name=unique_name.generate(
".".join(['xavier_init', var.name, 'tmp'])
),
shape=var.shape,
shape=var._local_shape
if (isinstance(var, framework.EagerParamBase) and var.is_dist())
else var.shape,
dtype=out_dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
@@ -151,10 +153,15 @@ def forward(self, var, block=None):
if var.dtype == core.VarDesc.VarType.FP16 or (
var.dtype == core.VarDesc.VarType.BF16 and not self._uniform
):
var_tmp = _C_ops.cast(out_var, var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var)
out_var = _C_ops.cast(out_var, var.dtype)
if isinstance(var, framework.EagerParamBase) and var.is_dist():
# lazy init for dist tensor
out_var = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
out_var, var.process_mesh, var.placements
)
)
out_var._share_underline_tensor_to(var)
return None
elif in_pir_mode():
if self._uniform:
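For a lazily initialized dist parameter, the reworked tail of `XavierInitializer.forward` wraps the locally initialized dense tensor as a dist tensor (via the same `dtensor_from_local` helper the new test uses) before sharing it into the parameter. A hedged sketch of the equivalent standalone steps, with illustrative shapes and mesh:

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.api import dtensor_from_local

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
placements = [dist.Replicate()]

# Each rank fills its local value with the Xavier-uniform distribution...
fan_in = fan_out = 10
limit = (6.0 / (fan_in + fan_out)) ** 0.5
local_value = paddle.uniform([10, 10], min=-limit, max=limit)

# ...then rewraps it with the param's mesh/placements, so the subsequent
# _share_underline_tensor_to(var) hands a dist tensor to the dist param.
dist_value = dtensor_from_local(local_value, mesh, placements)
```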
4 changes: 4 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -162,6 +162,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_semi_auto_parallel_single_strategy)
set_tests_properties(test_semi_auto_parallel_single_strategy
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 400)
py_test_modules(test_semi_auto_parallel_lazy_init MODULES
test_semi_auto_parallel_lazy_init)
set_tests_properties(test_semi_auto_parallel_lazy_init
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_semi_auto_parallel_in_framework MODULES
test_semi_auto_parallel_in_framework)
set_tests_properties(test_semi_auto_parallel_in_framework
62 changes: 62 additions & 0 deletions test/auto_parallel/semi_auto_parallel_lazy_init.py
@@ -0,0 +1,62 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.distributed as dist
from paddle import LazyGuard


class TestSemiAutoParallelLazyInit:
def __init__(self):
self._backend = os.getenv("backend")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def test_replicate(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
with LazyGuard():
linear = paddle.nn.Linear(10, 10)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh, [dist.Replicate()]
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh, [dist.Replicate()]
)
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
assert param._is_initialized()

local_weight_md5 = linear.weight._local_value()._md5sum()
mesh0 = dist.ProcessMesh([0], dim_names=["x"])
mesh1 = dist.ProcessMesh([1], dim_names=["x"])
tmp = paddle.distributed.auto_parallel.api.dtensor_from_local(
linear.weight._local_value(),
mesh0 if dist.get_rank() == 0 else mesh1,
[dist.Replicate()],
)
tmp = dist.reshard(
tmp, mesh1 if dist.get_rank() == 0 else mesh0, [dist.Replicate()]
)
tmp_md5 = tmp._local_value()._md5sum()
assert local_weight_md5 == tmp_md5

def run_test_case(self):
self.test_replicate()


if __name__ == '__main__':
TestSemiAutoParallelLazyInit().run_test_case()