Skip to content

Commit

Permalink
Update warn msg
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 7, 2024
1 parent 863b29e commit 275abdd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 22 deletions.
16 changes: 10 additions & 6 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,21 @@ def _warn_quality_and_performance(self, column_name_to_transformer):

def _warn_unable_to_enforce_rounding(self, column_name_to_transformer):
if self.enforce_rounding:
invalid_columns = []
for column, transformer in column_name_to_transformer.items():
if (
hasattr(transformer, 'learn_rounding_scheme')
and not transformer.learn_rounding_scheme
):
warnings.warn(
f"Unable to turn off rounding scheme for column '{column}', "
'because the overall synthesizer is enforcing rounding. We '
"recommend setting the synthesizer's 'enforce_rounding' "
'parameter to False.'
)
invalid_columns.append(column)

if invalid_columns:
warnings.warn(
f'Unable to turn off rounding scheme for column(s) {invalid_columns}, '
'because the overall synthesizer is enforcing rounding. We '
"recommend setting the synthesizer's 'enforce_rounding' "
'parameter to False.'
)

def update_transformers(self, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def test_update_transformers(warning_mock):

# Assert
warning_mock.warn.assert_called_once_with(
"Unable to turn off rounding scheme for column 'amenities_fee', because the overall "
"Unable to turn off rounding scheme for column(s) ['amenities_fee'], because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
22 changes: 7 additions & 15 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,22 +850,14 @@ def test_update_transformers_warns_rounding(self):
instance.enforce_rounding = True
instance._fitted = False

# Run
with pytest.warns(UserWarning) as record:
instance.update_transformers(column_name_to_transformer)

# Assert
assert len(record) == 2
assert str(record[0].message) == (
"Unable to turn off rounding scheme for column 'col1', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
assert str(record[1].message) == (
"Unable to turn off rounding scheme for column 'col3', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
# Run and Assert
warn_msg = re.escape(
"Unable to turn off rounding scheme for column(s) ['col1', 'col3'], "
'because the overall synthesizer is enforcing rounding. We recommend '
"setting the synthesizer's 'enforce_rounding' parameter to False."
)
with pytest.warns(UserWarning, match=warn_msg):
instance.update_transformers(column_name_to_transformer)

@patch('sdv.single_table.base.DataProcessor')
def test__set_random_state(self, mock_data_processor):
Expand Down

0 comments on commit 275abdd

Please # to comment.