 import json
 import logging
 import os
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from zipfile import ZipFile

 import boto3
 import numpy as np
 import pandas as pd
+import yaml
 from botocore import UNSIGNED
 from botocore.client import Config
-from botocore.exceptions import ClientError

+from sdv.errors import DemoResourceNotFoundError
 from sdv.metadata.metadata import Metadata

 LOGGER = logging.getLogger(__name__)
-BUCKET = 'sdv-demo-datasets'
-BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com'
+BUCKET = 'sdv-datasets-public'
+BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'

@@ -39,41 +39,154 @@ def _validate_output_folder(output_folder_name):
         )


+def _create_s3_client():
+    """Create and return an S3 client with unsigned requests."""
+    return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+
+
 def _get_data_from_bucket(object_key):
-    session = boto3.Session()
-    s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+    s3 = _create_s3_client()
     response = s3.get_object(Bucket=BUCKET, Key=object_key)
     return response['Body'].read()


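# A minimal sketch of what the unsigned client enables, assuming the public
# '<modality>/<dataset>/' bucket layout used in this module; the object key below
# is a hypothetical example.
#
#   >>> raw = _get_data_from_bucket('single_table/fake_hotel_guests/metainfo.yaml')  # doctest: +SKIP
#   >>> isinstance(raw, bytes)  # doctest: +SKIP
#   True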
-def _download(modality, dataset_name):
-    dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip'
-    object_key = f'{modality.upper()}/{dataset_name}.zip'
-    LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}')
-    try:
-        file_content = _get_data_from_bucket(object_key)
-    except ClientError:
-        raise ValueError(
-            f"Invalid dataset name '{dataset_name}'. "
-            'Make sure you have the correct modality for the dataset name or '
-            "use 'get_available_demos' to get a list of demo datasets."
+def _list_objects(prefix):
+    """List all objects under a given prefix using pagination.
+
+    Args:
+        prefix (str):
+            The S3 prefix to list.
+
+    Returns:
+        list[dict]:
+            A list of object summaries.
+    """
+    client = _create_s3_client()
+    contents = []
+    paginator = client.get_paginator('list_objects_v2')
+    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+        contents.extend(resp.get('Contents', []))
+
+    if not contents:
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+
+    return contents
+
+
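# A minimal sketch of listing a whole modality with the paginator-backed helper,
# assuming at least one dataset exists under the prefix; the keys shown are
# hypothetical examples.
#
#   >>> contents = _list_objects('single_table/')  # doctest: +SKIP
#   >>> [entry['Key'] for entry in contents][:2]  # doctest: +SKIP
#   ['single_table/fake_hotel_guests/data.zip', 'single_table/fake_hotel_guests/metainfo.yaml']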
+def _search_contents_keys(contents, match_fn):
+    """Return list of keys from ``contents`` that satisfy ``match_fn``.
+
+    Args:
+        contents (list[dict]):
+            S3 list_objects-like contents entries.
+        match_fn (callable):
+            Function that receives a key (str) and returns True if it matches.
+
+    Returns:
+        list[str]:
+            Keys in their original order that matched the predicate.
+    """
+    matches = []
+    for entry in contents or []:
+        key = entry.get('Key', '')
+        try:
+            if match_fn(key):
+                matches.append(key)
+        except Exception:
+            continue
+
+    return matches
+
+
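# A minimal offline sketch of the predicate-based search; it only needs S3-style
# listing entries, so it can be exercised with a hand-built list (keys are made up).
#
#   >>> fake_contents = [{'Key': 'single_table/demo/data.zip'}, {'Key': 'single_table/demo/readme.txt'}]
#   >>> _search_contents_keys(fake_contents, lambda key: key.endswith('.zip'))
#   ['single_table/demo/data.zip']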
+def _find_data_zip_key(contents, dataset_prefix):
+    """Find the 'data.zip' object key under the dataset prefix, case-insensitively.
+
+    Args:
+        contents (list[dict]):
+            List of objects from S3.
+        dataset_prefix (str):
+            Prefix like 'single_table/dataset/'.
+
+    Returns:
+        str:
+            The key to the data zip if found.
+    """
+    prefix_lower = dataset_prefix.lower()
+
+    def is_data_zip(key):
+        return key.lower() == f'{prefix_lower}data.zip'
+
+    matches = _search_contents_keys(contents, is_data_zip)
+    if matches:
+        return matches[0]
+
+    raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")
+
+
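# A minimal offline sketch of the case-insensitive lookup, again with a made-up
# entry: a key such as 'single_table/demo/DATA.ZIP' still resolves to the original key.
#
#   >>> fake_contents = [{'Key': 'single_table/demo/DATA.ZIP'}]
#   >>> _find_data_zip_key(fake_contents, 'single_table/demo/')
#   'single_table/demo/DATA.ZIP'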
+def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+    """Find and return the bytes of the first V1 metadata JSON under ``dataset_prefix``.
+
+    Scans the S3 listing ``contents`` and, for any JSON file directly under the dataset
+    prefix, downloads and returns its bytes if it contains METADATA_SPEC_VERSION == 'V1'.
+
+    Returns:
+        bytes:
+            The bytes of the first V1 metadata JSON.
+    """
+    prefix_lower = dataset_prefix.lower()
+
+    def is_direct_json_under_prefix(key):
+        key_lower = key.lower()
+        return (
+            key_lower.startswith(prefix_lower)
+            and key_lower.endswith('.json')
+            and 'metadata' in key_lower
+            and key_lower.count('/') == prefix_lower.count('/')
         )

-    return io.BytesIO(file_content)
+    candidate_keys = _search_contents_keys(contents, is_direct_json_under_prefix)
+
+    for key in candidate_keys:
+        try:
+            raw = _get_data_from_bucket(key)
+            metadict = json.loads(raw)
+            if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
+                return raw
+
+        except Exception:
+            continue
+
+    raise DemoResourceNotFoundError(
+        'Could not find a valid metadata JSON with METADATA_SPEC_VERSION "V1".'
+    )
+
+
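# A sketch of the selection rule, assuming a hypothetical dataset layout: a key such
# as 'single_table/demo/metadata_v1.json' sits directly under the prefix (same number
# of '/' characters), contains 'metadata' and ends in '.json', so its bytes are
# returned once the downloaded JSON declares METADATA_SPEC_VERSION == 'V1'; nested
# JSON files or files with another spec version are skipped.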
+def _download(modality, dataset_name):
+    """Download dataset resources from a bucket.
+
+    Returns:
+        tuple:
+            (BytesIO(zip_bytes), metadata_bytes)
+    """
+    dataset_prefix = f'{modality}/{dataset_name}/'
+    LOGGER.info(
+        f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
+        f'{BUCKET_URL}/{dataset_prefix}'
+    )
+    contents = _list_objects(dataset_prefix)
+
+    zip_key = _find_data_zip_key(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key)
+    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+
+    return io.BytesIO(zip_bytes), metadata_bytes


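# A minimal sketch of the new return value, assuming the dataset exists in the bucket;
# the dataset name is a hypothetical example (ZIP archives start with the b'PK' magic).
#
#   >>> data_io, metadata_bytes = _download('single_table', 'fake_hotel_guests')  # doctest: +SKIP
#   >>> data_io.read(2)  # io.BytesIO wrapping data.zip  # doctest: +SKIP
#   b'PK'
#   >>> json.loads(metadata_bytes)['METADATA_SPEC_VERSION']  # doctest: +SKIP
#   'V1'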
 def _extract_data(bytes_io, output_folder_name):
     with ZipFile(bytes_io) as zf:
         if output_folder_name:
             os.makedirs(output_folder_name, exist_ok=True)
             zf.extractall(output_folder_name)
-            metadata_v0_filepath = os.path.join(output_folder_name, 'metadata_v0.json')
-            if os.path.isfile(metadata_v0_filepath):
-                os.remove(metadata_v0_filepath)
-            os.rename(
-                os.path.join(output_folder_name, 'metadata_v1.json'),
-                os.path.join(output_folder_name, METADATA_FILENAME),
-            )

         else:
             in_memory_directory = {}
@@ -104,32 +217,14 @@ def _get_data(modality, output_folder_name, in_memory_directory):
     return data


-def _get_metadata(output_folder_name, in_memory_directory, dataset_name):
-    metadata = Metadata()
-    if output_folder_name:
-        metadata_path = os.path.join(output_folder_name, METADATA_FILENAME)
-        metadata = metadata.load_from_json(metadata_path, dataset_name)
-
-    else:
-        metadata_path = 'metadata_v2.json'
-        if metadata_path not in in_memory_directory:
-            warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.')
-            metadata_path = 'metadata_v1.json'
-
-        metadict = json.loads(in_memory_directory[metadata_path])
-        metadata = metadata.load_from_dict(metadict, dataset_name)
-
-    return metadata
-
-
 def download_demo(modality, dataset_name, output_folder_name=None):
     """Download a demo dataset.

     Args:
         modality (str):
             The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
         dataset_name (str):
-            Name of the dataset to be downloaded from the sdv-datasets S3 bucket.
+            Name of the dataset to be downloaded from the sdv-datasets-public S3 bucket.
         output_folder_name (str or None):
             The name of the local folder where the metadata and data should be stored.
             If ``None`` the data is not saved locally and is loaded as a Python object.
@@ -149,14 +244,41 @@ def download_demo(modality, dataset_name, output_folder_name=None):
     """
     _validate_modalities(modality)
     _validate_output_folder(output_folder_name)
-    bytes_io = _download(modality, dataset_name)
-    in_memory_directory = _extract_data(bytes_io, output_folder_name)
+    data_io, metadata_bytes = _download(modality, dataset_name)
+    in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
-    metadata = _get_metadata(output_folder_name, in_memory_directory, dataset_name)
+
+    try:
+        metadict = json.loads(metadata_bytes)
+        metadata = Metadata().load_from_dict(metadict, dataset_name)
+    except Exception as e:
+        raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e

     return data, metadata


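# A minimal usage sketch, assuming 'fake_hotel_guests' is available as a single-table
# demo dataset in the bucket.
#
#   >>> data, metadata = download_demo('single_table', 'fake_hotel_guests')  # doctest: +SKIP
#   >>> data.head()  # pandas.DataFrame with the demo table  # doctest: +SKIP
#   >>> metadata.to_dict()  # Metadata loaded from the V1 metadata JSON  # doctest: +SKIP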
+def _iter_metainfo_yaml_entries(contents, modality):
+    """Yield (dataset_name, yaml_key) for metainfo.yaml files under a modality.
+
+    This matches keys like '<modality>/<dataset>/metainfo.yaml'.
+    """
+    modality_lower = (modality or '').lower()
+
+    def is_metainfo_yaml(key):
+        parts = key.split('/')
+        if len(parts) != 3:
+            return False
+        if parts[0].lower() != modality_lower:
+            return False
+        if parts[-1].lower() != 'metainfo.yaml':
+            return False
+        return bool(parts[1])
+
+    for key in _search_contents_keys(contents, is_metainfo_yaml):
+        dataset_name = key.split('/')[1]
+        yield dataset_name, key
+
+
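# A minimal offline sketch of the key filtering, using made-up listing entries:
# only keys of the exact form '<modality>/<dataset>/metainfo.yaml' are yielded.
#
#   >>> fake_contents = [
#   ...     {'Key': 'single_table/demo/metainfo.yaml'},
#   ...     {'Key': 'single_table/demo/extra/metainfo.yaml'},
#   ... ]
#   >>> list(_iter_metainfo_yaml_entries(fake_contents, 'single_table'))
#   [('demo', 'single_table/demo/metainfo.yaml')]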
 def get_available_demos(modality):
     """Get demo datasets available for a ``modality``.

@@ -170,23 +292,49 @@ def get_available_demos(modality):
             ``dataset_name``: The name of the dataset.
             ``size_MB``: The unzipped folder size in MB.
             ``num_tables``: The number of tables in the dataset.
-
-    Raises:
-        Error:
-            * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
     """
     _validate_modalities(modality)
-    client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+    contents = _list_objects(f'{modality}/')
     tables_info = defaultdict(list)
-    for item in client.list_objects(Bucket=BUCKET)['Contents']:
-        dataset_modality, dataset = item['Key'].split('/', 1)
-        if dataset_modality == modality.upper():
-            tables_info['dataset_name'].append(dataset.replace('.zip', ''))
-            headers = client.head_object(Bucket=BUCKET, Key=item['Key'])['Metadata']
-            size_mb = headers.get('size-mb', np.nan)
-            tables_info['size_MB'].append(round(float(size_mb), 2))
-            tables_info['num_tables'].append(headers.get('num-tables', np.nan))
-
-    df = pd.DataFrame(tables_info)
-    df['num_tables'] = pd.to_numeric(df['num_tables'])
-    return df
+    for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
+        try:
+            raw = _get_data_from_bucket(yaml_key)
+            info = yaml.safe_load(raw) or {}
+            name = info.get('dataset-name') or dataset_name
+
+            size_mb_val = info.get('dataset-size-mb')
+            try:
+                size_mb = float(size_mb_val) if size_mb_val is not None else np.nan
+            except (ValueError, TypeError):
+                LOGGER.info(
+                    f'Invalid dataset-size-mb {size_mb_val} for dataset {name}; defaulting to NaN.'
+                )
+                size_mb = np.nan
+
+            num_tables_val = info.get('num-tables', np.nan)
+            if isinstance(num_tables_val, str):
+                try:
+                    num_tables_val = float(num_tables_val)
+                except (ValueError, TypeError):
+                    LOGGER.info(
+                        f'Could not cast num_tables_val {num_tables_val} to float for '
+                        f'dataset {name}; defaulting to NaN.'
+                    )
+                    num_tables_val = np.nan
+
+            try:
+                num_tables = int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
+            except (ValueError, TypeError):
+                LOGGER.info(
+                    f'Invalid num-tables {num_tables_val} for dataset {name} when parsing as int.'
+                )
+                num_tables = np.nan
+
+            tables_info['dataset_name'].append(name)
+            tables_info['size_MB'].append(size_mb)
+            tables_info['num_tables'].append(num_tables)
+
+        except Exception:
+            continue
+
+    return pd.DataFrame(tables_info)
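# A minimal usage sketch, assuming the bucket lists at least one single-table dataset;
# the resulting frame carries the columns documented above.
#
#   >>> demos = get_available_demos('single_table')  # doctest: +SKIP
#   >>> list(demos.columns)  # doctest: +SKIP
#   ['dataset_name', 'size_MB', 'num_tables']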