
Commit 18ece9e

Fix download_demo for data.zip files (#2699)
1 parent 6ce265a commit 18ece9e

File tree

2 files changed: +237 -12 lines

sdv/datasets/demo.py
tests/unit/datasets/test_demo.py
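The user-visible effect of this fix: download_demo now loads every readable CSV inside a dataset's data.zip, retries non-UTF-8 files with a latin-1 fallback, and emits a single UserWarning naming anything it skipped, instead of failing outright. A minimal sketch of that behavior; the dataset name 'mixed_demo' is hypothetical:

    import warnings

    from sdv.datasets.demo import download_demo

    # 'mixed_demo' stands in for any dataset whose data.zip mixes CSVs
    # with extra files (readme, txt, ...); pass a third argument to
    # extract to a folder on disk instead of reading in memory.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        data, metadata = download_demo('single_table', 'mixed_demo')

    for warning in caught:
        print(warning.message)  # e.g. Skipped files: note.txt, readme.md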

sdv/datasets/demo.py

Lines changed: 59 additions & 12 deletions
@@ -24,6 +24,7 @@
 BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
+FALLBACK_ENCODING = 'latin-1'
 
 
 def _validate_modalities(modality):
@@ -197,21 +198,67 @@ def _extract_data(bytes_io, output_folder_name):
     return in_memory_directory
 
 
-def _get_data(modality, output_folder_name, in_memory_directory):
+def _get_data_with_output_folder(output_folder_name):
+    """Load CSV tables from an extracted folder on disk.
+
+    Returns a tuple of (data_dict, skipped_files).
+    Non-CSV files are skipped and reported.
+    """
     data = {}
-    if output_folder_name:
-        for root, _dirs, files in os.walk(output_folder_name):
-            for filename in files:
-                if filename.endswith('.csv'):
-                    table_name = Path(filename).stem
-                    data_path = os.path.join(root, filename)
-                    data[table_name] = pd.read_csv(data_path)
+    skipped_files = []
+    for root, _dirs, files in os.walk(output_folder_name):
+        for filename in files:
+            if not filename.lower().endswith('.csv'):
+                skipped_files.append(filename)
+                continue
+
+            table_name = Path(filename).stem
+            data_path = os.path.join(root, filename)
+            try:
+                data[table_name] = pd.read_csv(data_path)
+            except UnicodeDecodeError:
+                data[table_name] = pd.read_csv(data_path, encoding=FALLBACK_ENCODING)
+            except Exception as e:
+                rel = os.path.relpath(data_path, output_folder_name)
+                skipped_files.append(f'{rel}: {e}')
+
+    return data, skipped_files
+
+
+def _get_data_without_output_folder(in_memory_directory):
+    """Load CSV tables directly from in-memory zip contents.
+
+    Returns a tuple of (data_dict, skipped_files).
+    Non-CSV entries are skipped and reported.
+    """
+    data = {}
+    skipped_files = []
+    for filename, file_ in in_memory_directory.items():
+        if not filename.lower().endswith('.csv'):
+            skipped_files.append(filename)
+            continue
+
+        table_name = Path(filename).stem
+        try:
+            data[table_name] = pd.read_csv(io.BytesIO(file_), low_memory=False)
+        except UnicodeDecodeError:
+            data[table_name] = pd.read_csv(
+                io.BytesIO(file_), low_memory=False, encoding=FALLBACK_ENCODING
+            )
+        except Exception as e:
+            skipped_files.append(f'{filename}: {e}')
 
+    return data, skipped_files
+
+
+def _get_data(modality, output_folder_name, in_memory_directory):
+    if output_folder_name:
+        data, skipped_files = _get_data_with_output_folder(output_folder_name)
     else:
-        for filename, file_ in in_memory_directory.items():
-            if filename.endswith('.csv'):
-                table_name = Path(filename).stem
-                data[table_name] = pd.read_csv(io.StringIO(file_.decode()), low_memory=False)
+        data, skipped_files = _get_data_without_output_folder(in_memory_directory)
+
+    if skipped_files:
+        warnings.warn('Skipped files: ' + ', '.join(sorted(skipped_files)))
 
     if not data:
         raise DemoResourceNotFoundError(
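The core decoding change: the old in-memory path called file_.decode(), which assumes UTF-8 and raises UnicodeDecodeError on anything else, such as latin-1 bytes. The new helpers hand pandas the raw bytes and retry with FALLBACK_ENCODING on failure; since latin-1 maps every possible byte to a character, the retry cannot hit another decode error. A self-contained sketch of the pattern, with read_csv_bytes as an illustrative helper (not part of the commit):

    import io

    import pandas as pd

    FALLBACK_ENCODING = 'latin-1'  # same constant the commit introduces

    def read_csv_bytes(raw: bytes) -> pd.DataFrame:
        """Try pandas' default UTF-8 first, then fall back to latin-1."""
        try:
            return pd.read_csv(io.BytesIO(raw), low_memory=False)
        except UnicodeDecodeError:
            # latin-1 decodes any byte sequence, though bytes that were
            # really some other encoding may come out as mojibake.
            return pd.read_csv(io.BytesIO(raw), low_memory=False, encoding=FALLBACK_ENCODING)

    # Bytes that are valid latin-1 but not valid UTF-8:
    print(read_csv_bytes('id,name\n1,café\n'.encode('latin-1')))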

tests/unit/datasets/test_demo.py

Lines changed: 178 additions & 0 deletions
@@ -1019,3 +1019,181 @@ def test_download_demo_raises_when_no_csv_in_zip_single_table(mock_list, mock_ge
     msg = 'Demo data could not be downloaded because no csv files were found in data.zip'
     with pytest.raises(DemoResourceNotFoundError, match=re.escape(msg)):
         download_demo('single_table', 'word')
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_warns_skipped_non_csv_in_memory(mock_list, mock_get):
+    """In-memory path: warn once about skipped non-CSV entries; load valid CSVs."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/mix/data.zip'},
+        {'Key': 'single_table/mix/metadata.json'},
+    ]
+
+    df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('good.csv', df.to_csv(index=False))
+        zf.writestr('note.txt', 'hello world')
+        zf.writestr('nested/readme.md', '# readme')
+        # Add a directory entry explicitly
+        zf.writestr('empty_dir/', '')
+    zip_bytes = buf.getvalue()
+
+    meta_bytes = json.dumps({
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'good': {
+                'columns': {
+                    'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                    'name': {'sdtype': 'categorical'},
+                }
+            }
+        },
+        'relationships': [],
+    }).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    # Run and Assert
+    warn_msg = 'Skipped files: empty_dir/, nested/readme.md, note.txt'
+    with pytest.warns(UserWarning, match=warn_msg) as rec:
+        data, _ = download_demo('single_table', 'mix')
+
+    assert len(rec) == 1
+    expected = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
+    pd.testing.assert_frame_equal(data, expected)
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_on_disk_warns_skipped_files(mock_list, mock_get, tmp_path, monkeypatch):
+    """On-disk path: warn once listing both failed CSVs and skipped non-CSV files."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/mix/data.zip'},
+        {'Key': 'single_table/mix/metadata.json'},
+    ]
+
+    good = pd.DataFrame({'x': [1, 2]})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('good.csv', good.to_csv(index=False))
+        zf.writestr('bad.csv', 'will_fail')
+        zf.writestr('info.txt', 'ignore me')
+    zip_bytes = buf.getvalue()
+
+    meta_bytes = json.dumps({
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'good': {
+                'columns': {
+                    'x': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                }
+            }
+        },
+        'relationships': [],
+    }).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    # Force read_csv to fail on bad.csv only
+    orig_read_csv = pd.read_csv
+
+    def fake_read_csv(path_or_buf, *args, **kwargs):
+        if isinstance(path_or_buf, str) and path_or_buf.endswith('bad.csv'):
+            raise ValueError('bad-parse')
+        return orig_read_csv(path_or_buf, *args, **kwargs)
+
+    monkeypatch.setattr('pandas.read_csv', fake_read_csv)
+
+    out_dir = tmp_path / 'mix_out'
+
+    # Run and Assert
+    warn_msg = 'Skipped files: bad.csv: bad-parse, info.txt'
+    with pytest.warns(UserWarning, match=warn_msg) as rec:
+        data, _ = download_demo('single_table', 'mix', out_dir)
+
+    assert len(rec) == 1
+    pd.testing.assert_frame_equal(data, good)
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_handles_non_utf8_in_memory(mock_list, mock_get):
+    """It should successfully read Latin-1 encoded CSVs from in-memory extraction."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/nonutf/data.zip'},
+        {'Key': 'single_table/nonutf/metadata.json'},
+    ]
+
+    df = pd.DataFrame({'id': [1], 'name': ['café']})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('nonutf.csv', df.to_csv(index=False).encode('latin-1'))
+    zip_bytes = buf.getvalue()
+
+    meta_bytes = json.dumps({
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'nonutf': {
+                'columns': {
+                    'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                    'name': {'sdtype': 'categorical'},
+                }
+            }
+        },
+        'relationships': [],
+    }).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    # Run
+    data, _ = download_demo('single_table', 'nonutf')
+
+    # Assert
+    expected = pd.DataFrame({'id': [1], 'name': ['café']})
+    pd.testing.assert_frame_equal(data, expected)
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_handles_non_utf8_on_disk(mock_list, mock_get, tmp_path):
+    """It should successfully read Latin-1 encoded CSVs when extracted to disk."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/nonutf/data.zip'},
+        {'Key': 'single_table/nonutf/metadata.json'},
+    ]
+
+    df = pd.DataFrame({'id': [1], 'name': ['café']})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('nonutf.csv', df.to_csv(index=False).encode('latin-1'))
+    zip_bytes = buf.getvalue()
+
+    meta_bytes = json.dumps({
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'nonutf': {
+                'columns': {
+                    'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                    'name': {'sdtype': 'categorical'},
+                }
+            }
+        },
+        'relationships': [],
+    }).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    out_dir = tmp_path / 'latin_out'
+
+    # Run
+    data, _ = download_demo('single_table', 'nonutf', out_dir)
+
+    # Assert
+    expected = pd.DataFrame({'id': [1], 'name': ['café']})
+    pd.testing.assert_frame_equal(data, expected)
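A note on the warning assertions above: pytest.warns interprets match as a regular expression (applied with re.search), so the literal dots in these expected messages match any character and the checks pass somewhat loosely. The pre-existing no-CSV test wraps its message in re.escape, which is the stricter idiom whenever the message contains regex metacharacters; a minimal sketch:

    import re
    import warnings

    import pytest

    def test_warns_exact_message():
        expected = 'Skipped files: bad.csv: bad-parse, info.txt'
        # re.escape turns the literal message into a safe pattern
        with pytest.warns(UserWarning, match=re.escape(expected)):
            warnings.warn(expected)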
