
Commit 47c09f4

Update the download_demo and get_available_demos functions (#2669)
1 parent 1705ced commit 47c09f4

6 files changed: 895 additions & 489 deletions

sdv/datasets/demo.py

Lines changed: 212 additions & 64 deletions
```diff
@@ -4,23 +4,23 @@
 import json
 import logging
 import os
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from zipfile import ZipFile
 
 import boto3
 import numpy as np
 import pandas as pd
+import yaml
 from botocore import UNSIGNED
 from botocore.client import Config
-from botocore.exceptions import ClientError
 
+from sdv.errors import DemoResourceNotFoundError
 from sdv.metadata.metadata import Metadata
 
 LOGGER = logging.getLogger(__name__)
-BUCKET = 'sdv-demo-datasets'
-BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com'
+BUCKET = 'sdv-datasets-public'
+BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
```

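Note that the new `BUCKET` is public and is read with unsigned requests, so no AWS credentials are needed. A minimal standalone sketch of that pattern (the prefix and `MaxKeys` values are illustrative, not part of this commit):

```python
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Unsigned config lets anonymous clients read the public demo bucket.
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

# List a few keys under a modality prefix (values chosen for illustration).
response = s3.list_objects_v2(Bucket='sdv-datasets-public', Prefix='single_table/', MaxKeys=5)
for entry in response.get('Contents', []):
    print(entry['Key'])
```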
```diff
@@ -39,41 +39,154 @@ def _validate_output_folder(output_folder_name):
         )
 
 
+def _create_s3_client():
+    """Create and return an S3 client with unsigned requests."""
+    return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+
+
 def _get_data_from_bucket(object_key):
-    session = boto3.Session()
-    s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+    s3 = _create_s3_client()
     response = s3.get_object(Bucket=BUCKET, Key=object_key)
     return response['Body'].read()
 
 
-def _download(modality, dataset_name):
-    dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip'
-    object_key = f'{modality.upper()}/{dataset_name}.zip'
-    LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}')
-    try:
-        file_content = _get_data_from_bucket(object_key)
-    except ClientError:
-        raise ValueError(
-            f"Invalid dataset name '{dataset_name}'. "
-            'Make sure you have the correct modality for the dataset name or '
-            "use 'get_available_demos' to get a list of demo datasets."
+def _list_objects(prefix):
+    """List all objects under a given prefix using pagination.
+
+    Args:
+        prefix (str):
+            The S3 prefix to list.
+
+    Returns:
+        list[dict]:
+            A list of object summaries.
+    """
+    client = _create_s3_client()
+    contents = []
+    paginator = client.get_paginator('list_objects_v2')
+    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+        contents.extend(resp.get('Contents', []))
+
+    if not contents:
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+
+    return contents
+
+
+def _search_contents_keys(contents, match_fn):
+    """Return list of keys from ``contents`` that satisfy ``match_fn``.
+
+    Args:
+        contents (list[dict]):
+            S3 list_objects-like contents entries.
+        match_fn (callable):
+            Function that receives a key (str) and returns True if it matches.
+
+    Returns:
+        list[str]:
+            Keys in their original order that matched the predicate.
+    """
+    matches = []
+    for entry in contents or []:
+        key = entry.get('Key', '')
+        try:
+            if match_fn(key):
+                matches.append(key)
+        except Exception:
+            continue
+
+    return matches
+
+
+def _find_data_zip_key(contents, dataset_prefix):
+    """Find the 'data.zip' object key under dataset prefix, case-insensitive.
+
+    Args:
+        contents (list[dict]):
+            List of objects from S3.
+        dataset_prefix (str):
+            Prefix like 'single_table/dataset/'.
+
+    Returns:
+        str:
+            The key to the data zip if found.
+    """
+    prefix_lower = dataset_prefix.lower()
+
+    def is_data_zip(key):
+        return key.lower() == f'{prefix_lower}data.zip'
+
+    matches = _search_contents_keys(contents, is_data_zip)
+    if matches:
+        return matches[0]
+
+    raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")
+
+
+def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+    """Find and return bytes of the first V1 metadata JSON under `dataset_prefix`.
+
+    Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix,
+    downloads and returns its bytes if it contains METADATA_SPEC_VERSION == 'V1'.
+
+    Returns:
+        bytes:
+            The bytes of the first V1 metadata JSON.
+    """
+    prefix_lower = dataset_prefix.lower()
+
+    def is_direct_json_under_prefix(key):
+        key_lower = key.lower()
+        return (
+            key_lower.startswith(prefix_lower)
+            and key_lower.endswith('.json')
+            and 'metadata' in key_lower
+            and key_lower.count('/') == prefix_lower.count('/')
         )
 
-    return io.BytesIO(file_content)
+    candidate_keys = _search_contents_keys(contents, is_direct_json_under_prefix)
+
+    for key in candidate_keys:
+        try:
+            raw = _get_data_from_bucket(key)
+            metadict = json.loads(raw)
+            if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
+                return raw
+
+        except Exception:
+            continue
+
+    raise DemoResourceNotFoundError(
+        'Could not find a valid metadata JSON with METADATA_SPEC_VERSION "V1".'
+    )
+
+
+def _download(modality, dataset_name):
+    """Download dataset resources from a bucket.
+
+    Returns:
+        tuple:
+            (BytesIO(zip_bytes), metadata_bytes)
+    """
+    dataset_prefix = f'{modality}/{dataset_name}/'
+    LOGGER.info(
+        f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
+        f'{BUCKET_URL}/{dataset_prefix}'
+    )
+    contents = _list_objects(dataset_prefix)
+
+    zip_key = _find_data_zip_key(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key)
+    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+
+    return io.BytesIO(zip_bytes), metadata_bytes
 
 
 def _extract_data(bytes_io, output_folder_name):
     with ZipFile(bytes_io) as zf:
         if output_folder_name:
             os.makedirs(output_folder_name, exist_ok=True)
             zf.extractall(output_folder_name)
-            metadata_v0_filepath = os.path.join(output_folder_name, 'metadata_v0.json')
-            if os.path.isfile(metadata_v0_filepath):
-                os.remove(metadata_v0_filepath)
-            os.rename(
-                os.path.join(output_folder_name, 'metadata_v1.json'),
-                os.path.join(output_folder_name, METADATA_FILENAME),
-            )
 
         else:
             in_memory_directory = {}
```
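The rewritten `_download` implies a per-dataset layout of `<modality>/<dataset_name>/` containing a `data.zip` plus a V1 metadata JSON beside it. A small self-contained sketch of the predicate-style key matching the new helpers rely on (the keys and dataset name are hypothetical):

```python
# Hypothetical listing entries in the shape list_objects_v2 returns.
contents = [
    {'Key': 'single_table/fake_hotel_guests/data.zip'},
    {'Key': 'single_table/fake_hotel_guests/metadata.json'},
    {'Key': 'single_table/fake_hotel_guests/metainfo.yaml'},
]
prefix = 'single_table/fake_hotel_guests/'

def is_data_zip(key):
    # Case-insensitive exact match, mirroring _find_data_zip_key's predicate.
    return key.lower() == f'{prefix.lower()}data.zip'

print([entry['Key'] for entry in contents if is_data_zip(entry['Key'])])
# -> ['single_table/fake_hotel_guests/data.zip']
```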
```diff
@@ -104,32 +217,14 @@ def _get_data(modality, output_folder_name, in_memory_directory):
     return data
 
 
-def _get_metadata(output_folder_name, in_memory_directory, dataset_name):
-    metadata = Metadata()
-    if output_folder_name:
-        metadata_path = os.path.join(output_folder_name, METADATA_FILENAME)
-        metadata = metadata.load_from_json(metadata_path, dataset_name)
-
-    else:
-        metadata_path = 'metadata_v2.json'
-        if metadata_path not in in_memory_directory:
-            warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.')
-            metadata_path = 'metadata_v1.json'
-
-        metadict = json.loads(in_memory_directory[metadata_path])
-        metadata = metadata.load_from_dict(metadict, dataset_name)
-
-    return metadata
-
-
 def download_demo(modality, dataset_name, output_folder_name=None):
     """Download a demo dataset.
 
     Args:
         modality (str):
             The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
         dataset_name (str):
-            Name of the dataset to be downloaded from the sdv-datasets S3 bucket.
+            Name of the dataset to be downloaded from the sdv-datasets-public S3 bucket.
         output_folder_name (str or None):
             The name of the local folder where the metadata and data should be stored.
             If ``None`` the data is not saved locally and is loaded as a Python object.
```
```diff
@@ -149,14 +244,41 @@ def download_demo(modality, dataset_name, output_folder_name=None):
     """
     _validate_modalities(modality)
     _validate_output_folder(output_folder_name)
-    bytes_io = _download(modality, dataset_name)
-    in_memory_directory = _extract_data(bytes_io, output_folder_name)
+    data_io, metadata_bytes = _download(modality, dataset_name)
+    in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
-    metadata = _get_metadata(output_folder_name, in_memory_directory, dataset_name)
+
+    try:
+        metadict = json.loads(metadata_bytes)
+        metadata = Metadata().load_from_dict(metadict, dataset_name)
+    except Exception as e:
+        raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e
 
     return data, metadata
 
 
+def _iter_metainfo_yaml_entries(contents, modality):
+    """Yield (dataset_name, yaml_key) for metainfo.yaml files under a modality.
+
+    This matches keys like '<modality>/<dataset>/metainfo.yaml'.
+    """
+    modality_lower = (modality or '').lower()
+
+    def is_metainfo_yaml(key):
+        parts = key.split('/')
+        if len(parts) != 3:
+            return False
+        if parts[0].lower() != modality_lower:
+            return False
+        if parts[-1].lower() != 'metainfo.yaml':
+            return False
+        return bool(parts[1])
+
+    for key in _search_contents_keys(contents, is_metainfo_yaml):
+        dataset_name = key.split('/')[1]
+        yield dataset_name, key
+
+
 def get_available_demos(modality):
     """Get demo datasets available for a ``modality``.
 
```
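`download_demo` now builds the `Metadata` object straight from the JSON bytes fetched from the bucket rather than from files unpacked out of the zip. A minimal sketch of that path, using a hypothetical trimmed V1 payload in place of the downloaded bytes:

```python
import json

from sdv.metadata.metadata import Metadata

# Hypothetical trimmed V1 payload standing in for the bytes fetched from S3.
metadata_bytes = b'{"METADATA_SPEC_VERSION": "V1", "tables": {}, "relationships": []}'

metadict = json.loads(metadata_bytes)
metadata = Metadata().load_from_dict(metadict, 'fake_hotel_guests')
```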
```diff
@@ -170,23 +292,49 @@ def get_available_demos(modality):
                 ``dataset_name``: The name of the dataset.
                 ``size_MB``: The unzipped folder size in MB.
                 ``num_tables``: The number of tables in the dataset.
-
-    Raises:
-        Error:
-            * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
     """
     _validate_modalities(modality)
-    client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+    contents = _list_objects(f'{modality}/')
     tables_info = defaultdict(list)
-    for item in client.list_objects(Bucket=BUCKET)['Contents']:
-        dataset_modality, dataset = item['Key'].split('/', 1)
-        if dataset_modality == modality.upper():
-            tables_info['dataset_name'].append(dataset.replace('.zip', ''))
-            headers = client.head_object(Bucket=BUCKET, Key=item['Key'])['Metadata']
-            size_mb = headers.get('size-mb', np.nan)
-            tables_info['size_MB'].append(round(float(size_mb), 2))
-            tables_info['num_tables'].append(headers.get('num-tables', np.nan))
-
-    df = pd.DataFrame(tables_info)
-    df['num_tables'] = pd.to_numeric(df['num_tables'])
-    return df
+    for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
+        try:
+            raw = _get_data_from_bucket(yaml_key)
+            info = yaml.safe_load(raw) or {}
+            name = info.get('dataset-name') or dataset_name
+
+            size_mb_val = info.get('dataset-size-mb')
+            try:
+                size_mb = float(size_mb_val) if size_mb_val is not None else np.nan
+            except (ValueError, TypeError):
+                LOGGER.info(
+                    f'Invalid dataset-size-mb {size_mb_val} for dataset {name}; defaulting to NaN.'
+                )
+                size_mb = np.nan
+
+            num_tables_val = info.get('num-tables', np.nan)
+            if isinstance(num_tables_val, str):
+                try:
+                    num_tables_val = float(num_tables_val)
+                except (ValueError, TypeError):
+                    LOGGER.info(
+                        f'Could not cast num_tables_val {num_tables_val} to float for '
+                        f'dataset {name}; defaulting to NaN.'
+                    )
+                    num_tables_val = np.nan
+
+            try:
+                num_tables = int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
+            except (ValueError, TypeError):
+                LOGGER.info(
+                    f'Invalid num-tables {num_tables_val} for dataset {name} when parsing as int.'
+                )
+                num_tables = np.nan
+
+            tables_info['dataset_name'].append(name)
+            tables_info['size_MB'].append(size_mb)
+            tables_info['num_tables'].append(num_tables)
+
+        except Exception:
+            continue
+
+    return pd.DataFrame(tables_info)
```
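`get_available_demos` now sources its columns from each dataset's `metainfo.yaml` rather than from S3 object metadata headers. A sketch of the fields the loop consumes, parsed with the same `yaml.safe_load` call (the values are hypothetical; real files may carry additional keys):

```python
import yaml

# Hypothetical metainfo.yaml body showing the keys the loop reads.
raw = 'dataset-name: fake_hotel_guests\ndataset-size-mb: 0.2\nnum-tables: 1\n'

info = yaml.safe_load(raw) or {}
print(info.get('dataset-name'), info.get('dataset-size-mb'), info.get('num-tables'))
# -> fake_hotel_guests 0.2 1
```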

sdv/errors.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -91,3 +91,11 @@ class RefitWarning(UserWarning):
     Warning to be raised if a change to a synthesizer requires the synthesizer
     to be refit for the change to be applied.
     """
+
+
+class DemoResourceNotFoundError(Exception):
+    """Raised when a demo dataset or one of its resources cannot be found.
+
+    This error is intended for missing demo assets such as the dataset archive,
+    metadata, license, README, or other auxiliary files in the demo bucket.
+    """
```
