Skip to content

Commit 461036c

Browse files
committed
Ensure that the source/readme files can only be downloaded in txt format (#2693)
1 parent 4adc5c6 commit 461036c

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

sdv/datasets/demo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,14 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
384384
The decoded text contents if the file exists, otherwise ``None``.
385385
"""
386386
_validate_modalities(modality)
387+
if output_filepath is not None and not str(output_filepath).endswith('.txt'):
388+
fname = (filename or '').lower()
389+
file_type = 'README' if 'readme' in fname else 'source'
390+
raise ValueError(
391+
f'The {file_type} can only be saved as a txt file. '
392+
"Please provide a filepath ending in '.txt'"
393+
)
394+
387395
dataset_prefix = f'{modality}/{dataset_name}/'
388396
contents = _list_objects(dataset_prefix)
389397

sdv/single_table/copulas.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import warnings
66
from copy import deepcopy
77

8-
import copulas
98
import copulas.univariate
109
import numpy as np
1110
import pandas as pd

tests/unit/datasets/test_demo.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -762,13 +762,13 @@ def fake(modality, dataset_name, filename, output_filepath=None):
762762
monkeypatch.setattr('sdv.datasets.demo._get_text_file_content', fake)
763763

764764
# Run
765-
r = get_readme('single_table', 'dataset1', '/tmp/readme')
766-
s = get_source('single_table', 'dataset1', '/tmp/source')
765+
readme = get_readme('single_table', 'dataset1', '/tmp/readme.txt')
766+
source = get_source('single_table', 'dataset1', '/tmp/source.txt')
767767

768768
# Assert
769-
assert r == 'X' and s == 'X'
770-
assert calls[0] == ('single_table', 'dataset1', 'README.txt', '/tmp/readme')
771-
assert calls[1] == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source')
769+
assert readme == 'X' and source == 'X'
770+
assert calls[0] == ('single_table', 'dataset1', 'README.txt', '/tmp/readme.txt')
771+
assert calls[1] == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source.txt')
772772

773773

774774
@patch('sdv.datasets.demo._get_data_from_bucket')
@@ -807,3 +807,17 @@ def test_get_source_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
807807
err = f"A file named '{out}' already exists. Please specify a different filepath."
808808
with pytest.raises(ValueError, match=re.escape(err)):
809809
get_source('single_table', 'dataset1', str(out))
810+
811+
812+
def test_get_readme_raises_for_non_txt_output():
813+
"""get_readme should raise ValueError if output path is not .txt."""
814+
err = "The README can only be saved as a txt file. Please provide a filepath ending in '.txt'"
815+
with pytest.raises(ValueError, match=re.escape(err)):
816+
get_readme('single_table', 'dataset1', '/tmp/readme.md')
817+
818+
819+
def test_get_source_raises_for_non_txt_output():
820+
"""get_source should raise ValueError if output path is not .txt."""
821+
err = "The source can only be saved as a txt file. Please provide a filepath ending in '.txt'"
822+
with pytest.raises(ValueError, match=re.escape(err)):
823+
get_source('single_table', 'dataset1', '/tmp/source.pdf')

0 commit comments

Comments
 (0)