Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,113 @@ def get_available_demos(modality):
continue

return pd.DataFrame(tables_info)


def _find_text_key(contents, dataset_prefix, filename):
"""Find a text file key (README.txt or SOURCE.txt).

Performs a case-insensitive search for ``filename`` directly under ``dataset_prefix``.

Args:
contents (list[dict]):
List of objects from S3.
dataset_prefix (str):
Prefix like 'single_table/dataset/'.
filename (str):
The filename to look for (e.g., 'README.txt').

Returns:
str or None:
The key if found, otherwise ``None``.
"""
expected_lower = f'{dataset_prefix}{filename}'.lower()
for entry in contents:
key = entry.get('Key') or ''
if key.lower() == expected_lower:
return key

return None


def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
    """Fetch text file content under the dataset prefix.

    The objects under ``'{modality}/{dataset_name}/'`` are listed and searched for
    ``filename``. Fetching and saving are both best-effort: a missing file or a failed
    download is logged and reported as ``None``, and a failed save is logged while the
    text is still returned.

    Args:
        modality (str):
            The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
        dataset_name (str):
            The name of the dataset.
        filename (str):
            The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``).
        output_filepath (str or None):
            If provided, save the file contents at this path. Missing parent folders
            are created.

    Returns:
        str or None:
            The decoded text contents if the file exists and could be fetched,
            otherwise ``None``.
    """
    _validate_modalities(modality)
    dataset_prefix = f'{modality}/{dataset_name}/'
    contents = _list_objects(dataset_prefix)

    key = _find_text_key(contents, dataset_prefix, filename)
    if not key:
        LOGGER.info(f'No {filename} found for dataset {dataset_name}.')
        return None

    try:
        raw = _get_data_from_bucket(key)
    except Exception:
        # Best-effort: report the failure as "no content", but keep the traceback in the
        # log record so the underlying cause is not lost.
        LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.', exc_info=True)
        return None

    # Undecodable bytes are replaced rather than raising, so a text is always produced.
    text = raw.decode('utf-8', errors='replace')
    if output_filepath:
        try:
            parent = os.path.dirname(str(output_filepath))
            if parent:
                os.makedirs(parent, exist_ok=True)

            # newline='' disables newline translation so the saved file holds exactly the
            # returned text (otherwise '\r\n' would be written as '\r\r\n' on Windows).
            with open(output_filepath, 'w', encoding='utf-8', newline='') as f:
                f.write(text)

        except Exception:
            # Saving is optional: log the failure (with traceback) and still return the text.
            LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.', exc_info=True)

    return text


def get_source(modality, dataset_name, output_filepath=None):
    """Return the source/citation text of a dataset.

    Delegates to ``_get_text_file_content`` asking for the dataset's ``SOURCE.txt``.

    Args:
        modality (str):
            The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
        dataset_name (str):
            The name of the dataset to get the source information for.
        output_filepath (str or None):
            Optional path where to save the file.

    Returns:
        str or None:
            The contents of the source file if it exists; otherwise ``None``.
    """
    source_filename = 'SOURCE.txt'
    return _get_text_file_content(modality, dataset_name, source_filename, output_filepath)


def get_readme(modality, dataset_name, output_filepath=None):
    """Return the README text of a dataset.

    Delegates to ``_get_text_file_content`` asking for the dataset's ``README.txt``.

    Args:
        modality (str):
            The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
        dataset_name (str):
            The name of the dataset to get the README for.
        output_filepath (str or None):
            Optional path where to save the file.

    Returns:
        str or None:
            The contents of the README file if it exists; otherwise ``None``.
    """
    readme_filename = 'README.txt'
    return _get_text_file_content(modality, dataset_name, readme_filename, output_filepath)
29 changes: 28 additions & 1 deletion tests/integration/datasets/test_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from sdv.datasets.demo import get_available_demos
from sdv.datasets.demo import get_available_demos, get_readme, get_source


def test_get_available_demos_single_table():
Expand Down Expand Up @@ -85,3 +85,30 @@ def test_get_available_demos_multi_table():
'num_tables': [2, 2],
})
pd.testing.assert_frame_equal(tables_info[['dataset_name', 'size_MB', 'num_tables']], expected)


def test_get_readme_and_source_single_table_dataset1(tmp_path):
    """Test the README and SOURCE of a single table dataset are returned and can be saved."""
    # Setup
    modality, dataset_name = 'single_table', 'dataset1'

    # Run
    readme = get_readme(modality, dataset_name)
    source = get_source(modality, dataset_name)

    # Assert
    assert isinstance(readme, str)
    assert 'sample dataset' in readme.lower()
    assert isinstance(source, str)
    assert source.strip() == 'unknown'

    # Run again, this time also saving the contents to disk
    readme_path = tmp_path / 'r.txt'
    source_path = tmp_path / 's.txt'
    saved_readme = get_readme(modality, dataset_name, str(readme_path))
    saved_source = get_source(modality, dataset_name, str(source_path))

    # Assert the returned text is unchanged and matches what was written
    assert saved_readme == readme
    assert saved_source == source
    assert readme_path.read_text(encoding='utf-8').strip() == readme.strip()
    assert source_path.read_text(encoding='utf-8').strip() == source.strip()


def test_get_readme_missing_returns_none():
    """Test ``None`` is returned for a dataset that has neither README nor SOURCE."""
    # Run and Assert
    missing_readme = get_readme('single_table', 'dataset2')
    assert missing_readme is None

    missing_source = get_source('single_table', 'dataset2')
    assert missing_source is None
190 changes: 190 additions & 0 deletions tests/unit/datasets/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
from sdv.datasets.demo import (
_download,
_find_data_zip_key,
_find_text_key,
_get_data_from_bucket,
_get_first_v1_metadata_bytes,
_get_text_file_content,
_iter_metainfo_yaml_entries,
download_demo,
get_available_demos,
get_readme,
get_source,
)
from sdv.errors import DemoResourceNotFoundError

Expand Down Expand Up @@ -579,3 +583,189 @@ def test_download_demo_no_v1_metadata_raises(mock_list, mock_get):
# Run and Assert
with pytest.raises(DemoResourceNotFoundError, match='METADATA_SPEC_VERSION'):
download_demo('single_table', 'word')


def test__find_text_key_returns_none_when_missing():
    """Test ``None`` is returned when no listed object matches the filename."""
    # Setup
    prefix = 'single_table/dataset/'
    listing = [
        {'Key': 'single_table/dataset/metadata.json'},
        {'Key': 'single_table/dataset/data.zip'},
    ]

    # Run
    found = _find_text_key(listing, prefix, 'README.txt')

    # Assert
    assert found is None


def test__find_text_key_ignores_nested_paths():
    """Test a file inside a nested folder under the dataset prefix is not matched."""
    # Setup
    prefix = 'single_table/dataset1/'
    listing = [{'Key': 'single_table/dataset1/bad_folder/SOURCE.txt'}]

    # Run
    found = _find_text_key(listing, prefix, 'SOURCE.txt')

    # Assert
    assert found is None


@patch('sdv.datasets.demo._get_data_from_bucket')
@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_happy_path(mock_list, mock_get):
    """Test it returns the decoded text when the file exists.

    No output path is passed, so nothing is written to disk and no temporary
    directory fixture is needed.
    """
    # Setup
    mock_list.return_value = [
        {'Key': 'single_table/dataset1/README.txt'},
    ]
    mock_get.return_value = 'Hello README'.encode()

    # Run
    text = _get_text_file_content('single_table', 'dataset1', 'README.txt')

    # Assert
    assert text == 'Hello README'
    # The content must be fetched using the key found in the listing.
    mock_get.assert_called_once_with('single_table/dataset1/README.txt')


@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_missing_key_returns_none(mock_list):
    """Test ``None`` is returned when the listing does not contain the requested file."""
    # Setup
    only_metadata = {'Key': 'single_table/dataset1/metadata.json'}
    mock_list.return_value = [only_metadata]

    # Run
    result = _get_text_file_content('single_table', 'dataset1', 'README.txt')

    # Assert
    assert result is None


@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_logs_when_missing_key(mock_list, caplog):
    """Test an info message is logged when the file is not under the dataset prefix."""
    # Setup
    only_metadata = {'Key': 'single_table/dataset1/metadata.json'}
    mock_list.return_value = [only_metadata]
    expected_message = 'No README.txt found for dataset dataset1.'

    # Run
    caplog.set_level(logging.INFO, logger='sdv.datasets.demo')
    result = _get_text_file_content('single_table', 'dataset1', 'README.txt')

    # Assert
    assert result is None
    assert expected_message in caplog.text


@patch('sdv.datasets.demo._get_data_from_bucket')
@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_fetch_error_returns_none(mock_list, mock_get):
    """Test ``None`` is returned when downloading the file raises an error."""
    # Setup
    source_entry = {'Key': 'single_table/dataset1/SOURCE.txt'}
    mock_list.return_value = [source_entry]
    mock_get.side_effect = Exception('boom')

    # Run
    result = _get_text_file_content('single_table', 'dataset1', 'SOURCE.txt')

    # Assert
    assert result is None


@patch('sdv.datasets.demo._get_data_from_bucket')
@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_logs_on_fetch_error(mock_list, mock_get, caplog):
    """Test an info message is logged when downloading the file raises an error."""
    # Setup
    source_entry = {'Key': 'single_table/dataset1/SOURCE.txt'}
    mock_list.return_value = [source_entry]
    mock_get.side_effect = Exception('boom')
    expected_message = 'Error fetching SOURCE.txt for dataset dataset1.'

    # Run
    caplog.set_level(logging.INFO, logger='sdv.datasets.demo')
    result = _get_text_file_content('single_table', 'dataset1', 'SOURCE.txt')

    # Assert
    assert result is None
    assert expected_message in caplog.text


@patch('sdv.datasets.demo._get_data_from_bucket')
@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_writes_file_when_output_filepath_given(
    mock_list, mock_get, tmp_path
):
    """Test the text is returned and also saved when an output filepath is passed.

    The target lives in a subfolder that does not exist yet, so this also covers
    the creation of missing parent folders.
    """
    # Setup
    expected_text = 'Write me'
    mock_list.return_value = [{'Key': 'single_table/dataset1/README.txt'}]
    mock_get.return_value = expected_text.encode()
    destination = tmp_path.joinpath('subdir', 'readme.txt')

    # Run
    result = _get_text_file_content('single_table', 'dataset1', 'README.txt', str(destination))

    # Assert
    assert result == expected_text
    assert destination.read_text(encoding='utf-8') == expected_text


@patch('sdv.datasets.demo._get_data_from_bucket')
@patch('sdv.datasets.demo._list_objects')
def test__get_text_file_content_logs_on_save_error(
    mock_list, mock_get, tmp_path, caplog, monkeypatch
):
    """Test a failed save is logged as info and the text is still returned."""
    # Setup
    expected_text = 'Write me'
    mock_list.return_value = [{'Key': 'single_table/dataset1/README.txt'}]
    mock_get.return_value = expected_text.encode()
    destination = tmp_path.joinpath('subdir', 'readme.txt')

    def _raise_on_open(*args, **kwargs):
        raise OSError('fail-open')

    # Make every ``open`` call fail so that writing the output file errors out.
    monkeypatch.setattr('builtins.open', _raise_on_open)

    # Run
    caplog.set_level(logging.INFO, logger='sdv.datasets.demo')
    result = _get_text_file_content('single_table', 'dataset1', 'README.txt', str(destination))

    # Assert
    assert result == expected_text
    assert 'Error saving README.txt for dataset dataset1.' in caplog.text


def test_get_readme_and_get_source_call_wrapper(monkeypatch):
    """Test ``get_readme`` and ``get_source`` forward their arguments to the shared helper.

    The helper is replaced by a recorder, so the paths passed here are never written to.
    """
    # Setup
    recorded = []

    def _record_call(modality, dataset_name, filename, output_filepath=None):
        recorded.append((modality, dataset_name, filename, output_filepath))
        return 'X'

    monkeypatch.setattr('sdv.datasets.demo._get_text_file_content', _record_call)

    # Run
    readme = get_readme('single_table', 'dataset1', '/tmp/readme')
    source = get_source('single_table', 'dataset1', '/tmp/source')

    # Assert
    assert readme == 'X'
    assert source == 'X'
    readme_call, source_call = recorded[0], recorded[1]
    assert readme_call == ('single_table', 'dataset1', 'README.txt', '/tmp/readme')
    assert source_call == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source')
Loading