Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _set_temp_numpy_seed(self):

def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
for table_name, table_metadata in self._modified_multi_table_metadata.tables.items():
synthesizer_parameters = {'locales': self.locales}
synthesizer_parameters.update(self._table_parameters.get(table_name, {}))
metadata_dict = {'tables': {table_name: table_metadata.to_dict()}}
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self._original_metadata = deepcopy(self.metadata)
self._modified_multi_table_metadata = deepcopy(self.metadata)
self.constraints = []
self._has_seen_single_table_constraint = False
self._single_table_constraints = []
if synthesizer_kwargs is not None:
warn_message = (
'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
Expand Down Expand Up @@ -180,9 +180,10 @@ def _detect_single_table_constraints(self, constraints):
constraints (list):
A list of constraints to filter.
"""
idx_single_table_constraint = 0 if self._has_seen_single_table_constraint else None
has_seen_single_table_constraint = len(self._single_table_constraints) > 0
idx_single_table_constraint = 0 if has_seen_single_table_constraint else None
for idx, constraint in enumerate(constraints):
if self._has_seen_single_table_constraint and constraint._is_single_table is False:
if has_seen_single_table_constraint and constraint._is_single_table is False:
raise SynthesizerInputError(
'Cannot apply multi-table constraint after single-table constraint has '
'been applied.'
Expand All @@ -191,8 +192,8 @@ def _detect_single_table_constraints(self, constraints):
if not constraint._is_single_table:
continue

if not self._has_seen_single_table_constraint:
self._has_seen_single_table_constraint = True
if not has_seen_single_table_constraint:
has_seen_single_table_constraint = True
idx_single_table_constraint = idx

return idx_single_table_constraint
Expand Down Expand Up @@ -234,20 +235,21 @@ def add_constraints(self, constraints):
self.constraints += multi_table_constraints
self._constraints_fitted = False
self._initialize_models()
if single_table_constraints:
for constraint in single_table_constraints:
if self._single_table_constraints or single_table_constraints:
for constraint in [*self._single_table_constraints, *single_table_constraints]:
table_name = constraint.table_name
self._table_synthesizers[table_name].add_constraints([constraint])
try:
self.metadata = constraint.get_updated_metadata(self.metadata)
except ConstraintNotMetError:
constraint.get_updated_metadata(self._modified_multi_table_metadata)

self._single_table_constraints += single_table_constraints

def get_constraints(self):
"""Get a copy of the list of constraints applied to the synthesizer."""
if not hasattr(self, 'constraints'):
return []

constraints = []
for constraint in self.constraints:
if isinstance(constraint, ProgrammableConstraintHarness):
Expand Down
44 changes: 44 additions & 0 deletions tests/integration/multi_table/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,47 @@ def test_overlapping_single_table_constraints():
# Assert
assert all(sampled['parent_table']['colA'] < sampled['parent_table']['colB'])
assert all(sampled['parent_table']['colB'] < sampled['parent_table']['colC'])


def test_add_constraint_iteratively():
"""Test adding constraints in multiple steps."""
# Setup
parent_table = pd.DataFrame({
'id': [i for i in range(20)],
'colA': np.random.randint(low=0, high=100, size=20),
})
parent_table['colB'] = parent_table['colA'] + np.random.randint(low=1, high=10, size=20)
parent_table['colC'] = parent_table['colB'] + np.random.randint(low=1, high=10, size=20)

child_table = pd.DataFrame({
'parent_id': np.random.randint(low=0, high=20, size=100),
'colD': np.random.randint(low=100, high=200, size=100),
})
data = {'parent_table': parent_table, 'child_table': child_table}

metadata = Metadata()
metadata = Metadata.detect_from_dataframes(data)

constraint1 = Inequality(
low_column_name='colA',
high_column_name='colB',
table_name='parent_table',
strict_boundaries=True,
)
constraint2 = Inequality(
low_column_name='colB',
high_column_name='colC',
table_name='parent_table',
strict_boundaries=True,
)
synthesizer = HMASynthesizer(metadata)

# Run
synthesizer.add_constraints([constraint1])
synthesizer.add_constraints([constraint2])
synthesizer.fit(data)
sampled = synthesizer.sample(10)

# Assert
assert all(sampled['parent_table']['colA'] < sampled['parent_table']['colB'])
assert all(sampled['parent_table']['colB'] < sampled['parent_table']['colC'])
9 changes: 7 additions & 2 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test__initialize_models(self):
}
instance.locales = locales
instance.metadata = get_multi_table_metadata()
instance._modified_multi_table_metadata = instance.metadata

# Run
BaseMultiTableSynthesizer._initialize_models(instance)
Expand Down Expand Up @@ -1488,7 +1489,7 @@ def test__detect_single_table_constraints(self):
# Setup
instance = Mock()
instance.metadata = get_multi_table_metadata()
instance._has_seen_single_table_constraint = False
instance._single_table_constraints = []
instance._table_synthesizers = {
'table1': Mock(),
'table2': Mock(),
Expand All @@ -1512,9 +1513,10 @@ def test__detect_single_table_constraints(self):
idx_single_table_2 = BaseMultiTableSynthesizer._detect_single_table_constraints(
instance, [constraint3, constraint4]
)
instance._single_table_constraints = [constraint3, constraint4]
with pytest.raises(SynthesizerInputError, match=expected_error):
BaseMultiTableSynthesizer._detect_single_table_constraints(instance, [constraint1])
instance._has_seen_single_table_constraint = False
instance._single_table_constraints = []
with pytest.raises(SynthesizerInputError, match=expected_error):
BaseMultiTableSynthesizer._detect_single_table_constraints(
instance, [constraint1, constraint3, constraint2]
Expand All @@ -1534,6 +1536,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons
instance.metadata = original_metadata
instance._original_metadata = original_metadata
instance.constraints = []
instance._single_table_constraints = []
constraint1 = Mock()
constraint2 = Mock()
constraint3 = ProgrammableConstraint()
Expand Down Expand Up @@ -1584,6 +1587,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints):
instance.metadata = original_metadata
instance._original_metadata = original_metadata
instance.constraints = []
instance._single_table_constraints = []
constraint1 = Mock()
constraint1.table_name = 'table1'
constraint2 = Mock()
Expand Down Expand Up @@ -1634,6 +1638,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr
constraint1 = Mock()
constraint2 = Mock()
instance.constraints = [constraint1]
instance._single_table_constraints = []
instance._detect_single_table_constraints = Mock(return_value=None)
mock_validate_constraints.return_value = [constraint2]

Expand Down