Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
The decoded text contents if the file exists, otherwise ``None``.
"""
_validate_modalities(modality)
if output_filepath is not None and not str(output_filepath).endswith('.txt'):
fname = (filename or '').lower()
file_type = 'README' if 'readme' in fname else 'source'
raise ValueError(
f'The {file_type} can only be saved as a txt file. '
"Please provide a filepath ending in '.txt'"
)

dataset_prefix = f'{modality}/{dataset_name}/'
contents = _list_objects(dataset_prefix)

Expand Down
1 change: 0 additions & 1 deletion sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from copy import deepcopy

import copulas
import copulas.univariate
import numpy as np
import pandas as pd
Expand Down
24 changes: 19 additions & 5 deletions tests/unit/datasets/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,13 +762,13 @@ def fake(modality, dataset_name, filename, output_filepath=None):
monkeypatch.setattr('sdv.datasets.demo._get_text_file_content', fake)

# Run
r = get_readme('single_table', 'dataset1', '/tmp/readme')
s = get_source('single_table', 'dataset1', '/tmp/source')
readme = get_readme('single_table', 'dataset1', '/tmp/readme.txt')
source = get_source('single_table', 'dataset1', '/tmp/source.txt')

# Assert
assert r == 'X' and s == 'X'
assert calls[0] == ('single_table', 'dataset1', 'README.txt', '/tmp/readme')
assert calls[1] == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source')
assert readme == 'X' and source == 'X'
assert calls[0] == ('single_table', 'dataset1', 'README.txt', '/tmp/readme.txt')
assert calls[1] == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source.txt')


@patch('sdv.datasets.demo._get_data_from_bucket')
Expand Down Expand Up @@ -807,3 +807,17 @@ def test_get_source_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
err = f"A file named '{out}' already exists. Please specify a different filepath."
with pytest.raises(ValueError, match=re.escape(err)):
get_source('single_table', 'dataset1', str(out))


def test_get_readme_raises_for_non_txt_output():
"""get_readme should raise ValueError if output path is not .txt."""
err = "The README can only be saved as a txt file. Please provide a filepath ending in '.txt'"
with pytest.raises(ValueError, match=re.escape(err)):
get_readme('single_table', 'dataset1', '/tmp/readme.md')


def test_get_source_raises_for_non_txt_output():
"""get_source should raise ValueError if output path is not .txt."""
err = "The source can only be saved as a txt file. Please provide a filepath ending in '.txt'"
with pytest.raises(ValueError, match=re.escape(err)):
get_source('single_table', 'dataset1', '/tmp/source.pdf')
Loading