Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 123 additions & 214 deletions python/paddle/fluid/dygraph/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,121 @@ def trainer_endpoints(self):
Env = ParallelEnv


def _build_default_parallel_strategy():
strategy = ParallelStrategy()
strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
return strategy


def _coalesce_tensors(var_groups):
from ..layers import nn
coalesced_grads_and_grad_vars = []
for group_id, grad_vars in var_groups.items():
flattened_vars = []
g_var_shapes = []
for g_var in grad_vars:
g_var_shapes.append(g_var.shape)
flattened_vars.append(
nn.reshape(
x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
coalesced_grad = nn.concat(flattened_vars)
coalesced_grads_and_grad_vars.append(
[coalesced_grad, grad_vars, g_var_shapes])
return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
x_shape = framework._varbase_creator(dtype=x.dtype)
framework._dygraph_tracer().trace_op(
type="reshape2",
inputs={'X': x},
outputs={'Out': x,
'XShape': x_shape},
attrs={'shape': shape})


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_grad},
outputs={'Out': origin_grad_vars},
attrs={'sections': grad_var_len,
'axis': 0})
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
_reshape_inplace(x=g_var, shape=g_shape)
assert g_var.shape == g_shape


def scale_loss(loss):
if not ParallelEnv().world_size > 1:
return loss

loss_scale = to_variable(
np.array([ParallelEnv().world_size]).astype("float32"))
loss_scale.stop_gradient = True
scaled_loss = loss / loss_scale
return scaled_loss


@no_grad
def apply_collective_grads(parameters):
if not ParallelEnv().world_size > 1:
return

grad_var_set = set()
grad_vars = []
sparse_grad_vars = []
strategy = _build_default_parallel_strategy()
for param in parameters:
# NOTE(zcd): The grad_ivar maybe no generated.
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if g_var._is_sparse():
sparse_grad_vars.append(g_var)
continue
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)

if sparse_grad_vars:
sparse_grad_vars.sort(key=lambda x: x.name)
for grad_var in sparse_grad_vars:
grad_var._allreduce(strategy)

# FIXME(zcd): the type of the var should be LoDTensor, i.e
# the gradients should be dense, otherwise, the following
# logic should be updated.
# 128 MB as a group
mega_bytes = 128 * 1024 * 1024
group_idx = 0
memory_counter = 0
grad_var_groups = OrderedDict()
dtype = grad_vars[0].dtype
for g_var in grad_vars:
# NOTE: the dtype of the same group should be the same.
bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
if memory_counter < mega_bytes and dtype == g_var.dtype:
memory_counter += bytes
else:
memory_counter = bytes
group_idx += 1
grad_var_groups.setdefault(group_idx, []).append(g_var)

coalesced_grads_and_vars = _coalesce_tensors(grad_var_groups)

for coalesced_grad, _, _ in coalesced_grads_and_vars:
coalesced_grad._allreduce(strategy)

_split_tensors(coalesced_grads_and_vars)


class DataParallel(layers.Layer):
"""
Run the dygraph module with data parallelism.
Expand Down Expand Up @@ -325,232 +440,26 @@ def __init__(self, layers, strategy=None):
if strategy is not None:
self._strategy = strategy
else:
self._strategy = ParallelStrategy()
self._strategy.nranks = ParallelEnv().nranks
self._strategy.local_rank = ParallelEnv().local_rank
self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
self._strategy.current_endpoint = ParallelEnv().current_endpoint
self._strategy = _build_default_parallel_strategy()

def forward(self, *inputs, **kwargs):
return self._layers(*inputs, **kwargs)

@deprecated(
since="2.0.0", reason="This method does not need to be called anymore.")
def scale_loss(self, loss):
"""
Scale the loss. In data parallel mode, the loss should be scale with
the number of trainers. If not in data parallel mode, return the loss
directly.

Args:
loss(Variable): The loss of the current Model.

Returns:
Variable: the scaled loss.

Examples:
.. code-block:: python

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)

def forward(self, x):
return self._linear2(self._linear1(x))

def train():
# 1. enable dynamic mode
paddle.disable_static()

# 2. initialize parallel environment
dist.init_parallel_env()

# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)

loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())

# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.step()
adam.clear_grad()

if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
Deprecated method, now ``scale_loss`` is an empty method.
"""
if not self._is_data_parallel_mode():
return loss

loss_scale = to_variable(
np.array([self._strategy.nranks]).astype("float32"))
loss_scale.stop_gradient = True
loss = loss / loss_scale
return loss

def _coalesce_tensors(self, var_groups):
from ..layers import nn
coalesced_grads_and_grad_vars = []
for group_id, grad_vars in var_groups.items():
flattened_vars = []
g_var_shapes = []
for g_var in grad_vars:
g_var_shapes.append(g_var.shape)
flattened_vars.append(
nn.reshape(
x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
coalesced_grad = nn.concat(flattened_vars)
coalesced_grads_and_grad_vars.append(
[coalesced_grad, grad_vars, g_var_shapes])
return coalesced_grads_and_grad_vars

def _reshape_inplace(self, x, shape):
x_shape = self._helper.create_variable_for_type_inference(dtype=x.dtype)
self._helper.append_op(
type="reshape2",
inputs={'X': x},
attrs={'shape': shape},
outputs={'Out': x,
'XShape': x_shape})

def _split_tensors(self, coalesced_grads_and_grad_vars):
from ..layers import nn
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
self._helper.main_program.current_block().append_op(
type='split',
inputs={'X': coalesced_grad},
outputs={'Out': origin_grad_vars},
attrs={'sections': grad_var_len,
'axis': 0})
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
self._reshape_inplace(x=g_var, shape=g_shape)
assert g_var.shape == g_shape

@no_grad
@deprecated(
since="2.0.0", reason="This method does not need to be called anymore.")
def apply_collective_grads(self):
"""
AllReduce the Parameters' gradient.

Examples:
.. code-block:: python

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)

def forward(self, x):
return self._linear2(self._linear1(x))

def train():
# 1. enable dynamic mode
paddle.disable_static()

# 2. initialize parallel environment
dist.init_parallel_env()

# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)

loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())

# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.step()
adam.clear_grad()

if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
Deprecated method, now ``apply_collective_grads`` is an empty method.
"""
if not self._is_data_parallel_mode():
return

grad_var_set = set()
grad_vars = []
sparse_grad_vars = []
for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar maybe no generated.
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if g_var._is_sparse():
sparse_grad_vars.append(g_var)
continue
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)

if sparse_grad_vars:
sparse_grad_vars.sort(key=lambda x: x.name)
for grad_var in sparse_grad_vars:
grad_var._allreduce(self._strategy)

# FIXME(zcd): the type of the var should be LoDTensor, i.e
# the gradients should be dense, otherwise, the following
# logic should be updated.
# 128 MB as a group
mega_bytes = 128 * 1024 * 1024
group_idx = 0
memory_counter = 0
grad_var_groups = OrderedDict()
dtype = grad_vars[0].dtype
for g_var in grad_vars:
# Note: the dtype of the same group should be the same.
bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
if memory_counter < mega_bytes and dtype == g_var.dtype:
memory_counter += bytes
else:
memory_counter = bytes
group_idx += 1
grad_var_groups.setdefault(group_idx, []).append(g_var)

coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)

for coalesced_grad, _, _ in coalesced_grads_and_vars:
coalesced_grad._allreduce(self._strategy)

self._split_tensors(coalesced_grads_and_vars)

def _is_data_parallel_mode(self):
return self._strategy.nranks > 1
return

def state_dict(self,
destination=None,
Expand Down
12 changes: 10 additions & 2 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.

import inspect
import numpy as np

import paddle
from .. import framework
from .. import core
from ..framework import Variable, Parameter, ParamBase
from .base import switch_to_static_graph
import numpy as np
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss


def monkey_patch_varbase():
Expand Down Expand Up @@ -165,7 +168,12 @@ def backward(self, retain_graph=False):

"""
if framework.in_dygraph_mode():
self._run_backward(framework._dygraph_tracer(), retain_graph)
if paddle.distributed.get_world_size() > 1:
scaled_loss = scale_loss(self)
scaled_loss._run_backward(framework._dygraph_tracer(),
retain_graph)
else:
self._run_backward(framework._dygraph_tracer(), retain_graph)
else:
raise ValueError(
"Variable.backward() is only available in DyGraph mode")
Expand Down
Loading