Skip to content

Commit a17c4d6

Browse files
authored
HMASynthesizer - Cap displayed column count in PerformanceAlert Message (#2734)
1 parent 5fbeeeb commit a17c4d6

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

sdv/multi_table/hma.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
from sdv.sampling import BaseHierarchicalSampler
1717

1818
LOGGER = logging.getLogger(__name__)
19-
MAX_NUMBER_OF_COLUMNS = 1000
19+
PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
2020
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
21+
MAX_NUMBER_OF_COLUMNS = 1000
2122

2223

2324
class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
@@ -139,6 +140,10 @@ def _estimate_columns_traversal(
139140
metadata, child_name, table_name, columns_per_table, distributions
140141
)
141142

143+
total_cols = sum(columns_list[1] for columns_list in columns_per_table.values())
144+
if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
145+
return
146+
142147
visited.add(table_name)
143148

144149
@classmethod
@@ -171,6 +176,9 @@ def _estimate_num_columns(cls, metadata, distributions=None):
171176
cls._estimate_columns_traversal(
172177
metadata, table_name, columns_per_table, visited, distributions
173178
)
179+
total_cols = sum(columns_list[1] for columns_list in columns_per_table.values())
180+
if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
181+
break
174182

175183
return {
176184
table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items()
@@ -257,19 +265,25 @@ def _print_estimate_warning(self):
257265
metadata_columns = self._get_num_data_columns(self.metadata)
258266
print_table = []
259267
distributions = self._get_distributions()
260-
for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items():
268+
estimated_columns = self._estimate_num_columns(self.metadata, distributions)
269+
for table, est_cols in estimated_columns.items():
261270
entry = []
262271
entry.append(table)
263272
entry.append(sum(metadata_columns[table]))
264273
total_est_cols += est_cols
265-
entry.append(est_cols)
274+
entry.append(min(est_cols, PERFORMANCE_ALERT_DISPLAY_CAP))
266275
print_table.append(entry)
267276

268277
if total_est_cols > MAX_NUMBER_OF_COLUMNS:
278+
display_total = (
279+
f'{PERFORMANCE_ALERT_DISPLAY_CAP}+'
280+
if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP
281+
else f'{total_est_cols}'
282+
)
269283
self._print(
270284
'PerformanceAlert: Using the HMASynthesizer on this metadata '
271285
'schema is not recommended. To model this data, HMA will '
272-
f'generate a large number of columns. ({total_est_cols} columns)\n\n'
286+
f'generate a large number of columns. ({display_total} columns)\n\n'
273287
)
274288
self._print(
275289
pd.DataFrame(

tests/unit/multi_table/test_hma.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from sdv.errors import SynthesizerInputError
99
from sdv.metadata.metadata import Metadata
10-
from sdv.multi_table.hma import HMASynthesizer
10+
from sdv.multi_table.hma import (
11+
HMASynthesizer,
12+
)
1113
from sdv.single_table.copulas import GaussianCopulaSynthesizer
1214
from tests.utils import get_multi_table_data, get_multi_table_metadata
1315

@@ -129,6 +131,30 @@ def test__print_estimate_warning(self, get_distributions_mock, estimate_mock, ca
129131
match = re.search(constraint, captured.out + captured.err)
130132
assert match is None
131133

134+
@patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns')
135+
@patch('sdv.multi_table.hma.HMASynthesizer._get_distributions')
136+
def test__print_estimate_warning_many_cols(self, get_distributions_mock, estimate_mock, capsys):
137+
"""Test that a warning appears if there are more than 1_000_000 expected columns"""
138+
# Setup
139+
metadata = get_multi_table_metadata()
140+
estimate_mock.side_effect = [{'nesreca': 1_000_010}, {'nesreca': 10}]
141+
142+
# Run
143+
HMASynthesizer(metadata)
144+
captured = capsys.readouterr()
145+
146+
# Assert
147+
expected_output = (
148+
'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended.'
149+
' To model this data, HMA will generate a large number of columns. (1000000+ columns)\n'
150+
'\n\nTable Name # Columns in Metadata Est # Columns\n'
151+
' nesreca 1 1000000\n\n'
152+
"We recommend simplifying your metadata schema using 'sdv.utils.poc.simplify_schema'."
153+
'\nIf this is not possible, please visit datacebo.com and reach out to us for '
154+
'enterprise solutions.\n\n'
155+
)
156+
assert captured.out == expected_output
157+
132158
def test__get_extension_foreign_key_only(self):
133159
"""Test the ``_get_extension`` method.
134160

0 commit comments

Comments
 (0)