|
16 | 16 | from sdv.sampling import BaseHierarchicalSampler |
17 | 17 |
|
18 | 18 | LOGGER = logging.getLogger(__name__) |
19 | | -MAX_NUMBER_OF_COLUMNS = 1000 |
| 19 | +PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000 |
20 | 20 | DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm' |
| 21 | +MAX_NUMBER_OF_COLUMNS = 1000 |
21 | 22 |
|
22 | 23 |
|
23 | 24 | class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): |
@@ -139,6 +140,10 @@ def _estimate_columns_traversal( |
139 | 140 | metadata, child_name, table_name, columns_per_table, distributions |
140 | 141 | ) |
141 | 142 |
|
| 143 | + total_cols = sum(columns_list[1] for columns_list in columns_per_table.values()) |
| 144 | + if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP: |
| 145 | + return |
| 146 | + |
142 | 147 | visited.add(table_name) |
143 | 148 |
|
144 | 149 | @classmethod |
@@ -171,6 +176,9 @@ def _estimate_num_columns(cls, metadata, distributions=None): |
171 | 176 | cls._estimate_columns_traversal( |
172 | 177 | metadata, table_name, columns_per_table, visited, distributions |
173 | 178 | ) |
| 179 | + total_cols = sum(columns_list[1] for columns_list in columns_per_table.values()) |
| 180 | + if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP: |
| 181 | + break |
174 | 182 |
|
175 | 183 | return { |
176 | 184 | table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items() |
@@ -257,19 +265,25 @@ def _print_estimate_warning(self): |
257 | 265 | metadata_columns = self._get_num_data_columns(self.metadata) |
258 | 266 | print_table = [] |
259 | 267 | distributions = self._get_distributions() |
260 | | - for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items(): |
| 268 | + estimated_columns = self._estimate_num_columns(self.metadata, distributions) |
| 269 | + for table, est_cols in estimated_columns.items(): |
261 | 270 | entry = [] |
262 | 271 | entry.append(table) |
263 | 272 | entry.append(sum(metadata_columns[table])) |
264 | 273 | total_est_cols += est_cols |
265 | | - entry.append(est_cols) |
| 274 | + entry.append(min(est_cols, PERFORMANCE_ALERT_DISPLAY_CAP)) |
266 | 275 | print_table.append(entry) |
267 | 276 |
|
268 | 277 | if total_est_cols > MAX_NUMBER_OF_COLUMNS: |
| 278 | + display_total = ( |
| 279 | + f'{PERFORMANCE_ALERT_DISPLAY_CAP}+' |
| 280 | + if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP |
| 281 | + else f'{total_est_cols}' |
| 282 | + ) |
269 | 283 | self._print( |
270 | 284 | 'PerformanceAlert: Using the HMASynthesizer on this metadata ' |
271 | 285 | 'schema is not recommended. To model this data, HMA will ' |
272 | | - f'generate a large number of columns. ({total_est_cols} columns)\n\n' |
| 286 | + f'generate a large number of columns. ({display_total} columns)\n\n' |
273 | 287 | ) |
274 | 288 | self._print( |
275 | 289 | pd.DataFrame( |
|
0 commit comments