Skip to content

Commit 01c9dc8

Browse files
committed
tests
1 parent db87d48 commit 01c9dc8

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

sdmetrics/single_table/privacy/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def hamming_distance(target, test):
6969
The hamming distance
7070
"""
7171
dist = 0
72-
assert len(target) == len(test), (
73-
'Tuples must have the same length in the calculation of hamming distance!'
74-
)
72+
assert len(target) == len(
73+
test
74+
), 'Tuples must have the same length in the calculation of hamming distance!'
7575

7676
for target_entry, test_entry in zip(target, test):
7777
if target_entry != test_entry:

tests/integration/reports/single_table/test_quality_report.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_report_end_to_end(self):
7777
key: val for key, val in metadata['columns'].items() if key in column_names
7878
}
7979
report = QualityReport()
80+
report.num_rows_subsample = None
8081

8182
# Run
8283
generate_start_time = time.time()
@@ -141,7 +142,8 @@ def test_report_end_to_end(self):
141142
report.get_details('Column Pair Trends'), expected_details_cpt
142143
)
143144
assert report.get_score() == 0.8393750143888287
144-
145+
assert report._properties['Column Shapes'].num_rows_subsample is None
146+
assert report._properties['Column Pair Trends'].num_rows_subsample == None
145147
report_info = report.get_info()
146148
assert report_info == report.report_info
147149

@@ -183,6 +185,8 @@ def test_with_large_dataset(self):
183185
# Assert
184186
cpt_report_1 = report_1.get_properties().iloc[1]['Score']
185187
cpt_report_2 = report_2.get_properties().iloc[1]['Score']
188+
assert report_1._properties['Column Pair Trends'].num_rows_subsample == 50000
189+
assert report_2._properties['Column Pair Trends'].num_rows_subsample == 50000
186190
assert score_1_run_1 != score_1_run_2
187191
assert np.isclose(score_1_run_1, score_1_run_2, atol=0.001)
188192
assert np.isclose(report_2.get_score(), score_1_run_1, atol=0.001)

tests/unit/reports/multi_table/test_base_multi_table_report.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from sdmetrics.demos import load_demo
1010
from sdmetrics.reports.multi_table.base_multi_table_report import BaseMultiTableReport
11+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
1112

1213

1314
class TestBaseReport:
@@ -21,6 +22,7 @@ def test__init__(self):
2122
assert report.is_generated is False
2223
assert report._properties == {}
2324
assert report.table_names == []
25+
assert report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE
2426

2527
def test__validate_data_format(self):
2628
"""Test the ``_validate_data_format`` method.

tests/unit/reports/test_base_report.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,21 @@
99

1010
from sdmetrics.demos import load_demo
1111
from sdmetrics.reports.base_report import BaseReport
12+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
1213

1314

1415
class TestBaseReport:
16+
def test__init__(self):
17+
"""Test the initialization of the BaseReport class."""
18+
# Run
19+
base_report = BaseReport()
20+
21+
# Assert
22+
assert base_report._overall_score is None
23+
assert not base_report.is_generated
24+
assert base_report._properties == {}
25+
assert base_report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE
26+
1527
def test__validate_data_format(self):
1628
"""Test the ``_validate_data_format`` method.
1729
@@ -268,6 +280,7 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
268280
version_mock.return_value = 'version'
269281

270282
base_report = BaseReport()
283+
base_report.num_rows_subsample = 1000
271284
mock_validate = Mock()
272285
mock__print_results = Mock()
273286
base_report._print_results = mock__print_results
@@ -292,9 +305,11 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
292305
base_report._properties['Property 1'].get_score.assert_called_with(
293306
real_data, synthetic_data, metadata, progress_bar=None
294307
)
308+
assert base_report._properties['Property 1'].num_rows_subsample == 1000
295309
base_report._properties['Property 2'].get_score.assert_called_with(
296310
real_data, synthetic_data, metadata, progress_bar=None
297311
)
312+
assert base_report._properties['Property 2'].num_rows_subsample == 1000
298313
expected_info = {
299314
'report_type': 'BaseReport',
300315
'generated_date': '2020-01-05',

0 commit comments

Comments
 (0)