Skip to content

Commit 960b9c1

Browse files
committed
Initial notebook
1 parent 73b3368 commit 960b9c1

File tree

3 files changed

+824
-0
lines changed

3 files changed

+824
-0
lines changed

examples/TSICE.ipynb

Lines changed: 680 additions & 0 deletions
Large diffs are not rendered by default.

examples/data/sunspots.zip

22.3 KB
Binary file not shown.

examples/utils/plots.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import matplotlib.pyplot as plt
2+
import pandas as pd
3+
from sklearn.linear_model import LinearRegression
4+
import numpy as np
5+
6+
def add_timeseries(ax, ts, color="green", name="time series", showlegend=False):
    """Draw the first column of ``ts`` on ``ax`` as a line.

    Args:
        ax: Object exposing a matplotlib-style ``plot`` method.
        ts: Frame whose index supplies the x values and whose first
            column supplies the y values.
        color: Line color forwarded to ``ax.plot``.
        name: Legend label, used only when ``showlegend`` is truthy.
        showlegend: When falsy the line is labelled ``'_nolegend_'`` so
            it is left out of the legend.
    """
    first_column = ts.columns[0]
    if showlegend:
        legend_label = name
    else:
        legend_label = '_nolegend_'
    ax.plot(ts.index, ts[first_column], color=color, label=legend_label)
9+
10+
def plot_timeseries(ts, color="green", ax=None, name="time series"):
    """Plot one or several time series on a single axes.

    Args:
        ts: A single time series frame, a list/tuple of frames, or a
            mapping of key -> frame. Mapping subclasses (for example
            ``OrderedDict``) are treated as mappings.
        color: Either one color string applied to every series, or a
            container indexable by the series keys (mapping keys, or
            integer positions for a list/tuple input). For a single
            series the value is used as-is.
        ax: Axes to draw on; a new figure and axes are created when None.
        name: Legend label. Only the first series drawn carries it, so a
            group of series yields a single legend entry.

    Returns:
        The axes that was drawn on.
    """
    if isinstance(ts, dict):
        data = ts
    elif isinstance(ts, (list, tuple)):
        # Key each series by its position in the sequence.
        data = dict(enumerate(ts))
    else:
        data = {"default": ts}
        color = {"default": color}

    # A single color string applies to every series in the collection.
    if isinstance(color, str):
        color = {key: color for key in data}

    if ax is None:
        _, ax = plt.subplots()

    for position, (key, series) in enumerate(data.items()):
        # Label only the first series to avoid duplicate legend entries.
        add_timeseries(
            ax,
            series,
            color=color[key],
            showlegend=(position == 0),
            name=name,
        )

    return ax
39+
40+
def plot_tsice_explanation(explanation, forecast_horizon, xlabel="Month/Year", ylabel="sunspots"):
    """Plot a TSICE explanation: input series, perturbations and forecasts.

    Args:
        explanation: Mapping read with the keys ``"data_x"`` (the input
            time series), ``"perturbations"`` (iterable of perturbed
            series), ``"forecasts_on_perturbations"`` (iterable of
            forecasts, one per perturbation) and ``"current_forecast"``
            (forecast on the unperturbed input).
        forecast_horizon: Number of future timestamps generated after the
            last timestamp of ``data_x``; each forecast is indexed by
            these timestamps, so it presumably must contain exactly this
            many rows.
        xlabel: Label of the x axis. Defaults to the previously
            hard-coded value.
        ylabel: Label of the y axis. Defaults to the previously
            hard-coded value.

    Returns:
        The matplotlib figure holding the plot (also shown via
        ``plt.show()``).

    Raises:
        ValueError: If no regular frequency can be inferred from the
            index of ``data_x``.
    """
    from pandas.tseries.frequencies import to_offset

    original_ts = pd.DataFrame(explanation["data_x"])
    perturbations = explanation["perturbations"]
    forecasts_on_perturbations = explanation["forecasts_on_perturbations"]

    # Work with a local offset rather than assigning to ``index.freq``:
    # that assignment can also alter the caller's index when the index
    # object is shared with ``explanation["data_x"]``.
    inferred_freq = pd.infer_freq(original_ts.index)
    if inferred_freq is None:
        raise ValueError(
            "Cannot infer a regular frequency from the index of "
            "explanation['data_x']; the forecast timestamps cannot be generated."
        )
    step = to_offset(inferred_freq)
    last_timestamp = original_ts.index[-1]
    new_timestamps = [
        last_timestamp + i * step for i in range(1, forecast_horizon + 1)
    ]

    new_perturbations = [pd.DataFrame(perturbation) for perturbation in perturbations]
    pred_ts = [
        pd.DataFrame(forecast, index=new_timestamps)
        for forecast in forecasts_on_perturbations
    ]
    pred_original_ts = pd.DataFrame(
        explanation["current_forecast"], index=new_timestamps
    )

    fig, ax = plt.subplots()

    # Draw order matters: later layers are painted over earlier ones, so
    # the original series and forecast stay visible above the samples.
    ax = plot_timeseries(new_perturbations, color="lightgreen", ax=ax, name="perturbed timeseries samples")
    ax = plot_timeseries(original_ts, color="green", ax=ax, name="input/original timeseries")
    ax = plot_timeseries(pred_ts, color="lightblue", ax=ax, name="forecast on perturbed samples")
    ax = plot_timeseries(pred_original_ts, color="blue", ax=ax, name="original forecast")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("Time Series Individual Conditional Expectation (TSICE) Plot")
    ax.legend()

    plt.show()

    return fig
89+
90+
91+
def plot_tsice_with_observed_features(explanation, feature_per_row=2):
    """Plot signed forecast impact against each observed feature.

    One subplot is drawn per entry of ``explanation["feature_names"]``:
    a scatter of the feature values versus ``explanation["signed_impact"]``,
    a fitted linear trend line, and a dashed vertical line at the current
    feature value.

    Args:
        explanation: Mapping read with the keys ``"feature_names"``,
            ``"feature_values"`` (indexed as ``[feature, :, 0]``;
            presumably the second axis runs over perturbed samples),
            ``"signed_impact"`` and ``"current_feature_values"``
            (indexed as ``[feature][0]``).
        feature_per_row: Number of subplot columns in the grid.

    Returns:
        The matplotlib figure containing the grid of subplots.
    """
    feature_names = explanation["feature_names"]
    signed_impact = explanation["signed_impact"]
    feat_values = np.array(explanation["feature_values"])
    n_row = int(np.ceil(len(feature_names) / feature_per_row))

    # squeeze=False always yields a 2-D array of axes; without it a 1x1
    # grid returns a bare Axes object and the ravel() below fails.
    fig, axs = plt.subplots(n_row, feature_per_row, figsize=(15, 15), squeeze=False)
    axs = axs.ravel()  # Flatten the grid so it can be indexed per feature

    for i, feat in enumerate(feature_names):
        x_feat = feat_values[i, :, 0]
        trend_line = LinearRegression().fit(x_feat.reshape(-1, 1), signed_impact)
        # Evaluate the fitted line over the observed range of the feature.
        x_trend = np.linspace(x_feat.min(), x_feat.max(), 101)
        y_trend = trend_line.predict(x_trend[..., np.newaxis])

        # Scatter plot
        axs[i].scatter(x=x_feat, y=signed_impact, color='blue')
        # Line plot
        axs[i].plot(x_trend, y_trend, color="green", label="correlation between forecast and observed feature")
        # Reference line
        current_value = explanation["current_feature_values"][i][0]
        axs[i].axvline(x=current_value, color='firebrick', linestyle='--', label="current value")

        axs[i].set_xlabel(feat)
        axs[i].set_ylabel('Δ forecast')

    # Display the legend on the first subplot
    axs[0].legend()

    fig.suptitle("Impact of Derived Variable On The Forecast", fontsize=16)
    plt.tight_layout()
    # Leave room at the top for the figure title.
    plt.subplots_adjust(top=0.95)
    return fig
124+
125+
def plot_ts(df, df_timestamps, df_timestamp_name, df_targets, df_description):
    """Plot every target column of ``df`` in its own vertically stacked subplot.

    Args:
        df: Container indexed by target name (``df[target]``) to get the
            values to plot.
        df_timestamps: X values shared by all subplots.
        df_timestamp_name: Label of the x axis.
        df_targets: Sized collection of target names; one subplot each.
        df_description: Text used as the figure title.

    Returns:
        The matplotlib figure containing the subplots.
    """
    target_count = len(df_targets)
    fig, axs = plt.subplots(target_count, 1, figsize=(10, 5 * target_count))

    # With a single target, subplots returns a bare Axes; wrap it so the
    # loop below can treat both cases alike.
    axes_seq = [axs] if target_count == 1 else axs

    for axis, target_name in zip(axes_seq, df_targets):
        axis.plot(df_timestamps, df[target_name], color="black", label=target_name)
        axis.set_title(f"[target] {target_name}")
        axis.set_xlabel(df_timestamp_name)
        axis.set_ylabel(target_name)
        axis.legend()

    fig.suptitle(df_description)
    plt.tight_layout()
    # Leave room at the top for the figure title.
    plt.subplots_adjust(top=0.95)
    return fig

0 commit comments

Comments
 (0)