Skip to content

Commit 50edf52

Browse files
authored
Merge pull request #862 from backyes/fix_data_sources
refine data_sources.py and PyDataProvider2.py to make more readable
2 parents f0449b8 + 7b08a98 commit 50edf52

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

python/paddle/trainer/PyDataProvider2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def integer_value_sub_sequence(dim):
107107
return integer_value(dim, seq_type=SequenceType.SUB_SEQUENCE)
108108

109109

110-
def integer_sequence(dim):
111-
return index_slot(dim, seq_type=SequenceType.SEQUENCE)
110+
integer_sequence = integer_value_sequence
112111

113112

114113
class SingleSlotWrapper(object):

python/paddle/trainer_config_helpers/data_sources.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,6 @@ def define_py_data_source(file_list,
7878
if not isinstance(args, basestring) and args is not None:
7979
args = pickle.dumps(args, 0)
8080

81-
if data_cls is None:
82-
83-
def py_data2(files, load_data_module, load_data_object, load_data_args,
84-
**kwargs):
85-
data = DataBase()
86-
data.type = 'py2'
87-
data.files = files
88-
data.load_data_module = load_data_module
89-
data.load_data_object = load_data_object
90-
data.load_data_args = load_data_args
91-
data.async_load_data = True
92-
return data
93-
94-
data_cls = py_data2
95-
9681
cls(
9782
data_cls(
9883
files=file_list,
@@ -207,10 +192,22 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
207192
:return: None
208193
:rtype: None
209194
"""
195+
196+
def py_data2(files, load_data_module, load_data_object, load_data_args,
197+
**kwargs):
198+
data = DataBase()
199+
data.type = 'py2'
200+
data.files = files
201+
data.load_data_module = load_data_module
202+
data.load_data_object = load_data_object
203+
data.load_data_args = load_data_args
204+
data.async_load_data = True
205+
return data
206+
210207
define_py_data_sources(
211208
train_list=train_list,
212209
test_list=test_list,
213210
module=module,
214211
obj=obj,
215212
args=args,
216-
data_cls=None)
213+
data_cls=py_data2)

0 commit comments

Comments
 (0)