Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import tqdm

from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
from sdmetrics.visualization import set_plotly_config


Expand All @@ -27,6 +28,7 @@ def __init__(self):
self._overall_score = None
self.is_generated = False
self._properties = {}
self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE
self.report_info = {
'report_type': self.__class__.__name__,
'generated_date': None,
Expand Down Expand Up @@ -163,6 +165,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
f'({ind + 1}/{len(self._properties)}) Evaluating {property_name}'
)

self._properties[property_name].num_rows_subsample = self.num_rows_subsample
score = self._properties[property_name].get_score(
real_data, synthetic_data, metadata, progress_bar=progress_bar
)
Expand Down
3 changes: 3 additions & 0 deletions sdmetrics/reports/multi_table/_properties/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pandas as pd

from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE


class BaseMultiTableProperty:
"""Base class for multi table properties.
Expand All @@ -26,6 +28,7 @@ def __init__(self):
self._properties = {}
self.is_computed = False
self.details = pd.DataFrame()
self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE

def _get_num_iterations(self, metadata):
"""Get the number of iterations for the property."""
Expand Down
3 changes: 3 additions & 0 deletions sdmetrics/reports/single_table/_properties/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pandas as pd

from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE


class BaseSingleTableProperty:
"""Base class for single table properties.
Expand All @@ -14,6 +16,7 @@ class BaseSingleTableProperty:

def __init__(self):
    """Initialize the property with the default subsample size and empty details."""
    # Starts at the shared module-level default.
    self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE
    # Starts as an empty DataFrame.
    self.details = pd.DataFrame()

def _compute_average(self):
"""Average the scores for each column."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from sdmetrics.reports.single_table._properties import BaseSingleTableProperty
from sdmetrics.reports.utils import PlotConfig

DEFAULT_NUM_ROWS_SUBSAMPLE = 50000


class ColumnPairTrends(BaseSingleTableProperty):
"""Column pair trends property.
Expand All @@ -30,6 +28,7 @@ class ColumnPairTrends(BaseSingleTableProperty):
}

def __init__(self):
    """Initialize the base property state and the two failure-tracking dicts."""
    super().__init__()
    # Two distinct empty dicts (not a shared object), one per failure category.
    (
        self._columns_datetime_conversion_failed,
        self._columns_discretization_failed,
    ) = ({}, {})

Expand Down Expand Up @@ -276,10 +275,12 @@ def _generate_details(
)

metric_params = {}
if (metric == ContingencySimilarity) and (
max(len(col_real), len(col_synthetic)) > DEFAULT_NUM_ROWS_SUBSAMPLE
if (
self.num_rows_subsample
and (metric == ContingencySimilarity)
and (max(len(col_real), len(col_synthetic)) > self.num_rows_subsample)
):
metric_params['num_rows_subsample'] = DEFAULT_NUM_ROWS_SUBSAMPLE
metric_params['num_rows_subsample'] = self.num_rows_subsample

try:
error = self._preprocessing_failed(
Expand Down
1 change: 1 addition & 0 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

CONTINUOUS_SDTYPES = ['numerical', 'datetime']
DISCRETE_SDTYPES = ['categorical', 'boolean']
DEFAULT_NUM_ROWS_SUBSAMPLE = 50000


class PlotConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_report_end_to_end(self):
key: val for key, val in metadata['columns'].items() if key in column_names
}
report = QualityReport()
report.num_rows_subsample = None

# Run
generate_start_time = time.time()
Expand Down Expand Up @@ -141,7 +142,8 @@ def test_report_end_to_end(self):
report.get_details('Column Pair Trends'), expected_details_cpt
)
assert report.get_score() == 0.8393750143888287

assert report._properties['Column Shapes'].num_rows_subsample is None
assert report._properties['Column Pair Trends'].num_rows_subsample is None
report_info = report.get_info()
assert report_info == report.report_info

Expand Down Expand Up @@ -183,6 +185,8 @@ def test_with_large_dataset(self):
# Assert
cpt_report_1 = report_1.get_properties().iloc[1]['Score']
cpt_report_2 = report_2.get_properties().iloc[1]['Score']
assert report_1._properties['Column Pair Trends'].num_rows_subsample == 50000
assert report_2._properties['Column Pair Trends'].num_rows_subsample == 50000
assert score_1_run_1 != score_1_run_2
assert np.isclose(score_1_run_1, score_1_run_2, atol=0.001)
assert np.isclose(report_2.get_score(), score_1_run_1, atol=0.001)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table.base_multi_table_report import BaseMultiTableReport
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE


class TestBaseReport:
Expand All @@ -21,6 +22,7 @@ def test__init__(self):
assert report.is_generated is False
assert report._properties == {}
assert report.table_names == []
assert report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE

def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,21 @@

from sdmetrics.demos import load_demo
from sdmetrics.reports.base_report import BaseReport
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE


class TestBaseReport:
def test__init__(self):
    """Test the initialization of the BaseReport class."""
    # Run
    base_report = BaseReport()

    # Assert
    assert base_report._overall_score is None
    # Identity check rather than truthiness, so a falsy non-bool value
    # such as ``None`` or ``0`` cannot satisfy the assertion.
    assert base_report.is_generated is False
    assert base_report._properties == {}
    assert base_report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE

def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.

Expand Down Expand Up @@ -268,6 +280,7 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
version_mock.return_value = 'version'

base_report = BaseReport()
base_report.num_rows_subsample = 1000
mock_validate = Mock()
mock__print_results = Mock()
base_report._print_results = mock__print_results
Expand All @@ -292,9 +305,11 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
base_report._properties['Property 1'].get_score.assert_called_with(
real_data, synthetic_data, metadata, progress_bar=None
)
assert base_report._properties['Property 1'].num_rows_subsample == 1000
base_report._properties['Property 2'].get_score.assert_called_with(
real_data, synthetic_data, metadata, progress_bar=None
)
assert base_report._properties['Property 2'].num_rows_subsample == 1000
expected_info = {
'report_type': 'BaseReport',
'generated_date': '2020-01-05',
Expand Down
Loading