|
3 | 3 | # pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
|
4 | 4 | from typing import Dict, Union
|
5 | 5 |
|
6 |
| -import bokeh.models |
7 |
| -import matplotlib.pyplot as plt |
8 |
| -import matplotlib as mpl |
9 | 6 | import numpy as np
|
10 |
| -from bokeh.models import ColumnDataSource, HoverTool |
11 |
| -from bokeh.plotting import figure |
12 | 7 | import pandas as pd
|
13 | 8 | from matplotlib.colors import LinearSegmentedColormap
|
14 | 9 |
|
15 | 10 | from trustyai import _default_initializer # pylint: disable=unused-import
|
16 |
| -from trustyai.utils._visualisation import ( |
17 |
| - DEFAULT_STYLE as ds, |
18 |
| - DEFAULT_RC_PARAMS as drcp, |
19 |
| - bold_red_html, |
20 |
| - bold_green_html, |
21 |
| - output_html, |
22 |
| - feature_html, |
23 |
| -) |
24 |
| - |
| 11 | +from trustyai.utils._visualisation import DEFAULT_STYLE as ds |
25 | 12 | from trustyai.utils.data_conversions import (
|
26 | 13 | OneInputUnionType,
|
27 | 14 | data_conversion_docstring,
|
@@ -137,96 +124,6 @@ def as_html(self) -> pd.io.formats.style.Styler:
|
137 | 124 | )
|
138 | 125 | return htmls
|
139 | 126 |
|
140 |
| - def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None: |
141 |
| - """Plot the LIME saliencies.""" |
142 |
| - with mpl.rc_context(drcp): |
143 |
| - dictionary = {} |
144 |
| - for feature_importance in ( |
145 |
| - self.saliency_map().get(output_name).getPerFeatureImportance() |
146 |
| - ): |
147 |
| - dictionary[ |
148 |
| - feature_importance.getFeature().name |
149 |
| - ] = feature_importance.getScore() |
150 |
| - |
151 |
| - colours = [ |
152 |
| - ds["negative_primary_colour"] |
153 |
| - if i < 0 |
154 |
| - else ds["positive_primary_colour"] |
155 |
| - for i in dictionary.values() |
156 |
| - ] |
157 |
| - plt.title(f"LIME: Feature Importances to {output_name}") |
158 |
| - plt.barh( |
159 |
| - range(len(dictionary)), |
160 |
| - dictionary.values(), |
161 |
| - align="center", |
162 |
| - color=colours, |
163 |
| - ) |
164 |
| - plt.yticks(range(len(dictionary)), list(dictionary.keys())) |
165 |
| - plt.tight_layout() |
166 |
| - |
167 |
| - if call_show: |
168 |
| - plt.show(block=block) |
169 |
| - |
170 |
| - def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot: |
171 |
| - lime_data_source = pd.DataFrame( |
172 |
| - [ |
173 |
| - { |
174 |
| - "feature": str(pfi.getFeature().getName()), |
175 |
| - "saliency": pfi.getScore(), |
176 |
| - } |
177 |
| - for pfi in self.saliency_map()[output_name].getPerFeatureImportance() |
178 |
| - ] |
179 |
| - ) |
180 |
| - lime_data_source["color"] = lime_data_source["saliency"].apply( |
181 |
| - lambda x: ds["positive_primary_colour"] |
182 |
| - if x >= 0 |
183 |
| - else ds["negative_primary_colour"] |
184 |
| - ) |
185 |
| - lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply( |
186 |
| - lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x)) |
187 |
| - ) |
188 |
| - |
189 |
| - lime_data_source["color_faded"] = lime_data_source["saliency"].apply( |
190 |
| - lambda x: ds["positive_primary_colour_faded"] |
191 |
| - if x >= 0 |
192 |
| - else ds["negative_primary_colour_faded"] |
193 |
| - ) |
194 |
| - source = ColumnDataSource(lime_data_source) |
195 |
| - htool = HoverTool( |
196 |
| - name="bars", |
197 |
| - tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format( |
198 |
| - feature_html("@feature"), output_html(output_name) |
199 |
| - ), |
200 |
| - ) |
201 |
| - bokeh_plot = figure( |
202 |
| - sizing_mode="stretch_both", |
203 |
| - title="Lime Feature Importances", |
204 |
| - y_range=lime_data_source["feature"], |
205 |
| - tools=[htool], |
206 |
| - ) |
207 |
| - bokeh_plot.hbar( |
208 |
| - y="feature", |
209 |
| - left=0, |
210 |
| - right="saliency", |
211 |
| - fill_color="color_faded", |
212 |
| - line_color="color", |
213 |
| - hover_color="color", |
214 |
| - color="color", |
215 |
| - height=0.75, |
216 |
| - name="bars", |
217 |
| - source=source, |
218 |
| - ) |
219 |
| - bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000") |
220 |
| - bokeh_plot.xaxis.axis_label = "Saliency Value" |
221 |
| - bokeh_plot.yaxis.axis_label = "Feature" |
222 |
| - return bokeh_plot |
223 |
| - |
224 |
| - def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]: |
225 |
| - return { |
226 |
| - output_name: self._get_bokeh_plot(output_name) |
227 |
| - for output_name in self.saliency_map().keys() |
228 |
| - } |
229 |
| - |
230 | 127 |
|
231 | 128 | class LimeExplainer:
|
232 | 129 | """*"Which features were most important to the results?"*
|
|
0 commit comments