Commit d483b8c

Add solutions to PyLayer which is unsupported in DataParallel (#35401)
* Add solutions to PyLayer which is unsupported in DataParallel
* modify note format for parallel.py
* modify docs of dataparallel
* add docs of dp with pylayer
* modify docs format
* modify example format
* change example of dp with pylayer
* add unittest for dp with pylayer
* modify ut
* merge latest codes
* update
* modify for CI-Coverage
* modify text-indent
1 parent 4bc0853 commit d483b8c

File tree: 6 files changed (+218, -5 lines)

paddle/fluid/imperative/reducer.cc

Lines changed: 15 additions & 0 deletions
@@ -451,6 +451,21 @@ void Reducer::PrepareDeps(const std::unordered_set<GradOpNode *> &init_nodes) {
       PADDLE_ENFORCE_NOT_NULL(
           grad_pending_node,
           platform::errors::NotFound("Grad pending node should not be null"));
+      // py_layer is not supported in DataParallel
+      auto begin = grad_pending_node->begin();
+      auto end = grad_pending_node->end();
+      for (auto op_base = begin; op_base != end; op_base++) {
+        PADDLE_ENFORCE_EQ(
+            op_base->Type() != "py_layer", true,
+            platform::errors::PreconditionNotMet(
+                "Note: Currently PyLayer is not supported in DataParallel. For "
+                "using PyLayer in a DataParallel model, you can skip gradient "
+                "synchronization among multiple cards by 'no_sync', and "
+                "manually implement 'all_reduce' before model optimization. "
+                "There is an example showing the specific implementation "
+                "in the official docs: https://www.paddlepaddle.org.cn/documentation"
+                "/docs/api/paddle/DataParallel_cn.html"));
+      }
       ++node_deps_[grad_pending_node.get()];
       if (visited.count(grad_pending_node.get()) == 0) {
         visited.insert(grad_pending_node.get());
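For context, a minimal sketch of the usage pattern this new check rejects (not part of the commit; it reuses the cus_tanh/SimpleNet shapes from the DataParallel docstring example added below and assumes a two-card launch, e.g. via paddle.distributed.launch). With gradient synchronization left enabled, the Reducer walks the backward graph during loss.backward(), meets the py_layer node, and raises the PreconditionNotMet error above:

# Illustrative repro sketch only: a PyLayer inside a DataParallel model
# while gradient synchronization is still active (no 'no_sync').
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        return dy * (1 - paddle.square(y))

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(2, 2)

    def forward(self, inputs):
        return self.linear(cus_tanh.apply(inputs))

dist.init_parallel_env()
model = paddle.DataParallel(SimpleNet())

x = paddle.randn([2, 2])
x.stop_gradient = False
loss = model(x).mean()
loss.backward()  # no 'no_sync': the Reducer hits the py_layer op and raises the error

The workaround the message recommends is the one documented in parallel.py below: run forward/backward under model.no_sync() and call fused_allreduce_gradients manually before the optimizer step.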

python/paddle/distributed/fleet/utils/hybrid_parallel_util.py

Lines changed: 5 additions & 3 deletions
@@ -19,6 +19,7 @@
 from paddle import framework
 import paddle
 from paddle.fluid import core
+import paddle.distributed as dist
 from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
 from collections import OrderedDict
 from .log_util import logger
@@ -44,8 +45,9 @@ def _apply_collective_grads(parameters, comm_group):

     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
-        div_factor = paddle.to_tensor(
-            comm_group.nranks, dtype=coalesced_grad.dtype)
+        nranks = dist.get_world_size(
+        ) if comm_group is None else comm_group.nranks
+        div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
         paddle.fluid.framework._dygraph_tracer().trace_op(
             type="elementwise_div",
             inputs={'X': coalesced_grad,
@@ -115,7 +117,7 @@ def broadcast_dp_parameters(model, hcg):


 def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = hcg.get_data_parallel_group()
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
     logger.debug("dp start fuse allreduce gradients")
     with framework.no_grad():
         _apply_collective_grads(parameter_list, data_parallel_group)
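Taken together, these changes let fused_allreduce_gradients be called with hcg=None outside of fleet hybrid parallelism: the data-parallel group is left as None and _apply_collective_grads divides by dist.get_world_size() instead of comm_group.nranks. A minimal usage sketch follows (the Linear model, optimizer, and input are illustrative, not part of the patch; it mirrors the DataParallel docstring example below):

# Illustrative usage of the new hcg=None path under a multi-GPU launch.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(8, 8))
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

x = paddle.randn([4, 8])
with model.no_sync():               # accumulate local gradients, skip per-step sync
    model(x).mean().backward()

# hcg=None: fuse gradients, all-reduce them, and divide by the world size
fused_allreduce_gradients(list(model.parameters()), None)
opt.step()
opt.clear_grad()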

python/paddle/fluid/dygraph/parallel.py

Lines changed: 69 additions & 1 deletion
@@ -426,8 +426,10 @@ class DataParallel(layers.Layer):
         Layer: The data paralleled module.

     Examples:
+
         .. code-block:: python
-
+            :name: dp-example
+
             # required: distributed
             import paddle
             import paddle.nn as nn
@@ -471,6 +473,72 @@ def train():
             dist.spawn(train, nprocs=2)
             # 2. start by ``paddle.distributed.launch``
             # train()
+
+
+    .. note::
+        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind, it
+        is recommended to skip gradient synchronization among multiple cards with
+        'no_sync' and to manually implement 'all_reduce' before model optimization.
+        The following example shows the specific implementation.
+
+    Examples:
+
+        .. code-block:: python
+            :name: dp-pylayer-example
+
+            # required: distributed
+            import numpy
+            import paddle
+            import paddle.distributed as dist
+            from paddle.autograd import PyLayer
+            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
+
+            class cus_tanh(PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    y = paddle.tanh(x)
+                    ctx.save_for_backward(y)
+                    return y
+
+                @staticmethod
+                def backward(ctx, dy):
+                    y, = ctx.saved_tensor()
+                    grad = dy * (1 - paddle.square(y))
+                    return grad
+
+            class SimpleNet(paddle.nn.Layer):
+                def __init__(self):
+                    super(SimpleNet, self).__init__()
+                    self.linear = paddle.nn.Linear(2, 2)
+
+                def forward(self, inputs):
+                    inputs = cus_tanh.apply(inputs)
+                    return self.linear(inputs)
+
+            if __name__ == '__main__':
+                dist.init_parallel_env()
+
+                model = SimpleNet()
+                model = paddle.DataParallel(model)
+                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+
+                for step in range(10):
+                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
+                    x = paddle.to_tensor(x_data)
+                    x.stop_gradient = False
+
+                    # step 1 : skip gradient synchronization by 'no_sync'
+                    with model.no_sync():
+                        y_pred = model(x)
+                        loss = y_pred.mean()
+                        loss.backward()
+
+                    # step 2 : fuse + allreduce manually before optimization
+                    fused_allreduce_gradients(list(model.parameters()), None)
+
+                    opt.step()
+                    opt.clear_grad()
+
     """

     def __init__(self,

python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_with_pylayer.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import paddle
+import numpy as np
+import paddle.distributed as dist
+from paddle.fluid.dygraph.nn import Linear
+from paddle.autograd import PyLayer
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
+
+batch = 5
+in_dim = 20
+out_dim = 10
+
+
+class cus_tanh(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        y = paddle.tanh(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        y, = ctx.saved_tensor()
+        grad = dy * (1 - paddle.square(y))
+        return grad
+
+
+class SimpleNet(paddle.nn.Layer):
+    def __init__(self, train_id, model_id):
+        super(SimpleNet, self).__init__()
+        self.w = self.create_parameter(shape=[in_dim, batch], dtype="float32")
+        self.linear = paddle.nn.Linear(in_dim, out_dim)
+        self.tanh = paddle.tanh
+
+        self.trainer_id = train_id
+        self.model_id = model_id
+
+    def forward(self, inputs):
+        if self.model_id == 0:
+            inputs = cus_tanh.apply(inputs)
+        else:
+            inputs = self.tanh(inputs)
+
+        inputs = paddle.matmul(self.w, inputs)
+        return self.linear(inputs)
+
+
+class TestDistTraning(unittest.TestCase):
+    def test_multiple_gpus(self):
+        self.trainer_id = dist.get_rank()
+        dist.init_parallel_env()
+
+        model_a = SimpleNet(self.trainer_id, 0)
+        model_b = SimpleNet(self.trainer_id, 1)
+
+        state_dict = model_a.state_dict()
+        model_b.set_state_dict(state_dict)
+
+        model_a = paddle.DataParallel(model_a)
+        model_b = paddle.DataParallel(model_b)
+
+        for step in range(10):
+            x_data = np.random.randn(batch, in_dim).astype(np.float32)
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = False
+
+            with model_a.no_sync():
+                y_pred_a = model_a(x)
+                loss_a = y_pred_a.mean()
+                loss_a.backward()
+            fused_allreduce_gradients(list(model_a.parameters()), None)
+
+            y_pred_b = model_b(x)
+            loss_b = y_pred_b.mean()
+            loss_b.backward()
+
+            self.check_gradient(model_a.parameters())
+            self.check_gradient(model_b.parameters())
+
+            self.check_acc(model_a._layers.w.grad, model_b._layers.w.grad)
+
+            model_a.clear_gradients()
+            model_b.clear_gradients()
+
+    def check_acc(self, grad, acc_grad):
+        grad = grad.numpy() if grad is not None else None
+        acc_grad = acc_grad.numpy() if acc_grad is not None else None
+        return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6)
+
+    def broadcast_param(self, param, root):
+        paddle.distributed.broadcast(param, root)
+        return param
+
+    def check_gradient(self, params):
+        other_param = []
+        for param in params:
+            if param.trainable and (param._grad_ivar() is not None):
+                grad = param._grad_ivar()
+                other_grad = self.broadcast_param(grad.clone(), root=1)
+                if self.trainer_id == 0:
+                    np.testing.assert_allclose(other_grad.numpy(), grad.numpy())
+
+
+if __name__ == '__main__':
+    unittest.main()

python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py

Lines changed: 5 additions & 0 deletions
@@ -130,5 +130,10 @@ def test_multiple_gpus_dynamic(self):
         self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')


+class TestDataParallelWithPyLayer(TestMultipleGpus):
+    def test_parallel_dygraph_dataparallel_with_pylayer(self):
+        self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
+
+
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from test_parallel_dygraph_dataparallel import TestMultipleGpus


-class TestModelParallelLayer(TestMultipleGpus):
+class TestDataParallelLayer(TestMultipleGpus):
     def test_parallel_dygraph_dataparallel_no_sync(self):
         self.run_mnist_2gpu('parallel_dygraph_no_sync_gradient_check.py')
