|
3 | 3 | # pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name |
4 | 4 | from typing import Dict, Union |
5 | 5 |
|
6 | | -import bokeh.models |
7 | | -import matplotlib.pyplot as plt |
8 | | -import matplotlib as mpl |
9 | 6 | import numpy as np |
10 | | -from bokeh.models import ColumnDataSource, HoverTool |
11 | | -from bokeh.plotting import figure |
12 | 7 | import pandas as pd |
13 | 8 | from matplotlib.colors import LinearSegmentedColormap |
14 | 9 |
|
15 | 10 | from trustyai import _default_initializer # pylint: disable=unused-import |
16 | | -from trustyai.utils._visualisation import ( |
17 | | - DEFAULT_STYLE as ds, |
18 | | - DEFAULT_RC_PARAMS as drcp, |
19 | | - bold_red_html, |
20 | | - bold_green_html, |
21 | | - output_html, |
22 | | - feature_html, |
23 | | -) |
24 | | - |
| 11 | +from trustyai.utils._visualisation import DEFAULT_STYLE as ds |
25 | 12 | from trustyai.utils.data_conversions import ( |
26 | 13 | OneInputUnionType, |
27 | 14 | data_conversion_docstring, |
@@ -137,96 +124,6 @@ def as_html(self) -> pd.io.formats.style.Styler: |
137 | 124 | ) |
138 | 125 | return htmls |
139 | 126 |
|
140 | | - def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None: |
141 | | - """Plot the LIME saliencies.""" |
142 | | - with mpl.rc_context(drcp): |
143 | | - dictionary = {} |
144 | | - for feature_importance in ( |
145 | | - self.saliency_map().get(output_name).getPerFeatureImportance() |
146 | | - ): |
147 | | - dictionary[ |
148 | | - feature_importance.getFeature().name |
149 | | - ] = feature_importance.getScore() |
150 | | - |
151 | | - colours = [ |
152 | | - ds["negative_primary_colour"] |
153 | | - if i < 0 |
154 | | - else ds["positive_primary_colour"] |
155 | | - for i in dictionary.values() |
156 | | - ] |
157 | | - plt.title(f"LIME: Feature Importances to {output_name}") |
158 | | - plt.barh( |
159 | | - range(len(dictionary)), |
160 | | - dictionary.values(), |
161 | | - align="center", |
162 | | - color=colours, |
163 | | - ) |
164 | | - plt.yticks(range(len(dictionary)), list(dictionary.keys())) |
165 | | - plt.tight_layout() |
166 | | - |
167 | | - if call_show: |
168 | | - plt.show(block=block) |
169 | | - |
170 | | - def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot: |
171 | | - lime_data_source = pd.DataFrame( |
172 | | - [ |
173 | | - { |
174 | | - "feature": str(pfi.getFeature().getName()), |
175 | | - "saliency": pfi.getScore(), |
176 | | - } |
177 | | - for pfi in self.saliency_map()[output_name].getPerFeatureImportance() |
178 | | - ] |
179 | | - ) |
180 | | - lime_data_source["color"] = lime_data_source["saliency"].apply( |
181 | | - lambda x: ds["positive_primary_colour"] |
182 | | - if x >= 0 |
183 | | - else ds["negative_primary_colour"] |
184 | | - ) |
185 | | - lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply( |
186 | | - lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x)) |
187 | | - ) |
188 | | - |
189 | | - lime_data_source["color_faded"] = lime_data_source["saliency"].apply( |
190 | | - lambda x: ds["positive_primary_colour_faded"] |
191 | | - if x >= 0 |
192 | | - else ds["negative_primary_colour_faded"] |
193 | | - ) |
194 | | - source = ColumnDataSource(lime_data_source) |
195 | | - htool = HoverTool( |
196 | | - name="bars", |
197 | | - tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format( |
198 | | - feature_html("@feature"), output_html(output_name) |
199 | | - ), |
200 | | - ) |
201 | | - bokeh_plot = figure( |
202 | | - sizing_mode="stretch_both", |
203 | | - title="Lime Feature Importances", |
204 | | - y_range=lime_data_source["feature"], |
205 | | - tools=[htool], |
206 | | - ) |
207 | | - bokeh_plot.hbar( |
208 | | - y="feature", |
209 | | - left=0, |
210 | | - right="saliency", |
211 | | - fill_color="color_faded", |
212 | | - line_color="color", |
213 | | - hover_color="color", |
214 | | - color="color", |
215 | | - height=0.75, |
216 | | - name="bars", |
217 | | - source=source, |
218 | | - ) |
219 | | - bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000") |
220 | | - bokeh_plot.xaxis.axis_label = "Saliency Value" |
221 | | - bokeh_plot.yaxis.axis_label = "Feature" |
222 | | - return bokeh_plot |
223 | | - |
224 | | - def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]: |
225 | | - return { |
226 | | - output_name: self._get_bokeh_plot(output_name) |
227 | | - for output_name in self.saliency_map().keys() |
228 | | - } |
229 | | - |
230 | 127 |
|
231 | 128 | class LimeExplainer: |
232 | 129 | """*"Which features were most important to the results?"* |
|
0 commit comments