Skip to content

Commit 4233a61

Browse files
authored
[Enhance] Combined dataset supports custom sampling ratio (#2562)
1 parent cbbea68 commit 4233a61

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

docs/en/user_guides/mixed_datasets.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ dataset = dict(
8484
# The pipeline includes typical transforms, such as loading the
8585
# image and data augmentation
8686
pipeline=train_pipeline,
87+
# The sample_ratio_factor controls the sampling ratio of
88+
# each dataset in the combined dataset. The length of sample_ratio_factor
89+
# should match the number of datasets. Each factor indicates the sampling
90+
# ratio of the corresponding dataset relative to its original length.
91+
sample_ratio_factor=[1.0, 0.5]
8792
)
8893
```
8994

docs/zh_cn/user_guides/mixed_datasets.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ dataset = dict(
8484
# `train_pipeline` 包含了常用的数据预处理,
8585
# 比如图片读取、数据增广等
8686
pipeline=train_pipeline,
87+
# sample_ratio_factor 参数是用来调节每个子数据集
88+
# 在组合数据集中的样本数量比例的
89+
sample_ratio_factor=[1.0, 0.5]
8790
)
8891
```
8992

mmpose/datasets/dataset_wrappers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

33
from copy import deepcopy
4-
from typing import Any, Callable, List, Tuple, Union
4+
from typing import Any, Callable, List, Optional, Tuple, Union
55

6+
import numpy as np
67
from mmengine.dataset import BaseDataset
78
from mmengine.registry import build_from_cfg
89

@@ -18,21 +19,37 @@ class CombinedDataset(BaseDataset):
1819
metainfo (dict): The meta information of combined dataset.
1920
datasets (list): The configs of datasets to be combined.
2021
pipeline (list, optional): Processing pipeline. Defaults to [].
22+
sample_ratio_factor (list, optional): A list of sampling ratio
23+
factors for each dataset. Defaults to None.
2124
"""
2225

2326
def __init__(self,
2427
metainfo: dict,
2528
datasets: list,
2629
pipeline: List[Union[dict, Callable]] = [],
30+
sample_ratio_factor: Optional[List[float]] = None,
2731
**kwargs):
2832

2933
self.datasets = []
34+
self.resample = sample_ratio_factor is not None
3035

3136
for cfg in datasets:
3237
dataset = build_from_cfg(cfg, DATASETS)
3338
self.datasets.append(dataset)
3439

3540
self._lens = [len(dataset) for dataset in self.datasets]
41+
if self.resample:
42+
assert len(sample_ratio_factor) == len(datasets), f'the length ' \
43+
f'of `sample_ratio_factor` {len(sample_ratio_factor)} does ' \
44+
f'not match the length of `datasets` {len(datasets)}'
45+
assert min(sample_ratio_factor) >= 0.0, 'the ratio values in ' \
46+
'`sample_ratio_factor` should not be negative.'
47+
self._lens_ori = self._lens
48+
self._lens = [
49+
round(l * sample_ratio_factor[i])
50+
for i, l in enumerate(self._lens_ori)
51+
]
52+
3653
self._len = sum(self._lens)
3754

3855
super(CombinedDataset, self).__init__(pipeline=pipeline, **kwargs)
@@ -71,6 +88,12 @@ def _get_subset_index(self, index: int) -> Tuple[int, int]:
7188
while index >= self._lens[subset_index]:
7289
index -= self._lens[subset_index]
7390
subset_index += 1
91+
92+
if self.resample:
93+
gap = (self._lens_ori[subset_index] -
94+
1e-4) / self._lens[subset_index]
95+
index = round(gap * index + np.random.rand() * gap - 0.5)
96+
7497
return subset_index, index
7598

7699
def prepare_data(self, idx: int) -> Any:

tests/test_datasets/test_datasets/test_dataset_wrappers/test_combined_dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,29 @@ def test_get_subset_index(self):
8181
self.assertEqual(subset_idx, 0)
8282
self.assertEqual(sample_idx, lens[0] - 1)
8383

84+
# combined dataset with resampling ratio
85+
dataset = self.build_combined_dataset(sample_ratio_factor=[1, 0.3])
86+
self.assertEqual(
87+
len(dataset),
88+
len(dataset.datasets[0]) + round(0.3 * len(dataset.datasets[1])))
89+
lens = dataset._lens
90+
91+
index = lens[0]
92+
subset_idx, sample_idx = dataset._get_subset_index(index)
93+
self.assertEqual(subset_idx, 1)
94+
self.assertIn(sample_idx, (0, 1, 2))
95+
96+
index = -lens[1] - 1
97+
subset_idx, sample_idx = dataset._get_subset_index(index)
98+
self.assertEqual(subset_idx, 0)
99+
self.assertEqual(sample_idx, lens[0] - 1)
100+
101+
with self.assertRaises(AssertionError):
102+
_ = self.build_combined_dataset(sample_ratio_factor=[1, 0.3, 0.1])
103+
104+
with self.assertRaises(AssertionError):
105+
_ = self.build_combined_dataset(sample_ratio_factor=[1, -0.3])
106+
84107
def test_prepare_data(self):
85108
dataset = self.build_combined_dataset()
86109
lens = dataset._lens

0 commit comments

Comments
 (0)