Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ def _generate_box_plot(all_data, columns):
return fig


def _generate_violin_plot(data, columns):
    """Build an overlaid violin plot for a pair of columns.

    Args:
        data (pandas.DataFrame):
            Table holding the two columns to plot plus a ``Data`` column that
            labels each row. The colour map covers the ``'Real'`` and
            ``'Synthetic'`` labels.
        columns (list[str]):
            Column names; the first is drawn on the x axis, the second on the y axis.

    Returns:
        The figure returned by ``px.violin`` after its layout has been updated.
    """
    x_column = columns[0]
    y_column = columns[1]
    palette = {
        'Real': PlotConfig.DATACEBO_DARK,
        'Synthetic': PlotConfig.DATACEBO_GREEN,
    }
    fig = px.violin(
        data,
        x=x_column,
        y=y_column,
        box=False,
        violinmode='overlay',
        color='Data',
        color_discrete_map=palette,
    )

    # The title lists whichever labels are actually present in the data.
    labels = data['Data'].unique()
    heading = ' vs. '.join(labels) + f" Data for columns '{x_column}' and '{y_column}'"
    fig.update_layout(
        title=heading,
        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
        font={'size': PlotConfig.FONT_SIZE},
    )

    return fig


def _generate_scatter_plot(all_data, columns):
"""Generate a scatter plot for column pair plot.

Expand Down Expand Up @@ -615,10 +642,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
)
synthetic_data = synthetic_data[column_names]

if plot_type not in ['box', 'heatmap', 'scatter', None]:
if plot_type not in ['box', 'heatmap', 'scatter', 'violin', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of "
"['box', 'heatmap', 'scatter', None]."
"['box', 'heatmap', 'scatter', 'violin', None]."
)

if plot_type is None:
Expand Down Expand Up @@ -654,6 +681,8 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
return _generate_scatter_plot(all_data, column_names)
elif plot_type == 'heatmap':
return _generate_heatmap_plot(all_data, column_names)
elif plot_type == 'violin':
return _generate_violin_plot(all_data, column_names)

return _generate_box_plot(all_data, column_names)

Expand Down
66 changes: 65 additions & 1 deletion tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_generate_heatmap_plot,
_generate_line_plot,
_generate_scatter_plot,
_generate_violin_plot,
_get_cardinality,
_get_max_between_datasets,
_get_min_between_datasets,
Expand Down Expand Up @@ -1004,6 +1005,47 @@ def test__generate_box_plot_title_one_dataset_only(px_mock):
assert fig_real == mock_figure


@patch('sdmetrics.visualization.px')
def test__generate_violin_plot(px_mock):
    """Check that ``_generate_violin_plot`` builds and styles the violin figure."""
    # Setup
    columns = ['col1', 'col2']
    real_rows = pd.DataFrame({
        'col1': [1, 2, 3, 4],
        'col2': ['a', 'b', 'c', 'd'],
        'Data': ['Real'] * 4,
    })
    synthetic_rows = pd.DataFrame({
        'col1': [1, 2, 4, 5],
        'col2': ['a', 'b', 'c', 'd'],
        'Data': ['Synthetic'] * 4,
    })
    all_data = pd.concat([real_rows, synthetic_rows], axis=0, ignore_index=True)

    expected_figure = Mock()
    px_mock.violin.return_value = expected_figure

    # Run
    result = _generate_violin_plot(all_data, columns)

    # Assert
    expected_violin_kwargs = {
        'x': 'col1',
        'y': 'col2',
        'box': False,
        'violinmode': 'overlay',
        'color': 'Data',
        'color_discrete_map': {'Real': '#000036', 'Synthetic': '#01E0C9'},
    }
    px_mock.violin.assert_called_once_with(
        DataFrameMatcher(all_data),
        **expected_violin_kwargs,
    )
    expected_layout_kwargs = {
        'title': "Real vs. Synthetic Data for columns 'col1' and 'col2'",
        'plot_bgcolor': '#F5F5F8',
        'font': {'size': 18},
    }
    expected_figure.update_layout.assert_called_once_with(**expected_layout_kwargs)
    assert result == expected_figure


def test_get_column_pair_plot_invalid_column_names():
"""Test ``get_column_pair_plot`` method with invalid ``column_names``."""
# Setup
Expand Down Expand Up @@ -1049,12 +1091,34 @@ def test_get_column_pair_plot_invalid_plot_type():

# Run and Assert
match = re.escape(
"Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter', None]."
"Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter',"
" 'violin', None]."
)
with pytest.raises(ValueError, match=match):
get_column_pair_plot(real_data, synthetic_data, columns, plot_type='distplot')


@patch('sdmetrics.visualization._generate_violin_plot')
def test_get_column_pair_plot_violin(mock__generate_violin_plot):
    """Check that ``plot_type='violin'`` routes to ``_generate_violin_plot``."""
    # Setup
    columns = ['amount', 'price']
    real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]})
    synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]})

    # Expected input to the plotting helper: both tables stacked, with a label column.
    expected_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
    expected_data['Data'] = ['Real'] * 3 + ['Synthetic'] * 3
    expected_columns = ['amount', 'price']

    # Run
    result = get_column_pair_plot(real_data, synthetic_data, columns, plot_type='violin')

    # Assert
    mock__generate_violin_plot.assert_called_once_with(
        DataFrameMatcher(expected_data),
        expected_columns,
    )
    assert result == mock__generate_violin_plot.return_value


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""
Expand Down
Loading