Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions sdmetrics/reports/multi_table/_properties/inter_table_trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pandas as pd
import plotly.express as px
from plotly import graph_objects as go

from sdmetrics.reports.multi_table._properties import BaseMultiTableProperty
from sdmetrics.reports.single_table._properties import (
Expand Down Expand Up @@ -180,6 +181,55 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No

self.details = self.details[detail_columns]

def _filter_details_for_plot(self, table_name=None):
to_plot = self.details.copy()
if table_name is not None:
to_plot = to_plot[
(to_plot['Parent Table'] == table_name) | (to_plot['Child Table'] == table_name)
]

if 'Error' in to_plot.columns:
to_plot = to_plot[to_plot['Error'].isna()]
if 'Score' in to_plot.columns:
to_plot = to_plot[to_plot['Score'].notna()]

return to_plot

def _create_empty_plot(self):
    """Build the placeholder figure shown when there is nothing to plot.

    Returns:
        plotly.graph_objects._figure.Figure:
            A figure without traces, titled ``'No data to plot'``, with both axes
            and the legend hidden.
    """
    placeholder_layout = {
        'title_text': 'No data to plot',
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'showlegend': False,
        'plot_bgcolor': 'rgba(0,0,0,0)',
        'paper_bgcolor': 'white',
    }

    empty_fig = go.Figure()
    empty_fig.update_layout(**placeholder_layout)
    return empty_fig

def _prepare_plot_data(self, to_plot):
to_plot = to_plot.reset_index(drop=True)

parent_cols = to_plot['Parent Table'] + '.' + to_plot['Column 1']
child_cols = to_plot['Child Table'] + '.' + to_plot['Column 2']
to_plot['Columns'] = parent_cols + ', ' + child_cols
duplicated = to_plot['Columns'].duplicated(keep=False)
to_plot.loc[duplicated, 'Columns'] = (
to_plot.loc[duplicated, 'Columns'] + ' (' + to_plot.loc[duplicated, 'Foreign Key'] + ')'
)

to_plot['Real Correlation'] = to_plot['Real Correlation'].fillna('None')
to_plot['Synthetic Correlation'] = to_plot['Synthetic Correlation'].fillna('None')

return to_plot

def _compute_average_score(self, to_plot):
if 'Meets Threshold?' in to_plot.columns:
contributing = to_plot['Meets Threshold?'].astype('boolean').fillna(False)
return round(to_plot.loc[contributing, 'Score'].mean(), 2)

return round(to_plot['Score'].mean(), 2)

def get_visualization(self, table_name=None):
"""Create a plot to show the inter table trends data.

Expand All @@ -202,28 +252,12 @@ def get_visualization(self, table_name=None):
'Please call the ``get_score`` method first.'
)

to_plot = self.details.copy()
if table_name is not None:
to_plot = to_plot[
(to_plot['Parent Table'] == table_name) | (to_plot['Child Table'] == table_name)
]
to_plot = self._filter_details_for_plot(table_name)
if to_plot.empty:
return self._create_empty_plot()

parent_cols = to_plot['Parent Table'] + '.' + to_plot['Column 1']
child_cols = to_plot['Child Table'] + '.' + to_plot['Column 2']
to_plot['Columns'] = parent_cols + ', ' + child_cols
duplicated = to_plot['Columns'].duplicated(keep=False)
to_plot.loc[duplicated, 'Columns'] = (
to_plot.loc[duplicated, 'Columns'] + ' (' + to_plot.loc[duplicated, 'Foreign Key'] + ')'
)

to_plot['Real Correlation'] = to_plot['Real Correlation'].fillna('None')
to_plot['Synthetic Correlation'] = to_plot['Synthetic Correlation'].fillna('None')

if 'Meets Threshold?' in to_plot.columns:
contributing = to_plot['Meets Threshold?'].astype('boolean').fillna(False)
average_score = round(to_plot.loc[contributing, 'Score'].mean(), 2)
else:
average_score = round(to_plot['Score'].mean(), 2)
to_plot = self._prepare_plot_data(to_plot)
average_score = self._compute_average_score(to_plot)

fig = px.bar(
to_plot,
Expand Down
223 changes: 162 additions & 61 deletions sdmetrics/reports/single_table/_properties/column_pair_trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,28 +398,45 @@ def _generate_details(

return result

def _get_correlation_table(self, column_name):
if column_name not in ['Score', 'Real Correlation', 'Synthetic Correlation']:
raise ValueError(f"Invalid column name for _get_correlation_matrix : '{column_name}'")

table = self.details.copy()
if 'Error' in table.columns:
table = table[table['Error'].isna()]

if column_name in ['Real Correlation', 'Synthetic Correlation']:
table = table[table['Metric'] == 'CorrelationSimilarity']

return table.dropna(subset=['Score'])

def _get_pair_score(self, table, column_name_1, column_name_2, column_name):
# check whether the combination (Column 1, Column 2) or (Column 2, Column 1) is in the table
col_1_loc = table['Column 1'] == column_name_1
col_2_loc = table['Column 2'] == column_name_2
if table.loc[col_1_loc & col_2_loc].empty:
col_1_loc = table['Column 1'] == column_name_2
col_2_loc = table['Column 2'] == column_name_1

if not table.loc[col_1_loc & col_2_loc].empty:
return table.loc[col_1_loc & col_2_loc][column_name].array[0]

return None

def _get_correlation_matrix(self, column_name):
"""Get the correlation matrix for the given column name.

Args:
column_name (str):
The column name to use.
"""
if column_name not in ['Score', 'Real Correlation', 'Synthetic Correlation']:
raise ValueError(f"Invalid column name for _get_correlation_matrix : '{column_name}'")

table = self.details
if column_name in ['Real Correlation', 'Synthetic Correlation']:
names_source = table[table['Metric'] == 'CorrelationSimilarity']
else:
names_source = table
table = self._get_correlation_table(column_name)
if table.empty:
return pd.DataFrame()

table = names_source.dropna(subset=[column_name])
if column_name == 'Score':
names_source = self.details

names = list(pd.concat([names_source['Column 1'], names_source['Column 2']]).unique())
available_columns = set(pd.concat([table['Column 1'], table['Column 2']]).unique())
names = list(pd.concat([table['Column 1'], table['Column 2']]).unique())
available_columns = set(names)
heatmap_df = pd.DataFrame(index=names, columns=names)

for idx_1, column_name_1 in enumerate(names):
Expand All @@ -430,22 +447,12 @@ def _get_correlation_matrix(self, column_name):
)
continue

# check whether the combination (Column 1, Column 2) or (Column 2, Column 1)
# is in the table
col_1_loc = table['Column 1'] == column_name_1
col_2_loc = table['Column 2'] == column_name_2
if table.loc[col_1_loc & col_2_loc].empty:
col_1_loc = table['Column 1'] == column_name_2
col_2_loc = table['Column 2'] == column_name_1

if not table.loc[col_1_loc & col_2_loc].empty:
score = table.loc[col_1_loc & col_2_loc][column_name].array[0]
score = self._get_pair_score(table, column_name_1, column_name_2, column_name)
if score is not None:
heatmap_df.loc[column_name_1, column_name_2] = score
heatmap_df.loc[column_name_2, column_name_1] = score

heatmap_df = heatmap_df.astype(float)

return heatmap_df.round(3)
return heatmap_df.astype(float).round(3)

def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=None):
"""Get the heatmap for the given correlation matrix.
Expand Down Expand Up @@ -484,12 +491,14 @@ def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=

return [base_heatmap, nan_heatmap]

def _update_layout(self, fig):
def _update_layout(self, fig, show_correlations=True):
"""Update the layout of the figure.

Args:
fig (plotly.graph_objects._figure.Figure):
The figure to update.
show_correlations (bool):
Whether or not to show numerical correlation plots.
"""
average_score = round(self._compute_average(), 2)
color_dict = {
Expand All @@ -501,18 +510,123 @@ def _update_layout(self, fig):
colors_1 = [PlotConfig.RED, PlotConfig.ORANGE, PlotConfig.GREEN]
colors_2 = [PlotConfig.DATACEBO_BLUE, PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]

layout_args = {
'title_text': (f'Data Quality: Column Pair Trends (Average Score={average_score})'),
'coloraxis': {
**color_dict,
'colorbar_x': 0.8,
'colorbar_y': 0.8,
'colorscale': colors_1,
},
'height': 900 if show_correlations else 450,
'width': 900,
'font': {'size': PlotConfig.FONT_SIZE},
}

if show_correlations:
layout_args.update({
'coloraxis2': {**color_dict, 'colorbar_y': 0.2, 'cmin': -1, 'colorscale': colors_2},
'yaxis3': {'visible': False, 'matches': 'y2'},
'xaxis3': {'matches': 'x2'},
})
else:
layout_args['coloraxis'].pop('colorbar_len', None)
layout_args['coloraxis']['colorbar'] = {
'x': 0.92,
'xanchor': 'left',
'y': 0.5,
'yanchor': 'middle',
'len': 1,
}
layout_args['xaxis'] = {'constrain': 'domain', 'domain': [0.1, 0.9]}
layout_args['yaxis'] = {'scaleanchor': 'x', 'scaleratio': 1}
layout_args['margin'] = {'r': 60}

fig.update_layout(**layout_args)
fig.update_yaxes(autorange='reversed')

def _get_empty_visualization(self):
    """Build the placeholder figure shown when there is nothing to plot.

    Only the placeholder layout is applied. The heatmap layout arguments
    (score title, color axes, fixed size) rely on values that are not available
    in this method, and repeating ``title_text`` in one call is a syntax error.

    Returns:
        plotly.graph_objects._figure.Figure:
            A figure without traces, titled ``'No data to plot'``, with both axes
            and the legend hidden.
    """
    fig = go.Figure()
    fig.update_layout(
        title_text='No data to plot',
        xaxis={'visible': False},
        yaxis={'visible': False},
        showlegend=False,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='white',
    )
    return fig

def _add_heatmap_traces(
    self,
    fig,
    correlation_matrix,
    coloraxis,
    hovertemplate,
    customdata=None,
    row=None,
    col=None,
):
    """Add the heatmap traces for a correlation matrix to a figure.

    Args:
        fig (plotly.graph_objects._figure.Figure):
            The figure to add the traces to.
        correlation_matrix (pandas.DataFrame):
            The matrix passed on to ``_get_heatmap``.
        coloraxis (str):
            The color axis name passed on to ``_get_heatmap``.
        hovertemplate (str):
            The hover template passed on to ``_get_heatmap``.
        customdata (pandas.DataFrame, optional):
            Extra data passed on to ``_get_heatmap``. Defaults to ``None``.
        row (int, optional):
            Subplot row. Only used when ``col`` is also given.
        col (int, optional):
            Subplot column. Only used when ``row`` is also given.
    """
    traces = self._get_heatmap(
        correlation_matrix, coloraxis, hovertemplate, customdata=customdata
    )

    # A subplot position is only forwarded when both coordinates are provided.
    position = () if row is None or col is None else (row, col)
    for trace in traces:
        fig.add_trace(trace, *position)

def _build_single_heatmap_figure(self, similarity_correlation, title):
    """Build a figure that contains only the similarity heatmap.

    Args:
        similarity_correlation (pandas.DataFrame):
            The matrix to draw on the ``'coloraxis'`` color axis.
        title (str):
            The title of the single subplot.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    similarity_hover = (
        '<b>Column Pair</b><br>(%{x},%{y})<br><br>Similarity: %{z}<extra></extra>'
    )

    figure = make_subplots(rows=1, cols=1, subplot_titles=[title])
    figure.update_xaxes(tickangle=45)
    self._add_heatmap_traces(figure, similarity_correlation, 'coloraxis', similarity_hover)

    # The layout is told that the numerical correlation plots are not shown.
    self._update_layout(figure, show_correlations=False)
    return figure

def _build_full_heatmap_figure(
    self,
    similarity_correlation,
    real_correlation,
    synthetic_correlation,
    titles,
):
    """Build the figure with the similarity heatmap and both correlation heatmaps.

    Args:
        similarity_correlation (pandas.DataFrame):
            The matrix drawn without an explicit subplot position.
        real_correlation (pandas.DataFrame):
            The matrix drawn in row 2, column 1.
        synthetic_correlation (pandas.DataFrame):
            The matrix drawn in row 2, column 2, after being reindexed to the rows
            and columns of ``real_correlation``.
        titles (list[str]):
            The subplot titles.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    similarity_hover = (
        '<b>Column Pair</b><br>(%{x},%{y})<br><br>Similarity: %{z}<extra></extra>'
    )
    correlation_hover = (
        '<b>Correlation</b><br>(%{x},%{y})<br><br>Synthetic: %{z}<br>(vs. Real: '
        '%{customdata})<extra></extra>'
    )

    top_row = [{'colspan': 2, 'l': 0.26, 'r': 0.26}, None]
    bottom_row = [{}, {}]
    figure = make_subplots(
        rows=2, cols=2, subplot_titles=titles, specs=[top_row, bottom_row]
    )
    figure.update_xaxes(tickangle=45)
    self._add_heatmap_traces(figure, similarity_correlation, 'coloraxis', similarity_hover)

    # Give the synthetic matrix the same labels as the real one, so each heatmap
    # and the matrix used as its hover customdata share rows and columns.
    aligned_synthetic = synthetic_correlation.reindex(
        index=real_correlation.index, columns=real_correlation.columns
    )
    bottom_panels = (
        (real_correlation, aligned_synthetic, 1),
        (aligned_synthetic, real_correlation, 2),
    )
    for matrix, counterpart, subplot_col in bottom_panels:
        self._add_heatmap_traces(
            figure,
            matrix,
            'coloraxis2',
            correlation_hover,
            customdata=counterpart,
            row=2,
            col=subplot_col,
        )

    self._update_layout(figure)
    return figure

def get_visualization(self):
"""Create a plot to show the column pairs data.
Expand All @@ -523,6 +637,9 @@ def get_visualization(self):
plotly.graph_objects._figure.Figure
"""
similarity_correlation = self._get_correlation_matrix('Score')
if similarity_correlation.empty:
return self._get_empty_visualization()

real_correlation = self._get_correlation_matrix('Real Correlation')
synthetic_correlation = self._get_correlation_matrix('Synthetic Correlation')

Expand All @@ -531,28 +648,12 @@ def get_visualization(self):
'Numerical Correlation (Real Data)',
'Numerical Correlation (Synthetic Data)',
]
specs = [[{'colspan': 2, 'l': 0.26, 'r': 0.26}, None], [{}, {}]]
tmpl_1 = '<b>Column Pair</b><br>(%{x},%{y})<br><br>Similarity: %{z}<extra></extra>'
tmpl_2 = (
'<b>Correlation</b><br>(%{x},%{y})<br><br>Synthetic: %{z}<br>(vs. Real: '
'%{customdata})<extra></extra>'
)

fig = make_subplots(rows=2, cols=2, subplot_titles=titles, specs=specs)

fig.update_xaxes(tickangle=45)
if real_correlation.empty:
return self._build_single_heatmap_figure(similarity_correlation, titles[0])

for trace in self._get_heatmap(similarity_correlation, 'coloraxis', tmpl_1):
fig.add_trace(trace, 1, 1)
for trace in self._get_heatmap(
real_correlation, 'coloraxis2', tmpl_2, synthetic_correlation
):
fig.add_trace(trace, 2, 1)
for trace in self._get_heatmap(
synthetic_correlation, 'coloraxis2', tmpl_2, real_correlation
):
fig.add_trace(trace, 2, 2)

self._update_layout(fig)

return fig
return self._build_full_heatmap_figure(
similarity_correlation,
real_correlation,
synthetic_correlation,
titles,
)
Loading