Skip to content

Commit 9c1ca77

Browse files
committed
Small fix
1 parent e353e51 commit 9c1ca77

File tree

3 files changed

+106
-55
lines changed

3 files changed

+106
-55
lines changed

sdmetrics/reports/base_report.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,8 @@ def get_details(self, property_name):
272272
pandas.DataFrame
273273
"""
274274
self._validate_property_generated(property_name)
275-
return self._properties[property_name].details.copy()
275+
details = self._properties[property_name].details.copy()
276+
return details
276277

277278
def save(self, filepath):
278279
"""Save this report instance to the given path using pickle.

sdmetrics/reports/single_table/_properties/column_pair_trends.py

Lines changed: 101 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,60 @@ def _preprocessing_failed(self, column_name_1, column_name_2, sdtype_col_1, sdty
217217

218218
return error
219219

220+
def _compute_pair_score(
221+
self,
222+
metric,
223+
col_real,
224+
col_synthetic,
225+
metric_params,
226+
):
227+
"""Compute the score breakdown and threshold check for a column pair."""
228+
params = dict(metric_params)
229+
if metric.__name__ == 'CorrelationSimilarity':
230+
if self.real_correlation_threshold > 0:
231+
params['real_correlation_threshold'] = self.real_correlation_threshold
232+
233+
score_breakdown = metric.compute_breakdown(
234+
real_data=col_real, synthetic_data=col_synthetic, **params
235+
)
236+
pair_score = score_breakdown['score']
237+
real_correlation = score_breakdown['real']
238+
synthetic_correlation = score_breakdown['synthetic']
239+
real_association = np.nan
240+
if self.real_correlation_threshold <= 0:
241+
meets_threshold_pair = True
242+
else:
243+
meets_threshold_pair = (
244+
not np.isnan(real_correlation)
245+
and abs(real_correlation) > self.real_correlation_threshold
246+
)
247+
else:
248+
if self.real_association_threshold > 0:
249+
params['real_association_threshold'] = self.real_association_threshold
250+
251+
score_breakdown = metric.compute_breakdown(
252+
real_data=col_real, synthetic_data=col_synthetic, **params
253+
)
254+
pair_score = score_breakdown['score']
255+
real_correlation = np.nan
256+
synthetic_correlation = np.nan
257+
real_association = score_breakdown.get('real_association', np.nan)
258+
if self.real_association_threshold > 0:
259+
meets_threshold_pair = (
260+
not np.isnan(real_association)
261+
and real_association > self.real_association_threshold
262+
)
263+
else:
264+
meets_threshold_pair = True
265+
266+
return (
267+
pair_score,
268+
real_correlation,
269+
synthetic_correlation,
270+
real_association,
271+
meets_threshold_pair,
272+
)
273+
220274
def _generate_details(
221275
self, real_data, synthetic_data, metadata, progress_bar=None, column_pairs=None
222276
):
@@ -297,45 +351,13 @@ def _generate_details(
297351
if error:
298352
raise Exception('Preprocessing failed')
299353

300-
real_association = np.nan
301-
if metric.__name__ == 'CorrelationSimilarity':
302-
if self.real_correlation_threshold > 0:
303-
metric_params['real_correlation_threshold'] = (
304-
self.real_correlation_threshold
305-
)
306-
score_breakdown = metric.compute_breakdown(
307-
real_data=col_real, synthetic_data=col_synthetic, **metric_params
308-
)
309-
pair_score = score_breakdown['score']
310-
real_correlation = score_breakdown['real']
311-
synthetic_correlation = score_breakdown['synthetic']
312-
if self.real_correlation_threshold <= 0:
313-
meets_threshold_pair = True
314-
else:
315-
meets_threshold_pair = (
316-
not np.isnan(real_correlation)
317-
and abs(real_correlation) > self.real_correlation_threshold
318-
)
319-
else:
320-
real_correlation = np.nan
321-
synthetic_correlation = np.nan
322-
if self.real_association_threshold > 0:
323-
metric_params['real_association_threshold'] = (
324-
self.real_association_threshold
325-
)
326-
score_breakdown = metric.compute_breakdown(
327-
real_data=col_real, synthetic_data=col_synthetic, **metric_params
328-
)
329-
pair_score = score_breakdown['score']
330-
real_association = score_breakdown.get('real_association', np.nan)
331-
332-
if self.real_association_threshold > 0:
333-
meets_threshold_pair = (
334-
not np.isnan(real_association)
335-
and real_association > self.real_association_threshold
336-
)
337-
else:
338-
meets_threshold_pair = True
354+
(
355+
pair_score,
356+
real_correlation,
357+
synthetic_correlation,
358+
real_association,
359+
meets_threshold_pair,
360+
) = self._compute_pair_score(metric, col_real, col_synthetic, metric_params)
339361

340362
except Exception as e:
341363
pair_score = np.nan
@@ -386,14 +408,26 @@ def _get_correlation_matrix(self, column_name):
386408
if column_name not in ['Score', 'Real Correlation', 'Synthetic Correlation']:
387409
raise ValueError(f"Invalid column name for _get_correlation_matrix : '{column_name}'")
388410

389-
table = self.details.dropna(subset=[column_name])
390-
names = list(pd.concat([table['Column 1'], table['Column 2']]).unique())
411+
table = self.details
412+
if column_name in ['Real Correlation', 'Synthetic Correlation']:
413+
names_source = table[table['Metric'] == 'CorrelationSimilarity']
414+
else:
415+
names_source = table
416+
417+
table = names_source.dropna(subset=[column_name])
418+
if column_name == 'Score':
419+
names_source = self.details
420+
421+
names = list(pd.concat([names_source['Column 1'], names_source['Column 2']]).unique())
422+
available_columns = set(pd.concat([table['Column 1'], table['Column 2']]).unique())
391423
heatmap_df = pd.DataFrame(index=names, columns=names)
392424

393425
for idx_1, column_name_1 in enumerate(names):
394426
for column_name_2 in names[idx_1:]:
395427
if column_name_1 == column_name_2:
396-
heatmap_df.loc[column_name_1, column_name_2] = 1
428+
heatmap_df.loc[column_name_1, column_name_2] = (
429+
1 if column_name_1 in available_columns else np.nan
430+
)
397431
continue
398432

399433
# check wether the combination (Colunm 1, Column 2) or (Column 2, Column 1)
@@ -426,7 +460,7 @@ def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=
426460
customdata (pandas.DataFrame or None):
427461
The customdata to use. Defaults to None.
428462
"""
429-
fig = go.Heatmap(
463+
base_heatmap = go.Heatmap(
430464
x=correlation_matrix.columns,
431465
y=correlation_matrix.columns,
432466
z=correlation_matrix,
@@ -435,7 +469,20 @@ def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=
435469
hovertemplate=hovertemplate,
436470
)
437471

438-
return fig
472+
nan_mask = correlation_matrix.isna().to_numpy()
473+
if not nan_mask.any():
474+
return [base_heatmap]
475+
476+
nan_heatmap = go.Heatmap(
477+
x=correlation_matrix.columns,
478+
y=correlation_matrix.columns,
479+
z=np.where(nan_mask, 1, np.nan),
480+
colorscale=[[0, '#B0B0B0'], [1, '#B0B0B0']],
481+
showscale=False,
482+
hoverinfo='skip',
483+
)
484+
485+
return [base_heatmap, nan_heatmap]
439486

440487
def _update_layout(self, fig):
441488
"""Update the layout of the figure.
@@ -495,13 +542,16 @@ def get_visualization(self):
495542

496543
fig.update_xaxes(tickangle=45)
497544

498-
fig.add_trace(self._get_heatmap(similarity_correlation, 'coloraxis', tmpl_1), 1, 1)
499-
fig.add_trace(
500-
self._get_heatmap(real_correlation, 'coloraxis2', tmpl_2, synthetic_correlation), 2, 1
501-
)
502-
fig.add_trace(
503-
self._get_heatmap(synthetic_correlation, 'coloraxis2', tmpl_2, real_correlation), 2, 2
504-
)
545+
for trace in self._get_heatmap(similarity_correlation, 'coloraxis', tmpl_1):
546+
fig.add_trace(trace, 1, 1)
547+
for trace in self._get_heatmap(
548+
real_correlation, 'coloraxis2', tmpl_2, synthetic_correlation
549+
):
550+
fig.add_trace(trace, 2, 1)
551+
for trace in self._get_heatmap(
552+
synthetic_correlation, 'coloraxis2', tmpl_2, real_correlation
553+
):
554+
fig.add_trace(trace, 2, 2)
505555

506556
self._update_layout(fig)
507557

tests/unit/reports/single_table/_properties/test_column_pair_trends.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def test__get_correlation_matrix_score(self):
385385
cpt_property.details = pd.DataFrame({
386386
'Column 1': ['col1', 'col1', 'col2'],
387387
'Column 2': ['col2', 'col3', 'col3'],
388-
'metric': ['CorrelationSimilarity', 'ContingencySimilarity', 'ContingencySimilarity'],
388+
'Metric': ['CorrelationSimilarity', 'ContingencySimilarity', 'ContingencySimilarity'],
389389
'Score': [0.5, 0.6, 0.7],
390390
})
391391

@@ -411,7 +411,7 @@ def test__get_correlation_matrix_correlation(self):
411411
cpt_property.details = pd.DataFrame({
412412
'Column 1': ['col1', 'col1', 'col2'],
413413
'Column 2': ['col2', 'col3', 'col3'],
414-
'metric': ['CorrelationSimilarity', 'ContingencySimilarity', 'ContingencySimilarity'],
414+
'Metric': ['CorrelationSimilarity', 'ContingencySimilarity', 'ContingencySimilarity'],
415415
'Score': [0.5, 0.6, 0.7],
416416
'Real Correlation': [0.3, None, None],
417417
'Synthetic Correlation': [0.4, None, None],
@@ -526,7 +526,7 @@ def test_get_visualization(self, mock_make_subplots):
526526
cpt_property._get_correlation_matrix = mock__get_correlation_matrix
527527

528528
mock_heatmap = Mock()
529-
cpt_property._get_heatmap = Mock(return_value=mock_heatmap)
529+
cpt_property._get_heatmap = Mock(return_value=[mock_heatmap])
530530

531531
mock__update_layout = Mock()
532532
cpt_property._update_layout = mock__update_layout

0 commit comments

Comments
 (0)