import os
import subprocess
import sys
import tempfile
import unittest

import numpy

import paddle
from paddle import base
from paddle.distributed import fleet
from paddle.distributed.fleet.base import role_maker
from paddle.distributed.utils.launch_utils import find_free_ports
2628
2729paddle .enable_static ()
2830
2931
3032class TestCommunicatorHalfAsyncEnd2End (unittest .TestCase ):
3133 def net (self ):
3234 x = paddle .static .data (name = 'x' , shape = [- 1 , 13 ], dtype = 'float32' )
33- y_predict = paddle .static .nn .fc (x , size = 1 , activation = None )
34- y = paddle .static .data (name = 'y' , shape = [- 1 , 1 ], dtype = 'float32' )
35+ x1 = paddle .static .data (
36+ name = 'x1' , shape = [- 1 , 1 ], dtype = 'int64' , lod_level = 1
37+ )
3538
39+ emb = paddle .static .nn .embedding (
40+ input = x1 ,
41+ size = [10000 , 10 ],
42+ param_attr = base .ParamAttr (
43+ name = "embedding" ,
44+ initializer = paddle .nn .initializer .Constant (value = 0.01 ),
45+ ),
46+ is_sparse = True ,
47+ )
48+
49+ pool = paddle .static .nn .sequence_lod .sequence_pool (
50+ input = emb .squeeze (- 2 ), pool_type = "sum"
51+ )
52+ z = paddle .concat ([x , pool ], axis = 1 )
53+
54+ y_predict = paddle .static .nn .fc (x = z , size = 1 )
55+ y = paddle .static .data (name = 'y' , shape = [- 1 , 1 ], dtype = 'float32' )
3656 cost = paddle .nn .functional .square_error_cost (input = y_predict , label = y )
3757 avg_cost = paddle .mean (cost )
38- return avg_cost , x , y
58+ return avg_cost , x , x1 , y
3959
4060 def fake_reader (self ):
4161 def reader ():
4262 for i in range (10000 ):
4363 x = numpy .random .random ((1 , 13 )).astype ('float32' )
64+ z = numpy .random .randint (0 , 9999 , (1 , 1 )).astype ('int64' )
4465 y = numpy .random .randint (0 , 2 , (1 , 1 )).astype ('int64' )
45- yield x , y
66+ yield x , z , y
4667
4768 return reader
4869
4970 def run_pserver (self , role , strategy ):
5071 fleet .init (role )
51- avg_cost , x , y = self .net ()
72+ avg_cost , x , z , y = self .net ()
5273 optimizer = paddle .optimizer .SGD (0.01 )
5374 optimizer = fleet .distributed_optimizer (optimizer , strategy )
5475 optimizer .minimize (avg_cost )
@@ -61,102 +82,79 @@ def run_trainer(self, role, strategy):
6182 exe = base .Executor (place )
6283
6384 fleet .init (role )
64- avg_cost , x , y = self .net ()
85+ avg_cost , x , z , y = self .net ()
6586 optimizer = paddle .optimizer .SGD (0.01 )
6687 optimizer = fleet .distributed_optimizer (optimizer , strategy )
6788 optimizer .minimize (avg_cost )
6889
69- exe .run (paddle . static .default_startup_program ())
90+ exe .run (base .default_startup_program ())
7091 fleet .init_worker ()
7192
7293 train_reader = paddle .batch (self .fake_reader (), batch_size = 24 )
73- feeder = base .DataFeeder (place = place , feed_list = [x , y ])
94+ feeder = base .DataFeeder (place = place , feed_list = [x , z , y ])
7495
7596 for batch_id , data in enumerate (train_reader ()):
7697 exe .run (
77- paddle . static .default_main_program (),
98+ base .default_main_program (),
7899 feed = feeder .feed (data ),
79100 fetch_list = [],
80101 )
81102
82103 fleet .stop_worker ()
83104
84105 def run_ut (self ):
85- strategy = paddle .distributed .fleet .DistributedStrategy ()
86- strategy .a_sync = True
87-
88106 training_role = os .getenv ("TRAINING_ROLE" , "TRAINER" )
89107
90- role = role_maker .UserDefinedRoleMaker (
91- current_id = 0 ,
92- role = role_maker .Role .WORKER
93- if training_role == "TRAINER"
94- else role_maker .Role .SERVER ,
95- worker_num = 1 ,
96- server_endpoints = ["127.0.0.1:6002" ],
97- )
108+ os .environ ["PADDLE_PSERVER_NUMS" ] = "1"
109+ os .environ ["PADDLE_TRAINERS_NUM" ] = "1"
110+ os .environ ["PADDLE_TRAINER_ID" ] = "0"
111+ os .environ ["PADDLE_TRAINERS_NUM" ] = "1"
112+ os .environ ["POD_IP" ] = "127.0.0.1"
113+
114+ role = role_maker .PaddleCloudRoleMaker ()
115+
116+ strategy = paddle .distributed .fleet .DistributedStrategy ()
117+ strategy .a_sync = True
98118
99119 if training_role == "TRAINER" :
100120 self .run_trainer (role , strategy )
101121 else :
102122 self .run_pserver (role , strategy )
103123
104124 def test_communicator (self ):
105- run_server_cmd = """
125+ temp_dir = tempfile .TemporaryDirectory ()
126+ pipe_name = os .path .join (temp_dir .name , 'mypipe' )
127+ try :
128+ os .mkfifo (pipe_name )
129+ except OSError as oe :
130+ print (f"Failed to create pipe: { oe } " )
106131
107- import sys
108- import os
132+ port = find_free_ports (1 ).pop ()
109133
110- import time
111- import threading
112- import subprocess
113- import unittest
114- import numpy
115-
116- from test_communicator_half_async import TestCommunicatorHalfAsyncEnd2End
117-
118- import paddle
119- import paddle.base as base
120- import paddle.distributed.fleet as fleet
121- import paddle.distributed.fleet.base.role_maker as role_maker
122-
123- paddle.enable_static()
124-
125- class RunServer(TestCommunicatorHalfAsyncEnd2End):
126- def runTest(self):
127- pass
128-
129- os.environ["http_proxy"] = ""
130- os.environ["https_proxy"] = ""
131- os.environ["TRAINING_ROLE"] = "PSERVER"
132- half_run_server = RunServer()
133- half_run_server.run_ut()
134- """
135-
136- server_file = "run_server_for_communicator_haflaysnc.py"
137- with open (server_file , "w" ) as wb :
138- wb .write (run_server_cmd )
139134 os .environ ["TRAINING_ROLE" ] = "PSERVER"
140- _python = sys .executable
135+ os .environ ["PADDLE_PORT" ] = str (port )
136+ os .environ ["PADDLE_PSERVERS_IP_PORT_LIST" ] = f"127.0.0.1:{ port } "
137+ os .environ ["PIPE_FILE" ] = pipe_name
141138
139+ _python = sys .executable
140+ server_file = "run_server_for_communicator_half_async.py"
142141 ps_cmd = f"{ _python } { server_file } "
142+
143143 ps_proc = subprocess .Popen (
144144 ps_cmd .strip ().split (" " ),
145145 stdout = subprocess .PIPE ,
146146 stderr = subprocess .PIPE ,
147147 )
148148
149- os .environ ["http_proxy" ] = ""
150- os .environ ["https_proxy" ] = ""
149+ with open (pipe_name , 'r' ) as pipe :
150+ start_command = pipe .read ()
151+
151152 os .environ ["TRAINING_ROLE" ] = "TRAINER"
152- os .environ ["FLAGS_communicator_send_queue_size" ] = "1"
153- os .environ ["FLAGS_communicator_max_merge_var_num" ] = "1"
154153
155154 self .run_ut ()
156155 ps_proc .kill ()
157-
158- if os .path .exists (server_file ):
159- os .remove (server_file )
156+ ps_proc .wait ()
157+ outs , errs = ps_proc .communicate ()
160158
161159
162160if __name__ == '__main__' :
0 commit comments