|
22 | 22 | ) |
23 | 23 | from sdv.data_processing.datetime_formatter import DatetimeFormatter |
24 | 24 | from sdv.data_processing.errors import InvalidConstraintsError, NotFittedError |
25 | | -from sdv.data_processing.numerical_formatter import INTEGER_BOUNDS, NumericalFormatter |
| 25 | +from sdv.data_processing.numerical_formatter import NumericalFormatter |
26 | 26 | from sdv.data_processing.utils import load_module_from_path |
27 | 27 | from sdv.errors import SynthesizerInputError, log_exc_stacktrace |
28 | 28 | from sdv.metadata.single_table import SingleTableMetadata |
29 | 29 |
|
30 | 30 | LOGGER = logging.getLogger(__name__) |
31 | | -INTEGER_BOUNDS = {str(key).lower(): value for key, value in INTEGER_BOUNDS.items()} |
32 | 31 |
|
33 | 32 |
|
34 | 33 | class DataProcessor: |
@@ -70,8 +69,6 @@ class DataProcessor: |
70 | 69 | 'M': 'datetime', |
71 | 70 | } |
72 | 71 |
|
73 | | - _COLUMN_RELATIONSHIP_TO_TRANSFORMER = {'address': 'RandomLocationGenerator', 'gps': 'GPSNoiser'} |
74 | | - |
75 | 72 | def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values): |
76 | 73 | custom_float_formatter = rdt.transformers.FloatFormatter( |
77 | 74 | missing_value_replacement='mean', |
@@ -124,6 +121,10 @@ def __init__( |
124 | 121 | self._constraints = [] |
125 | 122 | self._constraints_to_reverse = [] |
126 | 123 | self._custom_constraint_classes = {} |
| 124 | + self._COLUMN_RELATIONSHIP_TO_TRANSFORMER = { |
| 125 | + 'address': 'RandomLocationGenerator', |
| 126 | + 'gps': 'GPSNoiser', |
| 127 | + } |
127 | 128 |
|
128 | 129 | self._transformers_by_sdtype = deepcopy(get_default_transformers()) |
129 | 130 | self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator() |
@@ -575,11 +576,11 @@ def _create_config(self, data, columns_created_by_constraints): |
575 | 576 | if is_numeric: |
576 | 577 | function_name = 'random_int' |
577 | 578 | column_dtype = str(column_dtype).lower() |
578 | | - function_kwargs = {'min': 0, 'max': 9999999} |
579 | | - for key in INTEGER_BOUNDS: |
580 | | - if key in column_dtype: |
581 | | - _, max_value = INTEGER_BOUNDS[key] |
582 | | - function_kwargs = {'min': 0, 'max': max_value} |
| 579 | + function_kwargs = {'min': 0, 'max': 16777216} |
| 580 | + if 'int8' in column_dtype: |
| 581 | + function_kwargs['max'] = 127 |
| 582 | + elif 'int16' in column_dtype: |
| 583 | + function_kwargs['max'] = 32767 |
583 | 584 |
|
584 | 585 | else: |
585 | 586 | function_kwargs = {'text': 'sdv-id-??????'} |
|
0 commit comments