Skip to content

Commit cbee162

Browse files
authored
Support TSICE explanations as plots (#169)
* Add plots to TSICE explanation * Remove duplicate plots and add perturbers import * Correct subclass for TSICE explanation results * Fix linting and formatting
1 parent 0cc313e commit cbee162

File tree

3 files changed

+183
-15
lines changed

3 files changed

+183
-15
lines changed

src/trustyai/explainers/extras/tsice.py

Lines changed: 178 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
from aix360.algorithms.tsice import TSICEExplainer as TSICEExplainerAIX
99
from aix360.algorithms.tsutils.tsperturbers import TSPerturber
10-
import bokeh
1110
import pandas as pd
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
from sklearn.linear_model import LinearRegression
1214

13-
from trustyai.model import SaliencyResults
15+
from trustyai.explainers.explanation_results import ExplanationResults
1416

1517

16-
class TSICEResults(SaliencyResults):
18+
class TSICEResults(ExplanationResults):
1719
"""Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`,
1820
and provides a variety of methods to visualize and interact with the explanation.
1921
"""
@@ -23,24 +25,187 @@ def __init__(self, explanation):
2325

2426
def as_dataframe(self) -> pd.DataFrame:
2527
"""Returns the explanation as a pandas dataframe."""
26-
return pd.DataFrame(self.explanation)
28+
# Initialize an empty DataFrame
29+
dataframe = pd.DataFrame()
30+
31+
# Loop through each feature_name and each key in data_x
32+
for key in self.explanation["data_x"]:
33+
for i, feature in enumerate(self.explanation["feature_names"]):
34+
dataframe[f"{key}-{feature}"] = [
35+
val[0] for val in self.explanation["feature_values"][i]
36+
]
37+
38+
# Add "total_impact" as a column
39+
dataframe["total_impact"] = self.explanation["total_impact"]
40+
return dataframe
2741

2842
def as_html(self) -> pd.io.formats.style.Styler:
2943
"""Returns the explanation as an HTML table."""
3044
dataframe = self.as_dataframe()
3145
return dataframe.style
3246

33-
def saliency_map(self):
34-
"""
35-
Returns a dictionary of feature names and their total impact.
36-
"""
37-
dict(zip(self.explanation["feature_names"], self.explanation["total_impact"]))
47+
def plot_forecast(self, variable): # pylint: disable=too-many-locals
48+
"""Plots the explanation.
49+
Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
50+
forecast_horizon = self.explanation["current_forecast"].shape[0]
51+
original_ts = pd.DataFrame(
52+
data={variable: self.explanation["data_x"][variable]}
53+
)
54+
perturbations = [d for d in self.explanation["perturbations"] if variable in d]
55+
56+
# Generate a list of keys
57+
keys = list(self.explanation["data_x"].keys())
58+
# Find the index of the given key
59+
key = keys.index(variable)
60+
forecasts_on_perturbations = [
61+
arr[:, key : key + 1]
62+
for arr in self.explanation["forecasts_on_perturbations"]
63+
]
64+
65+
new_perturbations = []
66+
new_timestamps = []
67+
pred_ts = []
68+
69+
original_ts.index.freq = pd.infer_freq(original_ts.index)
70+
for i in range(1, forecast_horizon + 1):
71+
new_timestamps.append(original_ts.index[-1] + (i * original_ts.index.freq))
72+
73+
for perturbation in perturbations:
74+
new_perturbations.append(pd.DataFrame(perturbation))
75+
76+
for forecast in forecasts_on_perturbations:
77+
pred_ts.append(pd.DataFrame(forecast, index=new_timestamps))
78+
79+
current_forecast = self.explanation["current_forecast"][:, key : key + 1]
80+
pred_original_ts = pd.DataFrame(current_forecast, index=new_timestamps)
81+
82+
_, axis = plt.subplots()
83+
84+
# Plot perturbed time series
85+
axis = self._plot_timeseries(
86+
new_perturbations,
87+
color="lightgreen",
88+
axis=axis,
89+
name="perturbed timeseries samples",
90+
)
91+
92+
# Plot original time series
93+
axis = self._plot_timeseries(
94+
original_ts, color="green", axis=axis, name="input/original timeseries"
95+
)
96+
97+
# Plot varying forecast range
98+
axis = self._plot_timeseries(
99+
pred_ts, color="lightblue", axis=axis, name="forecast on perturbed samples"
100+
)
101+
102+
# Plot original forecast
103+
axis = self._plot_timeseries(
104+
pred_original_ts, color="blue", axis=axis, name="original forecast"
105+
)
106+
107+
# Set labels and title
108+
axis.set_xlabel("Timestamp")
109+
axis.set_ylabel(variable)
110+
axis.set_title("Time-Series Individual Conditional Expectation (TSICE)")
111+
112+
axis.legend()
113+
114+
# Display the plot
115+
plt.show()
116+
117+
def _plot_timeseries(
118+
self, timeseries, color="green", axis=None, name="time series"
119+
):
120+
showlegend = True
121+
if isinstance(timeseries, dict):
122+
data = timeseries
123+
if isinstance(color, str):
124+
color = {k: color for k in data}
125+
elif isinstance(timeseries, list):
126+
data = {}
127+
for k, ts_data in enumerate(timeseries):
128+
data[k] = ts_data
129+
if isinstance(color, str):
130+
color = {k: color for k in data}
131+
else:
132+
data = {}
133+
data["default"] = timeseries
134+
color = {"default": color}
135+
136+
if axis is None:
137+
_, axis = plt.subplots()
138+
139+
first = True
140+
for key, _timeseries in data.items():
141+
if not first:
142+
showlegend = False
143+
144+
self._add_timeseries(
145+
axis, _timeseries, color=color[key], showlegend=showlegend, name=name
146+
)
147+
first = False
148+
149+
return axis
150+
151+
def _add_timeseries(
152+
self, axis, timeseries, color="green", name="time series", showlegend=False
153+
):
154+
timestamps = timeseries.index
155+
axis.plot(
156+
timestamps,
157+
timeseries[timeseries.columns[0]],
158+
color=color,
159+
label=(name if showlegend else "_nolegend_"),
160+
)
161+
162+
def plot_impact(self, feature_per_row=2):
163+
"""Plot the impace.
164+
Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
165+
166+
n_row = int(np.ceil(len(self.explanation["feature_names"]) / feature_per_row))
167+
feat_values = np.array(self.explanation["feature_values"])
168+
169+
fig, axs = plt.subplots(n_row, feature_per_row, figsize=(15, 15))
170+
axs = axs.ravel() # Flatten the axs to iterate over it
171+
172+
for i, feat in enumerate(self.explanation["feature_names"]):
173+
x_feat = feat_values[i, :, 0]
174+
trend_fit = LinearRegression()
175+
trend_line = trend_fit.fit(
176+
x_feat.reshape(-1, 1), self.explanation["signed_impact"]
177+
)
178+
x_trend = np.linspace(min(x_feat), max(x_feat), 101)
179+
y_trend = trend_line.predict(x_trend[..., np.newaxis])
180+
181+
# Scatter plot
182+
axs[i].scatter(x=x_feat, y=self.explanation["signed_impact"], color="blue")
183+
# Line plot
184+
axs[i].plot(
185+
x_trend,
186+
y_trend,
187+
color="green",
188+
label="correlation between forecast and observed feature",
189+
)
190+
# Reference line
191+
current_value = self.explanation["current_feature_values"][i][0]
192+
axs[i].axvline(
193+
x=current_value,
194+
color="firebrick",
195+
linestyle="--",
196+
label="current value",
197+
)
198+
199+
axs[i].set_xlabel(feat)
200+
axs[i].set_ylabel("Δ forecast")
38201

39-
def _matplotlib_plot(self, output_name: str, block: bool, call_show: bool) -> None:
40-
pass
202+
# Display the legend on the first subplot
203+
axs[0].legend()
41204

42-
def _get_bokeh_plot(self, output_name: str) -> bokeh.models.Plot:
43-
pass
205+
fig.suptitle("Impact of Derived Variable On The Forecast", fontsize=16)
206+
plt.tight_layout()
207+
plt.subplots_adjust(top=0.95)
208+
plt.show()
44209

45210

46211
class TSICEExplainer(TSICEExplainerAIX):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Extra time series utilities."""
2+
from aix360.algorithms.tsutils.tsframe import tsFrame # pylint: disable=unused-import
3+
from aix360.algorithms.tsutils.tsperturbers import * # pylint: disable=wildcard-import,unused-wildcard-import

tests/extras/tsice/test_tsice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ def test_tsice_with_range(self):
7979
explanation_window_start=10,
8080
explanation_window_length=observation_length,
8181
features_to_analyze=[
82-
"mean", # analyze mean metric from recent time series of lengh <observation_length>
82+
"mean", "std" # analyze mean metric from recent time series of lengh <observation_length>
8383
],
8484
perturbers=[
8585
BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2),
8686
],
8787
input_length=input_length,
8888
forecast_lookahead=forecast_horizon,
89-
n_perturbations=20,
89+
n_perturbations=30,
9090
)
9191

9292
# compute explanations

0 commit comments

Comments
 (0)