Skip to content

Commit 1a7c849

Browse files
authored
Merge branch 'main' into issue-2661-download-demo-update
2 parents bc35d42 + 1705ced commit 1a7c849

File tree

6 files changed

+24
-7
lines changed

6 files changed

+24
-7
lines changed

sdv/cag/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66

77
from sdv.cag._errors import ConstraintNotMetError
8-
from sdv.errors import SynthesizerInputError, TableNameError
8+
from sdv.errors import RefitWarning, SynthesizerInputError, TableNameError
99
from sdv.metadata import Metadata
1010

1111

@@ -185,7 +185,8 @@ def _validate_constraints(constraints, synthesizer_fitted):
185185

186186
if synthesizer_fitted:
187187
warnings.warn(
188-
"For these constraints to take effect, please refit the synthesizer using 'fit'."
188+
"For these constraints to take effect, please refit the synthesizer using 'fit'.",
189+
RefitWarning,
189190
)
190191

191192
return _filter_old_style_constraints(constraints)

sdv/errors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,11 @@ class DemoResourceNotFoundError(Exception):
9191
This error is intended for missing demo assets such as the dataset archive,
9292
metadata, license, README, or other auxiliary files in the demo bucket.
9393
"""
94+
95+
96+
class RefitWarning(UserWarning):
97+
"""Warning to be raised if the synthesizer needs to be refit.
98+
99+
Warning to be raised if a change to a synthesizer requires the synthesizer
100+
to be refit for the change to be applied.
101+
"""

sdv/multi_table/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
3030
from sdv.errors import (
3131
InvalidDataError,
32+
RefitWarning,
3233
SamplingError,
3334
SynthesizerInputError,
3435
)
@@ -551,10 +552,11 @@ def preprocess(self, data):
551552
self.validate(data)
552553
data = self._validate_transform_constraints(data)
553554
if self._fitted:
554-
warnings.warn(
555+
msg = (
555556
'This model has already been fitted. To use the new preprocessed data, '
556557
"please refit the model using 'fit' or 'fit_processed_data'."
557558
)
559+
warnings.warn(msg, RefitWarning)
558560

559561
processed_data = {}
560562
pbar_args = self._get_pbar_args(desc='Preprocess Tables')

sdv/single_table/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from sdv.errors import (
4242
ConstraintsNotMetError,
4343
InvalidDataError,
44+
RefitWarning,
4445
SamplingError,
4546
SynthesizerInputError,
4647
)
@@ -306,7 +307,7 @@ def update_transformers(self, column_name_to_transformer):
306307
self._data_processor.update_transformers(column_name_to_transformer)
307308
if self._fitted:
308309
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'
309-
warnings.warn(msg, UserWarning)
310+
warnings.warn(msg, RefitWarning)
310311

311312
def get_parameters(self):
312313
"""Return the parameters used to instantiate the synthesizer."""
@@ -587,10 +588,12 @@ def _preprocess_helper(self, data):
587588
"""
588589
self.validate(data)
589590
if self._fitted:
590-
warnings.warn(
591+
msg = (
591592
'This model has already been fitted. To use the new preprocessed data, '
592593
"please refit the model using 'fit' or 'fit_processed_data'."
593594
)
595+
warnings.warn(msg, RefitWarning)
596+
594597
data = self._validate_transform_constraints(data)
595598

596599
return data

tests/unit/multi_table/test_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sdv.errors import (
1818
InvalidDataError,
1919
NotFittedError,
20+
RefitWarning,
2021
SamplingError,
2122
SynthesizerInputError,
2223
VersionError,
@@ -1013,7 +1014,8 @@ def test_preprocess_warning(self, mock_warnings):
10131014
assert args[0].equals(data['upravna_enota'])
10141015
mock_warnings.warn.assert_called_once_with(
10151016
'This model has already been fitted. To use the new preprocessed data, '
1016-
"please refit the model using 'fit' or 'fit_processed_data'."
1017+
"please refit the model using 'fit' or 'fit_processed_data'.",
1018+
RefitWarning,
10171019
)
10181020

10191021
@patch('sdv.metadata.single_table.SingleTableMetadata._validate_metadata_matches_data')

tests/unit/single_table/test_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sdv.errors import (
2424
ConstraintsNotMetError,
2525
InvalidDataError,
26+
RefitWarning,
2627
SamplingError,
2728
SynthesizerInputError,
2829
VersionError,
@@ -665,7 +666,7 @@ def test__preprocess_helper(self, mock_warnings):
665666
result = BaseSynthesizer._preprocess_helper(instance, data)
666667

667668
# Assert
668-
mock_warnings.warn.assert_called_once_with(expected_warning)
669+
mock_warnings.warn.assert_called_once_with(expected_warning, RefitWarning)
669670
instance.validate.assert_called_once_with(data)
670671
instance._validate_transform_constraints.assert_called_once_with(data)
671672
pd.testing.assert_frame_equal(result, data)

0 commit comments

Comments
 (0)