Commit 16ebca7

use floyd algorithm to find meta optimizer max path, test=develop

1 parent 6cdf2c9 commit 16ebca7

File tree

3 files changed: +162, -16 lines

python/paddle/distributed/fleet/base/strategy_compiler.py

Lines changed: 87 additions & 16 deletions
@@ -13,24 +13,95 @@
 # limitations under the License.
 
 
-def maximum_path_len_algo(optimizer_list):
-    max_idx = 0
-    max_len = 0
-    candidates = []
-    for idx, opt in enumerate(optimizer_list):
-        local_buffer = [opt]
-        for opt_inner in optimizer_list:
+def create_graph(optimizer_list):
+    nsize = len(optimizer_list)
+
+    edge = [[0] * nsize for _ in range(nsize)]  # adjacency matrix
+    indegree = [0] * nsize
+    for i, opt in enumerate(optimizer_list):
+        for j, opt_inner in enumerate(optimizer_list):
             if opt._can_update(opt_inner):
-                local_buffer.append(opt_inner)
-        if len(local_buffer) > max_len:
-            max_idx = idx
-            max_len = len(local_buffer)
-        candidates.append(local_buffer)
-    if len(candidates) == 0:
+                edge[i][j] = 1  # weight
+                indegree[j] += 1
+
+    return edge, indegree
+
+
+def topo_sort(edge, indegree):
+    nsize = len(indegree)
+
+    topo = [-1] * nsize
+    for i in range(nsize):
+        j = 0
+        while j < nsize and indegree[j] != 0:
+            j += 1
+        assert j < nsize, 'The combination of meta optimizers contains ring'
+
+        topo[i] = j
+        indegree[j] = -1
+        for k in range(nsize):
+            if edge[j][k] != 0:
+                indegree[k] -= 1
+
+    return topo
+
+
+def floyd(edge):
+    nsize = len(edge)
+    max_len = -1
+    max_edge = [-1, -1]
+
+    max_path = [[[] for _ in range(nsize)] for _ in range(nsize)]
+    for i in range(nsize):
+        for j in range(nsize):
+            if edge[i][j] > 0:
+                max_path[i][j] = [j]
+
+                if edge[i][j] > max_len:
+                    max_len = edge[i][j]
+                    max_edge = [i, j]
+
+    # use floyd algorithm to find max_path
+    for k in range(nsize):
+        for i in range(nsize):
+            for j in range(nsize):
+                # if a-->b-->c, but a-/->c, can only apply a-->b or b-->c,
+                # however if a-->b-->c, and a-->c, can apply a->b->c
+                if edge[i][j] == 0:
+                    continue
+
+                if edge[i][k] == 0 or edge[k][j] == 0:
+                    continue
+
+                if edge[i][j] < edge[i][k] + edge[k][j]:
+                    edge[i][j] = edge[i][k] + edge[k][j]
+                    max_path[i][j] = max_path[i][k] + max_path[k][j]
+
+                    max_len = edge[i][j]
+                    max_edge = [i, j]
+
+    if max_len == -1:
+        return [0]
+
+    return [max_edge[0]] + max_path[max_edge[0]][max_edge[1]]
+
+
+def maximum_path_len_algo(optimizer_list):
+    if len(optimizer_list) == 0:
         return None
-    for idx, opt in enumerate(candidates[max_idx][:-1]):
-        opt._update_inner_optimizer(candidates[max_idx][idx + 1])
-    return candidates[max_idx]
+
+    edge, indegree = create_graph(optimizer_list)
+    topo_sort(edge, indegree)
+    max_path = floyd(edge)
+
+    candidate = []
+    for idx in max_path:
+        candidate.append(optimizer_list[idx])
+
+    for idx, opt in enumerate(candidate[:-1]):
+        opt._update_inner_optimizer(candidate[idx + 1])
+
+    return candidate
 
 
 class StrategyCompilerBase(object):
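
For reference, a minimal sketch of how the new pipeline behaves end to end. MockMeta and its compatibility table are hypothetical stand-ins for Fleet's real meta optimizers, and the import assumes the module above is importable from this path:

# A minimal sketch, not Fleet's real API surface: MockMeta and its
# compatibility table are hypothetical stand-ins for the real meta
# optimizers (amp, recompute, lars, dgc, ...).
from paddle.distributed.fleet.base.strategy_compiler import (
    maximum_path_len_algo, topo_sort)


class MockMeta(object):
    def __init__(self, name, targets):
        self.name = name
        self._targets = targets  # names of optimizers this one may wrap

    def _can_update(self, other):
        return other.name in self._targets

    def _update_inner_optimizer(self, inner):
        self.inner = inner  # record the next optimizer in the chain


amp = MockMeta('amp', {'recompute', 'lars'})
recompute = MockMeta('recompute', {'lars'})
lars = MockMeta('lars', set())
dgc = MockMeta('dgc', set())  # unreachable from amp, so it drops out

chain = maximum_path_len_algo([amp, recompute, lars, dgc])
print([opt.name for opt in chain])  # ['amp', 'recompute', 'lars']

# topo_sort's role in the pipeline is cycle detection: with a-->b and
# b-->a no indegree ever reaches zero, so its assert fires.
try:
    topo_sort([[0, 1], [1, 0]], [1, 1])
except AssertionError as e:
    print(e)  # The combination of meta optimizers contains ring

Note the guard in floyd's inner loop: a detour i-->k-->j only lengthens the path when the direct edge i-->j also exists, so every optimizer on the selected chain stays directly compatible with the ones it ends up wrapping.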

python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py

Lines changed: 45 additions & 0 deletions
@@ -103,6 +103,51 @@ def test_amp_recompute_optimizer(self):
         # recompute
         self.assertIn('subprog', ''.join(outs))
 
+    def test_amp_recompute_lars_optimizer(self):
+        """ test amp + recompute + lars """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lars')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        strategy = fleet._final_strategy()
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lars
+        self.assertIn('lars_momentum', ops)
+
+    def test_amp_recompute_lamb_optimizer(self):
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lamb')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam')
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lamb
+        self.assertIn('lamb', ops)
+
 
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py

Lines changed: 30 additions & 0 deletions
@@ -128,6 +128,36 @@ def test_dgc_recompute_optimizer(self):
         # recompute
         self.assertIn('subprog', ''.join(outs))
 
+    def test_amp_recompute_lars_dgc_not_apply_optimizer(self):
+        """ test amp + recompute + lars + dgc,
+        amp -/-> dgc, max_path is amp-->recompute-->lars
+        """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'dgc')
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lars')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        strategy = fleet._final_strategy()
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lars
+        self.assertIn('lars_momentum', ops)
+
+        # dgc not apply
+        self.assertFalse(strategy.dgc)
+
 
 if __name__ == "__main__":
     unittest.main()
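
To make the docstring's scenario concrete, here is a hedged sketch of the graph this combination produces. The adjacency matrix below is hypothetical (the real edges come from each meta optimizer's _can_update), but it mirrors the amp -/-> dgc constraint the test relies on:

from paddle.distributed.fleet.base.strategy_compiler import floyd

# Hypothetical adjacency in the order [dgc, amp, recompute, lars];
# amp cannot update dgc, so no chain through amp ever reaches dgc.
edge = [
    [0, 0, 0, 0],  # dgc: wraps nothing in this sketch
    [0, 0, 1, 1],  # amp --> recompute, amp --> lars
    [0, 0, 0, 1],  # recompute --> lars
    [0, 0, 0, 0],  # lars: innermost, wraps nothing
]

# floyd mutates its argument, so hand it a copy.
print(floyd([row[:] for row in edge]))  # [1, 2, 3]: amp, recompute, lars

With dgc off the maximum path it is never chained into the applied optimizers, which is what the assertFalse(strategy.dgc) check on fleet._final_strategy() verifies.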
