77
88from aix360 .algorithms .tsice import TSICEExplainer as TSICEExplainerAIX
99from aix360 .algorithms .tsutils .tsperturbers import TSPerturber
10- import bokeh
1110import pandas as pd
11+ import matplotlib .pyplot as plt
12+ import numpy as np
13+ from sklearn .linear_model import LinearRegression
1214
13- from trustyai .model import SaliencyResults
15+ from trustyai .explainers . explanation_results import ExplanationResults
1416
1517
16- class TSICEResults (SaliencyResults ):
18+ class TSICEResults (ExplanationResults ):
1719 """Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`,
1820 and provides a variety of methods to visualize and interact with the explanation.
1921 """
@@ -23,24 +25,187 @@ def __init__(self, explanation):
2325
2426 def as_dataframe (self ) -> pd .DataFrame :
2527 """Returns the explanation as a pandas dataframe."""
26- return pd .DataFrame (self .explanation )
28+ # Initialize an empty DataFrame
29+ dataframe = pd .DataFrame ()
30+
31+ # Loop through each feature_name and each key in data_x
32+ for key in self .explanation ["data_x" ]:
33+ for i , feature in enumerate (self .explanation ["feature_names" ]):
34+ dataframe [f"{ key } -{ feature } " ] = [
35+ val [0 ] for val in self .explanation ["feature_values" ][i ]
36+ ]
37+
38+ # Add "total_impact" as a column
39+ dataframe ["total_impact" ] = self .explanation ["total_impact" ]
40+ return dataframe
2741
2842 def as_html (self ) -> pd .io .formats .style .Styler :
2943 """Returns the explanation as an HTML table."""
3044 dataframe = self .as_dataframe ()
3145 return dataframe .style
3246
33- def saliency_map (self ):
34- """
35- Returns a dictionary of feature names and their total impact.
36- """
37- dict (zip (self .explanation ["feature_names" ], self .explanation ["total_impact" ]))
47+ def plot_forecast (self , variable ): # pylint: disable=too-many-locals
48+ """Plots the explanation.
49+ Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
50+ forecast_horizon = self .explanation ["current_forecast" ].shape [0 ]
51+ original_ts = pd .DataFrame (
52+ data = {variable : self .explanation ["data_x" ][variable ]}
53+ )
54+ perturbations = [d for d in self .explanation ["perturbations" ] if variable in d ]
55+
56+ # Generate a list of keys
57+ keys = list (self .explanation ["data_x" ].keys ())
58+ # Find the index of the given key
59+ key = keys .index (variable )
60+ forecasts_on_perturbations = [
61+ arr [:, key : key + 1 ]
62+ for arr in self .explanation ["forecasts_on_perturbations" ]
63+ ]
64+
65+ new_perturbations = []
66+ new_timestamps = []
67+ pred_ts = []
68+
69+ original_ts .index .freq = pd .infer_freq (original_ts .index )
70+ for i in range (1 , forecast_horizon + 1 ):
71+ new_timestamps .append (original_ts .index [- 1 ] + (i * original_ts .index .freq ))
72+
73+ for perturbation in perturbations :
74+ new_perturbations .append (pd .DataFrame (perturbation ))
75+
76+ for forecast in forecasts_on_perturbations :
77+ pred_ts .append (pd .DataFrame (forecast , index = new_timestamps ))
78+
79+ current_forecast = self .explanation ["current_forecast" ][:, key : key + 1 ]
80+ pred_original_ts = pd .DataFrame (current_forecast , index = new_timestamps )
81+
82+ _ , axis = plt .subplots ()
83+
84+ # Plot perturbed time series
85+ axis = self ._plot_timeseries (
86+ new_perturbations ,
87+ color = "lightgreen" ,
88+ axis = axis ,
89+ name = "perturbed timeseries samples" ,
90+ )
91+
92+ # Plot original time series
93+ axis = self ._plot_timeseries (
94+ original_ts , color = "green" , axis = axis , name = "input/original timeseries"
95+ )
96+
97+ # Plot varying forecast range
98+ axis = self ._plot_timeseries (
99+ pred_ts , color = "lightblue" , axis = axis , name = "forecast on perturbed samples"
100+ )
101+
102+ # Plot original forecast
103+ axis = self ._plot_timeseries (
104+ pred_original_ts , color = "blue" , axis = axis , name = "original forecast"
105+ )
106+
107+ # Set labels and title
108+ axis .set_xlabel ("Timestamp" )
109+ axis .set_ylabel (variable )
110+ axis .set_title ("Time-Series Individual Conditional Expectation (TSICE)" )
111+
112+ axis .legend ()
113+
114+ # Display the plot
115+ plt .show ()
116+
117+ def _plot_timeseries (
118+ self , timeseries , color = "green" , axis = None , name = "time series"
119+ ):
120+ showlegend = True
121+ if isinstance (timeseries , dict ):
122+ data = timeseries
123+ if isinstance (color , str ):
124+ color = {k : color for k in data }
125+ elif isinstance (timeseries , list ):
126+ data = {}
127+ for k , ts_data in enumerate (timeseries ):
128+ data [k ] = ts_data
129+ if isinstance (color , str ):
130+ color = {k : color for k in data }
131+ else :
132+ data = {}
133+ data ["default" ] = timeseries
134+ color = {"default" : color }
135+
136+ if axis is None :
137+ _ , axis = plt .subplots ()
138+
139+ first = True
140+ for key , _timeseries in data .items ():
141+ if not first :
142+ showlegend = False
143+
144+ self ._add_timeseries (
145+ axis , _timeseries , color = color [key ], showlegend = showlegend , name = name
146+ )
147+ first = False
148+
149+ return axis
150+
151+ def _add_timeseries (
152+ self , axis , timeseries , color = "green" , name = "time series" , showlegend = False
153+ ):
154+ timestamps = timeseries .index
155+ axis .plot (
156+ timestamps ,
157+ timeseries [timeseries .columns [0 ]],
158+ color = color ,
159+ label = (name if showlegend else "_nolegend_" ),
160+ )
161+
162+ def plot_impact (self , feature_per_row = 2 ):
163+ """Plot the impace.
164+ Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py"""
165+
166+ n_row = int (np .ceil (len (self .explanation ["feature_names" ]) / feature_per_row ))
167+ feat_values = np .array (self .explanation ["feature_values" ])
168+
169+ fig , axs = plt .subplots (n_row , feature_per_row , figsize = (15 , 15 ))
170+ axs = axs .ravel () # Flatten the axs to iterate over it
171+
172+ for i , feat in enumerate (self .explanation ["feature_names" ]):
173+ x_feat = feat_values [i , :, 0 ]
174+ trend_fit = LinearRegression ()
175+ trend_line = trend_fit .fit (
176+ x_feat .reshape (- 1 , 1 ), self .explanation ["signed_impact" ]
177+ )
178+ x_trend = np .linspace (min (x_feat ), max (x_feat ), 101 )
179+ y_trend = trend_line .predict (x_trend [..., np .newaxis ])
180+
181+ # Scatter plot
182+ axs [i ].scatter (x = x_feat , y = self .explanation ["signed_impact" ], color = "blue" )
183+ # Line plot
184+ axs [i ].plot (
185+ x_trend ,
186+ y_trend ,
187+ color = "green" ,
188+ label = "correlation between forecast and observed feature" ,
189+ )
190+ # Reference line
191+ current_value = self .explanation ["current_feature_values" ][i ][0 ]
192+ axs [i ].axvline (
193+ x = current_value ,
194+ color = "firebrick" ,
195+ linestyle = "--" ,
196+ label = "current value" ,
197+ )
198+
199+ axs [i ].set_xlabel (feat )
200+ axs [i ].set_ylabel ("Δ forecast" )
38201
39- def _matplotlib_plot ( self , output_name : str , block : bool , call_show : bool ) -> None :
40- pass
202+ # Display the legend on the first subplot
203+ axs [ 0 ]. legend ()
41204
42- def _get_bokeh_plot (self , output_name : str ) -> bokeh .models .Plot :
43- pass
205+ fig .suptitle ("Impact of Derived Variable On The Forecast" , fontsize = 16 )
206+ plt .tight_layout ()
207+ plt .subplots_adjust (top = 0.95 )
208+ plt .show ()
44209
45210
46211class TSICEExplainer (TSICEExplainerAIX ):
0 commit comments