Skip to content

Commit ee10f7e

Browse files
authored
Added option to bypass mandatory plotting of exp results (#151)
1 parent ebc34ea commit ee10f7e

File tree

6 files changed

+53
-13
lines changed

6 files changed

+53
-13
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
123123
"""
124124
return self.as_dataframe().style
125125

126-
def plot(self, block=True) -> None:
126+
def plot(self, block=True, call_show=True) -> None:
127127
"""
128128
Plot the counterfactual result.
129129
"""
@@ -145,7 +145,8 @@ def change_colour(value):
145145
x="features", color={"proposed": colour, "original": "black"}
146146
)
147147
plot.set_title("Counterfactual")
148-
plt.show(block=block)
148+
if call_show:
149+
plt.show(block=block)
149150

150151

151152
class CounterfactualExplainer:

src/trustyai/explainers/explanation_results.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def saliency_map(self):
2929
"""Return the Saliencies as a dictionary, keyed by output name"""
3030

3131
@abstractmethod
32-
def _matplotlib_plot(self, output_name: str, block: bool) -> None:
32+
def _matplotlib_plot(self, output_name: str, block: bool, call_show: bool) -> None:
3333
"""Plot the saliencies of a particular output in matplotlib"""
3434

3535
@abstractmethod
@@ -44,7 +44,9 @@ def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
4444
for output_name in self.saliency_map().keys()
4545
}
4646

47-
def plot(self, output_name=None, render_bokeh=False, block=True) -> None:
47+
def plot(
48+
self, output_name=None, render_bokeh=False, block=True, call_show=True
49+
) -> None:
4850
"""
4951
Plot the found feature saliencies.
5052
@@ -57,15 +59,19 @@ def plot(self, output_name=None, render_bokeh=False, block=True) -> None:
5759
(default= `False`) If true, render plot in bokeh, otherwise use matplotlib.
5860
block: bool
5961
(default= `True`) Whether displaying the plot blocks subsequent code execution
62+
call_show: bool
63+
(default= 'True') Whether plt.show() will be called by default at the end of the
64+
plotting function. If `False`, the plot will be returned to the user for further
65+
editing.
6066
"""
6167
if output_name is None:
6268
for output_name_iterator in self.saliency_map().keys():
6369
if render_bokeh:
6470
show(self._get_bokeh_plot(output_name_iterator))
6571
else:
66-
self._matplotlib_plot(output_name_iterator, block)
72+
self._matplotlib_plot(output_name_iterator, block, call_show)
6773
else:
6874
if render_bokeh:
6975
show(self._get_bokeh_plot(output_name))
7076
else:
71-
self._matplotlib_plot(output_name, block)
77+
self._matplotlib_plot(output_name, block, call_show)

src/trustyai/explainers/lime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
137137
)
138138
return htmls
139139

140-
def _matplotlib_plot(self, output_name: str, block=True) -> None:
140+
def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None:
141141
"""Plot the LIME saliencies."""
142142
with mpl.rc_context(drcp):
143143
dictionary = {}
@@ -163,7 +163,9 @@ def _matplotlib_plot(self, output_name: str, block=True) -> None:
163163
)
164164
plt.yticks(range(len(dictionary)), list(dictionary.keys()))
165165
plt.tight_layout()
166-
plt.show(block=block)
166+
167+
if call_show:
168+
plt.show(block=block)
167169

168170
def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
169171
lime_data_source = pd.DataFrame(

src/trustyai/explainers/pdp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def as_html(self) -> Styler:
6262
"""
6363
return self.as_dataframe().style
6464

65-
def plot(self, output_name=None, block=True) -> None:
65+
def plot(self, output_name=None, block=True, call_show=True) -> None:
6666
"""
6767
Parameters
6868
----------
@@ -72,6 +72,10 @@ def plot(self, output_name=None, block=True) -> None:
7272
block: bool
7373
whether the plotting operation
7474
should be blocking or not
75+
call_show: bool
76+
(default= 'True') Whether plt.show() will be called by default at the end of
77+
the plotting function. If `False`, the plot will be returned to the user for
78+
further editing.
7579
"""
7680
fig, axs = plt.subplots(len(self.pdp_graphs), constrained_layout=True)
7781
p_idx = 0
@@ -94,7 +98,8 @@ def plot(self, output_name=None, block=True) -> None:
9498
axs[p_idx].grid()
9599
p_idx += 1
96100
fig.supylabel("Partial Dependence Plot")
97-
plt.show(block=block)
101+
if call_show:
102+
plt.show(block=block)
98103

99104
@staticmethod
100105
def _to_plottable(datum: Value):

src/trustyai/explainers/shap.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _color_feature_values(feature_values, background_vals):
201201
)
202202
return df_dict
203203

204-
def _matplotlib_plot(self, output_name, block=True) -> None:
204+
def _matplotlib_plot(self, output_name, block=True, call_show=True) -> None:
205205
"""Visualize the SHAP explanation of each output as a set of candlestick plots,
206206
one per output."""
207207
with mpl.rc_context(drcp):
@@ -219,7 +219,9 @@ def _matplotlib_plot(self, output_name, block=True) -> None:
219219
]
220220
fnull = self.get_fnull()[output_name]
221221
prediction = fnull + sum(shap_values)
222-
plt.figure()
222+
223+
if call_show:
224+
plt.figure()
223225
pos = fnull
224226
for j, shap_value in enumerate(shap_values):
225227
color = (
@@ -255,7 +257,8 @@ def _matplotlib_plot(self, output_name, block=True) -> None:
255257
plt.ylabel(self.saliency_map()[output_name].getOutput().getName())
256258
plt.xlabel("Feature SHAP Value")
257259
plt.title(f"SHAP: Feature Contributions to {output_name}")
258-
plt.show(block=block)
260+
if call_show:
261+
plt.show(block=block)
259262

260263
def _get_bokeh_plot(self, output_name):
261264
fnull = self.get_fnull()[output_name]

tests/general/test_shap.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pandas as pd
77
import numpy as np
8+
import matplotlib.pyplot as plt
89

910
np.random.seed(0)
1011

@@ -133,3 +134,25 @@ def test_shap_numpy():
133134
assert oname in explanation.as_dataframe().keys()
134135
for fname in fnames:
135136
assert fname in explanation.as_dataframe()[oname]['Feature'].values
137+
138+
139+
# deliberately make strange plot to test pre and post-function plot editing
140+
def test_shap_edit_plot():
141+
np.random.seed(0)
142+
data = pd.DataFrame(np.random.rand(101, 5))
143+
background = data.iloc[:100].values
144+
to_explain = data.iloc[100:101].values
145+
146+
model_weights = np.random.rand(5)
147+
predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1)
148+
149+
model = Model(predict_function, disable_arrow=True)
150+
151+
shap_explainer = SHAPExplainer(background=background)
152+
explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model)
153+
154+
plt.figure(figsize=(32,2))
155+
explanation.plot(call_show=False)
156+
plt.ylim(0, 123)
157+
plt.show()
158+

0 commit comments

Comments
 (0)