Skip to content

Commit 6f37990

Browse files
authored
Add explainer for local classifier per level #minor (#116)
1 parent 2fe8480 commit 6f37990

File tree

4 files changed

+157
-15
lines changed

4 files changed

+157
-15
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ predictions = pipeline.predict(X_test)
202202
```
203203

204204
## Explaining Hierarchical Classifiers
205-
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
205+
206+
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level](https://colab.research.google.com/drive/1VnGlJu-1wSG4wxHXL0Ijf2a7Pu3kklT-?usp=sharing) is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
206207

207208
## Step-by-step walk-through
208209

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=========================================
4+
Explaining Local Classifier Per Level
5+
=========================================
6+
7+
A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPL model.
8+
A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`.
9+
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
10+
"""
11+
from sklearn.ensemble import RandomForestClassifier
12+
from hiclass import LocalClassifierPerLevel, Explainer
13+
import shap
14+
from hiclass.datasets import load_platypus
15+
16+
# Load train and test splits
17+
X_train, X_test, Y_train, Y_test = load_platypus()
18+
19+
# Use random forest classifiers for every level
20+
rfc = RandomForestClassifier()
21+
classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
22+
23+
# Train local classifiers per level
24+
classifier.fit(X_train, Y_train)
25+
26+
# Define Explainer
27+
explainer = Explainer(classifier, data=X_train, mode="tree")
28+
explanations = explainer.explain(X_test.values)
29+
print(explanations)
30+
31+
# Let's filter the Shapley values corresponding to the Covid (level 1)
32+
# and 'Respiratory' (level 0)
33+
34+
covid_idx = classifier.predict(X_test)[:, 1] == "Covid"
35+
36+
shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
37+
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
38+
shap_val_covid = explanations.sel(**shap_filter_covid)
39+
shap_val_resp = explanations.sel(**shap_filter_resp)
40+
41+
42+
# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases.
43+
44+
# Feature names for the X-axis
45+
feature_names = X_train.columns.values
46+
47+
# SHAP values for 'Covid'
48+
shap_values_covid = shap_val_covid.shap_values.values
49+
50+
# SHAP values for 'Respiratory'
51+
shap_values_resp = shap_val_resp.shap_values.values
52+
53+
shap.summary_plot(
54+
[shap_values_covid, shap_values_resp],
55+
features=X_test.iloc[covid_idx],
56+
feature_names=X_train.columns.values,
57+
plot_type="bar",
58+
class_names=["Covid", "Respiratory"],
59+
)

hiclass/Explainer.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,31 @@ def _get_traversed_nodes_lcpn(self, samples):
229229

230230
return traversals
231231

232+
def _get_traversed_nodes_lcpl(self, samples):
233+
"""
234+
Return a list of all traversed nodes as per the provided LocalClassifierPerLevel model.
235+
236+
Parameters
237+
----------
238+
samples : array-like
239+
Sample data for which to generate traversed nodes.
240+
241+
Returns
242+
-------
243+
traversals : list
244+
A list of all traversed nodes as per LocalClassifierPerLevel (LCPL) strategy.
245+
"""
246+
traversals = []
247+
predictions = self.hierarchical_model.predict(samples)
248+
for pred in predictions:
249+
traversal_order = []
250+
filtered_pred = [p for p in pred if p.strip()]
251+
for i in range(1, len(filtered_pred) + 1):
252+
node = self.hierarchical_model.separator_.join(filtered_pred[:i])
253+
traversal_order.append(node)
254+
traversals.append(traversal_order)
255+
return traversals
256+
232257
def _calculate_shap_values(self, X):
233258
"""
234259
Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute.
@@ -244,23 +269,27 @@ def _calculate_shap_values(self, X):
244269
A single explanation for the prediction of given sample.
245270
"""
246271
traversed_nodes = []
247-
if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
272+
if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
273+
traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
274+
elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
248275
traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
249276
elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
250277
traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]
251278
datasets = []
252279
level = 0
253280
for node in traversed_nodes:
254-
# Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
255-
if (
256-
node == ""
257-
or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
281+
if node == "" or (
282+
("classifier" not in self.hierarchical_model.hierarchy_.nodes[node])
283+
and (not isinstance(self.hierarchical_model, LocalClassifierPerLevel))
258284
):
259285
continue
260286

261-
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
262-
"classifier"
263-
]
287+
if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
288+
local_classifier = self.hierarchical_model.local_classifiers_[level]
289+
else:
290+
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
291+
"classifier"
292+
]
264293

265294
# Create a SHAP explainer for the local classifier
266295
local_explainer = deepcopy(self.explainer)(local_classifier, self.data)
@@ -283,7 +312,7 @@ def _calculate_shap_values(self, X):
283312
for label in local_classifier.classes_
284313
]
285314
predicted_class = current_node
286-
else:
315+
elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
287316
simplified_labels = [
288317
label.split(self.hierarchical_model.separator_)[-1]
289318
for label in local_classifier.classes_
@@ -293,6 +322,12 @@ def _calculate_shap_values(self, X):
293322
.flatten()[0]
294323
.split(self.hierarchical_model.separator_)[-1]
295324
)
325+
else:
326+
simplified_labels = [
327+
label.split(self.hierarchical_model.separator_)[-1]
328+
for label in local_classifier.classes_
329+
]
330+
predicted_class = current_node
296331

297332
classes = xr.DataArray(
298333
simplified_labels,
@@ -326,7 +361,7 @@ def _calculate_shap_values(self, X):
326361
"level": level,
327362
}
328363
)
329-
level = level + 1
364+
level += 1
330365
datasets.append(local_dataset)
331366
sample_explanation = xr.concat(datasets, dim="level")
332367
return sample_explanation

tests/test_Explainer.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import numpy as np
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier
4-
from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer
4+
from hiclass import (
5+
LocalClassifierPerLevel,
6+
LocalClassifierPerParentNode,
7+
LocalClassifierPerNode,
8+
Explainer,
9+
)
510

611
try:
712
import shap
@@ -98,6 +103,26 @@ def test_explainer_tree_lcpn(data, request):
98103
assert str(explanations["node"][i].data[j]) == y_pred[j]
99104

100105

106+
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
107+
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
108+
def test_explainer_tree_lcpl(data, request):
109+
rfc = RandomForestClassifier()
110+
lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
111+
112+
x_train, x_test, y_train = request.getfixturevalue(data)
113+
114+
lcpl.fit(x_train, y_train)
115+
116+
explainer = Explainer(lcpl, data=x_train, mode="tree")
117+
explanations = explainer.explain(x_test)
118+
assert explanations is not None
119+
y_preds = lcpl.predict(x_test)
120+
for i in range(len(x_test)):
121+
y_pred = y_preds[i]
122+
for j in range(len(y_pred)):
123+
assert str(explanations["node"][i].data[j]) == y_pred[j]
124+
125+
101126
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
102127
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
103128
def test_traversal_path_lcppn(data, request):
@@ -142,11 +167,30 @@ def test_traversal_path_lcpn(data, request):
142167
assert label == preds[i][j]
143168

144169

170+
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
171+
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
172+
def test_traversal_path_lcpl(data, request):
173+
x_train, x_test, y_train = request.getfixturevalue(data)
174+
rfc = RandomForestClassifier()
175+
lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
176+
177+
lcpl.fit(x_train, y_train)
178+
explainer = Explainer(lcpl, data=x_train, mode="tree")
179+
traversals = explainer._get_traversed_nodes_lcpl(x_test)
180+
preds = lcpl.predict(x_test)
181+
assert len(preds) == len(traversals)
182+
for i in range(len(x_test)):
183+
for j in range(len(traversals[i])):
184+
label = traversals[i][j].split(lcpl.separator_)[-1]
185+
assert label == preds[i][j]
186+
187+
145188
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
146189
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
147190
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
148191
@pytest.mark.parametrize(
149-
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
192+
"classifier",
193+
[LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
150194
)
151195
def test_explain_with_xr(data, request, classifier):
152196
x_train, x_test, y_train = request.getfixturevalue(data)
@@ -162,7 +206,8 @@ def test_explain_with_xr(data, request, classifier):
162206

163207

164208
@pytest.mark.parametrize(
165-
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
209+
"classifier",
210+
[LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
166211
)
167212
def test_imports(classifier):
168213
x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]]
@@ -176,8 +221,10 @@ def test_imports(classifier):
176221
assert isinstance(explainer.data, np.ndarray)
177222

178223

224+
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
179225
@pytest.mark.parametrize(
180-
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
226+
"classifier",
227+
[LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
181228
)
182229
@pytest.mark.parametrize("data", ["explainer_data"])
183230
@pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""])

0 commit comments

Comments
 (0)