
Commit 5e0f199

youth123 and sneaxiy authored
Fix raw optim (#36176)
* fix raw optim

* pre-commit test file

Co-authored-by: sneaxiy <[email protected]>
1 parent 8af939f commit 5e0f199

File tree

3 files changed: +161 −0 lines changed

python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py

Lines changed: 2 additions & 0 deletions
@@ -460,6 +460,8 @@ def __get_ouputs_name_to_idx(self, first_backward_idx, block):
             if is_optimizer_op(op):
                 break
             for name in op.output_arg_names:
+                if name == core.kEmptyVarName():
+                    continue
                 var = block.var(name)
                 if not outputs_name_to_idx.get(var):
                     # if the grad only be generated by one op
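
The two added lines guard the output scan in `__get_ouputs_name_to_idx` against placeholder output names: in a static-graph Program, an op output slot that is intentionally left unset (for example, a gradient output for an input that needs no gradient) carries the name `core.kEmptyVarName()` ("@EMPTY@") and has no Variable registered in the block, so the unguarded `block.var(name)` lookup fails — presumably what the RNN model in the new test below triggers. Below is a minimal, hedged sketch of the guarded scan in isolation; `build_outputs_name_to_idx`, `block`, and `ops` are illustrative stand-ins, not the optimizer's actual interface.

# Minimal sketch of the guarded scan (illustrative, not the optimizer's real code);
# assumes the 2.x static-graph API where core.kEmptyVarName() returns "@EMPTY@".
from paddle.fluid import core


def build_outputs_name_to_idx(block, ops):
    """Map each real output Variable to the index of the op that produces it."""
    outputs_name_to_idx = {}
    for idx, op in enumerate(ops):
        for name in op.output_arg_names:
            # Placeholder outputs named "@EMPTY@" have no Variable in the block,
            # so block.var(name) would raise; skip them, as the fix above does.
            if name == core.kEmptyVarName():
                continue
            var = block.var(name)
            if var not in outputs_name_to_idx:
                outputs_name_to_idx[var] = idx
    return outputs_name_to_idx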

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
 list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
+list(APPEND DIST_TEST_OPS test_rnn_dp)
 list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)
@@ -66,6 +67,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_rnn_dp)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)

python/paddle/fluid/tests/unittests/test_rnn_dp.py

Lines changed: 157 additions & 0 deletions (new file)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os

import numpy as np
import paddle
import paddle.static as static
import paddle.distributed.fleet as fleet
import paddle.nn as nn
import paddle.nn.functional as F

paddle.enable_static()


class RNNEncoder(nn.Layer):
    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 direction="forward",
                 dropout=0.0,
                 pooling_type=None,
                 **kwargs):
        super().__init__()
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._direction = direction
        self._pooling_type = pooling_type

        self.rnn_layer = nn.SimpleRNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            direction=direction,
            dropout=dropout,
            **kwargs)

    def get_input_dim(self):
        return self._input_size

    def get_output_dim(self):
        if self._direction == "bidirect":
            return self._hidden_size * 2
        else:
            return self._hidden_size

    def forward(self, inputs, sequence_length):
        encoded_text, last_hidden = self.rnn_layer(
            inputs, sequence_length=sequence_length)
        output = paddle.max(encoded_text, axis=1)
        return output


class RNNModel(nn.Layer):
    def __init__(self,
                 vocab_size,
                 num_classes,
                 emb_dim=128,
                 padding_idx=0,
                 rnn_hidden_size=198,
                 direction='forward',
                 rnn_layers=1,
                 dropout_rate=0.0,
                 pooling_type=None,
                 fc_hidden_size=96):
        super().__init__()
        self.embedder = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=padding_idx)
        self.rnn_encoder = RNNEncoder(
            emb_dim,
            rnn_hidden_size,
            num_layers=rnn_layers,
            direction=direction,
            dropout=dropout_rate,
            pooling_type=pooling_type)
        self.fc = nn.Linear(self.rnn_encoder.get_output_dim(), fc_hidden_size)
        self.output_layer = nn.Linear(fc_hidden_size, num_classes)

    def forward(self, text, seq_len):
        embedded_text = self.embedder(text)
        text_repr = self.rnn_encoder(embedded_text, sequence_length=seq_len)
        fc_out = paddle.tanh(self.fc(text_repr))
        logits = self.output_layer(fc_out)
        return logits


def rnn_pretrain_forward(train_program, start_program, topo=None):
    with static.program_guard(train_program,
                              start_program), paddle.utils.unique_name.guard():
        batch_size = 1
        tokens = static.data(
            name="tokens", shape=[batch_size, -1], dtype="int64")
        seq_len = static.data(name="ids", shape=[batch_size], dtype="int64")
        labels = static.data(name="labels", shape=[batch_size], dtype="int64")
        data_holders = [tokens, seq_len, labels]
        vocab_size = 10
        num_classes = 2
        pad_token_id = 0
        model = RNNModel(
            vocab_size,
            num_classes,
            direction='forward',
            padding_idx=pad_token_id,
            pooling_type='max')

        optimizer = paddle.optimizer.Adam(
            parameters=model.parameters(), learning_rate=0.001)
        criterion = paddle.nn.CrossEntropyLoss()
        preds = model(tokens, seq_len)
        loss = criterion(preds, labels)

    return train_program, start_program, loss, optimizer, data_holders


class TestFleetMetaOptimizer(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "1"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"

    def test_rnn_raw_optimizer(self):
        import paddle.distributed.fleet as fleet
        import paddle.distributed.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        train_program = static.Program()
        start_program = static.Program()
        train_program, start_program, loss, optimizer, data_holders = \
            rnn_pretrain_forward(train_program, start_program)
        with paddle.static.program_guard(
                train_program, start_program), paddle.utils.unique_name.guard():
            strategy = fleet.DistributedStrategy()
            strategy.without_graph_optimization = True
            strategy.fuse_all_reduce_ops = True
            fleet.init(is_collective=True, strategy=strategy)
            optimizer = fleet.distributed_optimizer(optimizer)
            optimizer.minimize(loss)


if __name__ == "__main__":
    unittest.main()
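
The test drives the code path patched above: `without_graph_optimization = True` selects the raw-program meta optimizer, and `fuse_all_reduce_ops = True` enables the all-reduce fusion pass that performs the output scan where the `kEmptyVarName` guard was added. For orientation, a minimal, hedged sketch of that wiring with a trivial static-graph model in place of the RNN; the trainer environment (PADDLE_TRAINER_ID, PADDLE_TRAINER_ENDPOINTS) is assumed to be set by the launcher or, as in the test's setUp, by hand.

# Hedged sketch of the raw-program data-parallel setup the test exercises;
# a simple fc regression model stands in for the RNN text classifier above.
import paddle
import paddle.static as static
import paddle.distributed.fleet as fleet

paddle.enable_static()

strategy = fleet.DistributedStrategy()
strategy.without_graph_optimization = True  # select the raw-program meta optimizer
strategy.fuse_all_reduce_ops = True         # fuse gradient all-reduce ops
fleet.init(is_collective=True, strategy=strategy)

x = static.data(name="x", shape=[None, 8], dtype="float32")
y = static.data(name="y", shape=[None, 1], dtype="float32")
pred = paddle.static.nn.fc(x, size=1)
loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

optimizer = paddle.optimizer.Adam(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)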
