7
7
8
8
from aix360 .algorithms .tsice import TSICEExplainer as TSICEExplainerAIX
9
9
from aix360 .algorithms .tsutils .tsperturbers import TSPerturber
10
- import bokeh
11
10
import pandas as pd
11
+ import matplotlib .pyplot as plt
12
+ import numpy as np
13
+ from sklearn .linear_model import LinearRegression
12
14
13
- from trustyai .model import SaliencyResults
15
+ from trustyai .explainers . explanation_results import ExplanationResults
14
16
15
17
16
- class TSICEResults (SaliencyResults ):
18
+ class TSICEResults (ExplanationResults ):
17
19
"""Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`,
18
20
and provides a variety of methods to visualize and interact with the explanation.
19
21
"""
@@ -23,24 +25,187 @@ def __init__(self, explanation):
23
25
24
26
def as_dataframe (self ) -> pd .DataFrame :
25
27
"""Returns the explanation as a pandas dataframe."""
26
- return pd .DataFrame (self .explanation )
28
+ # Initialize an empty DataFrame
29
+ dataframe = pd .DataFrame ()
30
+
31
+ # Loop through each feature_name and each key in data_x
32
+ for key in self .explanation ["data_x" ]:
33
+ for i , feature in enumerate (self .explanation ["feature_names" ]):
34
+ dataframe [f"{ key } -{ feature } " ] = [
35
+ val [0 ] for val in self .explanation ["feature_values" ][i ]
36
+ ]
37
+
38
+ # Add "total_impact" as a column
39
+ dataframe ["total_impact" ] = self .explanation ["total_impact" ]
40
+ return dataframe
27
41
28
42
def as_html (self ) -> pd .io .formats .style .Styler :
29
43
"""Returns the explanation as an HTML table."""
30
44
dataframe = self .as_dataframe ()
31
45
return dataframe .style
32
46
33
- def saliency_map (self ):
34
- """
35
- Returns a dictionary of feature names and their total impact.
36
- """
37
- dict (zip (self .explanation ["feature_names" ], self .explanation ["total_impact" ]))
47
+ def plot_forecast (self , variable ): # pylint: disable=too-many-locals
48
+ """Plots the explanation.
49
+ Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
50
+ forecast_horizon = self .explanation ["current_forecast" ].shape [0 ]
51
+ original_ts = pd .DataFrame (
52
+ data = {variable : self .explanation ["data_x" ][variable ]}
53
+ )
54
+ perturbations = [d for d in self .explanation ["perturbations" ] if variable in d ]
55
+
56
+ # Generate a list of keys
57
+ keys = list (self .explanation ["data_x" ].keys ())
58
+ # Find the index of the given key
59
+ key = keys .index (variable )
60
+ forecasts_on_perturbations = [
61
+ arr [:, key : key + 1 ]
62
+ for arr in self .explanation ["forecasts_on_perturbations" ]
63
+ ]
64
+
65
+ new_perturbations = []
66
+ new_timestamps = []
67
+ pred_ts = []
68
+
69
+ original_ts .index .freq = pd .infer_freq (original_ts .index )
70
+ for i in range (1 , forecast_horizon + 1 ):
71
+ new_timestamps .append (original_ts .index [- 1 ] + (i * original_ts .index .freq ))
72
+
73
+ for perturbation in perturbations :
74
+ new_perturbations .append (pd .DataFrame (perturbation ))
75
+
76
+ for forecast in forecasts_on_perturbations :
77
+ pred_ts .append (pd .DataFrame (forecast , index = new_timestamps ))
78
+
79
+ current_forecast = self .explanation ["current_forecast" ][:, key : key + 1 ]
80
+ pred_original_ts = pd .DataFrame (current_forecast , index = new_timestamps )
81
+
82
+ _ , axis = plt .subplots ()
83
+
84
+ # Plot perturbed time series
85
+ axis = self ._plot_timeseries (
86
+ new_perturbations ,
87
+ color = "lightgreen" ,
88
+ axis = axis ,
89
+ name = "perturbed timeseries samples" ,
90
+ )
91
+
92
+ # Plot original time series
93
+ axis = self ._plot_timeseries (
94
+ original_ts , color = "green" , axis = axis , name = "input/original timeseries"
95
+ )
96
+
97
+ # Plot varying forecast range
98
+ axis = self ._plot_timeseries (
99
+ pred_ts , color = "lightblue" , axis = axis , name = "forecast on perturbed samples"
100
+ )
101
+
102
+ # Plot original forecast
103
+ axis = self ._plot_timeseries (
104
+ pred_original_ts , color = "blue" , axis = axis , name = "original forecast"
105
+ )
106
+
107
+ # Set labels and title
108
+ axis .set_xlabel ("Timestamp" )
109
+ axis .set_ylabel (variable )
110
+ axis .set_title ("Time-Series Individual Conditional Expectation (TSICE)" )
111
+
112
+ axis .legend ()
113
+
114
+ # Display the plot
115
+ plt .show ()
116
+
117
+ def _plot_timeseries (
118
+ self , timeseries , color = "green" , axis = None , name = "time series"
119
+ ):
120
+ showlegend = True
121
+ if isinstance (timeseries , dict ):
122
+ data = timeseries
123
+ if isinstance (color , str ):
124
+ color = {k : color for k in data }
125
+ elif isinstance (timeseries , list ):
126
+ data = {}
127
+ for k , ts_data in enumerate (timeseries ):
128
+ data [k ] = ts_data
129
+ if isinstance (color , str ):
130
+ color = {k : color for k in data }
131
+ else :
132
+ data = {}
133
+ data ["default" ] = timeseries
134
+ color = {"default" : color }
135
+
136
+ if axis is None :
137
+ _ , axis = plt .subplots ()
138
+
139
+ first = True
140
+ for key , _timeseries in data .items ():
141
+ if not first :
142
+ showlegend = False
143
+
144
+ self ._add_timeseries (
145
+ axis , _timeseries , color = color [key ], showlegend = showlegend , name = name
146
+ )
147
+ first = False
148
+
149
+ return axis
150
+
151
+ def _add_timeseries (
152
+ self , axis , timeseries , color = "green" , name = "time series" , showlegend = False
153
+ ):
154
+ timestamps = timeseries .index
155
+ axis .plot (
156
+ timestamps ,
157
+ timeseries [timeseries .columns [0 ]],
158
+ color = color ,
159
+ label = (name if showlegend else "_nolegend_" ),
160
+ )
161
+
162
+ def plot_impact (self , feature_per_row = 2 ):
163
+ """Plot the impace.
164
+ Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
165
+
166
+ n_row = int (np .ceil (len (self .explanation ["feature_names" ]) / feature_per_row ))
167
+ feat_values = np .array (self .explanation ["feature_values" ])
168
+
169
+ fig , axs = plt .subplots (n_row , feature_per_row , figsize = (15 , 15 ))
170
+ axs = axs .ravel () # Flatten the axs to iterate over it
171
+
172
+ for i , feat in enumerate (self .explanation ["feature_names" ]):
173
+ x_feat = feat_values [i , :, 0 ]
174
+ trend_fit = LinearRegression ()
175
+ trend_line = trend_fit .fit (
176
+ x_feat .reshape (- 1 , 1 ), self .explanation ["signed_impact" ]
177
+ )
178
+ x_trend = np .linspace (min (x_feat ), max (x_feat ), 101 )
179
+ y_trend = trend_line .predict (x_trend [..., np .newaxis ])
180
+
181
+ # Scatter plot
182
+ axs [i ].scatter (x = x_feat , y = self .explanation ["signed_impact" ], color = "blue" )
183
+ # Line plot
184
+ axs [i ].plot (
185
+ x_trend ,
186
+ y_trend ,
187
+ color = "green" ,
188
+ label = "correlation between forecast and observed feature" ,
189
+ )
190
+ # Reference line
191
+ current_value = self .explanation ["current_feature_values" ][i ][0 ]
192
+ axs [i ].axvline (
193
+ x = current_value ,
194
+ color = "firebrick" ,
195
+ linestyle = "--" ,
196
+ label = "current value" ,
197
+ )
198
+
199
+ axs [i ].set_xlabel (feat )
200
+ axs [i ].set_ylabel ("Δ forecast" )
38
201
39
- def _matplotlib_plot ( self , output_name : str , block : bool , call_show : bool ) -> None :
40
- pass
202
+ # Display the legend on the first subplot
203
+ axs [ 0 ]. legend ()
41
204
42
- def _get_bokeh_plot (self , output_name : str ) -> bokeh .models .Plot :
43
- pass
205
+ fig .suptitle ("Impact of Derived Variable On The Forecast" , fontsize = 16 )
206
+ plt .tight_layout ()
207
+ plt .subplots_adjust (top = 0.95 )
208
+ plt .show ()
44
209
45
210
46
211
class TSICEExplainer (TSICEExplainerAIX ):
0 commit comments