|
8 | 8 | from sdv.cag._errors import ConstraintNotMetError |
9 | 9 | from sdv.metadata import Metadata |
10 | 10 | from sdv.single_table import GaussianCopulaSynthesizer |
11 | | -from tests.utils import run_copula, run_hma |
| 11 | +from tests.utils import run_constraint, run_copula, run_hma |
12 | 12 |
|
13 | 13 |
|
14 | 14 | @pytest.fixture() |
@@ -201,3 +201,144 @@ def test_end_to_end_boolean(): |
201 | 201 | assert (samples.sum(axis=1) == 1).all() |
202 | 202 | for col in columns: |
203 | 203 | assert sorted(samples[col].unique().tolist()) == [0, 1] |
| 204 | + |
| 205 | + |
| 206 | +def test_end_to_end_categorical_single(data, metadata): |
| 207 | + """End-to-end with learning_strategy='categorical' for single-table data.""" |
| 208 | + # Setup |
| 209 | + constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical') |
| 210 | + |
| 211 | + # Run |
| 212 | + synthesizer = run_copula(data, metadata, [constraint]) |
| 213 | + synthetic_data = synthesizer.sample(200) |
| 214 | + synthesizer.validate_constraints(synthetic_data=synthetic_data) |
| 215 | + |
| 216 | + # Assert |
| 217 | + assert set(synthetic_data.columns) == {'a', 'b', 'c'} |
| 218 | + for col in ['a', 'b', 'c']: |
| 219 | + assert set(synthetic_data[col]) == {0, 1} |
| 220 | + assert (synthetic_data[['a', 'b', 'c']].sum(axis=1) == 1).all() |
| 221 | + |
| 222 | + |
| 223 | +def test_end_to_end_categorical_single_raises(data, metadata): |
| 224 | + """Invalid synthetic data should raise with learning_strategy='categorical'.""" |
| 225 | + # Setup |
| 226 | + invalid_data = pd.DataFrame({ |
| 227 | + 'a': [1, 2, 0], |
| 228 | + 'b': [0, 1, np.nan], |
| 229 | + 'c': [0, 0, 3], |
| 230 | + }) |
| 231 | + constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical') |
| 232 | + |
| 233 | + # Run and Assert |
| 234 | + msg = re.escape( |
| 235 | + "Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n" |
| 236 | + ' a b c\n' |
| 237 | + '1 2 1.0 0\n' |
| 238 | + '2 0 NaN 3' |
| 239 | + ) |
| 240 | + with pytest.raises(ConstraintNotMetError, match=msg): |
| 241 | + run_copula(invalid_data, metadata, [constraint]) |
| 242 | + |
| 243 | + # Run and Assert |
| 244 | + msg = re.escape('The one hot encoding requirement is not met for row indices: 1, 2') |
| 245 | + with pytest.raises(ConstraintNotMetError, match=msg): |
| 246 | + synthesizer = run_copula(data, metadata, [constraint]) |
| 247 | + synthesizer.validate_constraints(synthetic_data=invalid_data) |
| 248 | + |
| 249 | + |
| 250 | +def test_end_to_end_categorical_multi(data_multi, metadata_multi): |
| 251 | + """End-to-end with learning_strategy='categorical' for multi-table data.""" |
| 252 | + # Setup |
| 253 | + constraint = OneHotEncoding( |
| 254 | + column_names=['a', 'b', 'c'], table_name='table1', learning_strategy='categorical' |
| 255 | + ) |
| 256 | + |
| 257 | + # Run |
| 258 | + synthesizer = run_hma(data_multi, metadata_multi, [constraint]) |
| 259 | + synthetic = synthesizer.sample(200) |
| 260 | + synthesizer.validate_constraints(synthetic_data=synthetic) |
| 261 | + |
| 262 | + # Assert |
| 263 | + assert set(synthetic['table1'].columns) == {'a', 'b', 'c'} |
| 264 | + for col in ['a', 'b', 'c']: |
| 265 | + assert set(synthetic['table1'][col]) == {0, 1} |
| 266 | + assert (synthetic['table1'][['a', 'b', 'c']].sum(axis=1) == 1).all() |
| 267 | + |
| 268 | + |
| 269 | +def test_end_to_end_categorical_multi_raises(data_multi, metadata_multi): |
| 270 | + """Invalid multi-table synthetic data should raise with learning_strategy='categorical'.""" |
| 271 | + # Setup |
| 272 | + constraint = OneHotEncoding( |
| 273 | + column_names=['a', 'b', 'c'], table_name='table1', learning_strategy='categorical' |
| 274 | + ) |
| 275 | + invalid = { |
| 276 | + 'table1': pd.DataFrame({ |
| 277 | + 'a': [1, 2, 0], |
| 278 | + 'b': [0, 1, np.nan], |
| 279 | + 'c': [0, 0, 3], |
| 280 | + }), |
| 281 | + 'table2': pd.DataFrame({'id': range(5)}), |
| 282 | + } |
| 283 | + |
| 284 | + # Run and Assert |
| 285 | + msg = re.escape( |
| 286 | + "Data is not valid for the 'OneHotEncoding' constraint in table 'table1':\n " |
| 287 | + 'a b c\n1 2 1.0 0\n2 0 NaN 3' |
| 288 | + ) |
| 289 | + with pytest.raises(ConstraintNotMetError, match=msg): |
| 290 | + run_hma(invalid, metadata_multi, [constraint]) |
| 291 | + |
| 292 | + # Run and Assert |
| 293 | + msg = "Table 'table1': The one hot encoding requirement is not met for row indices: 1, 2." |
| 294 | + with pytest.raises(ConstraintNotMetError, match=msg): |
| 295 | + synthesizer = run_hma(data_multi, metadata_multi, [constraint]) |
| 296 | + synthesizer.validate_constraints(synthetic_data=invalid) |
| 297 | + |
| 298 | + |
| 299 | +def test_constraint_pipeline_categorical_single(data, metadata): |
| 300 | + """Constraint pipeline behavior for categorical strategy (single table).""" |
| 301 | + # Setup |
| 302 | + constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical') |
| 303 | + |
| 304 | + # Run |
| 305 | + updated_metadata, transformed, reverse_transformed = run_constraint(constraint, data, metadata) |
| 306 | + |
| 307 | + # Assert metadata |
| 308 | + assert updated_metadata.get_column_names() == ['a#b#c'] |
| 309 | + |
| 310 | + # Assert transform |
| 311 | + assert transformed.shape[1] == 1 |
| 312 | + assert not any(col in transformed.columns for col in ['a', 'b', 'c']) |
| 313 | + assert set(transformed.columns) == {'a#b#c'} |
| 314 | + |
| 315 | + # Assert reverse_transform |
| 316 | + assert set(reverse_transformed.columns) == {'a', 'b', 'c'} |
| 317 | + assert (reverse_transformed[['a', 'b', 'c']].sum(axis=1) == 1).all() |
| 318 | + assert set(reverse_transformed.columns) == {'a', 'b', 'c'} |
| 319 | + |
| 320 | + |
| 321 | +def test_constraint_pipeline_categorical_multi(data_multi, metadata_multi): |
| 322 | + """Constraint pipeline behavior for categorical strategy (multi table).""" |
| 323 | + # Setup |
| 324 | + orig_cols = ['a', 'b', 'c'] |
| 325 | + constraint = OneHotEncoding( |
| 326 | + column_names=orig_cols, table_name='table1', learning_strategy='categorical' |
| 327 | + ) |
| 328 | + |
| 329 | + # Run |
| 330 | + updated_metadata, transformed, reverse_transformed = run_constraint( |
| 331 | + constraint, data_multi, metadata_multi |
| 332 | + ) |
| 333 | + |
| 334 | + # Assert metadata |
| 335 | + assert updated_metadata.tables['table1'].get_column_names() == ['a#b#c'] |
| 336 | + |
| 337 | + # Assert transform |
| 338 | + assert list(transformed['table1'].columns) != orig_cols |
| 339 | + assert transformed['table1'].shape[1] == 1 |
| 340 | + assert list(transformed['table2'].columns) == list(data_multi['table2'].columns) |
| 341 | + |
| 342 | + # Assert reverse_transform |
| 343 | + assert set(reverse_transformed['table1'].columns) == set(orig_cols) |
| 344 | + assert (reverse_transformed['table1'][orig_cols].sum(axis=1) == 1).all() |
0 commit comments