Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions tests/test_dataset_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,33 @@ class DatasetTester:
def __init__(self, parent):
    """Bind this tester to a test case.

    Args:
        parent: The owning test case. When ``None``, a fresh ``TestCase()``
            is created and stored instead.
    """
    self.parent = parent if parent is not None else TestCase()

def load_builder_class(self, dataset_name, is_local=False, data_files=None):
    """Resolve the dataset module for ``dataset_name`` and return its builder class.

    Args:
        dataset_name: Name of the dataset to resolve.
        is_local: When exactly ``True``, the module is resolved from the
            ``datasets/<dataset_name>`` path. Otherwise it is resolved by
            name with ``DownloadConfig(force_download=True)``.
        data_files: Forwarded to ``dataset_module_factory`` in the non-local
            branch only; the local branch does not use it.

    Returns:
        Whatever ``import_main_class`` returns for the resolved module path.
    """
    # Download/copy dataset script
    if is_local is True:
        dataset_module = dataset_module_factory(os.path.join("datasets", dataset_name))
    else:
        dataset_module = dataset_module_factory(
            dataset_name, download_config=DownloadConfig(force_download=True), data_files=data_files
        )
    # Get dataset builder class
    return import_main_class(dataset_module.module_path)

def load_all_configs(self, dataset_name, is_local=False, data_files=None) -> List[Optional[BuilderConfig]]:
    """Return the builder configs declared by the dataset's builder class.

    Args:
        dataset_name: Name of the dataset whose builder class is loaded.
        is_local: Forwarded to ``load_builder_class``.
        data_files: Forwarded to ``load_builder_class``.

    Returns:
        The builder class's ``BUILDER_CONFIGS`` when it is non-empty,
        otherwise ``[None]`` so callers always get at least one entry.
    """
    # get builder class
    builder_cls = self.load_builder_class(dataset_name, is_local=is_local, data_files=data_files)

    if len(builder_cls.BUILDER_CONFIGS) == 0:
        return [None]
    return builder_cls.BUILDER_CONFIGS

def check_load_dataset(self, dataset_name, configs, is_local=False, use_local_dummy_data=False):
def check_load_dataset(self, dataset_name, configs, is_local=False, use_local_dummy_data=False, data_files=None):
for config in configs:
with tempfile.TemporaryDirectory() as processed_temp_dir, tempfile.TemporaryDirectory() as raw_temp_dir:

# create config and dataset
dataset_builder_cls = self.load_builder_class(dataset_name, is_local=is_local)
dataset_builder_cls = self.load_builder_class(dataset_name, is_local=is_local, data_files=data_files)
config_name = config.name if config is not None else None
dataset_builder = dataset_builder_cls(config_name=config_name, cache_dir=processed_temp_dir)

Expand Down Expand Up @@ -303,20 +305,29 @@ def setUp(self):
self.dataset_tester = DatasetTester(self)

def test_load_dataset_offline(self, dataset_name):
    """Check that the dataset loads under every offline simulation mode.

    For each member of ``OfflineSimulationMode``, the first config of the
    dataset is loaded and checked with local dummy data while the
    ``offline`` context manager is active.

    Args:
        dataset_name: Name of the dataset under test.
    """
    # pass existing dummy data_files to avoid slow inferring over root directory
    # overwritten afterwards with extracted dummy data
    dummy_data_files = f"datasets/{dataset_name}/dummy/0.0.0/dummy_data.zip"
    for offline_simulation_mode in list(OfflineSimulationMode):
        with offline(offline_simulation_mode):
            # only the first config is exercised
            configs = self.dataset_tester.load_all_configs(dataset_name, data_files=dummy_data_files)[:1]
            self.dataset_tester.check_load_dataset(
                dataset_name, configs, use_local_dummy_data=True, data_files=dummy_data_files
            )

def test_builder_class(self, dataset_name):
    """Check that the dataset's builder class instantiates as a ``DatasetBuilder``.

    The builder is constructed with the name of its first declared config
    (or ``None`` when ``BUILDER_CONFIGS`` is empty) and a temporary cache
    directory.

    Args:
        dataset_name: Name of the dataset under test.
    """
    # pass existing dummy data_files to avoid slow inferring over root directory; not used afterwards
    dummy_data_files = f"datasets/{dataset_name}/dummy/0.0.0/dummy_data.zip"
    builder_cls = self.dataset_tester.load_builder_class(dataset_name, data_files=dummy_data_files)
    config_name = builder_cls.BUILDER_CONFIGS[0].name if builder_cls.BUILDER_CONFIGS else None
    with tempfile.TemporaryDirectory() as tmp_cache_dir:
        builder = builder_cls(config_name=config_name, cache_dir=tmp_cache_dir)
        self.assertIsInstance(builder, DatasetBuilder)

def test_builder_configs(self, dataset_name):
builder_configs = self.dataset_tester.load_all_configs(dataset_name)
# pass existing dummy data_files to avoid slow inferring over root directory; not used afterwards
dummy_data_files = f"datasets/{dataset_name}/dummy/0.0.0/dummy_data.zip"
builder_configs = self.dataset_tester.load_all_configs(dataset_name, data_files=dummy_data_files)
self.assertTrue(len(builder_configs) > 0)

if builder_configs[0] is not None:
Expand Down