diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 5150f5f5b..8037fbf4b 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -82,7 +82,7 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): - for table_name, table_metadata in self.metadata.tables.items(): + for table_name, table_metadata in self._modified_multi_table_metadata.tables.items(): synthesizer_parameters = {'locales': self.locales} synthesizer_parameters.update(self._table_parameters.get(table_name, {})) metadata_dict = {'tables': {table_name: table_metadata.to_dict()}} @@ -132,7 +132,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._original_metadata = deepcopy(self.metadata) self._modified_multi_table_metadata = deepcopy(self.metadata) self.constraints = [] - self._has_seen_single_table_constraint = False + self._single_table_constraints = [] if synthesizer_kwargs is not None: warn_message = ( 'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not ' @@ -180,9 +180,10 @@ def _detect_single_table_constraints(self, constraints): constraints (list): A list of constraints to filter. """ - idx_single_table_constraint = 0 if self._has_seen_single_table_constraint else None + has_seen_single_table_constraint = len(self._single_table_constraints) > 0 + idx_single_table_constraint = 0 if has_seen_single_table_constraint else None for idx, constraint in enumerate(constraints): - if self._has_seen_single_table_constraint and constraint._is_single_table is False: + if has_seen_single_table_constraint and constraint._is_single_table is False: raise SynthesizerInputError( 'Cannot apply multi-table constraint after single-table constraint has ' 'been applied.' @@ -191,8 +192,8 @@ def _detect_single_table_constraints(self, constraints): if not constraint._is_single_table: continue - if not self._has_seen_single_table_constraint: - self._has_seen_single_table_constraint = True + if not has_seen_single_table_constraint: + has_seen_single_table_constraint = True idx_single_table_constraint = idx return idx_single_table_constraint @@ -234,8 +235,8 @@ def add_constraints(self, constraints): self.constraints += multi_table_constraints self._constraints_fitted = False self._initialize_models() - if single_table_constraints: - for constraint in single_table_constraints: + if self._single_table_constraints or single_table_constraints: + for constraint in [*self._single_table_constraints, *single_table_constraints]: table_name = constraint.table_name self._table_synthesizers[table_name].add_constraints([constraint]) try: @@ -243,11 +244,12 @@ def add_constraints(self, constraints): except ConstraintNotMetError: constraint.get_updated_metadata(self._modified_multi_table_metadata) + self._single_table_constraints += single_table_constraints + def get_constraints(self): """Get a copy of the list of constraints applied to the synthesizer.""" if not hasattr(self, 'constraints'): return [] - constraints = [] for constraint in self.constraints: if isinstance(constraint, ProgrammableConstraintHarness): diff --git a/tests/integration/multi_table/test_constraints.py b/tests/integration/multi_table/test_constraints.py index fc1c737bf..87605d9a6 100644 --- a/tests/integration/multi_table/test_constraints.py +++ b/tests/integration/multi_table/test_constraints.py @@ -47,3 +47,47 @@ def test_overlapping_single_table_constraints(): # Assert assert all(sampled['parent_table']['colA'] < sampled['parent_table']['colB']) assert all(sampled['parent_table']['colB'] < sampled['parent_table']['colC']) + + +def test_add_constraint_iteratively(): + """Test adding constraints in multiple steps.""" + # Setup + parent_table = pd.DataFrame({ + 'id': [i for i in range(20)], + 'colA': np.random.randint(low=0, high=100, size=20), + }) + parent_table['colB'] = parent_table['colA'] + np.random.randint(low=1, high=10, size=20) + parent_table['colC'] = parent_table['colB'] + np.random.randint(low=1, high=10, size=20) + + child_table = pd.DataFrame({ + 'parent_id': np.random.randint(low=0, high=20, size=100), + 'colD': np.random.randint(low=100, high=200, size=100), + }) + data = {'parent_table': parent_table, 'child_table': child_table} + + metadata = Metadata() + metadata = Metadata.detect_from_dataframes(data) + + constraint1 = Inequality( + low_column_name='colA', + high_column_name='colB', + table_name='parent_table', + strict_boundaries=True, + ) + constraint2 = Inequality( + low_column_name='colB', + high_column_name='colC', + table_name='parent_table', + strict_boundaries=True, + ) + synthesizer = HMASynthesizer(metadata) + + # Run + synthesizer.add_constraints([constraint1]) + synthesizer.add_constraints([constraint2]) + synthesizer.fit(data) + sampled = synthesizer.sample(10) + + # Assert + assert all(sampled['parent_table']['colA'] < sampled['parent_table']['colB']) + assert all(sampled['parent_table']['colB'] < sampled['parent_table']['colC']) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 877c71dd5..fb3cebd8a 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -48,6 +48,7 @@ def test__initialize_models(self): } instance.locales = locales instance.metadata = get_multi_table_metadata() + instance._modified_multi_table_metadata = instance.metadata # Run BaseMultiTableSynthesizer._initialize_models(instance) @@ -1488,7 +1489,7 @@ def test__detect_single_table_constraints(self): # Setup instance = Mock() instance.metadata = get_multi_table_metadata() - instance._has_seen_single_table_constraint = False + instance._single_table_constraints = [] instance._table_synthesizers = { 'table1': Mock(), 'table2': Mock(), @@ -1512,9 +1513,10 @@ def test__detect_single_table_constraints(self): idx_single_table_2 = BaseMultiTableSynthesizer._detect_single_table_constraints( instance, [constraint3, constraint4] ) + instance._single_table_constraints = [constraint3, constraint4] with pytest.raises(SynthesizerInputError, match=expected_error): BaseMultiTableSynthesizer._detect_single_table_constraints(instance, [constraint1]) - instance._has_seen_single_table_constraint = False + instance._single_table_constraints = [] with pytest.raises(SynthesizerInputError, match=expected_error): BaseMultiTableSynthesizer._detect_single_table_constraints( instance, [constraint1, constraint3, constraint2] @@ -1534,6 +1536,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons instance.metadata = original_metadata instance._original_metadata = original_metadata instance.constraints = [] + instance._single_table_constraints = [] constraint1 = Mock() constraint2 = Mock() constraint3 = ProgrammableConstraint() @@ -1584,6 +1587,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints): instance.metadata = original_metadata instance._original_metadata = original_metadata instance.constraints = [] + instance._single_table_constraints = [] constraint1 = Mock() constraint1.table_name = 'table1' constraint2 = Mock() @@ -1634,6 +1638,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr constraint1 = Mock() constraint2 = Mock() instance.constraints = [constraint1] + instance._single_table_constraints = [] instance._detect_single_table_constraints = Mock(return_value=None) mock_validate_constraints.return_value = [constraint2]