@@ -255,6 +255,71 @@ def _dynamic_adjust_before_train(self, thread_num):
def _dynamic_adjust_after_train(self):
    """No-op in this class: performs no work and returns None."""
257257
def _check_use_var_with_data_generator(self, var_list, data_generator_class,
                                       test_file):
    """
    Var consistency inspection of use_var_list and data_generator data.

    Reads ``test_file`` line by line, feeds every line to
    ``data_generator_class.generate_sample`` and checks each parsed sample
    against ``var_list``: the number of slots must match, no slot may be
    empty, and slot values must be Python floats for FP32 vars or Python
    ints for INT64/INT32 vars. Vars of any other dtype are not type-checked.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            from dataset_generator import CTRDataset
            dataset = paddle.distributed.fleet.DatasetBase()
            generator_class = CTRDataset()
            dataset._check_use_var_with_data_generator([data, label], generator_class, "data/part-00000")

    Args:
        var_list(list): variable list; each item must expose ``dtype``
        data_generator_class(class): data_generator object providing
            ``generate_sample(line)``, which returns a callable yielding
            parsed samples (sequences of ``(name, values)`` items)
        test_file(str): local test file path

    Raises:
        ValueError: if a parsed sample has a different number of slots than
            ``var_list``, or a slot contains no values.
        TypeError: if a slot's values do not match the dtype of the var at
            the same position in ``var_list``.
    """
    var_len = len(var_list)

    # The context manager guarantees the file is closed even when one of the
    # validation errors below propagates out of the loop.
    with open(test_file, "r") as f:
        for line in f:
            line_iter = data_generator_class.generate_sample(line)
            for user_parsed_line in line_iter():
                data_gen_len = len(user_parsed_line)
                if var_len != data_gen_len:
                    raise ValueError(
                        "var length mismatch error: var_list = %s vs data_generator = %s"
                        % (var_len, data_gen_len))

                for i, ele in enumerate(user_parsed_line):
                    values = ele[1]
                    # Explicit length test (not truthiness) so array-like
                    # value containers are handled the same as lists.
                    if len(values) == 0:
                        raise ValueError(
                            "var length error: var %s's length in data_generator is 0"
                            % ele[0])

                    dtype = var_list[i].dtype

                    if dtype == core.VarDesc.VarType.FP32 and not all(
                            isinstance(value, float) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-float value, which is %s \n "
                            "Please check if order of var_list and data_generator are aligned. \n "
                            "Please check if var's type in data_generator is correct."
                            % (ele[0], "float", values))

                    if (dtype == core.VarDesc.VarType.INT64 or
                            dtype == core.VarDesc.VarType.INT32
                        ) and not all(
                            isinstance(value, int) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-int value, which is %s \n "
                            "Please check if order of var_list and data_generator are aligned. \n "
                            "Please check if var's type in data_generator is correct."
                            % (ele[0], "int", values))
322+
258323
259324class InMemoryDataset (DatasetBase ):
260325 """
0 commit comments