Skip to content

Commit accf75e

Browse files
authored
Initial support for external algorithms (#160)
* Add dependencies * Add TSICE tests * Fix linting errors * Update tests * Refactored TSICE forecaster to model * Match Pandas dependencies between ODH and XAI360
1 parent ad2b420 commit accf75e

File tree

4 files changed

+258
-0
lines changed

4 files changed

+258
-0
lines changed

.github/workflows/workflow.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ jobs:
2727
run: |
2828
pip install .
2929
pip install ".[dev]"
30+
pip install ".[extras]"
3031
- name: Lint
3132
run: |
3233
pylint --ignore-imports=yes $(find src/trustyai -type f -name "*.py")
3334
- name: Test with pytest
3435
run: |
3536
pytest -v -s tests/general
37+
pytest -v -s tests/extras
3638
pytest -v -s tests/initialization --forked
3739
- name: Style
3840
run: |

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ dev = [
4848
"wheel~=0.38.4",
4949
"xgboost==1.4.2"
5050
]
51+
extras = [
52+
"aix360 [default,tsice] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
53+
]
5154

5255
[project.urls]
5356
homepage = "https://github.com/trustyai-explainability/trustyai-explainability-python"
src/trustyai/explainers/extras/tsice.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
Wrapper module for TSICEExplainer from aix360.
3+
Original at https://github.com/Trusted-AI/AIX360/
4+
"""
5+
# pylint: disable=too-many-arguments,import-error
6+
from typing import Callable, List, Optional, Union
7+
8+
from aix360.algorithms.tsice import TSICEExplainer as TSICEExplainerAIX
9+
from aix360.algorithms.tsutils.tsperturbers import TSPerturber
10+
import bokeh
11+
import pandas as pd
12+
13+
from trustyai.model import SaliencyResults
14+
15+
16+
class TSICEResults(SaliencyResults):
    """Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`,
    and provides a variety of methods to visualize and interact with the explanation.
    """

    def __init__(self, explanation):
        """Store the raw explanation produced by the wrapped explainer.

        :param explanation: the explanation object; :meth:`saliency_map` reads its
            ``"feature_names"`` and ``"total_impact"`` entries, and
            :meth:`as_dataframe` passes it to :class:`pandas.DataFrame` as-is.
        """
        self.explanation = explanation

    def as_dataframe(self) -> pd.DataFrame:
        """Returns the explanation as a pandas dataframe."""
        return pd.DataFrame(self.explanation)

    def as_html(self) -> pd.io.formats.style.Styler:
        """Returns the explanation as a pandas ``Styler`` built from :meth:`as_dataframe`."""
        dataframe = self.as_dataframe()
        return dataframe.style

    def saliency_map(self):
        """
        Returns a dictionary of feature names and their total impact.
        """
        # The mapping used to be built and silently discarded (the method
        # returned ``None``); it must be returned to honour the contract above.
        return dict(
            zip(self.explanation["feature_names"], self.explanation["total_impact"])
        )

    def _matplotlib_plot(self, output_name: str, block: bool, call_show: bool) -> None:
        """Matplotlib plotting hook; not implemented for TSICE results (no-op)."""

    def _get_bokeh_plot(self, output_name: str) -> bokeh.models.Plot:
        """Bokeh plotting hook; not implemented for TSICE results (returns ``None``)."""
44+
45+
46+
class TSICEExplainer(TSICEExplainerAIX):
    """
    Thin TrustyAI-facing wrapper around the aix360 ``TSICEExplainer``.

    The constructor forwards every argument unchanged to the aix360 base class;
    the only renaming is that ``model`` is handed over as ``forecaster``.
    """

    def __init__(
        self,
        model: Callable,
        input_length: int,
        forecast_lookahead: int,
        n_variables: int = 1,
        n_exogs: int = 0,
        n_perturbations: int = 25,
        features_to_analyze: Optional[List[str]] = None,
        perturbers: Optional[List[Union[TSPerturber, dict]]] = None,
        explanation_window_start: Optional[int] = None,
        explanation_window_length: int = 10,
    ):
        """Initialise the underlying aix360 explainer.

        All parameters are passed through as keyword arguments of the same
        name, except ``model`` which the base class receives as
        ``forecaster``. Their semantics are defined by aix360.
        """
        super().__init__(
            forecaster=model,
            input_length=input_length,
            forecast_lookahead=forecast_lookahead,
            n_variables=n_variables,
            n_exogs=n_exogs,
            n_perturbations=n_perturbations,
            features_to_analyze=features_to_analyze,
            perturbers=perturbers,
            explanation_window_start=explanation_window_start,
            explanation_window_length=explanation_window_length,
        )

    def explain(self, inputs, outputs=None, **kwargs) -> TSICEResults:
        """
        Run the base class ``explain_instance`` and wrap its result.

        ``inputs`` is passed positionally, ``outputs`` as ``y``, and any extra
        keyword arguments are forwarded untouched. The raw explanation is
        returned inside a :class:`TSICEResults`.
        """
        return TSICEResults(super().explain_instance(inputs, y=outputs, **kwargs))

tests/extras/tsice/test_tsice.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
""" Tests for :py:mod:`aix360.algorithms.tsice.TSICEExplainer`.
2+
Original: https://github.com/Trusted-AI/AIX360/blob/master/tests/tsice/test_tsice.py
3+
"""
4+
import unittest
5+
import numpy as np
6+
import pandas as pd
7+
from sklearn.model_selection import train_test_split
8+
from sklearn.ensemble import RandomForestRegressor
9+
from aix360.algorithms.tsutils.tsframe import tsFrame
10+
from aix360.datasets import SunspotDataset
11+
from aix360.algorithms.tsutils.tsperturbers import BlockBootstrapPerturber
12+
from trustyai.explainers.extras.tsice import TSICEExplainer
13+
14+
15+
# transform a time series dataset into a supervised learning dataset
16+
# below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/
17+
class RandomForestUniVariateForecaster:
    """Sample univariate forecaster backed by a ``RandomForestRegressor``.

    The series is reframed as a supervised-learning problem: ``n_past`` lagged
    values are the inputs and the next ``n_future`` values are the targets.
    """

    def __init__(self, n_past=4, n_future=1, RFparams=None):
        """Create the (unfitted) forecaster.

        :param n_past: number of lagged observations used as model input.
        :param n_future: number of future steps predicted at once.
        :param RFparams: keyword arguments for ``RandomForestRegressor``;
            defaults to ``{"n_estimators": 250}`` when ``None``.
        """
        # A dict literal as default would be shared across all instances.
        if RFparams is None:
            RFparams = {"n_estimators": 250}
        self.n_past = n_past
        self.n_future = n_future
        self.model = RandomForestRegressor(**RFparams)

    def fit(self, X):
        """Fit the regressor on lagged windows built from ``X`` and return ``self``."""
        train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future)
        # Last ``n_future`` columns are the targets, everything before is input.
        trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future :]
        self.model = self.model.fit(trainX, trainy)
        return self

    def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True):
        """Turn a series into a 2-D array of ``[t-n_in, ..., t-1, t, ..., t+n_out-1]`` rows.

        :param data: anything accepted by ``pandas.DataFrame``.
        :param n_in: number of lag columns.
        :param n_out: number of forecast columns.
        :param dropnan: drop the rows made incomplete by shifting.
        :return: the aggregated frame's ``.values``.
        """
        df = pd.DataFrame(data)
        # input sequence (t-n, ... t-1)
        cols = [df.shift(lag) for lag in range(n_in, 0, -1)]
        # forecast sequence (t, t+1, ... t+n)
        cols.extend(df.shift(-step) for step in range(n_out))
        # put it all together
        agg = pd.concat(cols, axis=1)
        # drop rows with NaN values
        if dropnan:
            agg = agg.dropna()
        return agg.values

    def predict(self, X):
        """Predict from the last ``n_past`` entries of ``X``, flattened into one row."""
        row = X[-self.n_past :].flatten()
        y_pred = self.model.predict(np.asarray([row]))
        return y_pred
51+
52+
53+
class TestTSICEExplainer(unittest.TestCase):
    """Exercise the TrustyAI ``TSICEExplainer`` wrapper on the sunspot dataset."""

    # Keys each explanation is checked for, in the order they are asserted.
    _EXPECTED_KEYS = (
        "data_x",
        "feature_names",
        "feature_values",
        "signed_impact",
        "total_impact",
        "current_forecast",
        "current_feature_values",
        "perturbations",
        "forecasts_on_perturbations",
    )

    def setUp(self):
        """Load the sunspot series and split it chronologically into train/test."""
        df, schema = SunspotDataset().load_data()
        ts = tsFrame(
            df, timestamp_column=schema["timestamp"], columns=schema["targets"]
        )
        self.ts_train, self.ts_test = train_test_split(
            ts, shuffle=False, stratify=None, test_size=0.15, train_size=None
        )

    def _fit_forecaster(self, input_length, forecast_horizon):
        """Return a forecaster fitted on the last 200 training observations."""
        forecaster = RandomForestUniVariateForecaster(
            n_past=input_length, n_future=forecast_horizon
        )
        forecaster.fit(self.ts_train.iloc[-200:])
        return forecaster

    def _assert_explanation_structure(self, explanation):
        """Assert that the wrapped explanation exposes every expected key."""
        for key in self._EXPECTED_KEYS:
            self.assertIn(key, explanation.explanation)

    def test_tsice_with_range(self):
        """Explain with an explicit explanation window start."""
        input_length = 24
        forecast_horizon = 4
        forecaster = self._fit_forecaster(input_length, forecast_horizon)

        observation_length = 12
        explainer = TSICEExplainer(
            model=forecaster.predict,
            explanation_window_start=10,
            explanation_window_length=observation_length,
            # analyze the mean of the recent series of length <observation_length>
            features_to_analyze=["mean"],
            perturbers=[
                BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2),
            ],
            input_length=input_length,
            forecast_lookahead=forecast_horizon,
            n_perturbations=20,
        )

        explanation = explainer.explain(inputs=self.ts_test.iloc[:80])

        self._assert_explanation_structure(explanation)

    def test_tsice_with_latest(self):
        """Explain with no explicit window start, all features and mixed perturbers."""
        input_length = 24
        forecast_horizon = 4
        forecaster = self._fit_forecaster(input_length, forecast_horizon)

        observation_length = 12
        # Each metric is computed on the recent series of length <observation_length>.
        features = [
            "mean",
            "median",
            "std",
            "max_variation",
            "min",
            "max",
            "range",
            "intercept",
            "trend",
            "rsquared",
        ]
        # One perturber instance plus several dict-configured perturbers.
        perturbers = [
            BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2),
            {
                "type": "frequency",
                "window_length": 5,
                "truncate_frequencies": 5,
                "block_length": 4,
            },
            {"type": "moving-average", "window_length": 5, "lag": 5, "block_length": 4},
            {"type": "impute", "block_length": 4},
            {"type": "shift", "block_length": 4},
        ]
        explainer = TSICEExplainer(
            model=forecaster.predict,
            explanation_window_start=None,
            explanation_window_length=observation_length,
            features_to_analyze=features,
            perturbers=perturbers,
            input_length=input_length,
            forecast_lookahead=forecast_horizon,
            n_perturbations=20,
        )

        explanation = explainer.explain(inputs=self.ts_test.iloc[:80])

        self._assert_explanation_structure(explanation)
168+
169+
170+
# Run the test suite via unittest when this module is executed as a script.
if __name__ == "__main__":
171+
unittest.main()

0 commit comments

Comments
 (0)