Skip to content

Commit 366e87d

Browse files
authored
【Hackathon 6th No. 27】为 paddle.io.RandomSampler/random_split /Layer.clear_gradients 进行功能增强 -part (#62966)
* update layers sampler and dataset
* update test
* update test
* revise randomsampler num_samples and test
* update error check
* update test
* update docs
* update docs
1 parent 4ade81d commit 366e87d

6 files changed

Lines changed: 101 additions & 8 deletions

File tree

python/paddle/io/dataloader/dataset.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import bisect
16+
import math
17+
import warnings
1618
from typing import Iterable
1719

1820
import paddle
@@ -487,7 +489,7 @@ def random_split(dataset, lengths, generator=None):
487489
488490
Args:
489491
dataset (Dataset): Dataset to be split
490-
lengths (sequence): lengths of splits to be produced
492+
lengths (sequence): lengths or fractions of splits to be produced
491493
generator (Generator, optional): Generator used for the random permutation. Default is None then the DefaultGenerator is used in manual_seed().
492494
493495
Returns:
@@ -522,6 +524,28 @@ def random_split(dataset, lengths, generator=None):
522524
5 3
523525
6 8
524526
"""
527+
if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
528+
subset_lengths = []
529+
for i, frac in enumerate(lengths):
530+
if frac < 0 or frac > 1:
531+
raise ValueError(
532+
f"Fraction at index {i} is not between 0 and 1"
533+
)
534+
n_items_in_split = int(math.floor(len(dataset) * frac))
535+
subset_lengths.append(n_items_in_split)
536+
remainder = len(dataset) - sum(subset_lengths)
537+
538+
for i in range(remainder):
539+
idx_to_add_at = i % len(subset_lengths)
540+
subset_lengths[idx_to_add_at] += 1
541+
lengths = subset_lengths
542+
for i, length in enumerate(lengths):
543+
if length == 0:
544+
warnings.warn(
545+
f"Length of split at index {i} is 0. "
546+
f"This might result in an empty dataset."
547+
)
548+
525549
# Cannot verify that dataset is Sized
526550
if sum(lengths) != len(dataset): # type: ignore
527551
raise ValueError(

python/paddle/io/dataloader/sampler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,7 @@ class RandomSampler(Sampler):
160160
object which implemented :code:`__len__` to get indices as the range of :code:`dataset` length. Default None.
161161
replacement(bool, optional): If False, sample the whole dataset, If True,
162162
set :attr:`num_samples` for how many samples to draw. Default False.
163-
num_samples(int, optional): set sample number to draw if :attr:`replacement`
164-
is True, then it will take samples according to the number you set. Default None, disabled.
163+
num_samples(int, optional): set sample number to draw. Default None, which is set to the length of `data_source`.
165164
generator(Generator, optional): specify a generator to sample the :code:`data_source`. Default None, disabled.
166165
167166
Returns:
@@ -212,9 +211,10 @@ def __init__(
212211
f"replacement={self.replacement}"
213212
)
214213

215-
if self._num_samples is not None and not replacement:
214+
if not self.replacement and self.num_samples > len(self.data_source):
216215
raise ValueError(
217-
"num_samples should not be specified while replacement is False"
216+
"num_samples should be smaller than or equal to length of data_source when replacement is False, "
217+
f"but got num_samples: {self.num_samples} > data_source: {len(self.data_source)}"
218218
)
219219

220220
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
@@ -246,7 +246,7 @@ def __iter__(self):
246246
yield index
247247
else:
248248
for index in np.random.choice(
249-
np.arange(n), n, replace=False
249+
np.arange(n), self.num_samples, replace=False
250250
).tolist():
251251
yield index
252252

python/paddle/nn/layer/layers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,10 +1349,14 @@ def named_buffers(self, prefix='', include_sublayers=True):
13491349
name = layer_prefix + ('.' if layer_prefix else '') + key
13501350
yield name, buffer
13511351

1352-
def clear_gradients(self):
1352+
def clear_gradients(self, set_to_zero=True):
13531353
"""
13541354
Clear the gradients of all parameters for this layer.
13551355
1356+
Args:
1357+
set_to_zero (bool, optional): Whether to set the trainable parameters'
1358+
gradients to zero or None. Default is True.
1359+
13561360
Returns:
13571361
None
13581362
@@ -1375,7 +1379,7 @@ def clear_gradients(self):
13751379
"""
13761380
for p in self.parameters():
13771381
if p.trainable:
1378-
p.clear_gradient()
1382+
p.clear_gradient(set_to_zero)
13791383

13801384
def _build_once(self, *args, **kwargs):
13811385
pass

test/legacy_test/test_batch_sampler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ def test_with_num_samples(self):
8787
rets.append(i)
8888
assert i >= 0 and i < 100
8989

90+
def test_with_num_samples_and_without_replacement(self):
91+
dataset = RandomDataset(100, 10)
92+
sampler = RandomSampler(dataset, num_samples=80, replacement=False)
93+
assert len(sampler) == 80
94+
95+
rets = []
96+
for i in iter(sampler):
97+
rets.append(i)
98+
assert i >= 0 and i < 100
99+
90100
def test_with_generator(self):
91101
dataset = RandomDataset(100, 10)
92102
generator = iter(range(0, 60))
@@ -111,6 +121,10 @@ def test_with_generator_num_samples(self):
111121
rets.append(i)
112122
assert tuple(sorted(rets)) == tuple(range(0, 50))
113123

124+
def test_with_num_samples_error(self):
125+
dataset = RandomDataset(100, 10)
126+
self.assertRaises(ValueError, RandomSampler, dataset, False, 120)
127+
114128

115129
class TestSubsetRandomSampler(unittest.TestCase):
116130
def test_main(self):

test/legacy_test/test_dataloader_dataset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,34 @@ def test_multi_process(self):
8989
break
9090

9191

92+
class TestRandomSplitApi(unittest.TestCase):
93+
def test_main(self):
94+
paddle.seed(1)
95+
96+
dataset1, dataset2, dataset3 = paddle.io.random_split(
97+
range(5), [0.3, 0.0, 0.7]
98+
)
99+
100+
self.assertTrue(len(dataset1) == 2)
101+
self.assertTrue(len(dataset2) == 0)
102+
self.assertTrue(len(dataset3) == 3)
103+
104+
elements_list = list(range(5))
105+
106+
for _, val in enumerate(dataset1):
107+
elements_list.remove(val)
108+
109+
for _, val in enumerate(dataset3):
110+
elements_list.remove(val)
111+
112+
self.assertTrue(len(elements_list) == 0)
113+
114+
def test_errors(self):
115+
paddle.seed(1)
116+
self.assertRaises(
117+
ValueError, paddle.io.random_split, range(5), [-0.2, 1.2]
118+
)
119+
120+
92121
if __name__ == '__main__':
93122
unittest.main()

test/legacy_test/test_layers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,6 +2465,28 @@ def test_support_tuple(self):
24652465
self.assertTrue(model._linear.weight.dtype == paddle.float32)
24662466

24672467

2468+
class TestLayerClearGradientSetToZero(unittest.TestCase):
2469+
def test_layer_clear_gradient_set_to_zero_true(self):
2470+
with base.dygraph.guard():
2471+
net = MyLayer()
2472+
inputs = paddle.randn([10, 1])
2473+
outputs = net(inputs)
2474+
outputs.backward()
2475+
net.clear_gradients()
2476+
self.assertTrue(
2477+
net._linear.weight.grad.numpy() == np.array([[0.0]])
2478+
)
2479+
2480+
def test_layer_clear_gradient_set_to_zero_false(self):
2481+
with base.dygraph.guard():
2482+
net = MyLayer()
2483+
inputs = paddle.randn([10, 1])
2484+
outputs = net(inputs)
2485+
outputs.backward()
2486+
net.clear_gradients(set_to_zero=False)
2487+
self.assertTrue(net._linear.weight.grad is None)
2488+
2489+
24682490
if __name__ == '__main__':
24692491
paddle.enable_static()
24702492
unittest.main()

0 commit comments

Comments (0)