Skip to content

Commit be61e9e

Browse files
Merge pull request #16597 from guru4elephant/refine_dataset
refine dataset API
2 parents fb1ae72 + 2c5839f commit be61e9e

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

python/paddle/fluid/dataset.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from paddle.fluid.proto import data_feed_pb2
1616
from google.protobuf import text_format
1717
from . import core
18-
__all__ = ['DatasetFactory']
18+
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
1919

2020

2121
class DatasetFactory(object):
@@ -38,6 +38,10 @@ def create_dataset(self, datafeed_class="QueueDataset"):
3838
"""
3939
Create "QueueDataset" or "InMemoryDataset",
4040
the default is "QueueDataset".
41+
42+
Examples:
43+
import paddle.fluid as fluid
44+
dataset = fluid.DatasetFactory().create_dataset()
4145
"""
4246
try:
4347
dataset = globals()[datafeed_class]()
@@ -177,7 +181,8 @@ def desc(self):
177181
class InMemoryDataset(DatasetBase):
178182
"""
179183
InMemoryDataset, it will load data into memory
180-
and shuffle data before training
184+
and shuffle data before training.
185+
This class should be created by DatasetFactory
181186
182187
Example:
183188
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
@@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
259264

260265
def __init__(self):
261266
"""
262-
Init
267+
Initialize QueueDataset
268+
This class should be created by DatasetFactory
263269
"""
264270
super(QueueDataset, self).__init__()
265271
self.proto_desc.name = "MultiSlotDataFeed"
@@ -268,15 +274,17 @@ def local_shuffle(self):
268274
"""
269275
Local shuffle
270276
271-
QueueDataset does not support local shuffle
277+
Local shuffle is not supported in QueueDataset
278+
NotImplementedError will be raised
272279
"""
273280
raise NotImplementedError(
274281
"QueueDataset does not support local shuffle, "
275282
"please use InMemoryDataset for local_shuffle")
276283

277284
def global_shuffle(self, fleet=None):
278285
"""
279-
Global shuffle
286+
Global shuffle is not supported in QueueDataset
287+
NotImplementedError will be raised
280288
"""
281289
raise NotImplementedError(
282290
"QueueDataset does not support global shuffle, "

0 commit comments

Comments
 (0)