69 changes: 53 additions & 16 deletions sdv/datasets/demo.py
@@ -200,11 +200,12 @@ def _extract_data(bytes_io, output_folder_name):
 def _get_data(modality, output_folder_name, in_memory_directory):
     data = {}
     if output_folder_name:
-        for filename in os.listdir(output_folder_name):
-            if filename.endswith('.csv'):
-                table_name = Path(filename).stem
-                data_path = os.path.join(output_folder_name, filename)
-                data[table_name] = pd.read_csv(data_path)
+        for root, _dirs, files in os.walk(output_folder_name):
+            for filename in files:
+                if filename.endswith('.csv'):
+                    table_name = Path(filename).stem
+                    data_path = os.path.join(root, filename)
+                    data[table_name] = pd.read_csv(data_path)
 
     else:
         for filename, file_ in in_memory_directory.items():
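
The move from ``os.listdir`` to ``os.walk`` makes CSV discovery recursive, so tables extracted into subfolders are picked up too. A minimal standalone sketch of the difference (the folder layout is hypothetical):

import os

import pandas as pd

out = 'outdir'
os.makedirs(os.path.join(out, 'level1', 'level2'), exist_ok=True)
pd.DataFrame({'a': [1, 2]}).to_csv(
    os.path.join(out, 'level1', 'level2', 'my_table.csv'), index=False
)

# os.listdir only sees the top level, so the old loop found nothing here.
print([f for f in os.listdir(out) if f.endswith('.csv')])  # []

# os.walk visits every level, so the new loop finds the nested table.
print([
    os.path.join(root, f)
    for root, _dirs, files in os.walk(out)
    for f in files
    if f.endswith('.csv')
])  # ['outdir/level1/level2/my_table.csv']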
@@ -223,6 +224,45 @@ def _get_data(modality, output_folder_name, in_memory_directory):
     return data
 
 
+def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
+    """Parse metadata bytes and optionally persist them to ``output_folder_name``.
+
+    Args:
+        metadata_bytes (bytes):
+            Raw bytes of the metadata JSON file.
+        dataset_name (str):
+            The dataset name used when loading into ``Metadata``.
+        output_folder_name (str or None):
+            Optional folder path where ``metadata.json`` will be written.
+
+    Returns:
+        Metadata:
+            Parsed metadata object.
+    """
+    try:
+        metadict = json.loads(metadata_bytes)
+        metadata = Metadata().load_from_dict(metadict, dataset_name)
+    except Exception as e:
+        raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e
+
+    if output_folder_name:
+        try:
+            metadata_path = os.path.join(str(output_folder_name), METADATA_FILENAME)
+            with open(metadata_path, 'wb') as f:
+                f.write(metadata_bytes)
+
+        except Exception:
+            warnings.warn(
+                f'Error saving {METADATA_FILENAME} for dataset {dataset_name} into '
+                f'{output_folder_name}.',
+                DemoResourceNotFoundWarning,
+            )
+
+    return metadata
+
+
 def download_demo(modality, dataset_name, output_folder_name=None):
     """Download a demo dataset.
 
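For context, this is roughly how the new helper behaves on its own; the metadata dict is a minimal made-up V1 spec, and ``out`` is an illustrative folder name:

import json
import os

meta = {
    'METADATA_SPEC_VERSION': 'V1',
    'tables': {'t': {'columns': {'a': {'sdtype': 'numerical'}}}},
    'relationships': [],
}
meta_bytes = json.dumps(meta).encode()

# Parse only: no folder given, nothing is written to disk.
metadata = _get_metadata(meta_bytes, 'dataset1')

# With a folder, the raw bytes are also persisted as metadata.json.
# A failed write only emits DemoResourceNotFoundWarning; it never raises.
os.makedirs('out', exist_ok=True)
metadata = _get_metadata(meta_bytes, 'dataset1', output_folder_name='out')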
@@ -250,15 +290,11 @@ def download_demo(modality, dataset_name, output_folder_name=None):
"""
_validate_modalities(modality)
_validate_output_folder(output_folder_name)

data_io, metadata_bytes = _download(modality, dataset_name)
in_memory_directory = _extract_data(data_io, output_folder_name)
data = _get_data(modality, output_folder_name, in_memory_directory)

try:
metadict = json.loads(metadata_bytes)
metadata = Metadata().load_from_dict(metadict, dataset_name)
except Exception as e:
raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e
metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)

return data, metadata

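Callers see no API change from this refactor; a typical call still looks like the sketch below (the dataset name and output folder are illustrative):

from sdv.datasets.demo import download_demo

data, metadata = download_demo(
    'single_table', 'word', output_folder_name='sdv_demo_out'
)

# For single-table demos, data is a DataFrame; metadata.json now lands in
# the output folder via _get_metadata instead of inline code in download_demo.
print(metadata.to_dict()['METADATA_SPEC_VERSION'])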
@@ -306,14 +342,14 @@ def get_available_demos(modality):
         try:
             raw = _get_data_from_bucket(yaml_key)
             info = yaml.safe_load(raw) or {}
-            name = info.get('dataset-name') or dataset_name
 
             size_mb_val = info.get('dataset-size-mb')
             try:
                 size_mb = float(size_mb_val) if size_mb_val is not None else np.nan
             except (ValueError, TypeError):
                 LOGGER.info(
-                    f'Invalid dataset-size-mb {size_mb_val} for dataset {name}; defaulting to NaN.'
+                    f'Invalid dataset-size-mb {size_mb_val} for dataset '
+                    f'{dataset_name}; defaulting to NaN.'
                 )
                 size_mb = np.nan
 
@@ -324,19 +360,20 @@
             except (ValueError, TypeError):
                 LOGGER.info(
                     f'Could not cast num_tables_val {num_tables_val} to float for '
-                    f'dataset {name}; defaulting to NaN.'
+                    f'dataset {dataset_name}; defaulting to NaN.'
                 )
                 num_tables_val = np.nan
 
             try:
                 num_tables = int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
             except (ValueError, TypeError):
                 LOGGER.info(
-                    f'Invalid num-tables {num_tables_val} for dataset {name} when parsing as int.'
+                    f'Invalid num-tables {num_tables_val} for '
+                    f'dataset {dataset_name} when parsing as int.'
                 )
                 num_tables = np.nan
 
-            tables_info['dataset_name'].append(name)
+            tables_info['dataset_name'].append(dataset_name)
             tables_info['size_MB'].append(size_mb)
             tables_info['num_tables'].append(num_tables)
 
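The two-stage cast (float first, then int, each falling back to NaN) tolerates strings, missing values, and junk alike. The same logic as a condensed standalone helper (the function name is local to this sketch):

import numpy as np
import pandas as pd

def parse_num_tables(value):
    """Cast via float, then int; NaN on None or any cast failure."""
    try:
        value = float(value) if value is not None else np.nan
    except (ValueError, TypeError):
        return np.nan
    try:
        return int(value) if not pd.isna(value) else np.nan
    except (ValueError, TypeError):
        return np.nan

print(parse_num_tables('3'))     # 3
print(parse_num_tables(None))    # nan
print(parse_num_tables('junk'))  # nan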
117 changes: 117 additions & 0 deletions tests/unit/datasets/test_demo.py
@@ -15,6 +15,7 @@
     _find_text_key,
     _get_data_from_bucket,
     _get_first_v1_metadata_bytes,
+    _get_metadata,
     _get_text_file_content,
     _iter_metainfo_yaml_entries,
     download_demo,
@@ -515,6 +516,31 @@ def side_effect(key):
     assert row['size_MB'] == 1.1
 
 
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_get_available_demos_ignores_yaml_dataset_name_mismatch(mock_list, mock_get):
+    """When YAML dataset-name mismatches folder, use folder name from S3 path."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/folder_name/metainfo.yaml'},
+    ]
+
+    # YAML uses a different name; should be ignored for dataset_name field
+    def side_effect(key):
+        return b'dataset-name: DIFFERENT\nnum-tables: 3\ndataset-size-mb: 2.5\n'
+
+    mock_get.side_effect = side_effect
+
+    # Run
+    df = get_available_demos('single_table')
+
+    # Assert
+    assert set(df['dataset_name']) == {'folder_name'}
+    row = df[df['dataset_name'] == 'folder_name'].iloc[0]
+    assert row['num_tables'] == 3
+    assert row['size_MB'] == 2.5
+
+
 @patch('sdv.datasets.demo._get_data_from_bucket')
 @patch('sdv.datasets.demo._list_objects')
 def test_download_demo_success_single_table(mock_list, mock_get):
@@ -585,6 +611,97 @@ def test_download_demo_no_v1_metadata_raises(mock_list, mock_get):
         download_demo('single_table', 'word')
 
 
+@patch('builtins.open', side_effect=OSError('fail-open'))
+def test__get_metadata_warns_on_save_error(_mock_open, tmp_path):
+    """_get_metadata should emit a warning if writing metadata.json fails."""
+    # Setup
+    meta = {
+        'METADATA_SPEC_VERSION': 'V1',
+        'relationships': [],
+        'tables': {
+            't': {
+                'columns': {
+                    'a': {'sdtype': 'numerical'},
+                }
+            }
+        },
+    }
+    meta_bytes = json.dumps(meta).encode()
+    out_dir = tmp_path / 'out'
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    # Run and Assert
+    warn_msg = 'Error saving metadata.json'
+    with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg):
+        md = _get_metadata(meta_bytes, 'dataset1', str(out_dir))
+
+    assert md.to_dict() == meta
+
+
+def test__get_metadata_raises_on_invalid_json():
+    """_get_metadata should raise a helpful error when JSON is invalid."""
+    # Run / Assert
+    err = 'Failed to parse metadata JSON for the dataset.'
+    with pytest.raises(DemoResourceNotFoundError, match=re.escape(err)):
+        _get_metadata(b'not-json', 'dataset1')
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_writes_metadata_and_discovers_nested_csv(mock_list, mock_get, tmp_path):
+    """When output folder is set, it writes metadata.json and finds nested CSVs."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/nested/data.zip'},
+        {'Key': 'single_table/nested/metadata.json'},
+    ]
+
+    df = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('level1/level2/my_table.csv', df.to_csv(index=False))
+    zip_bytes = buf.getvalue()
+
+    meta_dict = {
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'my_table': {
+                'columns': {
+                    'a': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                    'b': {'sdtype': 'categorical'},
+                }
+            }
+        },
+        'relationships': [],
+    }
+    meta_bytes = json.dumps(meta_dict).encode()
+
+    def side_effect(key):
+        if key.endswith('data.zip'):
+            return zip_bytes
+        if key.endswith('metadata.json'):
+            return meta_bytes
+        raise KeyError(key)
+
+    mock_get.side_effect = side_effect
+
+    out = tmp_path / 'outdir'
+
+    # Run
+    data, metadata = download_demo('single_table', 'nested', out)
+
+    # Assert
+    pd.testing.assert_frame_equal(data, df)
+    assert metadata.to_dict() == meta_dict
+
+    meta_path = out / 'metadata.json'
+    assert meta_path.is_file()
+
+    with open(meta_path, 'rb') as f:
+        on_disk = f.read()
+    assert on_disk == meta_bytes
+
+
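Nested entry names survive the round trip through an in-memory archive, which is what the assertion on ``data`` above relies on; a quick standalone check, independent of the test:

import io
import zipfile

buf = io.BytesIO()
with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
    zf.writestr('level1/level2/my_table.csv', 'a,b\n1,x\n2,y\n')

# Entry names keep their nested paths inside the archive.
with zipfile.ZipFile(io.BytesIO(buf.getvalue())) as zf:
    print(zf.namelist())  # ['level1/level2/my_table.csv']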
 def test__find_text_key_returns_none_when_missing():
     """Test it returns None when the key is missing."""
     # Setup