Skip to content

Commit

Permalink
Merge branch 'main' into issue-2266-rounding-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Nov 12, 2024
2 parents 275abdd + 0fe0123 commit 852e01b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 22 deletions.
4 changes: 2 additions & 2 deletions latest_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ copulas==0.11.1
ctgan==0.10.2
deepecho==0.6.1
graphviz==0.20.3
numpy==1.26.4
numpy==2.0.2
pandas==2.2.3
platformdirs==4.3.6
rdt==1.13.0
sdmetrics==0.16.0
tqdm==4.66.5
tqdm==4.67.0
19 changes: 10 additions & 9 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
)
from sdv.data_processing.datetime_formatter import DatetimeFormatter
from sdv.data_processing.errors import InvalidConstraintsError, NotFittedError
from sdv.data_processing.numerical_formatter import INTEGER_BOUNDS, NumericalFormatter
from sdv.data_processing.numerical_formatter import NumericalFormatter
from sdv.data_processing.utils import load_module_from_path
from sdv.errors import SynthesizerInputError, log_exc_stacktrace
from sdv.metadata.single_table import SingleTableMetadata

LOGGER = logging.getLogger(__name__)
INTEGER_BOUNDS = {str(key).lower(): value for key, value in INTEGER_BOUNDS.items()}


class DataProcessor:
Expand Down Expand Up @@ -70,8 +69,6 @@ class DataProcessor:
'M': 'datetime',
}

_COLUMN_RELATIONSHIP_TO_TRANSFORMER = {'address': 'RandomLocationGenerator', 'gps': 'GPSNoiser'}

def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values):
custom_float_formatter = rdt.transformers.FloatFormatter(
missing_value_replacement='mean',
Expand Down Expand Up @@ -124,6 +121,10 @@ def __init__(
self._constraints = []
self._constraints_to_reverse = []
self._custom_constraint_classes = {}
self._COLUMN_RELATIONSHIP_TO_TRANSFORMER = {
'address': 'RandomLocationGenerator',
'gps': 'GPSNoiser',
}

self._transformers_by_sdtype = deepcopy(get_default_transformers())
self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator()
Expand Down Expand Up @@ -575,11 +576,11 @@ def _create_config(self, data, columns_created_by_constraints):
if is_numeric:
function_name = 'random_int'
column_dtype = str(column_dtype).lower()
function_kwargs = {'min': 0, 'max': 9999999}
for key in INTEGER_BOUNDS:
if key in column_dtype:
_, max_value = INTEGER_BOUNDS[key]
function_kwargs = {'min': 0, 'max': max_value}
function_kwargs = {'min': 0, 'max': 16777216}
if 'int8' in column_dtype:
function_kwargs['max'] = 127
elif 'int16' in column_dtype:
function_kwargs['max'] = 32767

else:
function_kwargs = {'text': 'sdv-id-??????'}
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,16 @@ def test_numerical_columns_gets_pii():
# Assert
expected_sampled = pd.DataFrame({
'id': [
1089619006166876142,
8373046707753416652,
9070705361670139280,
7227045982112645011,
3461931576753619633,
1005734164466301683,
3312031189447929384,
82456842876428117,
1819741328868365520,
8019169766233150107,
1982005,
15967014,
10406639,
15230483,
14028549,
16499516,
9244156,
13145920,
10106629,
6297216,
],
'city': [
'Danielfort',
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ def test__create_config(self):
assert id_numeric_int_32_transformer.function_name == 'random_int'
assert id_numeric_int_32_transformer.function_kwargs == {
'min': 0,
'max': 2147483647,
'max': 16777216,
}

id_column_transformer = config['transformers']['id_column']
Expand Down

0 comments on commit 852e01b

Please # to comment.