Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion python/paddle/io/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import bisect
import math
import warnings
from typing import Iterable

import paddle
Expand Down Expand Up @@ -487,7 +489,7 @@ def random_split(dataset, lengths, generator=None):

Args:
dataset (Dataset): Dataset to be split
lengths (sequence): lengths of splits to be produced
lengths (sequence): lengths or fractions of splits to be produced
generator (Generator, optional): Generator used for the random permutation. Default is None then the DefaultGenerator is used in manual_seed().

Returns:
Expand Down Expand Up @@ -522,6 +524,28 @@ def random_split(dataset, lengths, generator=None):
5 3
6 8
"""
if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
subset_lengths = []
for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(
f"Fraction at index {i} is not between 0 and 1"
)
n_items_in_split = int(math.floor(len(dataset) * frac))
subset_lengths.append(n_items_in_split)
remainder = len(dataset) - sum(subset_lengths)

for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(
f"Length of split at index {i} is 0. "
f"This might result in an empty dataset."
)

# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore
raise ValueError(
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/io/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ class RandomSampler(Sampler):
object which implemented :code:`__len__` to get indices as the range of :code:`dataset` length. Default None.
replacement(bool, optional): If False, sample the whole dataset, If True,
set :attr:`num_samples` for how many samples to draw. Default False.
num_samples(int, optional): set sample number to draw if :attr:`replacement`
is True, then it will take samples according to the number you set. Default None, disabled.
num_samples(int, optional): set sample number to draw. Default None, which is set to the length of `data_source`.
generator(Generator, optional): specify a generator to sample the :code:`data_source`. Default None, disabled.

Returns:
Expand Down Expand Up @@ -212,9 +211,10 @@ def __init__(
f"replacement={self.replacement}"
)

if self._num_samples is not None and not replacement:
if not self.replacement and self.num_samples > len(self.data_source):
raise ValueError(
"num_samples should not be specified while replacement is False"
"num_samples should be smaller than or equal to length of data_source when replacement is False, "
f"but got num_samples: {self.num_samples} > data_source: {len(self.data_source)}"
)

if not isinstance(self.num_samples, int) or self.num_samples <= 0:
Expand Down Expand Up @@ -246,7 +246,7 @@ def __iter__(self):
yield index
else:
for index in np.random.choice(
np.arange(n), n, replace=False
np.arange(n), self.num_samples, replace=False
).tolist():
yield index

Expand Down
8 changes: 6 additions & 2 deletions python/paddle/nn/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,10 +1349,14 @@ def named_buffers(self, prefix='', include_sublayers=True):
name = layer_prefix + ('.' if layer_prefix else '') + key
yield name, buffer

def clear_gradients(self):
def clear_gradients(self, set_to_zero=True):
"""
Clear the gradients of all parameters for this layer.

Args:
set_to_zero (bool, optional): Whether to set the trainable parameters'
gradients to zero or None. Default is True.

Returns:
None

Expand All @@ -1375,7 +1379,7 @@ def clear_gradients(self):
"""
for p in self.parameters():
if p.trainable:
p.clear_gradient()
p.clear_gradient(set_to_zero)

def _build_once(self, *args, **kwargs):
pass
Expand Down
14 changes: 14 additions & 0 deletions test/legacy_test/test_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def test_with_num_samples(self):
rets.append(i)
assert i >= 0 and i < 100

def test_with_num_samples_and_without_replacement(self):
    """With replacement=False and num_samples < len(dataset), the sampler
    must draw exactly ``num_samples`` *distinct* in-range indices."""
    dataset = RandomDataset(100, 10)
    sampler = RandomSampler(dataset, num_samples=80, replacement=False)
    assert len(sampler) == 80

    rets = []
    for i in iter(sampler):
        rets.append(i)
        assert i >= 0 and i < 100
    # The original test collected `rets` but never checked it; without
    # replacement every drawn index must be unique and exactly 80 drawn.
    assert len(rets) == 80
    assert len(set(rets)) == 80

def test_with_generator(self):
dataset = RandomDataset(100, 10)
generator = iter(range(0, 60))
Expand All @@ -111,6 +121,10 @@ def test_with_generator_num_samples(self):
rets.append(i)
assert tuple(sorted(rets)) == tuple(range(0, 50))

def test_with_num_samples_error(self):
    """Requesting more samples than the dataset holds, with
    replacement=False, must raise ValueError."""
    dataset = RandomDataset(100, 10)
    with self.assertRaises(ValueError):
        RandomSampler(dataset, replacement=False, num_samples=120)


class TestSubsetRandomSampler(unittest.TestCase):
def test_main(self):
Expand Down
29 changes: 29 additions & 0 deletions test/legacy_test/test_dataloader_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,34 @@ def test_multi_process(self):
break


class TestRandomSplitApi(unittest.TestCase):
    """Tests for fractional-length support in ``paddle.io.random_split``."""

    def test_main(self):
        """Fractions must yield floor-sized splits with the remainder
        distributed round-robin, and the splits must partition the input."""
        paddle.seed(1)

        dataset1, dataset2, dataset3 = paddle.io.random_split(
            range(5), [0.3, 0.0, 0.7]
        )

        # floor(5*0.3)=1, floor(5*0.0)=0, floor(5*0.7)=3; the one leftover
        # element goes to the first split -> sizes 2, 0, 3.
        self.assertEqual(len(dataset1), 2)
        self.assertEqual(len(dataset2), 0)
        self.assertEqual(len(dataset3), 3)

        elements_list = list(range(5))

        # Every input element must appear in exactly one split.
        for val in dataset1:
            elements_list.remove(val)

        for val in dataset3:
            elements_list.remove(val)

        self.assertEqual(elements_list, [])

    def test_errors(self):
        """Fractions outside [0, 1] must raise ValueError."""
        paddle.seed(1)
        self.assertRaises(
            ValueError, paddle.io.random_split, range(5), [-0.2, 1.2]
        )


if __name__ == '__main__':
    # Entry point: discover and run all tests defined in this file.
    unittest.main()
22 changes: 22 additions & 0 deletions test/legacy_test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,28 @@ def test_support_tuple(self):
self.assertTrue(model._linear.weight.dtype == paddle.float32)


class TestLayerClearGradientSetToZero(unittest.TestCase):
    """Tests for the ``set_to_zero`` flag of ``Layer.clear_gradients``."""

    def test_layer_clear_gradient_set_to_zero_true(self):
        """Default clear_gradients() must zero the gradient in place."""
        with base.dygraph.guard():
            net = MyLayer()
            inputs = paddle.randn([10, 1])
            outputs = net(inputs)
            outputs.backward()
            net.clear_gradients()
            # assertTrue on a numpy array is ambiguous for arrays with more
            # than one element; compare element-wise instead.
            np.testing.assert_array_equal(
                net._linear.weight.grad.numpy(), np.array([[0.0]])
            )

    def test_layer_clear_gradient_set_to_zero_false(self):
        """clear_gradients(set_to_zero=False) must release the gradient,
        leaving it as None."""
        with base.dygraph.guard():
            net = MyLayer()
            inputs = paddle.randn([10, 1])
            outputs = net(inputs)
            outputs.backward()
            net.clear_gradients(set_to_zero=False)
            self.assertIsNone(net._linear.weight.grad)


if __name__ == '__main__':
    # Run the suite in static-graph mode by default; tests that need
    # dygraph opt back in explicitly via base.dygraph.guard().
    paddle.enable_static()
    unittest.main()