Skip to content

Commit 209075a

Browse files
authored
[CPU-PSLIB] Add consistency inspection of use_var_list and data_generator data, test=develop (#34463)
1 parent 8967a66 commit 209075a

File tree

2 files changed

+471
-0
lines changed

2 files changed

+471
-0
lines changed

python/paddle/distributed/fleet/dataset/dataset.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,71 @@ def _dynamic_adjust_before_train(self, thread_num):
255255
def _dynamic_adjust_after_train(self):
    """No-op: this implementation performs no work and returns ``None``."""
    return None
257257

258+
def _check_use_var_with_data_generator(self, var_list, data_generator_class,
                                       test_file):
    """
    Var consistency inspection of use_var_list and data_generator data.

    Each line of ``test_file`` is passed to
    ``data_generator_class.generate_sample``; the callable it returns is
    invoked and every sample it yields is compared against ``var_list``:

    * the sample must contain exactly ``len(var_list)`` entries;
    * the value list of every entry (``entry[1]``) must be non-empty;
    * for an FP32 variable every value must be a ``float``; for an
      INT64/INT32 variable every value must be an ``int``.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            from dataset_generator import CTRDataset
            dataset = paddle.distributed.fleet.DatasetBase()
            generator_class = CTRDataset()
            dataset._check_use_var_with_data_generator([data, label], generator_class, "data/part-00000")

    Args:
        var_list(list): variable list
        data_generator_class(class): data_generator object; it must provide
            ``generate_sample(line)`` returning a callable that yields
            samples, each a sequence of ``(name, values)`` entries
        test_file(str): local test file path

    Raises:
        ValueError: if a sample's entry count differs from ``len(var_list)``
            or an entry has an empty value list.
        TypeError: if the values of an entry do not match the dtype of the
            variable at the same position in ``var_list``.
    """
    var_len = len(var_list)

    # The context manager guarantees the file is closed even when one of
    # the checks below raises.
    with open(test_file, "r") as f:
        for line in f:
            line_iter = data_generator_class.generate_sample(line)
            for user_parsed_line in line_iter():
                data_gen_len = len(user_parsed_line)
                if var_len != data_gen_len:
                    raise ValueError(
                        "var length mismatch error: var_list = %s vs data_generator = %s"
                        % (var_len, data_gen_len))

                for i, ele in enumerate(user_parsed_line):
                    values = ele[1]
                    if len(values) == 0:
                        raise ValueError(
                            "var length error: var %s's length in data_generator is 0"
                            % ele[0])

                    var_dtype = var_list[i].dtype

                    if var_dtype == core.VarDesc.VarType.FP32 and not all(
                            isinstance(value, float) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-float value, which is %s \n"
                            "Please check if order of var_list and data_generator are aligned. \n"
                            "Please check if var's type in data_generator is correct."
                            % (ele[0], "float", values))

                    if (var_dtype == core.VarDesc.VarType.INT64 or
                            var_dtype == core.VarDesc.VarType.INT32
                        ) and not all(
                            isinstance(value, int) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-int value, which is %s \n"
                            "Please check if order of var_list and data_generator are aligned. \n"
                            "Please check if var's type in data_generator is correct."
                            % (ele[0], "int", values))
322+
258323

259324
class InMemoryDataset(DatasetBase):
260325
"""

0 commit comments

Comments
 (0)