Skip to content

Commit e4f0369

Browse files
committed
wip
1 parent 2d56969 commit e4f0369

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

sdv/metadata/multi_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def _get_foreign_keys(self, parent_table_name, child_table_name):
276276
return foreign_keys
277277

278278
def _get_all_foreign_keys(self, table_name):
279+
279280
foreign_keys = []
280281
for relation in self.relationships:
281282
if table_name == relation['child_table_name']:

sdv/multi_table/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,9 +468,19 @@ def _validate_table_name(self, table_name):
468468
def _assign_table_transformers(self, synthesizer, table_name, table_data):
469469
"""Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data."""
470470
synthesizer.auto_assign_transformers(table_data)
471-
foreign_key_columns = self.metadata._get_all_foreign_keys(table_name)
472-
column_name_to_transformers = {column_name: None for column_name in foreign_key_columns}
473-
print(column_name_to_transformers)
471+
472+
# foreign_key_columns = self.metadata._get_all_foreign_keys(table_name)
473+
# column_name_to_transformers = {column_name: None for column_name in foreign_key_columns}
474+
475+
column_name_to_transformers = {}
476+
primary_key = self.metadata.tables[table_name].primary_key
477+
for relation in self.metadata.relationships:
478+
column_name = deepcopy(relation['child_foreign_key'])
479+
if (
480+
relation['child_table_name'] == table_name
481+
and relation['child_foreign_key'] != primary_key
482+
):
483+
column_name_to_transformers[column_name] = None
474484
synthesizer.update_transformers(column_name_to_transformers)
475485

476486
def auto_assign_transformers(self, data):

sdv/single_table/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ def _validate_transformers(self, column_name_to_transformer):
247247
for column, transformer in column_name_to_transformer.items():
248248
if transformer is None:
249249
continue
250-
251250
if column in keys and not transformer.is_generator():
252251
raise SynthesizerInputError(
253252
f"Column '{column}' is a key. It cannot be preprocessed using "

tests/integration/multi_table/test_hma.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2903,6 +2903,7 @@ def data_metadata_1_to_1():
29032903
metadata = Metadata.load_from_dict(metadata_dict)
29042904
metadata.validate()
29052905
metadata.validate_data(data)
2906+
metadata.remove_primary_key('rooms')
29062907
return data, metadata
29072908

29082909

@@ -2917,3 +2918,7 @@ def test_hma_1_to_1(data_metadata_1_to_1):
29172918

29182919
# Assert
29192920
assert synthetic_data['guests']['guest_email'].equals(synthetic_data['guests']['guest_email'])
2921+
2922+
2923+
def test_hma_1_to_1_or_0(data_metadata_1_to_1):
2924+
pass

0 commit comments

Comments
 (0)