diff --git a/python/paddle/io/dataloader/dataset.py b/python/paddle/io/dataloader/dataset.py index 666c9afe7bab68..267ef23b4fc8ec 100755 --- a/python/paddle/io/dataloader/dataset.py +++ b/python/paddle/io/dataloader/dataset.py @@ -13,6 +13,8 @@ # limitations under the License. import bisect +import math +import warnings from typing import Iterable import paddle @@ -487,7 +489,7 @@ def random_split(dataset, lengths, generator=None): Args: dataset (Dataset): Dataset to be split - lengths (sequence): lengths of splits to be produced + lengths (sequence): lengths or fractions of splits to be produced generator (Generator, optional): Generator used for the random permutation. Default is None then the DefaultGenerator is used in manual_seed(). Returns: @@ -522,6 +524,28 @@ def random_split(dataset, lengths, generator=None): 5 3 6 8 """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError( + f"Fraction at index {i} is not between 0 and 1" + ) + n_items_in_split = int(math.floor(len(dataset) * frac)) + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) + + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset." + ) + # Cannot verify that dataset is Sized if sum(lengths) != len(dataset): # type: ignore raise ValueError( diff --git a/python/paddle/io/dataloader/sampler.py b/python/paddle/io/dataloader/sampler.py index 9fdfe15c64122d..eef54ae3c380e4 100644 --- a/python/paddle/io/dataloader/sampler.py +++ b/python/paddle/io/dataloader/sampler.py @@ -160,8 +160,7 @@ class RandomSampler(Sampler): object which implemented :code:`__len__` to get indices as the range of :code:`dataset` length. Default None. replacement(bool, optional): If False, sample the whole dataset, If True, set :attr:`num_samples` for how many samples to draw. Default False. - num_samples(int, optional): set sample number to draw if :attr:`replacement` - is True, then it will take samples according to the number you set. Default None, disabled. + num_samples(int, optional): set sample number to draw. Default None, which is set to the length of `data_source`. generator(Generator, optional): specify a generator to sample the :code:`data_source`. Default None, disabled. Returns: @@ -212,9 +211,10 @@ def __init__( f"replacement={self.replacement}" ) - if self._num_samples is not None and not replacement: + if not self.replacement and self.num_samples > len(self.data_source): raise ValueError( - "num_samples should not be specified while replacement is False" + "num_samples should be smaller than or equal to length of data_source when replacement is False, " + f"but got num_samples: {self.num_samples} > data_source: {len(self.data_source)}" ) if not isinstance(self.num_samples, int) or self.num_samples <= 0: @@ -246,7 +246,7 @@ def __iter__(self): yield index else: for index in np.random.choice( - np.arange(n), n, replace=False + np.arange(n), self.num_samples, replace=False ).tolist(): yield index diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py index 829494083d9d49..877d3eb1da5914 100644 --- a/python/paddle/nn/layer/layers.py +++ b/python/paddle/nn/layer/layers.py @@ -1349,10 +1349,14 @@ def named_buffers(self, prefix='', include_sublayers=True): name = layer_prefix + ('.' if layer_prefix else '') + key yield name, buffer - def clear_gradients(self): + def clear_gradients(self, set_to_zero=True): """ Clear the gradients of all parameters for this layer. + Args: + set_to_zero (bool, optional): Whether to set the trainable parameters' + gradients to zero or None. Default is True. + Returns: None @@ -1375,7 +1379,7 @@ def clear_gradients(self): """ for p in self.parameters(): if p.trainable: - p.clear_gradient() + p.clear_gradient(set_to_zero) def _build_once(self, *args, **kwargs): pass diff --git a/test/legacy_test/test_batch_sampler.py b/test/legacy_test/test_batch_sampler.py index 750a916b3b29a2..9440d9b5777fc1 100644 --- a/test/legacy_test/test_batch_sampler.py +++ b/test/legacy_test/test_batch_sampler.py @@ -87,6 +87,16 @@ def test_with_num_samples(self): rets.append(i) assert i >= 0 and i < 100 + def test_with_num_samples_and_without_replacement(self): + dataset = RandomDataset(100, 10) + sampler = RandomSampler(dataset, num_samples=80, replacement=False) + assert len(sampler) == 80 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert i >= 0 and i < 100 + def test_with_generator(self): dataset = RandomDataset(100, 10) generator = iter(range(0, 60)) @@ -111,6 +121,10 @@ def test_with_generator_num_samples(self): rets.append(i) assert tuple(sorted(rets)) == tuple(range(0, 50)) + def test_with_num_samples_error(self): + dataset = RandomDataset(100, 10) + self.assertRaises(ValueError, RandomSampler, dataset, False, 120) + class TestSubsetRandomSampler(unittest.TestCase): def test_main(self): diff --git a/test/legacy_test/test_dataloader_dataset.py b/test/legacy_test/test_dataloader_dataset.py index 0d28b558d1acb8..b6e5cfe204d290 100644 --- a/test/legacy_test/test_dataloader_dataset.py +++ b/test/legacy_test/test_dataloader_dataset.py @@ -89,5 +89,34 @@ def test_multi_process(self): break +class TestRandomSplitApi(unittest.TestCase): + def test_main(self): + paddle.seed(1) + + dataset1, dataset2, dataset3 = paddle.io.random_split( + range(5), [0.3, 0.0, 0.7] + ) + + self.assertTrue(len(dataset1) == 2) + self.assertTrue(len(dataset2) == 0) + self.assertTrue(len(dataset3) == 3) + + elements_list = list(range(5)) + + for _, val in enumerate(dataset1): + elements_list.remove(val) + + for _, val in enumerate(dataset3): + elements_list.remove(val) + + self.assertTrue(len(elements_list) == 0) + + def test_errors(self): + paddle.seed(1) + self.assertRaises( + ValueError, paddle.io.random_split, range(5), [-0.2, 1.2] + ) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py index 8529245d1fe2db..b2e3691eac705a 100644 --- a/test/legacy_test/test_layers.py +++ b/test/legacy_test/test_layers.py @@ -2465,6 +2465,28 @@ def test_support_tuple(self): self.assertTrue(model._linear.weight.dtype == paddle.float32) +class TestLayerClearGradientSetToZero(unittest.TestCase): + def test_layer_clear_gradient_set_to_zero_true(self): + with base.dygraph.guard(): + net = MyLayer() + inputs = paddle.randn([10, 1]) + outputs = net(inputs) + outputs.backward() + net.clear_gradients() + self.assertTrue( + net._linear.weight.grad.numpy() == np.array([[0.0]]) + ) + + def test_layer_clear_gradient_set_to_zero_false(self): + with base.dygraph.guard(): + net = MyLayer() + inputs = paddle.randn([10, 1]) + outputs = net(inputs) + outputs.backward() + net.clear_gradients(set_to_zero=False) + self.assertTrue(net._linear.weight.grad is None) + + if __name__ == '__main__': paddle.enable_static() unittest.main()