1616from sdv .sampling import BaseHierarchicalSampler
1717
1818LOGGER = logging .getLogger (__name__ )
19- MAX_NUMBER_OF_COLUMNS = 1000
19+ MAX_NUMBER_OF_COLUMNS = 1_000_000
2020DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
21- PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
22-
23-
24- class _EarlyStopEstimation (Exception ):
25- pass
21+ PERFORMANCE_ALERT_DISPLAY_CAP = 1000
2622
2723
2824class HMASynthesizer (BaseHierarchicalSampler , BaseMultiTableSynthesizer ):
@@ -107,7 +103,7 @@ def _get_num_extended_columns(
107103
108104 @classmethod
109105 def _estimate_columns_traversal (
110- cls , metadata , table_name , columns_per_table , visited , distributions = None , max_total = None
106+ cls , metadata , table_name , columns_per_table , visited , distributions = None
111107 ):
112108 """Given a table, estimate how many columns each parent will model.
113109
@@ -124,20 +120,20 @@ def _estimate_columns_traversal(
124120 for child_name in metadata ._get_child_map ()[table_name ]:
125121 if child_name not in visited :
126122 cls ._estimate_columns_traversal (
127- metadata , child_name , columns_per_table , visited , distributions , max_total
123+ metadata , child_name , columns_per_table , visited , distributions
128124 )
129125
130126 columns_per_table [table_name ] += cls ._get_num_extended_columns (
131127 metadata , child_name , table_name , columns_per_table , distributions
132128 )
133129
134- if max_total is not None and sum (columns_per_table .values ()) > max_total :
135- raise _EarlyStopEstimation
130+ if sum (columns_per_table .values ()) > MAX_NUMBER_OF_COLUMNS :
131+ return
136132
137133 visited .add (table_name )
138134
139135 @classmethod
140- def _estimate_num_columns (cls , metadata , distributions = None , max_total = None ):
136+ def _estimate_num_columns (cls , metadata , distributions = None ):
141137 """Estimate the number of columns that will be modeled for each table.
142138
143139 This method estimates how many extended columns will be generated during the
@@ -163,13 +159,10 @@ def _estimate_num_columns(cls, metadata, distributions=None, max_total=None):
163159 # each table will model
164160 visited = set ()
165161 for table_name in _get_root_tables (metadata .relationships ):
166- try :
167- cls ._estimate_columns_traversal (
168- metadata , table_name , columns_per_table , visited , distributions , max_total
169- )
170- except _EarlyStopEstimation :
171- break
172- if max_total is not None and sum (columns_per_table .values ()) > max_total :
162+ cls ._estimate_columns_traversal (
163+ metadata , table_name , columns_per_table , visited , distributions
164+ )
165+ if sum (columns_per_table .values ()) > MAX_NUMBER_OF_COLUMNS :
173166 break
174167
175168 return columns_per_table
@@ -250,23 +243,19 @@ def _print_estimate_warning(self):
250243 metadata_columns = self ._get_num_data_columns (self .metadata )
251244 print_table = []
252245 distributions = self ._get_distributions ()
253- estimated_columns = self ._estimate_num_columns (
254- self .metadata , distributions , max_total = PERFORMANCE_ALERT_DISPLAY_CAP
255- )
246+ estimated_columns = self ._estimate_num_columns (self .metadata , distributions )
256247 for table , est_cols in estimated_columns .items ():
257248 entry = []
258249 entry .append (table )
259250 entry .append (metadata_columns [table ])
260251 total_est_cols += est_cols
261252 entry .append (est_cols )
262253 print_table .append (entry )
263- if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP :
264- break
265254
266- if total_est_cols > MAX_NUMBER_OF_COLUMNS :
255+ if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP :
267256 display_total = (
268- f'{ PERFORMANCE_ALERT_DISPLAY_CAP } +'
269- if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP
257+ f'{ MAX_NUMBER_OF_COLUMNS } +'
258+ if total_est_cols > MAX_NUMBER_OF_COLUMNS
270259 else f'{ total_est_cols } '
271260 )
272261 self ._print (
0 commit comments