
 import numpy as np
 import pandas as pd
+from sdv.metadata import Metadata
+from sdv.utils import poc

 LOGGER = logging.getLogger(__name__)

+MAX_NUM_COLUMNS = 10
+MAX_NUM_ROWS = 1000
+


 def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     """Generic parser for numeric values with logging and NaN fallback."""
@@ -23,6 +28,67 @@ def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     return np.nan


+def _filter_columns(columns, mandatory_columns):
+    """Given a dictionary of columns and a list of mandatory ones, return a filtered subset."""
+    mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
+    optional_columns = [col for col in columns if col not in mandatory_columns]
+    # Keep all mandatory columns; fill the remaining slots so the total stays at MAX_NUM_COLUMNS
+    num_optional = max(MAX_NUM_COLUMNS - len(mandatory_columns), 0)
+    keep_columns = mandatory_columns + optional_columns[:num_optional]
+    return {col: columns[col] for col in keep_columns}
+
+
+def _get_multi_table_dataset_subset(data, metadata_dict):
+    """Create a smaller, referentially consistent subset of multi-table data.
+
+    This function limits each table to at most 10 columns by keeping all
+    mandatory columns and, if needed, a subset of the remaining columns, then
+    trims the underlying DataFrames to match the updated metadata. Finally, it
+    uses SDV's ``poc.get_random_subset`` utility to sample up to 1,000 rows
+    from the main table and a consistent subset of rows from all related
+    tables while preserving referential integrity.
+
+    Args:
+        data (dict):
+            A dictionary where keys are table names and values are DataFrames
+            representing tables.
+        metadata_dict (dict):
+            Metadata dictionary containing schema information for each table.
+
+    Returns:
+        tuple:
+            A tuple containing:
+            - dict: The subset of the input data with reduced columns and rows.
+            - dict: The updated metadata dictionary reflecting the reduced column sets.
+    """
+    metadata = Metadata.load_from_dict(metadata_dict)
+    for table_name, table in metadata.tables.items():
+        table_columns = table.columns
+        mandatory_columns = list(metadata._get_all_keys(table_name))
+        subset_column_schema = _filter_columns(
+            columns=table_columns, mandatory_columns=mandatory_columns
+        )
+        metadata_dict['tables'][table_name]['columns'] = subset_column_schema
+
+    # Reload the metadata object that will be used with the `SDV` utility function
+    metadata = Metadata.load_from_dict(metadata_dict)
+    largest_table_name = max(data, key=lambda table_name: len(data[table_name]))
+
+    # Trim the data to contain only the subset of columns
+    for table_name, table in metadata.tables.items():
+        data[table_name] = data[table_name][list(table.columns)]
+
+    # Subsample the data while maintaining referential integrity
+    data = poc.get_random_subset(
+        data=data,
+        metadata=metadata,
+        main_table_name=largest_table_name,
+        num_rows=MAX_NUM_ROWS,
+        verbose=False,
+    )
+    return data, metadata_dict
+
+
 def _get_dataset_subset(data, metadata_dict, modality):
     """Limit the size of a dataset for faster evaluation or testing.

@@ -31,52 +97,38 @@ def _get_dataset_subset(data, metadata_dict, modality):
     columns—such as sequence indices and keys in sequential datasets—are always retained.

     Args:
-        data (pd.DataFrame):
+        data (pd.DataFrame or dict):
             The dataset to be reduced.
         metadata_dict (dict):
-            A dictionary containing the dataset's metadata.
+            A dictionary representing the dataset's metadata.
         modality (str):
-            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``.
+            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``
+            or ``'multi_table'``.

     Returns:
-        tuple[pd.DataFrame, dict]:
+        tuple[pd.DataFrame or dict, dict]:
             A tuple containing:
-            - The reduced dataset as a DataFrame.
+            - The reduced dataset as a DataFrame or a dictionary of DataFrames.
             - The updated metadata dictionary reflecting any removed columns.
-
-    Raises:
-        ValueError:
-            If the provided modality is ``'multi_table'``.
     """
     if modality == 'multi_table':
-        raise ValueError('limit_dataset_size is not supported for multi-table datasets.')
+        return _get_multi_table_dataset_subset(data, metadata_dict)

-    max_rows, max_columns = (1000, 10)
     tables = metadata_dict.get('tables', {})
     mandatory_columns = []
     table_name, table_info = next(iter(tables.items()))
-
     columns = table_info.get('columns', {})
-    keep_columns = list(columns)
-    if modality == 'sequential':
-        seq_index = table_info.get('sequence_index')
-        seq_key = table_info.get('sequence_key')
-        mandatory_columns = [col for col in (seq_index, seq_key) if col]

-    optional_columns = [col for col in columns if col not in mandatory_columns]
+    seq_index = table_info.get('sequence_index')
+    seq_key = table_info.get('sequence_key')
+    mandatory_columns = [column for column in (seq_index, seq_key) if column]
+    filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns)

-    # If we have too many columns, drop extras but never mandatory ones
-    if len(columns) > max_columns:
-        keep_count = max_columns - len(mandatory_columns)
-        keep_columns = mandatory_columns + optional_columns[:keep_count]
-        table_info['columns'] = {
-            column_name: column_definition
-            for column_name, column_definition in columns.items()
-            if column_name in keep_columns
-        }
-
-    data = data[list(keep_columns)]
+    table_info['columns'] = filtered
+    data = data[list(filtered)]
+    max_rows = min(MAX_NUM_ROWS, len(data))
     data = data.sample(max_rows)
+
     return data, metadata_dict
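
For reference, a small worked example of the new `_filter_columns` helper, assuming the constants and helpers from this diff are in scope (the schema below is invented for illustration):

    # Hypothetical schema: ten metric columns plus two key columns.
    columns = {f'metric_{i}': {'sdtype': 'numerical'} for i in range(10)}
    columns.update({'user_id': {'sdtype': 'id'}, 'parent_id': {'sdtype': 'id'}})

    subset = _filter_columns(columns=columns, mandatory_columns=['user_id', 'parent_id'])
    assert len(subset) == MAX_NUM_COLUMNS           # capped at 10 columns total
    assert {'user_id', 'parent_id'} <= set(subset)  # mandatory keys always survive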


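A similar sketch for the sequential path, showing that the sequence key and index always survive the column cap while the rows are sampled down (the table name, column names, and sizes are invented):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    columns = {f'metric_{i}': {'sdtype': 'numerical'} for i in range(12)}
    columns.update({'entity': {'sdtype': 'id'}, 'timestamp': {'sdtype': 'datetime'}})
    metadata_dict = {
        'tables': {
            'readings': {
                'columns': columns,
                'sequence_key': 'entity',
                'sequence_index': 'timestamp',
            },
        },
    }
    data = pd.DataFrame({name: rng.random(2000) for name in columns})

    subset, updated = _get_dataset_subset(data, metadata_dict, modality='sequential')
    assert len(subset) == MAX_NUM_ROWS
    assert {'entity', 'timestamp'} <= set(subset.columns)
    assert len(updated['tables']['readings']['columns']) == MAX_NUM_COLUMNS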
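And a minimal end-to-end sketch of the new multi-table path. The tables and metadata below are invented; the metadata dictionary follows SDV's multi-table format, and row subsampling is delegated to `poc.get_random_subset`, which keeps parent and child rows consistent:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    users = pd.DataFrame({
        'user_id': np.arange(2000),
        'age': rng.integers(18, 90, size=2000),
    })
    sessions = pd.DataFrame({
        'session_id': np.arange(5000),
        'user_id': rng.integers(0, 2000, size=5000),
    })
    data = {'users': users, 'sessions': sessions}
    metadata_dict = {
        'tables': {
            'users': {
                'primary_key': 'user_id',
                'columns': {'user_id': {'sdtype': 'id'}, 'age': {'sdtype': 'numerical'}},
            },
            'sessions': {
                'primary_key': 'session_id',
                'columns': {'session_id': {'sdtype': 'id'}, 'user_id': {'sdtype': 'id'}},
            },
        },
        'relationships': [
            {
                'parent_table_name': 'users',
                'parent_primary_key': 'user_id',
                'child_table_name': 'sessions',
                'child_foreign_key': 'user_id',
            },
        ],
    }

    subset, updated = _get_dataset_subset(data, metadata_dict, modality='multi_table')
    assert len(subset['sessions']) <= MAX_NUM_ROWS  # main table capped at 1,000 rows
    # Referential integrity: every sampled session still points at a sampled user
    assert set(subset['sessions']['user_id']) <= set(subset['users']['user_id'])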