Skip to content

Commit 202227a

Browse files
committed
Add support for composite primary keys in the metadata
1 parent 6f05c1f commit 202227a

File tree

13 files changed

+370
-61
lines changed

13 files changed

+370
-61
lines changed

sdv/metadata/multi_table.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,14 @@ def _validate_new_foreign_key_is_not_reused(
183183
and relationship['parent_primary_key'] == parent_primary_key
184184
)
185185
if foreign_key_already_used and not parent_matches:
186+
child_foreign_key = (
187+
f"('{child_foreign_key}')"
188+
if isinstance(child_foreign_key, str)
189+
else f'({child_foreign_key})'
190+
)
186191
raise InvalidMetadataError(
187192
f'Relationship between tables ({parent_table_name}, {child_table_name}) uses '
188-
f"a foreign key column ('{child_foreign_key}') that is already used in another "
193+
f'a foreign key {child_foreign_key} that is already used in another '
189194
'relationship.'
190195
)
191196

@@ -197,15 +202,23 @@ def _validate_foreign_key_uniqueness_across_relationships(
197202
child_foreign_key,
198203
seen_foreign_keys,
199204
):
200-
key = (child_table_name, child_foreign_key)
205+
key = (
206+
tuple(_cast_to_iterable(child_table_name)),
207+
tuple(_cast_to_iterable(child_foreign_key)),
208+
)
201209
current_relationship = (parent_table_name, parent_primary_key)
202210

203211
if key in seen_foreign_keys:
204212
existing_relationship = seen_foreign_keys[key]
205213
if existing_relationship != current_relationship:
214+
child_foreign_key = (
215+
f"('{child_foreign_key}')"
216+
if isinstance(child_foreign_key, str)
217+
else f'({child_foreign_key})'
218+
)
206219
raise InvalidMetadataError(
207220
f'Relationship between tables ({parent_table_name}, {child_table_name}) uses '
208-
f"a foreign key column ('{child_foreign_key}') that is already used in another "
221+
f'a foreign key {child_foreign_key} that is already used in another '
209222
'relationship.'
210223
)
211224
else:
@@ -296,10 +309,10 @@ def add_relationship(
296309
A string representing the name of the parent table.
297310
child_table_name (str):
298311
A string representing the name of the child table.
299-
parent_primary_key (str or tuple):
300-
A string or tuple of strings representing the primary key of the parent.
301-
child_foreign_key (str or tuple):
302-
A string or tuple of strings representing the foreign key of the child.
312+
parent_primary_key (str or list[str]):
313+
A string or list of strings representing the primary key of the parent.
314+
child_foreign_key (str or list[str]):
315+
A string or list of strings representing the foreign key of the child.
303316
304317
Raises:
305318
- ``InvalidMetadataError`` if a table is missing.
@@ -1206,9 +1219,19 @@ def _set_metadata_dict(self, metadata):
12061219
) from error
12071220

12081221
for relationship in metadata.get('relationships', []):
1222+
parent_pk = relationship.get('parent_primary_key')
1223+
child_fk = relationship.get('child_foreign_key')
1224+
type_safe_pk = (
1225+
[str(col) for col in parent_pk] if isinstance(parent_pk, list) else str(parent_pk)
1226+
)
1227+
type_safe_fk = (
1228+
[str(col) for col in child_fk] if isinstance(parent_pk, list) else str(child_fk)
1229+
)
12091230
type_safe_relationships = {
1210-
key: str(value) if not isinstance(value, str) else value
1211-
for key, value in relationship.items()
1231+
'parent_table_name': str(relationship.get('parent_table_name')),
1232+
'child_table_name': str(relationship.get('child_table_name')),
1233+
'parent_primary_key': type_safe_pk,
1234+
'child_foreign_key': type_safe_fk,
12121235
}
12131236
self.relationships.append(type_safe_relationships)
12141237

sdv/metadata/single_table.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,33 @@ def _validate_pii(column_name, **kwargs):
211211

212212
def __init__(self):
213213
self.columns = {}
214-
self.primary_key = None
214+
self._primary_key = None
215215
self.alternate_keys = []
216216
self.sequence_key = None
217217
self.sequence_index = None
218218
self.column_relationships = []
219219
self._version = self.METADATA_SPEC_VERSION
220220
self._updated = False
221221

222+
@property
223+
def _primary_key_is_composite(self):
224+
if self.primary_key and isinstance(self.primary_key, list) and len(self.primary_key) > 1:
225+
return True
226+
227+
return False
228+
229+
@property
230+
def primary_key(self):
231+
"""Property to handle singleton composite key case."""
232+
if isinstance(self._primary_key, list) and len(self._primary_key) == 1:
233+
return self._primary_key[0]
234+
235+
return self._primary_key
236+
237+
@primary_key.setter
238+
def primary_key(self, primary_key):
239+
self._primary_key = primary_key
240+
222241
def _get_unexpected_kwargs(self, sdtype, **kwargs):
223242
expected_kwargs = self._SDTYPE_KWARGS.get(sdtype, ['pii'])
224243
unexpected_kwargs = set(kwargs) - set(expected_kwargs)
@@ -774,29 +793,40 @@ def detect_from_csv(self, filepath, read_csv_parameters=None):
774793
self.detect_from_dataframe(data)
775794

776795
@staticmethod
777-
def _validate_key_datatype(column_name):
796+
def _validate_key_datatype(column_name, key_type):
778797
"""Check whether column_name is a string."""
779-
return isinstance(column_name, str)
798+
is_string = isinstance(column_name, str)
799+
is_list_of_strings = isinstance(column_name, list) and all(
800+
isinstance(i, str) for i in column_name
801+
)
802+
return is_string or (key_type == 'primary' and is_list_of_strings)
780803

781804
def _validate_keys_sdtype(self, keys, key_type):
782805
"""Validate that each key is of type 'id' or a valid Faker function."""
783-
bad_keys = set()
806+
bad_keys = []
784807
for key in keys:
785-
if not (
786-
self.columns[key]['sdtype'] == 'id'
787-
or is_faker_function(self.columns[key]['sdtype'])
808+
if not any(
809+
self.columns[key_col]['sdtype'] == 'id'
810+
or is_faker_function(self.columns[key_col]['sdtype'])
811+
for key_col in _cast_to_iterable(key)
788812
):
789-
bad_keys.add(key)
813+
bad_keys.append(key)
814+
790815
if bad_keys:
791816
raise InvalidMetadataError(
792-
f"The {key_type}_keys {sorted(bad_keys)} must be type 'id' or another PII type."
817+
f'The {key_type}_keys {bad_keys} must have a column of '
818+
"type 'id' or another PII type."
793819
)
794820

795821
def _validate_key(self, column_name, key_type):
796822
"""Validate the primary and sequence keys."""
797823
if column_name is not None:
798-
if not self._validate_key_datatype(column_name):
799-
raise InvalidMetadataError(f"'{key_type}_key' must be a string.")
824+
if not self._validate_key_datatype(column_name, key_type):
825+
err_msg = f"'{key_type}_key' must be a string"
826+
if key_type == 'primary':
827+
err_msg += ' or a list of strings'
828+
829+
raise InvalidMetadataError(err_msg + '.')
800830

801831
keys = {column_name} if isinstance(column_name, str) else set(column_name)
802832
setting_sequence_as_primary = key_type == 'primary' and column_name == self.sequence_key
@@ -814,7 +844,7 @@ def _validate_key(self, column_name, key_type):
814844
' Keys should be columns that exist in the table.'
815845
)
816846

817-
self._validate_keys_sdtype(keys, key_type)
847+
self._validate_keys_sdtype([column_name], key_type)
818848

819849
def set_primary_key(self, column_name):
820850
"""Set the metadata primary key.
@@ -866,7 +896,8 @@ def set_sequence_key(self, column_name):
866896

867897
def _validate_alternate_keys(self, column_names):
868898
if not isinstance(column_names, list) or not all(
869-
self._validate_key_datatype(column_name) for column_name in column_names
899+
self._validate_key_datatype(column_name, 'alternate_keys')
900+
for column_name in column_names
870901
):
871902
raise InvalidMetadataError("'alternate_keys' must be a list of strings.")
872903

@@ -1158,7 +1189,10 @@ def _get_primary_and_alternate_keys(self):
11581189
"""
11591190
keys = set(self.alternate_keys)
11601191
if self.primary_key:
1161-
keys.update({self.primary_key})
1192+
primary_key = (
1193+
tuple(self.primary_key) if isinstance(self.primary_key, list) else self.primary_key
1194+
)
1195+
keys.add(primary_key)
11621196

11631197
return keys
11641198

@@ -1181,31 +1215,45 @@ def _validate_keys_dont_have_missing_values(self, data):
11811215
errors = []
11821216
keys = self._get_primary_and_alternate_keys()
11831217
keys.update(self._get_set_of_sequence_keys())
1184-
for key in sorted(keys):
1185-
if pd.isna(data[key]).any():
1186-
errors.append(f"Key column '{key}' contains missing values.")
1218+
for key in sorted(keys, key=lambda key: key if isinstance(key, str) else key[0]):
1219+
key_list = [key] if isinstance(key, str) else list(key)
1220+
if pd.isna(data[key_list]).all(axis=1).any():
1221+
key = f"'{key}'" if isinstance(key, str) else f'{key}'
1222+
errors.append(f'Key column {key} contains missing values.')
11871223

11881224
return errors
11891225

11901226
def _validate_key_values_are_unique(self, data):
11911227
errors = []
11921228
keys = self._get_primary_and_alternate_keys()
1193-
for key in sorted(keys):
1194-
repeated_values = set(data[key][data[key].duplicated()])
1195-
if repeated_values:
1196-
repeated_values = _format_invalid_values_string(repeated_values, 3)
1197-
errors.append(f"Key column '{key}' contains repeating values: " + repeated_values)
1229+
for key in sorted(keys, key=lambda key: key if isinstance(key, str) else key[0]):
1230+
key_list = [key] if isinstance(key, str) else list(key)
1231+
repeated_values = data[key_list][data[key_list].duplicated()]
1232+
if not repeated_values.empty:
1233+
if len(repeated_values.columns) == 1:
1234+
repeated_values = ' ' + _format_invalid_values_string(
1235+
set(repeated_values[key]), 3
1236+
)
1237+
else:
1238+
repeated_values = '\n' + _format_invalid_values_string(
1239+
repeated_values.drop_duplicates(), 3
1240+
)
1241+
1242+
key = f"'{key}'" if isinstance(key, str) else f'{key}'
1243+
errors.append(f'Key column {key} contains repeating values:' + repeated_values)
11981244

11991245
return errors
12001246

12011247
def _validate_primary_key(self, data):
12021248
error = []
1203-
is_int = self.primary_key and pd.api.types.is_integer_dtype(data[self.primary_key])
1204-
regex = self.columns.get(self.primary_key, {}).get('regex_format')
1205-
if is_int and regex:
1206-
possible_characters = get_possible_chars(regex, 1)
1207-
if '0' in possible_characters:
1208-
error.append(f'Primary key "{self.primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}')
1249+
primary_key_list = _cast_to_iterable(self.primary_key) if self.primary_key else []
1250+
for key in primary_key_list:
1251+
is_int = pd.api.types.is_integer_dtype(data[key])
1252+
regex = self.columns.get(key, {}).get('regex_format')
1253+
if is_int and regex:
1254+
possible_characters = get_possible_chars(regex, 1)
1255+
if '0' in possible_characters:
1256+
error.append(f'Primary key column "{key}" {INT_REGEX_ZERO_ERROR_MESSAGE}')
12091257

12101258
return error
12111259

sdv/multi_table/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,20 @@ def _check_metadata_updated(self):
114114
' in future SDV versions.'
115115
)
116116

117+
def _handle_composite_keys(self):
118+
"""Validates that composite keys are not used in Public SDV."""
119+
composite_key_tables = []
120+
for table, table_metadata in self.metadata.tables.items():
121+
if table_metadata._primary_key_is_composite:
122+
composite_key_tables.append(table)
123+
124+
if composite_key_tables:
125+
raise SynthesizerInputError(
126+
'Your metadata contains composite keys (primary key of tables '
127+
f'{composite_key_tables} have multiple columns). Composite keys are '
128+
'not supported in SDV Community.'
129+
)
130+
117131
def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
118132
self.metadata = metadata
119133
if type(metadata) is MultiTableMetadata:
@@ -123,6 +137,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
123137
self.metadata.validate()
124138

125139
self._check_metadata_updated()
140+
self._handle_composite_keys()
126141
self.locales = locales
127142
self.verbose = False
128143
self.extended_columns = defaultdict(dict)

sdv/single_table/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ def _validate_regex_format(self):
150150
)
151151
_check_regex_format(self._table_name, column_name, regex)
152152

153+
def _handle_composite_keys(self, single_table_metadata):
154+
"""Validates that composite keys are not used in Public SDV."""
155+
if single_table_metadata._primary_key_is_composite:
156+
raise SynthesizerInputError(
157+
'Your metadata contains composite keys (primary key of table '
158+
f"'{self._table_name}' has multiple columns). Composite keys are "
159+
'not supported in SDV Community.'
160+
)
161+
153162
def __init__(
154163
self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=['en_US']
155164
):
@@ -172,6 +181,8 @@ def __init__(
172181

173182
self.metadata.validate()
174183
self._check_metadata_updated()
184+
single_table_metadata = self.metadata._convert_to_single_table()
185+
self._handle_composite_keys(single_table_metadata)
175186

176187
# Points to a metadata object that conserves the initialized status of the synthesizer
177188
self._original_metadata = deepcopy(self.metadata)
@@ -180,7 +191,7 @@ def __init__(
180191
self.enforce_rounding = enforce_rounding
181192
self.locales = locales
182193
self._data_processor = DataProcessor(
183-
metadata=self.metadata._convert_to_single_table(),
194+
metadata=single_table_metadata,
184195
enforce_rounding=self.enforce_rounding,
185196
enforce_min_max_values=self.enforce_min_max_values,
186197
locales=self.locales,

tests/integration/metadata/test_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def test_validate_metadata_with_reused_foreign_keys():
993993
# Run and Assert
994994
error_msg = re.escape(
995995
'Relationships:\n'
996-
'Relationship between tables (A2, A3) uses a foreign key column '
996+
'Relationship between tables (A2, A3) uses a foreign key '
997997
"('fk3_A1_A2') that is already used in another relationship."
998998
)
999999
with pytest.raises(InvalidMetadataError, match=error_msg):

tests/integration/metadata/test_multi_table.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,47 @@ def test_multi_table_metadata():
2222
assert instance.relationships == []
2323

2424

25+
def test_multi_table_metadata_composite_keys():
26+
"""Test ``MultiTableMetadata`` with composite keys."""
27+
# Setup
28+
metadata_dict = {
29+
'tables': {
30+
'table1': {
31+
'columns': {
32+
'table1_id': {'sdtype': 'id'},
33+
'cat_col': {'sdtype': 'categorical'},
34+
},
35+
'primary_key': ['table1_id', 'cat_col'],
36+
},
37+
'table2': {
38+
'columns': {
39+
'pk': {'sdtype': 'id'},
40+
'fk1': {'sdtype': 'id'},
41+
'fk2': {'sdtype': 'categorical'},
42+
},
43+
'primary_key': 'pk',
44+
},
45+
},
46+
'relationships': [
47+
{
48+
'parent_table_name': 'table1',
49+
'parent_primary_key': ['table1_id', 'cat_col'],
50+
'child_table_name': 'table2',
51+
'child_foreign_key': ['fk1', 'fk2'],
52+
},
53+
],
54+
}
55+
56+
# Run
57+
instance = MultiTableMetadata.load_from_dict(metadata_dict)
58+
result = instance.to_dict()
59+
60+
# Assert
61+
instance.validate()
62+
assert result == {**metadata_dict, 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'}
63+
assert instance.relationships == metadata_dict['relationships']
64+
65+
2566
@patch('rdt.transformers')
2667
def test_add_column_relationship(mock_rdt_transformers):
2768
"""Test ``add_column_relationship`` method."""

0 commit comments

Comments
 (0)