Skip to content

Commit 4c835de

Browse files
committed
Added argument to as_html and as_df to allow for single output selection
1 parent 3afedc0 commit 4c835de

File tree

6 files changed

+135
-76
lines changed

6 files changed

+135
-76
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
import pandas as pd
8+
from pandas.io.formats.style import Styler
89
import uuid as _uuid
910

1011
from trustyai import _default_initializer # pylint: disable=unused-import
11-
from .explanation_results import ExplanationResults
1212
from trustyai.utils._visualisation import (
1313
DEFAULT_STYLE as ds,
1414
DEFAULT_RC_PARAMS as drcp,
@@ -50,7 +50,7 @@
5050
CounterfactualConfig = _CounterfactualConfig
5151

5252

53-
class CounterfactualResult(ExplanationResults):
53+
class CounterfactualResult:
5454
"""Wraps Counterfactual results. This object is returned by the
5555
:class:`~CounterfactualExplainer`, and provides a variety of methods to visualize and interact
5656
with the results of the counterfactual explanation.

src/trustyai/explainers/explanation_results.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,27 @@
11
"""Generic class for Explanation and Saliency results"""
22
from abc import ABC, abstractmethod
3-
from typing import Dict
3+
from typing import Dict, Union
44

55
import bokeh.models
66
import pandas as pd
77
from bokeh.io import show
88
from pandas.io.formats.style import Styler
99

1010

11-
class ExplanationResults(ABC):
12-
"""Abstract class for explanation visualisers"""
11+
# pylint: disable=too-few-public-methods
12+
class SaliencyResults(ABC):
13+
"""Abstract class for saliency visualisers"""
1314

1415
@abstractmethod
15-
def as_dataframe(self) -> pd.DataFrame:
16+
def as_dataframe(
17+
self, output_name=None
18+
) -> Union[Dict[str, pd.DataFrame], pd.DataFrame]:
1619
"""Display explanation result as a dataframe"""
1720

1821
@abstractmethod
19-
def as_html(self) -> Styler:
22+
def as_html(self, output_name=None) -> Union[Dict[str, Styler], Styler]:
2023
"""Visualise the styled dataframe"""
2124

22-
23-
# pylint: disable=too-few-public-methods
24-
class SaliencyResults(ExplanationResults):
25-
"""Abstract class for saliency visualisers"""
26-
2725
@abstractmethod
2826
def saliency_map(self):
2927
"""Return the Saliencies as a dictionary, keyed by output name"""

src/trustyai/explainers/lime.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,46 +72,69 @@ def saliency_map(self) -> Dict[str, Saliency]:
7272
for entry in self._java_saliency_results.saliencies.entrySet()
7373
}
7474

75-
def as_dataframe(self) -> pd.DataFrame:
75+
def as_dataframe(
76+
self, output_name: str = None
77+
) -> Union[Dict[str, pd.DataFrame], pd.DataFrame]:
7678
"""
7779
Return the LIME result as a dataframe.
7880
81+
Parameters
82+
----------
83+
output_name: str
84+
If an output_name is passed, that output's explanation is returned as a pandas
85+
dataframe. Otherwise, all outputs' explanation dataframes are returned in a dictionary.
86+
87+
7988
Returns
8089
-------
81-
pandas.DataFrame
90+
pandas.DataFrame or Dict[str, pandas.DataFrame]
8291
Dictionary of DataFrames, keyed by output name, containing the results of the LIME
83-
explanation. For each model output, the table will contain the following columns:
92+
explanation. Each dataframe will contain the following columns:
8493
8594
* ``Feature``: The name of the feature
8695
* ``Value``: The value of the feature for this particular input.
8796
* ``Saliency``: The importance of this feature to the output.
8897
* ``Confidence``: The confidence of this explanation as returned by the explainer.
8998
9099
"""
100+
91101
outputs = self.saliency_map().keys()
92102

93103
data = {}
94104
for output in outputs:
95-
output_rows = []
96-
for pfi in self.saliency_map().get(output).getPerFeatureImportance():
97-
output_rows.append(
98-
{
99-
"Feature": str(pfi.getFeature().getName().toString()),
100-
"Value": pfi.getFeature().getValue().getUnderlyingObject(),
101-
"Saliency": pfi.getScore(),
102-
"Confidence": pfi.getConfidence(),
103-
}
104-
)
105-
data[output] = pd.DataFrame(output_rows)
105+
if output_name is None or output == output_name:
106+
output_rows = []
107+
for pfi in self.saliency_map().get(output).getPerFeatureImportance():
108+
output_rows.append(
109+
{
110+
"Feature": str(pfi.getFeature().getName().toString()),
111+
"Value": pfi.getFeature().getValue().getUnderlyingObject(),
112+
"Saliency": pfi.getScore(),
113+
"Confidence": pfi.getConfidence(),
114+
}
115+
)
116+
data[output] = pd.DataFrame(output_rows)
117+
118+
if output_name is not None:
119+
return data[output_name]
106120
return data
107121

108-
def as_html(self) -> pd.io.formats.style.Styler:
122+
def as_html(
123+
self, output_name: str = None
124+
) -> Union[Dict[str, pd.io.formats.style.Styler], pd.io.formats.style.Styler]:
109125
"""
110126
Return the LIME results as Pandas Styler objects.
111127
128+
Parameters
129+
----------
130+
output_name: str
131+
If an output_name is passed, that output's explanation is returned as a pandas Styler.
132+
Otherwise, all outputs' explanation stylers are returned in a dictionary.
133+
134+
112135
Returns
113136
-------
114-
Dict[str, pandas.Styler]
137+
pandas.Styler or Dict[str, pandas.Styler]
115138
Dictionary of stylers keyed by output name. Each styler containing the results of the
116139
LIME explanation for that particular output, in the same
117140
schema as in :func:`as_dataframe`. This will:
@@ -121,19 +144,25 @@ def as_html(self) -> pd.io.formats.style.Styler:
121144

122145
htmls = {}
123146
for k, df in self.as_dataframe().items():
124-
htmls[k] = df.style.background_gradient(
125-
LinearSegmentedColormap.from_list(
126-
name="rwg",
127-
colors=[
128-
ds["negative_primary_colour"],
129-
ds["neutral_primary_colour"],
130-
ds["positive_primary_colour"],
131-
],
132-
),
133-
subset="Saliency",
134-
vmin=-1 * max(np.abs(df["Saliency"])),
135-
vmax=max(np.abs(df["Saliency"])),
136-
)
147+
if output_name is None or k == output_name:
148+
style = df.style.background_gradient(
149+
LinearSegmentedColormap.from_list(
150+
name="rwg",
151+
colors=[
152+
ds["negative_primary_colour"],
153+
ds["neutral_primary_colour"],
154+
ds["positive_primary_colour"],
155+
],
156+
),
157+
subset="Saliency",
158+
vmin=-1 * max(np.abs(df["Saliency"])),
159+
vmax=max(np.abs(df["Saliency"])),
160+
)
161+
style.set_caption(f"LIME Explanation of {output_name}")
162+
htmls[k] = style
163+
164+
if output_name is not None:
165+
return htmls[output_name]
137166
return htmls
138167

139168
def _matplotlib_plot(self, output_name: str, block=True) -> None:

src/trustyai/explainers/shap.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,24 @@ def _saliency_to_dataframe(self, saliency, output_name):
123123

124124
return pd.DataFrame([fnull] + data_rows)
125125

126-
def as_dataframe(self) -> Dict[str, pd.DataFrame]:
126+
def as_dataframe(
127+
self, output_name: str = None
128+
) -> Union[Dict[str, pd.DataFrame], pd.DataFrame]:
127129
"""
128130
Return the SHAP results as dataframes.
129131
132+
Parameters
133+
----------
134+
output_name: str
135+
If an output_name is passed, that output's explanation is returned as a dataframe.
136+
Otherwise, all outputs' explanation dataframes are returned in a dictionary.
137+
130138
Returns
131139
-------
132-
Dict[str, pandas.DataFrame]
133-
Dictionary of DataFrames, keyed by output name, containing the results of the SHAP
134-
explanation. For each model output, the table will contain the following columns:
140+
pandas.DataFrame or Dict[str, pandas.DataFrame]
141+
A dataframe or dictionary of DataFrames, keyed by output name. Each dataframe
142+
contains the results of the SHAP explanation for a particular output. Each dataframe
143+
will contain the following columns:
135144
136145
* ``Feature``: The name of the feature
137146
* ``Feature Value``: The value of the feature for this particular input.
@@ -140,18 +149,33 @@ def as_dataframe(self) -> Dict[str, pd.DataFrame]:
140149
* ``Confidence``: The confidence of this explanation as returned by the explainer.
141150
142151
"""
143-
df_dict = {}
144-
for output_name, saliency in self.saliency_map().items():
145-
df_dict[output_name] = self._saliency_to_dataframe(saliency, output_name)
146-
return df_dict
152+
if output_name is None:
153+
df_dict = {}
154+
for output_name_key, saliency in self.saliency_map().items():
155+
df_dict[output_name_key] = self._saliency_to_dataframe(
156+
saliency, output_name_key
157+
)
158+
return df_dict
159+
return self._saliency_to_dataframe(
160+
self.saliency_map()[output_name], output_name
161+
)
147162

148-
def as_html(self) -> Dict[str, pd.io.formats.style.Styler]:
163+
def as_html(
164+
self, output_name: str = None
165+
) -> Union[Dict[str, pd.io.formats.style.Styler], pd.io.formats.style.Styler]:
149166
"""
150167
Return the SHAP results as Pandas Styler objects.
151168
169+
Parameters
170+
----------
171+
output_name: str
172+
If an output_name is passed, that output's explanation is returned as a pandas Styler.
173+
Otherwise, all outputs' explanation stylers are returned in a dictionary.
174+
175+
152176
Returns
153177
-------
154-
Dict[str, pandas.Styler]
178+
pandas.Styler or Dict[str, pandas.Styler]
155179
Dictionary of stylers keyed by output name. Each styler containing the results of the
156180
SHAP explanation for that particular output, in the same
157181
schema as in :func:`as_dataframe`. This will:
@@ -174,31 +198,35 @@ def _color_feature_values(feature_values, background_vals):
174198
return [None] + formats
175199

176200
df_dict = {}
177-
for output_name, saliency in self.saliency_map().items():
178-
df = self._saliency_to_dataframe(saliency, output_name)
179-
shap_values = df["SHAP Value"].values[1:]
180-
background_mean_feature_values = df["Mean Background Value"].values[1:]
181-
182-
style = df.style.background_gradient(
183-
LinearSegmentedColormap.from_list(
184-
name="rwg",
185-
colors=[
186-
ds["negative_primary_colour"],
187-
ds["neutral_primary_colour"],
188-
ds["positive_primary_colour"],
189-
],
190-
),
191-
subset=(slice(1, None), "SHAP Value"),
192-
vmin=-1 * max(np.abs(shap_values)),
193-
vmax=max(np.abs(shap_values)),
194-
)
195-
style.set_caption(f"Explanation of {output_name}")
196-
df_dict[output_name] = style.apply(
197-
_color_feature_values,
198-
background_vals=background_mean_feature_values,
199-
subset="Value",
200-
axis=0,
201-
)
201+
for output_name_key, saliency in self.saliency_map().items():
202+
if output_name is None or output_name_key == output_name:
203+
df = self._saliency_to_dataframe(saliency, output_name_key)
204+
shap_values = df["SHAP Value"].values[1:]
205+
background_mean_feature_values = df["Mean Background Value"].values[1:]
206+
207+
style = df.style.background_gradient(
208+
LinearSegmentedColormap.from_list(
209+
name="rwg",
210+
colors=[
211+
ds["negative_primary_colour"],
212+
ds["neutral_primary_colour"],
213+
ds["positive_primary_colour"],
214+
],
215+
),
216+
subset=(slice(1, None), "SHAP Value"),
217+
vmin=-1 * max(np.abs(shap_values)),
218+
vmax=max(np.abs(shap_values)),
219+
)
220+
style.set_caption(f"SHAP Explanation of {output_name_key}")
221+
df_dict[output_name_key] = style.apply(
222+
_color_feature_values,
223+
background_vals=background_mean_feature_values,
224+
subset="Value",
225+
axis=0,
226+
)
227+
228+
if output_name is not None:
229+
return df_dict[output_name]
202230
return df_dict
203231

204232
def _matplotlib_plot(self, output_name, block=True) -> None:

tests/general/test_limeexplainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def test_lime_numpy():
188188

189189
for oname in onames:
190190
assert oname in explanation.as_dataframe().keys()
191+
assert len(explanation.as_dataframe(oname)) == 5
191192
for fname in fnames:
192193
assert fname in explanation.as_dataframe()[oname]['Feature'].values
193194

tests/general/test_shap.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,8 @@ def test_shap_numpy():
131131

132132
for oname in onames:
133133
assert oname in explanation.as_dataframe().keys()
134+
assert len(explanation.as_dataframe(oname)) == 5 + 1
135+
134136
for fname in fnames:
135137
assert fname in explanation.as_dataframe()[oname]['Feature'].values
138+

0 commit comments

Comments
 (0)