Skip to content

Commit c1c5a19

Browse files
authored
Add infer_sdtypes and infer_keys parameters to detect_from_dataframes method (#2363)
1 parent 363d8bd commit c1c5a19

File tree

7 files changed

+581
-84
lines changed

7 files changed

+581
-84
lines changed

sdv/metadata/metadata.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,13 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
6161
instance._set_metadata_dict(metadata_dict, single_table_name)
6262
return instance
6363

64+
@staticmethod
65+
def _validate_infer_sdtypes(infer_sdtypes):
66+
if not isinstance(infer_sdtypes, bool):
67+
raise ValueError("'infer_sdtypes' must be a boolean value.")
68+
6469
@classmethod
65-
def detect_from_dataframes(cls, data):
70+
def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'):
6671
"""Detect the metadata for all tables in a dictionary of dataframes.
6772
6873
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``.
@@ -71,23 +76,50 @@ def detect_from_dataframes(cls, data):
7176
Args:
7277
data (dict):
7378
Dictionary of table names to dataframes.
79+
infer_sdtypes (bool):
80+
A boolean describing whether to infer the sdtypes of each column.
81+
If True it infers the sdtypes based on the data.
82+
If False it does not infer the sdtypes and all columns are marked as unknown.
83+
Defaults to True.
84+
infer_keys (str or None):
85+
A string describing whether to infer the primary and/or foreign keys. Options are:
86+
- 'primary_and_foreign': Infer the primary keys in each table,
87+
and the foreign keys in other tables that refer to them
88+
- 'primary_only': Infer only the primary keys of each table
89+
- None: Do not infer any keys
90+
Defaults to 'primary_and_foreign'.
7491
7592
Returns:
7693
Metadata:
7794
A new metadata object with the sdtypes detected from the data.
7895
"""
7996
if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
8097
raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')
98+
if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
99+
raise ValueError(
100+
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
101+
)
102+
cls._validate_infer_sdtypes(infer_sdtypes)
81103

82104
metadata = Metadata()
83105
for table_name, dataframe in data.items():
84-
metadata.detect_table_from_dataframe(table_name, dataframe)
106+
metadata.detect_table_from_dataframe(
107+
table_name, dataframe, infer_sdtypes, None if infer_keys is None else 'primary_only'
108+
)
109+
110+
if infer_keys == 'primary_and_foreign':
111+
metadata._detect_relationships(data)
85112

86-
metadata._detect_relationships(data)
87113
return metadata
88114

89115
@classmethod
90-
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
116+
def detect_from_dataframe(
117+
cls,
118+
data,
119+
table_name=DEFAULT_SINGLE_TABLE_NAME,
120+
infer_sdtypes=True,
121+
infer_keys='primary_only',
122+
):
91123
"""Detect the metadata for a DataFrame.
92124
93125
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
@@ -96,16 +128,29 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
96128
Args:
97129
data (pandas.DataFrame):
98130
The dataframe to detect the metadata from.
131+
infer_sdtypes (bool):
132+
A boolean describing whether to infer the sdtypes of each column.
133+
If True it infers the sdtypes based on the data.
134+
If False it does not infer the sdtypes and all columns are marked as unknown.
135+
Defaults to True.
136+
infer_keys (str or None):
137+
A string describing whether to infer the primary keys. Options are:
138+
- 'primary_only': Infer only the primary keys of each table
139+
- None: Do not infer any keys
140+
Defaults to 'primary_only'.
99141
100142
Returns:
101143
Metadata:
102144
A new metadata object with the sdtypes detected from the data.
103145
"""
104146
if not isinstance(data, pd.DataFrame):
105147
raise ValueError('The provided data must be a pandas DataFrame object.')
148+
if infer_keys not in ['primary_only', None]:
149+
raise ValueError("'infer_keys' must be one of: 'primary_only', None.")
150+
cls._validate_infer_sdtypes(infer_sdtypes)
106151

107152
metadata = Metadata()
108-
metadata.detect_table_from_dataframe(table_name, data)
153+
metadata.detect_table_from_dataframe(table_name, data, infer_sdtypes, infer_keys)
109154
return metadata
110155

111156
def _set_metadata_dict(self, metadata, single_table_name=None):

sdv/metadata/multi_table.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,9 @@ def _detect_relationships(self, data=None):
530530
)
531531
continue
532532

533-
def detect_table_from_dataframe(self, table_name, data):
533+
def detect_table_from_dataframe(
534+
self, table_name, data, infer_sdtypes=True, infer_keys='primary_only'
535+
):
534536
"""Detect the metadata for a table from a dataframe.
535537
536538
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
@@ -541,10 +543,20 @@ def detect_table_from_dataframe(self, table_name, data):
541543
Name of the table to detect.
542544
data (pandas.DataFrame):
543545
``pandas.DataFrame`` to detect the metadata from.
546+
infer_sdtypes (bool):
547+
A boolean describing whether to infer the sdtypes of each column.
548+
If True it infers the sdtypes based on the data.
549+
If False it does not infer the sdtypes and all columns are marked as unknown.
550+
Defaults to True.
551+
infer_keys (str or None):
552+
A string describing whether to infer the primary keys. Options are:
553+
- 'primary_only': Infer only the primary keys of each table
554+
- None: Do not infer any keys
555+
Defaults to 'primary_only'.
544556
"""
545557
self._validate_table_not_detected(table_name)
546558
table = SingleTableMetadata()
547-
table._detect_columns(data, table_name)
559+
table._detect_columns(data, table_name, infer_sdtypes, infer_keys)
548560
self.tables[table_name] = table
549561
self._log_detected_table(table)
550562

sdv/metadata/single_table.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -595,53 +595,67 @@ def _detect_primary_key(self, data):
595595

596596
return None
597597

598-
def _detect_columns(self, data, table_name=None):
598+
def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys='primary_only'):
599599
"""Detect the columns' sdtypes from the data.
600600
601601
Args:
602602
data (pandas.DataFrame):
603603
The data to be analyzed.
604604
table_name (str):
605605
The name of the table to be analyzed. Defaults to ``None``.
606+
infer_sdtypes (bool):
607+
A boolean describing whether to infer the sdtypes of each column.
608+
If True it infers the sdtypes based on the data.
609+
If False it does not infer the sdtypes and all columns are marked as unknown.
610+
Defaults to True.
611+
infer_keys (str or None):
612+
A string describing whether to infer the primary keys. Options are:
613+
- 'primary_only': Infer the primary keys.
614+
- None: Do not infer any keys.
615+
Defaults to 'primary_only'.
606616
"""
607617
old_columns = data.columns
608618
data.columns = data.columns.astype(str)
609619
for field in data:
610-
try:
611-
column_data = data[field]
612-
clean_data = column_data.dropna()
613-
dtype = clean_data.infer_objects().dtype.kind
614-
615-
sdtype = self._detect_pii_column(field)
616-
if sdtype is None:
617-
if dtype in self._DTYPES_TO_SDTYPES:
618-
sdtype = self._DTYPES_TO_SDTYPES[dtype]
619-
elif dtype in ['i', 'f', 'u']:
620-
sdtype = self._determine_sdtype_for_numbers(column_data)
621-
622-
elif dtype == 'O':
623-
sdtype = self._determine_sdtype_for_objects(column_data)
620+
if infer_sdtypes:
621+
try:
622+
column_data = data[field]
623+
clean_data = column_data.dropna()
624+
dtype = clean_data.infer_objects().dtype.kind
624625

626+
sdtype = self._detect_pii_column(field)
625627
if sdtype is None:
626-
table_str = f"table '{table_name}' " if table_name else ''
627-
error_message = (
628-
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
629-
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
630-
" 'bool'."
631-
)
632-
raise InvalidMetadataError(error_message)
633-
634-
except Exception as e:
635-
error_type = type(e).__name__
636-
if error_type == 'InvalidMetadataError':
637-
raise e
638-
639-
table_str = f"table '{table_name}' " if table_name else ''
640-
error_message = (
641-
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
642-
f'data format.\n {error_type}: {e}'
643-
)
644-
raise InvalidMetadataError(error_message) from e
628+
if dtype in self._DTYPES_TO_SDTYPES:
629+
sdtype = self._DTYPES_TO_SDTYPES[dtype]
630+
elif dtype in ['i', 'f', 'u']:
631+
sdtype = self._determine_sdtype_for_numbers(column_data)
632+
633+
elif dtype == 'O':
634+
sdtype = self._determine_sdtype_for_objects(column_data)
635+
636+
if sdtype is None:
637+
table_str = f"table '{table_name}' " if table_name else ''
638+
error_message = (
639+
f"Unsupported data type for {table_str}column '{field}' "
640+
f"(kind: {dtype}). The valid data types are: 'object', "
641+
"'int', 'float', 'datetime', 'bool'."
642+
)
643+
raise InvalidMetadataError(error_message)
644+
645+
except Exception as e:
646+
error_type = type(e).__name__
647+
if error_type == 'InvalidMetadataError':
648+
raise e
649+
650+
table_str = f"table '{table_name}' " if table_name else ''
651+
error_message = (
652+
f"Unable to detect metadata for {table_str}column '{field}' due "
653+
f'to an invalid data format.\n {error_type}: {e}'
654+
)
655+
raise InvalidMetadataError(error_message) from e
656+
657+
else:
658+
sdtype = 'unknown'
645659

646660
column_dict = {'sdtype': sdtype}
647661
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
@@ -655,7 +669,8 @@ def _detect_columns(self, data, table_name=None):
655669

656670
self.columns[field] = deepcopy(column_dict)
657671

658-
self.primary_key = self._detect_primary_key(data)
672+
if infer_keys == 'primary_only':
673+
self.primary_key = self._detect_primary_key(data)
659674
self._updated = True
660675
data.columns = old_columns
661676

0 commit comments

Comments
 (0)