Skip to content

Commit e6ea26a

Browse files
move viz methods to seperate module (#203)
1 parent 295b423 commit e6ea26a

File tree

15 files changed

+566
-456
lines changed

15 files changed

+566
-456
lines changed
Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
"""Generic class for Explanation and Saliency results"""
22
from abc import ABC, abstractmethod
3-
from typing import Dict
43

5-
import bokeh.models
64
import pandas as pd
7-
from bokeh.io import show
85
from pandas.io.formats.style import Styler
96

107

@@ -27,51 +24,3 @@ class SaliencyResults(ExplanationResults):
2724
@abstractmethod
2825
def saliency_map(self):
2926
"""Return the Saliencies as a dictionary, keyed by output name"""
30-
31-
@abstractmethod
32-
def _matplotlib_plot(self, output_name: str, block: bool, call_show: bool) -> None:
33-
"""Plot the saliencies of a particular output in matplotlib"""
34-
35-
@abstractmethod
36-
def _get_bokeh_plot(self, output_name: str) -> bokeh.models.Plot:
37-
"""Get a bokeh plot visualizing the saliencies of a particular output"""
38-
39-
def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
40-
"""Get a dictionary containing visualizations of the saliencies of all outputs,
41-
keyed by output name"""
42-
return {
43-
output_name: self._get_bokeh_plot(output_name)
44-
for output_name in self.saliency_map().keys()
45-
}
46-
47-
def plot(
48-
self, output_name=None, render_bokeh=False, block=True, call_show=True
49-
) -> None:
50-
"""
51-
Plot the found feature saliencies.
52-
53-
Parameters
54-
----------
55-
output_name : str
56-
(default= `None`) The name of the output to be explainer. If `None`, all outputs will
57-
be displayed
58-
render_bokeh : bool
59-
(default= `False`) If true, render plot in bokeh, otherwise use matplotlib.
60-
block: bool
61-
(default= `True`) Whether displaying the plot blocks subsequent code execution
62-
call_show: bool
63-
(default= 'True') Whether plt.show() will be called by default at the end of the
64-
plotting function. If `False`, the plot will be returned to the user for further
65-
editing.
66-
"""
67-
if output_name is None:
68-
for output_name_iterator in self.saliency_map().keys():
69-
if render_bokeh:
70-
show(self._get_bokeh_plot(output_name_iterator))
71-
else:
72-
self._matplotlib_plot(output_name_iterator, block, call_show)
73-
else:
74-
if render_bokeh:
75-
show(self._get_bokeh_plot(output_name))
76-
else:
77-
self._matplotlib_plot(output_name, block, call_show)

src/trustyai/explainers/lime.py

Lines changed: 1 addition & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,12 @@
33
# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
44
from typing import Dict, Union
55

6-
import bokeh.models
7-
import matplotlib.pyplot as plt
8-
import matplotlib as mpl
96
import numpy as np
10-
from bokeh.models import ColumnDataSource, HoverTool
11-
from bokeh.plotting import figure
127
import pandas as pd
138
from matplotlib.colors import LinearSegmentedColormap
149

1510
from trustyai import _default_initializer # pylint: disable=unused-import
16-
from trustyai.utils._visualisation import (
17-
DEFAULT_STYLE as ds,
18-
DEFAULT_RC_PARAMS as drcp,
19-
bold_red_html,
20-
bold_green_html,
21-
output_html,
22-
feature_html,
23-
)
24-
11+
from trustyai.utils._visualisation import DEFAULT_STYLE as ds
2512
from trustyai.utils.data_conversions import (
2613
OneInputUnionType,
2714
data_conversion_docstring,
@@ -137,96 +124,6 @@ def as_html(self) -> pd.io.formats.style.Styler:
137124
)
138125
return htmls
139126

140-
def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None:
141-
"""Plot the LIME saliencies."""
142-
with mpl.rc_context(drcp):
143-
dictionary = {}
144-
for feature_importance in (
145-
self.saliency_map().get(output_name).getPerFeatureImportance()
146-
):
147-
dictionary[
148-
feature_importance.getFeature().name
149-
] = feature_importance.getScore()
150-
151-
colours = [
152-
ds["negative_primary_colour"]
153-
if i < 0
154-
else ds["positive_primary_colour"]
155-
for i in dictionary.values()
156-
]
157-
plt.title(f"LIME: Feature Importances to {output_name}")
158-
plt.barh(
159-
range(len(dictionary)),
160-
dictionary.values(),
161-
align="center",
162-
color=colours,
163-
)
164-
plt.yticks(range(len(dictionary)), list(dictionary.keys()))
165-
plt.tight_layout()
166-
167-
if call_show:
168-
plt.show(block=block)
169-
170-
def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
171-
lime_data_source = pd.DataFrame(
172-
[
173-
{
174-
"feature": str(pfi.getFeature().getName()),
175-
"saliency": pfi.getScore(),
176-
}
177-
for pfi in self.saliency_map()[output_name].getPerFeatureImportance()
178-
]
179-
)
180-
lime_data_source["color"] = lime_data_source["saliency"].apply(
181-
lambda x: ds["positive_primary_colour"]
182-
if x >= 0
183-
else ds["negative_primary_colour"]
184-
)
185-
lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply(
186-
lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x))
187-
)
188-
189-
lime_data_source["color_faded"] = lime_data_source["saliency"].apply(
190-
lambda x: ds["positive_primary_colour_faded"]
191-
if x >= 0
192-
else ds["negative_primary_colour_faded"]
193-
)
194-
source = ColumnDataSource(lime_data_source)
195-
htool = HoverTool(
196-
name="bars",
197-
tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format(
198-
feature_html("@feature"), output_html(output_name)
199-
),
200-
)
201-
bokeh_plot = figure(
202-
sizing_mode="stretch_both",
203-
title="Lime Feature Importances",
204-
y_range=lime_data_source["feature"],
205-
tools=[htool],
206-
)
207-
bokeh_plot.hbar(
208-
y="feature",
209-
left=0,
210-
right="saliency",
211-
fill_color="color_faded",
212-
line_color="color",
213-
hover_color="color",
214-
color="color",
215-
height=0.75,
216-
name="bars",
217-
source=source,
218-
)
219-
bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000")
220-
bokeh_plot.xaxis.axis_label = "Saliency Value"
221-
bokeh_plot.yaxis.axis_label = "Feature"
222-
return bokeh_plot
223-
224-
def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
225-
return {
226-
output_name: self._get_bokeh_plot(output_name)
227-
for output_name in self.saliency_map().keys()
228-
}
229-
230127

231128
class LimeExplainer:
232129
"""*"Which features were most important to the results?"*

src/trustyai/explainers/pdp.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Explainers.pdp module"""
2-
32
import math
4-
import matplotlib.pyplot as plt
53
import pandas as pd
64
from pandas.io.formats.style import Styler
75

@@ -62,45 +60,6 @@ def as_html(self) -> Styler:
6260
"""
6361
return self.as_dataframe().style
6462

65-
def plot(self, output_name=None, block=True, call_show=True) -> None:
66-
"""
67-
Parameters
68-
----------
69-
output_name: str
70-
name of the output to be plotted
71-
Default to None
72-
block: bool
73-
whether the plotting operation
74-
should be blocking or not
75-
call_show: bool
76-
(default= 'True') Whether plt.show() will be called by default at the end of
77-
the plotting function. If `False`, the plot will be returned to the user for
78-
further editing.
79-
"""
80-
fig, axs = plt.subplots(len(self.pdp_graphs), constrained_layout=True)
81-
p_idx = 0
82-
for pdp_graph in self.pdp_graphs:
83-
if output_name is not None and output_name != str(
84-
pdp_graph.getOutput().getName()
85-
):
86-
continue
87-
fig.suptitle(str(pdp_graph.getOutput().getName()))
88-
pdp_x = []
89-
for i in range(len(pdp_graph.getX())):
90-
pdp_x.append(self._to_plottable(pdp_graph.getX()[i]))
91-
pdp_y = []
92-
for i in range(len(pdp_graph.getY())):
93-
pdp_y.append(self._to_plottable(pdp_graph.getY()[i]))
94-
axs[p_idx].plot(pdp_x, pdp_y)
95-
axs[p_idx].set_title(
96-
str(pdp_graph.getFeature().getName()), loc="left", fontsize="small"
97-
)
98-
axs[p_idx].grid()
99-
p_idx += 1
100-
fig.supylabel("Partial Dependence Plot")
101-
if call_show:
102-
plt.show(block=block)
103-
10463
@staticmethod
10564
def _to_plottable(datum: Value):
10665
plottable = datum.asNumber()
@@ -187,7 +146,7 @@ def getInputShape(self):
187146
"""
188147
return self.data.sample()
189148

190-
# pylint: disable = invalid-name
149+
# pylint: disable = invalid-name, missing-final-newline
191150
@JOverride
192151
def getOutputShape(self):
193152
"""

0 commit comments

Comments
 (0)