Skip to content

Commit 7ff0afa

Browse files
authored
[Fix] Combining column relationship with constraints not working (#2771)
1 parent 4c16f9d commit 7ff0afa

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

sdv/single_table/base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -474,6 +474,7 @@ def add_constraints(self, constraints):
474474
except ConstraintNotMetError:
475475
raise e
476476

477+
self.metadata.validate()
477478
self._data_processor = DataProcessor(
478479
metadata=self.metadata._convert_to_single_table(),
479480
enforce_rounding=self.enforce_rounding,

tests/integration/single_table/test_constraints.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,56 @@ def demo_metadata():
3737
return DEMO_METADATA
3838

3939

40+
def test_column_relationships_with_constraints():
    """Fit and sample with both an address column relationship and a constraint. GH#2768

    The metadata declares an ``address`` relationship over the street/city/state/zip
    columns, and a ``FixedCombinations`` constraint is added on ``code``/``description``.
    The sampled data must keep the original column order and contain exactly the
    code/description pairs present in the training data.
    """
    # Setup
    column_sdtypes = {
        'id': 'id',
        'street': 'street_address',
        'city': 'city',
        'state': 'administrative_unit',
        'zip': 'postcode',
        'code': 'categorical',
        'description': 'categorical',
    }
    table_columns = {name: {'sdtype': sdtype} for name, sdtype in column_sdtypes.items()}
    metadata = Metadata.load_from_dict({'tables': {'table': {'columns': table_columns}}})
    metadata.add_column_relationship(
        table_name='table',
        relationship_type='address',
        column_names=['street', 'city', 'state', 'zip'],
    )

    num_rows = 6
    data = pd.DataFrame({
        'id': [f'id_{i}' for i in range(num_rows)],
        'street': ['123 Street'] * num_rows,
        'city': ['Boston', 'LA', 'Cambridge', 'San Francisco', 'Boston', 'LA'],
        'state': ['MA', 'CA'] * 3,
        'zip': ['72801'] * num_rows,
        'code': ['000', '001', '002'] * 2,
        'description': ['code 0', 'code 1', 'code 2'] * 2,
    })
    constraint = FixedCombinations(table_name='table', column_names=['code', 'description'])
    synth = GaussianCopulaSynthesizer(metadata)

    # Run
    synth.add_constraints([constraint])
    synth.fit(data)
    samples = synth.sample(100)

    # Assert
    assert samples.columns.tolist() == data.columns.to_list()

    # Collapse the sampled (code, description) rows into a set of pairs for comparison.
    unique_pairs = samples[['code', 'description']].drop_duplicates().to_numpy()
    sampled_combinations = {(code, description) for code, description in unique_pairs}
    expected_combinations = {('000', 'code 0'), ('001', 'code 1'), ('002', 'code 2')}
    assert expected_combinations == sampled_combinations
88+
89+
4090
def test_conditional_sampling_with_constraints(demo_data, demo_metadata):
4191
"""Test constraints with conditional sampling. GH#1737"""
4292
# Setup

0 commit comments

Comments
 (0)