Skip to content

Commit d1bbd33

Browse files
Explainer implementation for LCPN #minor (#108)
1 parent 4595264 commit d1bbd33

File tree

5 files changed

+152
-11
lines changed

5 files changed

+152
-11
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ predictions = pipeline.predict(X_test)
202202
```
203203

204204
## Explaining Hierarchical Classifiers
205-
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
205+
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
206206

207207
## Step-by-step walk-through
208208

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=========================================
4+
Explaining Local Classifier Per Node
5+
=========================================
6+
7+
A minimal example showing how to use the HiClass Explainer to obtain SHAP values for an LCPN model.
8+
A detailed summary of the Explainer class is given in the Algorithms Overview section on :ref:`Hierarchical Explainability`.
9+
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
10+
"""
11+
import numpy as np
12+
from sklearn.ensemble import RandomForestClassifier
13+
from hiclass import LocalClassifierPerNode, Explainer
14+
from hiclass.datasets import load_platypus
15+
import shap
16+
17+
# Load train and test splits
18+
X_train, X_test, Y_train, Y_test = load_platypus()
19+
20+
# Use random forest classifiers for every node
21+
rfc = RandomForestClassifier()
22+
classifier = LocalClassifierPerNode(local_classifier=rfc, replace_classifiers=False)
23+
24+
# Train local classifier per node
25+
classifier.fit(X_train, Y_train)
26+
27+
# Define Explainer
28+
explainer = Explainer(classifier, data=X_train.values, mode="tree")
29+
explanations = explainer.explain(X_test.values)
30+
print(explanations)
31+
32+
# Select only the samples that were predicted as "Respiratory" at the first level
33+
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"
34+
35+
# Specify additional filters to obtain only level 0
36+
shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx}
37+
38+
# Use .sel() method to apply the filter and obtain filtered results
39+
shap_val_respiratory = explanations.sel(shap_filter)
40+
41+
# Plot feature importance on test set
42+
shap.plots.violin(
43+
shap_val_respiratory.shap_values,
44+
feature_names=X_train.columns.values,
45+
plot_size=(13, 8),
46+
)

docs/examples/plot_lcppn_explainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
classifier.fit(X_train, Y_train)
2727

2828
# Define Explainer
29-
explainer = Explainer(classifier, data=X_train, mode="tree")
29+
explainer = Explainer(classifier, data=X_train.values, mode="tree")
3030
explanations = explainer.explain(X_test.values)
3131
print(explanations)
3232

hiclass/Explainer.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,46 @@ def _get_traversed_nodes_lcppn(self, samples):
189189
).flatten()
190190
return traversals
191191

192+
def _get_traversed_nodes_lcpn(self, samples):
    """
    Return the nodes traversed by the LocalClassifierPerNode model for each sample.

    The entry for a given level is the entry of the previous level joined
    with the model's separator and the label predicted at that level, so
    each column holds the path from the first level down to that level.

    Parameters
    ----------
    samples : array-like
        Sample data for which to generate traversed nodes. It is passed to
        the model's ``predict`` method and must expose a ``shape`` attribute
        whose first entry is the number of samples.

    Returns
    -------
    traversals : numpy.ndarray
        Array of shape ``(n_samples, max_levels_)`` with the traversed nodes
        as per the LocalClassifierPerNode (LCPN) strategy. Positions where the
        model predicted an empty label are set to the empty string.
    """
    model = self.hierarchical_model
    n_samples = samples.shape[0]

    predictions = model.predict(samples)

    traversals = np.empty((n_samples, model.max_levels_), dtype=model.dtype_)
    traversals[:, 0] = predictions[:, 0]

    # One separator per sample is all the element-wise concatenation needs.
    separator = np.full(n_samples, model.separator_, dtype=model.dtype_)

    for level in range(1, traversals.shape[1]):
        # Extend the path of the previous level with this level's prediction.
        traversals[:, level] = np.char.add(
            traversals[:, level - 1],
            np.char.add(separator, predictions[:, level]),
        )

    # For inconsistent hierarchies, levels with empty nodes should be ignored
    traversals[predictions == ""] = ""

    return traversals
231+
192232
def _calculate_shap_values(self, X):
193233
"""
194234
Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute.
@@ -206,11 +246,16 @@ def _calculate_shap_values(self, X):
206246
traversed_nodes = []
207247
if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
208248
traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
249+
elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
250+
traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]
209251
datasets = []
210252
level = 0
211253
for node in traversed_nodes:
212-
# Skip if classifier is not found, can happen in case of imbalanced hierarchies
213-
if "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]:
254+
# Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
255+
if (
256+
node == ""
257+
or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
258+
):
214259
continue
215260

216261
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][

tests/test_Explainer.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import numpy as np
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier
4-
from hiclass import (
5-
LocalClassifierPerParentNode,
6-
Explainer,
7-
)
4+
from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer
85

96
try:
107
import shap
@@ -76,6 +73,31 @@ def test_explainer_tree_lcppn(data, request):
7673
assert explanation.data[j].split(lcppn.separator_)[-1] == y_pred[j]
7774

7875

76+
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_explainer_tree_lcpn(data, request):
    """Check that tree-mode explanations of an LCPN model agree with its predictions."""
    x_train, x_test, y_train = request.getfixturevalue(data)

    lcpn = LocalClassifierPerNode(
        local_classifier=RandomForestClassifier(), replace_classifiers=False
    )
    lcpn.fit(x_train, y_train)

    explanations = Explainer(lcpn, data=x_train, mode="tree").explain(x_test)

    # The explainer must hand back an xarray.Dataset object
    assert isinstance(explanations, xarray.Dataset)

    # The node recorded for every sample and level must equal the prediction
    for i, y_pred in enumerate(lcpn.predict(x_test)):
        for j, label in enumerate(y_pred):
            assert str(explanations["node"][i].data[j]) == label
99+
100+
79101
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
80102
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
81103
def test_traversal_path_lcppn(data, request):
@@ -98,10 +120,34 @@ def test_traversal_path_lcppn(data, request):
98120
assert label == preds[i][j - 1]
99121

100122

123+
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcpn(data, request):
    """Check that the LCPN traversal paths end in the labels the model predicts."""
    x_train, x_test, y_train = request.getfixturevalue(data)

    lcpn = LocalClassifierPerNode(
        local_classifier=RandomForestClassifier(), replace_classifiers=False
    )
    lcpn.fit(x_train, y_train)

    explainer = Explainer(lcpn, data=x_train, mode="tree")
    traversals = explainer._get_traversed_nodes_lcpn(x_test)
    preds = lcpn.predict(x_test)

    # Predictions and traversals must cover the same number of samples
    assert len(preds) == len(traversals)

    # The last component of each traversed node must be the predicted label
    for i, path in enumerate(traversals):
        for j, node in enumerate(path):
            assert node.split(lcpn.separator_)[-1] == preds[i][j]
143+
144+
101145
@pytest.mark.skipif(not shap_installed, reason="shap not installed")
102146
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
103147
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
104-
@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode])
148+
@pytest.mark.parametrize(
149+
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
150+
)
105151
def test_explain_with_xr(data, request, classifier):
106152
x_train, x_test, y_train = request.getfixturevalue(data)
107153
rfc = RandomForestClassifier()
@@ -115,7 +161,9 @@ def test_explain_with_xr(data, request, classifier):
115161
assert isinstance(explanations, xarray.Dataset)
116162

117163

118-
@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode])
164+
@pytest.mark.parametrize(
165+
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
166+
)
119167
def test_imports(classifier):
120168
x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]]
121169
y_train = [["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]]
@@ -128,7 +176,9 @@ def test_imports(classifier):
128176
assert isinstance(explainer.data, np.ndarray)
129177

130178

131-
@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode])
179+
@pytest.mark.parametrize(
180+
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
181+
)
132182
@pytest.mark.parametrize("data", ["explainer_data"])
133183
@pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""])
134184
def test_explainers(data, request, classifier, mode):

0 commit comments

Comments
 (0)