Skip to content

Commit 6d9f6b4

Browse files
authored
Add plot and as_dataframe (#174)
- Move tests - Add tslime extra dependency
1 parent 20ae91a commit 6d9f6b4

File tree

6 files changed

+199
-1
lines changed

6 files changed

+199
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dev = [
4949
"xgboost==1.4.2"
5050
]
5151
extras = [
52-
"aix360 [default,tsice,tssaliency] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
52+
"aix360 [default,tsice,tslime,tssaliency] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
5353
]
5454

5555
[project.urls]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Wrapper module for TSLIME from aix360.
3+
Original at https://github.com/Trusted-AI/AIX360/
4+
"""
5+
6+
from typing import Callable, List, Union
7+
8+
import pandas as pd
9+
import numpy as np
10+
from aix360.algorithms.tslime import TSLimeExplainer as TSLimeExplainerAIX
11+
from aix360.algorithms.tslime.surrogate import LinearSurrogateModel
12+
from pandas.io.formats.style import Styler
13+
import matplotlib.pyplot as plt
14+
15+
from trustyai.explainers.explanation_results import ExplanationResults
16+
from trustyai.utils.extras.timeseries import TSPerturber
17+
18+
19+
class TSSLIMEResults(ExplanationResults):
20+
"""Wraps TSLimeExplainer results. This object is returned by the :class:`~TSLimeExplainer`,
21+
and provides a variety of methods to visualize and interact with the explanation.
22+
"""
23+
24+
def __init__(self, explanation):
25+
self.explanation = explanation
26+
27+
def as_dataframe(self) -> pd.DataFrame:
28+
"""Returns the weights as a pandas dataframe."""
29+
return pd.DataFrame(self.explanation["history_weights"])
30+
31+
def as_html(self) -> Styler:
32+
"""Returns the explanation as an HTML table."""
33+
dataframe = self.as_dataframe()
34+
return dataframe.style
35+
36+
def plot(self):
37+
"""Plot TSLime explanation for the time-series instance. Based on
38+
https://github.com/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_univariate_demo.ipynb"""
39+
relevant_history = self.explanation["history_weights"].shape[0]
40+
input_data = self.explanation["input_data"]
41+
relevant_df = input_data[-relevant_history:]
42+
43+
plt.figure(layout="constrained")
44+
plt.plot(relevant_df, label="Input Time Series", marker="o")
45+
plt.gca().invert_yaxis()
46+
47+
normalized_weights = (
48+
self.explanation["history_weights"]
49+
/ np.mean(np.abs(self.explanation["history_weights"]))
50+
).flatten()
51+
52+
plt.bar(
53+
input_data.index[-relevant_history:],
54+
normalized_weights,
55+
0.4,
56+
label="TSLime Weights (Normalized)",
57+
color="red",
58+
)
59+
plt.axhline(y=0, color="r", linestyle="-", alpha=0.4)
60+
plt.title("Time Series Lime Explanation Plot")
61+
plt.legend(bbox_to_anchor=(1.25, 1.0), loc="upper right")
62+
plt.show()
63+
64+
65+
class TSLimeExplainer(TSLimeExplainerAIX):
66+
"""
67+
Wrapper for TSLimeExplainer from aix360.
68+
"""
69+
70+
def __init__( # pylint: disable=too-many-arguments
71+
self,
72+
model: Callable,
73+
input_length: int,
74+
n_perturbations: int = 2000,
75+
relevant_history: int = None,
76+
perturbers: List[Union[TSPerturber, dict]] = None,
77+
local_interpretable_model: LinearSurrogateModel = None,
78+
random_seed: int = None,
79+
):
80+
super().__init__(
81+
model=model,
82+
input_length=input_length,
83+
n_perturbations=n_perturbations,
84+
relevant_history=relevant_history,
85+
perturbers=perturbers,
86+
local_interpretable_model=local_interpretable_model,
87+
random_seed=random_seed,
88+
)
89+
90+
def explain(self, inputs, outputs=None, **kwargs) -> TSSLIMEResults:
91+
"""
92+
Explain the model's prediction on X.
93+
"""
94+
_explanation = super().explain_instance(inputs, y=outputs, **kwargs)
95+
return TSSLIMEResults(_explanation)

src/trustyai/utils/extras/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
"""AIX360 model wrappers"""
2+
from aix360.algorithms.tsutils.model_wrappers import * # pylint: disable=wildcard-import,unused-wildcard-import
File renamed without changes.

tests/extras/test_tslime.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
import unittest
3+
import numpy as np
4+
import pandas as pd
5+
from sklearn.model_selection import train_test_split
6+
from sklearn.ensemble import RandomForestRegressor
7+
from trustyai.utils.extras.timeseries import tsFrame
8+
from aix360.datasets import SunspotDataset
9+
from trustyai.explainers.extras.tslime import TSLimeExplainer
10+
from trustyai.utils.extras.timeseries import BlockBootstrapPerturber
11+
12+
13+
# transform a time series dataset into a supervised learning dataset
14+
# below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/
15+
class RandomForestUniVariateForecaster:
16+
def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}):
17+
self.n_past = n_past
18+
self.n_future = n_future
19+
self.model = RandomForestRegressor(**RFparams)
20+
21+
def fit(self, X):
22+
train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future)
23+
trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:]
24+
self.model = self.model.fit(trainX, trainy)
25+
return self
26+
27+
def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True):
28+
n_vars = 1 if type(data) is list else data.shape[1]
29+
df = pd.DataFrame(data)
30+
cols = list()
31+
32+
# input sequence (t-n, ... t-1)
33+
for i in range(n_in, 0, -1):
34+
cols.append(df.shift(i))
35+
# forecast sequence (t, t+1, ... t+n)
36+
for i in range(0, n_out):
37+
cols.append(df.shift(-i))
38+
# put it all together
39+
agg = pd.concat(cols, axis=1)
40+
# drop rows with NaN values
41+
if dropnan:
42+
agg.dropna(inplace=True)
43+
return agg.values
44+
45+
def predict(self, X):
46+
row = X[-self.n_past:].flatten()
47+
y_pred = self.model.predict(np.asarray([row]))
48+
return y_pred
49+
50+
51+
class TestTSLimeExplainer(unittest.TestCase):
52+
def setUp(self):
53+
# load data
54+
df, schema = SunspotDataset().load_data()
55+
ts = tsFrame(
56+
df, timestamp_column=schema["timestamp"], columns=schema["targets"]
57+
)
58+
59+
(self.ts_train, self.ts_test) = train_test_split(
60+
ts, shuffle=False, stratify=None, test_size=0.15, train_size=None
61+
)
62+
63+
def test_tslime(self):
64+
# load model
65+
input_length = 24
66+
forecast_horizon = 4
67+
forecaster = RandomForestUniVariateForecaster(
68+
n_past=input_length, n_future=forecast_horizon
69+
)
70+
71+
forecaster.fit(self.ts_train.iloc[-200:])
72+
73+
# initialize/fit explainer
74+
75+
relevant_history = 12
76+
explainer = TSLimeExplainer(
77+
model=forecaster.predict,
78+
input_length=input_length,
79+
relevant_history=relevant_history,
80+
perturbers=[
81+
BlockBootstrapPerturber(
82+
window_length=min(4, input_length - 1), block_length=2, block_swap=2
83+
),
84+
],
85+
n_perturbations=10,
86+
random_seed=22,
87+
)
88+
89+
# compute explanations
90+
test_window = self.ts_test.iloc[:input_length]
91+
explanation = explainer.explain(test_window)
92+
93+
# validate explanation structure
94+
self.assertIn("input_data", explanation.explanation)
95+
self.assertIn("history_weights", explanation.explanation)
96+
self.assertIn("x_perturbations", explanation.explanation)
97+
self.assertIn("y_perturbations", explanation.explanation)
98+
self.assertIn("model_prediction", explanation.explanation)
99+
self.assertIn("surrogate_prediction", explanation.explanation)
100+
101+
self.assertEqual(explanation.explanation["history_weights"].shape[0], relevant_history)

0 commit comments

Comments
 (0)