Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "split_data", True)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug

Expand Down
19 changes: 10 additions & 9 deletions python/paddle/distributed/auto_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Engine:

import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down Expand Up @@ -540,7 +540,7 @@ def fit(self,

import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down Expand Up @@ -663,7 +663,7 @@ def evaluate(self,

import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down Expand Up @@ -771,7 +771,7 @@ def predict(self,

import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down Expand Up @@ -978,9 +978,10 @@ def _set_recompute_ckpts(self):

# extract ckpts by specific model
if isinstance(self._model, paddle.nn.Layer):
if hasattr(
self._model, "gpt"
) and self._model.__class__.__name__ == 'GPTForPretraining':
if hasattr(self._model,
"gpt") and self._model.__class__.__name__ in [
'GPTForPretraining', 'GPTForPretrainingAuto'
]:
exact_ckpts = self._model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
Expand Down Expand Up @@ -1041,7 +1042,7 @@ def save(self, path, training=True):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down Expand Up @@ -1107,7 +1108,7 @@ def load(self, path, strict=True, load_optimizer=True):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

transform = T.Compose([
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
.. code-block:: python

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.ones([4, 6])
Expand Down Expand Up @@ -129,7 +129,7 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
.. code-block:: python

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from .cost_model import estimate_cost
from .dist_op import DistributedOperator
from .process_group import _g_process_group_map
Expand Down
11 changes: 6 additions & 5 deletions python/paddle/distributed/auto_parallel/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def to_dict(self):
return result_dict

def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
result_dict = self.to_dict()
string = "{"
for k, v in result_dict.items():
string += "\"%s\":\"%s\"," % (k, v)
return string + "}"

def __deepcopy__(self, memo):
cls = self.__class__
Expand Down Expand Up @@ -130,7 +131,7 @@ class Strategy(BaseConfig):
.. code-block:: python

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
sharding = strategy.sharding
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def convert_to_dims_mapping(shard_spec, process_mesh):
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/distributed/fleet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,5 @@
shrink = fleet.shrink
get_hybrid_communicate_group = fleet.get_hybrid_communicate_group
distributed_scaler = distributed_scaler

from .. import auto_parallel as auto
42 changes: 18 additions & 24 deletions python/paddle/distributed/passes/auto_parallel_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import reduce
from collections import OrderedDict, defaultdict
from collections import OrderedDict
import numpy as np

import paddle
Expand All @@ -22,12 +22,15 @@
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr

OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign', "send_v2"
]
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
Expand Down Expand Up @@ -393,15 +396,16 @@ def _shard_gradient_synchronization(self, main_block):

dp_ring_ids = [group.id for group in self.dp_groups]
for idx, op in reversed(list(enumerate(main_block.ops))):
if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
if is_data_parallel_reduce_op(op):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
_insert_reduce_op(main_block, idx, input_name,
sharding_info.group.id,
sharding_info.get_var_rank(base_name),
self._dist_context)
if not self.partial_sharding:
if not self.partial_sharding or not sharding_info.is_in_local_shard(
base_name):
main_block._remove_op(idx + 1, sync=False)
else:
op._set_attr("ring_id", self.outer_dp_group.id)
Expand Down Expand Up @@ -439,7 +443,10 @@ def _shard_parameter(self, main_block, startup_block):
continue

for input_name in op.desc.input_arg_names():
if op.type == "cast":
# NOTE: hack for the embedding op under AMP O2/O3 —
# paddle AMP forces embedding (lookup table) to run on fp32
if _is_param_fp16_cast_op(main_block, op,
sharding_info.param_names):
continue
if input_name not in need_broadcast_vars:
continue
Expand Down Expand Up @@ -646,24 +653,6 @@ def _get_base_name_from_grad_name(grad_name):
return base_name


def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
    """Return True iff *op* is a data-parallel allreduce of a parameter's gradient.

    The op must be a backward-role ``c_allreduce_sum`` whose ``ring_id``
    belongs to one of the data-parallel groups, and whose output gradient
    maps back to a variable in *block* that is a parameter.
    """
    looks_like_dp_allreduce = (is_backward_op(op)
                               and op.type == "c_allreduce_sum"
                               and op.attr('ring_id') in dp_ring_ids)
    if not looks_like_dp_allreduce:
        return False

    grad_name = op.output_arg_names[0]
    param_name = _get_base_name_from_grad_name(grad_name)

    # The gradient name must resolve to an existing parameter variable.
    if not block.has_var(param_name):
        return False
    return block.var(param_name).is_parameter


def _is_param_grad_sum_op(op, block):

if not is_backward_op(op):
Expand Down Expand Up @@ -756,9 +745,14 @@ def get_var_rank(self, varname):
return self.param_to_rank[varname]
return -1

# determine fp32 and fp16 (cast) param
def is_in_local_shard(self, param_name):
    """Whether *param_name* belongs to the shard owned by the local rank."""
    owner_rank = self.get_var_rank(param_name)
    return owner_rank == self.local_rank

# NOTE: the following logic is designed to support AMP O1, where
# the param would be cast to fp16 before being used for calculation,
# and sharding should only broadcast the casted fp16 param
# instead of the origin fp32 version param.
def get_broadcast_vars_and_param_usage(self, block):
broadcast_vars = set([])
fp16_params = set([])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import paddle

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from paddle.fluid import layers
from paddle.io import IterableDataset, DataLoader
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

paddle.enable_static()
_global_parallel_strategy = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from auto_parallel_relaunch_model import mlp_pretrain_forward
from auto_parallel_relaunch_model import batch_generator_creator

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import paddle

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from paddle.fluid import layers
from paddle.io import Dataset, DataLoader

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

paddle.enable_static()
batch_size = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import random

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

sys.path.append("..")
import auto_parallel_gpt_model as modeling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import paddle

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import paddle
import unittest
import numpy as np
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.incubate.autograd import Hessian

np.random.seed(1234)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from engine_api_dp import MyDataset

paddle.enable_static()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import paddle

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import paddle

import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import paddle.nn.functional as F

from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.fleet import auto

from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
Expand Down
Loading