Skip to content

Commit 20ae91a

Browse files
authored
Add TSSaliency to Python TrustyAI (#172)
* Add TSSaliency * Fix linting and formatting
1 parent cbee162 commit 20ae91a

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dev = [
4949
"xgboost==1.4.2"
5050
]
5151
extras = [
52-
"aix360 [default,tsice] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
52+
"aix360 [default,tsice,tssaliency] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
5353
]
5454

5555
[project.urls]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
Wrapper module for TSSaliencyExplainer from aix360.
3+
Original at https://github.com/Trusted-AI/AIX360/
4+
"""
5+
6+
from typing import Callable, List
7+
8+
import pandas as pd
9+
import numpy as np
10+
from aix360.algorithms.tssaliency import TSSaliencyExplainer as TSSaliencyExplainerAIX
11+
from pandas.io.formats.style import Styler
12+
import matplotlib.pyplot as plt
13+
14+
from trustyai.explainers.explanation_results import ExplanationResults
15+
16+
17+
class TSSaliencyResults(ExplanationResults):
18+
"""Wraps TSSaliency results. This object is returned by the :class:`~TSSaliencyExplainer`,
19+
and provides a variety of methods to visualize and interact with the explanation.
20+
"""
21+
22+
def __init__(self, explanation):
23+
self.explanation = explanation
24+
25+
def as_dataframe(self) -> pd.DataFrame:
26+
saliencies = self.explanation["saliency"].reshape(-1)
27+
return pd.DataFrame(saliencies, columns=self.explanation["feature_names"])
28+
29+
def as_html(self) -> Styler:
30+
"""Returns the explanation as an HTML table."""
31+
dataframe = self.as_dataframe()
32+
return dataframe.style
33+
34+
def plot(self):
35+
"""Plot tssaliency explanation for the test point
36+
Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tssaliency"""
37+
max_abs = np.max(np.abs(self.explanation["saliency"]))
38+
39+
plt.imshow(
40+
self.explanation["saliency"][np.newaxis, :],
41+
aspect="auto",
42+
cmap="seismic",
43+
vmin=-max_abs,
44+
vmax=max_abs,
45+
)
46+
plt.colorbar()
47+
plt.plot(self.explanation["input_data"])
48+
plt.show()
49+
50+
51+
class TSSaliencyExplainer(TSSaliencyExplainerAIX):
52+
"""
53+
Wrapper for TSSaliencyExplainer from aix360.
54+
"""
55+
56+
def __init__( # pylint: disable=too-many-arguments
57+
self,
58+
model: Callable,
59+
input_length: int,
60+
feature_names: List[str],
61+
base_value: List[float] = None,
62+
n_samples: int = 50,
63+
gradient_samples: int = 25,
64+
gradient_function: Callable = None,
65+
random_seed: int = 22,
66+
):
67+
super().__init__(
68+
model=model,
69+
input_length=input_length,
70+
feature_names=feature_names,
71+
base_value=base_value,
72+
n_samples=n_samples,
73+
gradient_samples=gradient_samples,
74+
gradient_function=gradient_function,
75+
random_seed=random_seed,
76+
)
77+
78+
def explain(self, inputs, outputs=None, **kwargs) -> TSSaliencyResults:
79+
"""
80+
Explain the model's prediction on X.
81+
"""
82+
_explanation = super().explain_instance(inputs, y=outputs, **kwargs)
83+
return TSSaliencyResults(_explanation)

tests/extras/tsice/test_tssaliency.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import unittest
2+
import numpy as np
3+
import pandas as pd
4+
from sklearn.model_selection import train_test_split
5+
from sklearn.ensemble import RandomForestRegressor
6+
7+
from aix360.datasets import SunspotDataset
8+
from trustyai.explainers.extras.tssaliency import TSSaliencyExplainer
9+
from trustyai.utils.extras.timeseries import tsFrame
10+
11+
12+
# transform a time series dataset into a supervised learning dataset
13+
# below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/
14+
class RandomForestUniVariateForecaster:
15+
def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}):
16+
self.n_past = n_past
17+
self.n_future = n_future
18+
self.model = RandomForestRegressor(**RFparams)
19+
20+
def fit(self, X):
21+
train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future)
22+
trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:]
23+
self.model = self.model.fit(trainX, trainy)
24+
return self
25+
26+
def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True):
27+
n_vars = 1 if type(data) is list else data.shape[1]
28+
df = pd.DataFrame(data)
29+
cols = list()
30+
31+
# input sequence (t-n, ... t-1)
32+
for i in range(n_in, 0, -1):
33+
cols.append(df.shift(i))
34+
# forecast sequence (t, t+1, ... t+n)
35+
for i in range(0, n_out):
36+
cols.append(df.shift(-i))
37+
# put it all together
38+
agg = pd.concat(cols, axis=1)
39+
# drop rows with NaN values
40+
if dropnan:
41+
agg.dropna(inplace=True)
42+
return agg.values
43+
44+
def predict(self, X):
45+
row = X[-self.n_past:].flatten()
46+
y_pred = self.model.predict(np.asarray([row]))
47+
return y_pred
48+
49+
50+
class TestTSSaliencyExplainer(unittest.TestCase):
51+
def setUp(self):
52+
# load data
53+
df, schema = SunspotDataset().load_data()
54+
ts = tsFrame(
55+
df, timestamp_column=schema["timestamp"], columns=schema["targets"]
56+
)
57+
58+
(self.ts_train, self.ts_test) = train_test_split(
59+
ts, shuffle=False, stratify=None, test_size=0.15, train_size=None
60+
)
61+
62+
def test_tssaliency(self):
63+
# load model
64+
input_length = 48
65+
forecast_horizon = 10
66+
forecaster = RandomForestUniVariateForecaster(
67+
n_past=input_length, n_future=forecast_horizon
68+
)
69+
70+
forecaster.fit(self.ts_train.iloc[-200:])
71+
72+
# initialize/fit explainer
73+
74+
explainer = TSSaliencyExplainer(
75+
model=forecaster.predict,
76+
input_length=input_length,
77+
feature_names=self.ts_train.columns.tolist(),
78+
n_samples=2,
79+
gradient_samples=50,
80+
)
81+
82+
# compute explanations
83+
test_window = self.ts_test.iloc[:input_length]
84+
explanation = explainer.explain(test_window)
85+
86+
# validate explanation structure
87+
self.assertIn("input_data", explanation.explanation)
88+
self.assertIn("feature_names", explanation.explanation)
89+
self.assertIn("saliency", explanation.explanation)
90+
self.assertIn("timestamps", explanation.explanation)
91+
self.assertIn("base_value", explanation.explanation)
92+
self.assertIn("instance_prediction", explanation.explanation)
93+
self.assertIn("base_value_prediction", explanation.explanation)
94+
95+
self.assertEqual(explanation.explanation["saliency"].shape, test_window.shape)

0 commit comments

Comments
 (0)