Skip to content

Commit 76f11d9

Browse files
committed
make release-tag: Merge branch 'main' into stable
2 parents 0f130f8 + 18c301d commit 76f11d9

File tree

17 files changed

+650
-56
lines changed

17 files changed

+650
-56
lines changed

HISTORY.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
# Release Notes
22

3+
## v1.27.0 - 2025-09-15
4+
5+
### New Features
6+
7+
* Create a specific warning type for the purposes of refitting a synthesizer - Issue [#2662](https://github.com/sdv-dev/SDV/issues/2662) by @frances-h
8+
* [OneHotEncoding constraint] Allow me to specify whether to keep the one-hot columns or collapse them into one categorical column - Issue [#2650](https://github.com/sdv-dev/SDV/issues/2650) by @fealho
9+
10+
### Bugs Fixed
11+
12+
* "numerical_distributions" in HMASynthesizer get ignored - Issue [#2648](https://github.com/sdv-dev/SDV/issues/2648) by @fealho
13+
14+
### Internal
15+
16+
* Add helper method for transforming conditions - Issue [#2660](https://github.com/sdv-dev/SDV/issues/2660) by @rwedge
17+
* [OneHotEncoding Constraint] For higher quality, ensure the model creates floating point numbers - Issue [#2649](https://github.com/sdv-dev/SDV/issues/2649) by @fealho
18+
319
## v1.26.0 - 2025-08-18
420

521
### New Features

latest_requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ ctgan==0.11.0
44
deepecho==0.7.0
55
graphviz==0.21
66
numpy==2.3.2
7-
pandas==2.3.1
8-
platformdirs==4.3.8
7+
pandas==2.3.2
8+
platformdirs==4.4.0
99
rdt==1.18.0
1010
sdmetrics==0.23.0
1111
tqdm==4.67.1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ namespaces = false
143143
version = {attr = 'sdv.__version__'}
144144

145145
[tool.bumpversion]
146-
current_version = "1.26.0"
146+
current_version = "1.27.0.dev0"
147147
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
148148
serialize = [
149149
'{major}.{minor}.{patch}.{release}{candidate}',

sdv/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
__author__ = 'DataCebo, Inc.'
88
__email__ = 'info@sdv.dev'
9-
__version__ = '1.26.0'
9+
__version__ = '1.27.0.dev0'
1010

1111

1212
import sys

sdv/cag/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66

77
from sdv.cag._errors import ConstraintNotMetError
8-
from sdv.errors import SynthesizerInputError, TableNameError
8+
from sdv.errors import RefitWarning, SynthesizerInputError, TableNameError
99
from sdv.metadata import Metadata
1010

1111

@@ -185,7 +185,8 @@ def _validate_constraints(constraints, synthesizer_fitted):
185185

186186
if synthesizer_fitted:
187187
warnings.warn(
188-
"For these constraints to take effect, please refit the synthesizer using 'fit'."
188+
"For these constraints to take effect, please refit the synthesizer using 'fit'.",
189+
RefitWarning,
189190
)
190191

191192
return _filter_old_style_constraints(constraints)

sdv/cag/one_hot_encoding.py

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
"""One Hot Encoding constraint."""
22

3+
from copy import deepcopy
4+
35
import numpy as np
46

7+
from sdv._utils import _create_unique_name
58
from sdv.cag._errors import ConstraintNotMetError
69
from sdv.cag._utils import (
710
_get_is_valid_dict,
811
_is_list_of_type,
12+
_remove_columns_from_metadata,
913
_validate_table_and_column_names,
1014
_validate_table_name_if_defined,
1115
)
1216
from sdv.cag.base import BaseConstraint
1317

18+
EPSILON = float(np.finfo(np.float32).eps)
19+
1420

1521
class OneHotEncoding(BaseConstraint):
1622
"""Ensure the appropriate columns are one hot encoded.
@@ -26,20 +32,30 @@ class OneHotEncoding(BaseConstraint):
2632
table_name (str, optional):
2733
The name of the table that contains the columns. Optional if the
2834
data is only a single table. Defaults to None.
35+
learning_strategy (str, optional):
36+
Strategy for how the model should learn the one-hot fields. Supported values:
37+
- 'one_hot' (default): Learn each one-hot column separately.
38+
- 'categorical': Internally collapse the one-hot columns into a single categorical
39+
column for the model to learn, then expand back to one-hot at sampling time.
2940
"""
3041

3142
@staticmethod
32-
def _validate_init_inputs(column_names, table_name):
43+
def _validate_init_inputs(column_names, table_name, learning_strategy):
3344
if not _is_list_of_type(column_names):
3445
raise ValueError('`column_names` must be a list of strings.')
3546

3647
_validate_table_name_if_defined(table_name)
3748

38-
def __init__(self, column_names, table_name=None):
49+
if learning_strategy not in ['one_hot', 'categorical']:
50+
raise ValueError("`learning_strategy` must be either 'one_hot' or 'categorical'.")
51+
52+
def __init__(self, column_names, table_name=None, learning_strategy='one_hot'):
3953
super().__init__()
40-
self._validate_init_inputs(column_names, table_name)
54+
self._validate_init_inputs(column_names, table_name, learning_strategy)
4155
self._column_names = column_names
4256
self.table_name = table_name
57+
self.learning_strategy = learning_strategy
58+
self._categorical_column = '#'.join(self._column_names)
4359

4460
def _validate_constraint_with_metadata(self, metadata):
4561
"""Validate the constraint is compatible with the provided metadata.
@@ -88,6 +104,28 @@ def _fit(self, data, metadata):
88104
"""
89105
pass
90106

107+
def _get_updated_metadata(self, metadata):
108+
table_name = self._get_single_table_name(metadata)
109+
if self.learning_strategy == 'categorical':
110+
self._categorical_column = _create_unique_name(
111+
self._categorical_column, metadata.tables[table_name].columns
112+
)
113+
md = metadata.to_dict()
114+
md['tables'][table_name]['columns'][self._categorical_column] = {
115+
'sdtype': 'categorical'
116+
}
117+
return _remove_columns_from_metadata(md, table_name, columns_to_drop=self._column_names)
118+
119+
else:
120+
metadata = deepcopy(metadata)
121+
for column in self._column_names:
122+
if metadata.tables[table_name].columns[column]['sdtype'] in [
123+
'categorical',
124+
'boolean',
125+
]:
126+
metadata.tables[table_name].columns[column]['sdtype'] = 'numerical'
127+
return metadata
128+
91129
def _transform(self, data):
92130
"""Transform the data.
93131
@@ -99,6 +137,17 @@ def _transform(self, data):
99137
dict[str, pd.DataFrame]:
100138
Transformed data.
101139
"""
140+
table_name = self._get_single_table_name(self.metadata)
141+
if self.learning_strategy == 'categorical':
142+
table_data = data[table_name]
143+
categories = table_data[self._column_names].idxmax(axis=1)
144+
table_data[self._categorical_column] = categories
145+
data[table_name] = table_data.drop(self._column_names, axis=1)
146+
else:
147+
one_hot_data = data[table_name][self._column_names]
148+
one_hot_data = np.where(one_hot_data == 0, EPSILON, 1 - EPSILON)
149+
data[table_name][self._column_names] = one_hot_data
150+
102151
return data
103152

104153
def _reverse_transform(self, data):
@@ -116,13 +165,28 @@ def _reverse_transform(self, data):
116165
"""
117166
table_name = self._get_single_table_name(self.metadata)
118167
table_data = data[table_name]
119-
one_hot_data = table_data[self._column_names]
120-
transformed_data = np.zeros_like(one_hot_data.to_numpy())
121-
max_category_indices = np.argmax(one_hot_data.to_numpy(), axis=1)
122-
transformed_data[np.arange(len(one_hot_data)), max_category_indices] = 1
123-
table_data[self._column_names] = transformed_data
124-
data[table_name] = table_data
125168

169+
if self.learning_strategy == 'categorical':
170+
categories = table_data.pop(self._categorical_column)
171+
num_rows = len(table_data)
172+
num_cols = len(self._column_names)
173+
transformed = np.zeros((num_rows, num_cols), dtype=float)
174+
175+
column_to_index = {name: idx for idx, name in enumerate(self._column_names)}
176+
indices = categories.map(lambda x: column_to_index[x]).to_numpy()
177+
transformed[np.arange(num_rows), indices] = 1
178+
179+
for idx, col in enumerate(self._column_names):
180+
table_data[col] = transformed[:, idx]
181+
182+
else:
183+
one_hot_data = table_data[self._column_names]
184+
transformed_data = np.zeros_like(one_hot_data.to_numpy())
185+
max_category_indices = np.argmax(one_hot_data.to_numpy(), axis=1)
186+
transformed_data[np.arange(len(one_hot_data)), max_category_indices] = 1
187+
table_data[self._column_names] = transformed_data
188+
189+
data[table_name] = table_data
126190
return data
127191

128192
def _is_valid(self, data, metadata):

sdv/errors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,11 @@ def __init__(self, message):
8383

8484

8585
TableNameError = ValueError('`table_name` must be a string or None.')
86+
87+
88+
class RefitWarning(UserWarning):
89+
"""Warning to be raised if the synthesizer needs to be refit.
90+
91+
Warning to be raised if a change to a synthesizer requires the synthesizer
92+
to be refit for the change to be applied.
93+
"""

sdv/multi_table/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
3030
from sdv.errors import (
3131
InvalidDataError,
32+
RefitWarning,
3233
SamplingError,
3334
SynthesizerInputError,
3435
)
@@ -551,10 +552,11 @@ def preprocess(self, data):
551552
self.validate(data)
552553
data = self._validate_transform_constraints(data)
553554
if self._fitted:
554-
warnings.warn(
555+
msg = (
555556
'This model has already been fitted. To use the new preprocessed data, '
556557
"please refit the model using 'fit' or 'fit_processed_data'."
557558
)
559+
warnings.warn(msg, RefitWarning)
558560

559561
processed_data = {}
560562
pbar_args = self._get_pbar_args(desc='Preprocess Tables')

sdv/multi_table/hma.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,11 @@ def _set_extended_columns_distributions(self, synthesizer, table_name, valid_col
283283
for extended_column in self._parent_extended_columns[table_name]:
284284
if extended_column in valid_columns:
285285
numerical_distributions[extended_column] = DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
286-
synthesizer._set_numerical_distributions(numerical_distributions)
286+
287+
if numerical_distributions:
288+
existing = getattr(synthesizer, 'numerical_distributions', {}) or {}
289+
merged = {**existing, **numerical_distributions}
290+
synthesizer._set_numerical_distributions(merged)
287291

288292
def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
289293
"""Generate the extension columns for this child table.

sdv/single_table/base.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from sdv.errors import (
4242
ConstraintsNotMetError,
4343
InvalidDataError,
44+
RefitWarning,
4445
SamplingError,
4546
SynthesizerInputError,
4647
)
@@ -306,7 +307,7 @@ def update_transformers(self, column_name_to_transformer):
306307
self._data_processor.update_transformers(column_name_to_transformer)
307308
if self._fitted:
308309
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'
309-
warnings.warn(msg, UserWarning)
310+
warnings.warn(msg, RefitWarning)
310311

311312
def get_parameters(self):
312313
"""Return the parameters used to instantiate the synthesizer."""
@@ -587,10 +588,12 @@ def _preprocess_helper(self, data):
587588
"""
588589
self.validate(data)
589590
if self._fitted:
590-
warnings.warn(
591+
msg = (
591592
'This model has already been fitted. To use the new preprocessed data, '
592593
"please refit the model using 'fit' or 'fit_processed_data'."
593594
)
595+
warnings.warn(msg, RefitWarning)
596+
594597
data = self._validate_transform_constraints(data)
595598

596599
return data
@@ -1208,18 +1211,19 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
12081211

12091212
return sampled_data
12101213

1214+
def _transform_conditions(self, condition_df):
1215+
return self._data_processor.transform(condition_df, is_condition=True)
1216+
12111217
def _transform_conditions_chained_constraints(self, condition_df):
12121218
try:
12131219
transformed_condition = self._validate_transform_constraints(condition_df)
1214-
transformed_condition = self._data_processor.transform(
1215-
transformed_condition, is_condition=True
1216-
)
1220+
transformed_condition = self._transform_conditions(transformed_condition)
12171221
except ConstraintNotMetError:
12181222
raise ConstraintNotMetError(
12191223
'Provided conditions are not valid for the given constraints.'
12201224
)
12211225
except Exception:
1222-
transformed_condition = self._data_processor.transform(condition_df, is_condition=True)
1226+
transformed_condition = self._transform_conditions(condition_df)
12231227

12241228
return transformed_condition
12251229

@@ -1274,13 +1278,13 @@ def _sample_with_conditions(
12741278

12751279
condition = dict(zip(condition_columns, group))
12761280
condition_df = dataframe.iloc[0].to_frame().T
1281+
dtypes = conditions.dtypes.to_dict()
1282+
condition_df = condition_df.astype(dtypes)
12771283
if hasattr(self, '_chained_constraints'):
12781284
transformed_condition = self._transform_conditions_chained_constraints(condition_df)
12791285
else:
12801286
try:
1281-
transformed_condition = self._data_processor.transform(
1282-
condition_df, is_condition=True
1283-
)
1287+
transformed_condition = self._transform_conditions(condition_df)
12841288
except ConstraintsNotMetError as error:
12851289
raise ConstraintsNotMetError(
12861290
'Provided conditions are not valid for the given constraints.'

0 commit comments

Comments
 (0)