|  | 
| 3 | 3 | # pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name | 
| 4 | 4 | from typing import Dict, Union | 
| 5 | 5 | 
 | 
| 6 |  | -import bokeh.models | 
| 7 |  | -import matplotlib.pyplot as plt | 
| 8 |  | -import matplotlib as mpl | 
| 9 | 6 | import numpy as np | 
| 10 |  | -from bokeh.models import ColumnDataSource, HoverTool | 
| 11 |  | -from bokeh.plotting import figure | 
| 12 | 7 | import pandas as pd | 
| 13 | 8 | from matplotlib.colors import LinearSegmentedColormap | 
| 14 | 9 | 
 | 
| 15 | 10 | from trustyai import _default_initializer  # pylint: disable=unused-import | 
| 16 |  | -from trustyai.utils._visualisation import ( | 
| 17 |  | -    DEFAULT_STYLE as ds, | 
| 18 |  | -    DEFAULT_RC_PARAMS as drcp, | 
| 19 |  | -    bold_red_html, | 
| 20 |  | -    bold_green_html, | 
| 21 |  | -    output_html, | 
| 22 |  | -    feature_html, | 
| 23 |  | -) | 
| 24 |  | - | 
|  | 11 | +from trustyai.utils._visualisation import DEFAULT_STYLE as ds | 
| 25 | 12 | from trustyai.utils.data_conversions import ( | 
| 26 | 13 |     OneInputUnionType, | 
| 27 | 14 |     data_conversion_docstring, | 
| @@ -137,96 +124,6 @@ def as_html(self) -> pd.io.formats.style.Styler: | 
| 137 | 124 |             ) | 
| 138 | 125 |         return htmls | 
| 139 | 126 | 
 | 
| 140 |  | -    def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None: | 
| 141 |  | -        """Plot the LIME saliencies.""" | 
| 142 |  | -        with mpl.rc_context(drcp): | 
| 143 |  | -            dictionary = {} | 
| 144 |  | -            for feature_importance in ( | 
| 145 |  | -                self.saliency_map().get(output_name).getPerFeatureImportance() | 
| 146 |  | -            ): | 
| 147 |  | -                dictionary[ | 
| 148 |  | -                    feature_importance.getFeature().name | 
| 149 |  | -                ] = feature_importance.getScore() | 
| 150 |  | - | 
| 151 |  | -            colours = [ | 
| 152 |  | -                ds["negative_primary_colour"] | 
| 153 |  | -                if i < 0 | 
| 154 |  | -                else ds["positive_primary_colour"] | 
| 155 |  | -                for i in dictionary.values() | 
| 156 |  | -            ] | 
| 157 |  | -            plt.title(f"LIME: Feature Importances to {output_name}") | 
| 158 |  | -            plt.barh( | 
| 159 |  | -                range(len(dictionary)), | 
| 160 |  | -                dictionary.values(), | 
| 161 |  | -                align="center", | 
| 162 |  | -                color=colours, | 
| 163 |  | -            ) | 
| 164 |  | -            plt.yticks(range(len(dictionary)), list(dictionary.keys())) | 
| 165 |  | -            plt.tight_layout() | 
| 166 |  | - | 
| 167 |  | -            if call_show: | 
| 168 |  | -                plt.show(block=block) | 
| 169 |  | - | 
| 170 |  | -    def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot: | 
| 171 |  | -        lime_data_source = pd.DataFrame( | 
| 172 |  | -            [ | 
| 173 |  | -                { | 
| 174 |  | -                    "feature": str(pfi.getFeature().getName()), | 
| 175 |  | -                    "saliency": pfi.getScore(), | 
| 176 |  | -                } | 
| 177 |  | -                for pfi in self.saliency_map()[output_name].getPerFeatureImportance() | 
| 178 |  | -            ] | 
| 179 |  | -        ) | 
| 180 |  | -        lime_data_source["color"] = lime_data_source["saliency"].apply( | 
| 181 |  | -            lambda x: ds["positive_primary_colour"] | 
| 182 |  | -            if x >= 0 | 
| 183 |  | -            else ds["negative_primary_colour"] | 
| 184 |  | -        ) | 
| 185 |  | -        lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply( | 
| 186 |  | -            lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x)) | 
| 187 |  | -        ) | 
| 188 |  | - | 
| 189 |  | -        lime_data_source["color_faded"] = lime_data_source["saliency"].apply( | 
| 190 |  | -            lambda x: ds["positive_primary_colour_faded"] | 
| 191 |  | -            if x >= 0 | 
| 192 |  | -            else ds["negative_primary_colour_faded"] | 
| 193 |  | -        ) | 
| 194 |  | -        source = ColumnDataSource(lime_data_source) | 
| 195 |  | -        htool = HoverTool( | 
| 196 |  | -            name="bars", | 
| 197 |  | -            tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format( | 
| 198 |  | -                feature_html("@feature"), output_html(output_name) | 
| 199 |  | -            ), | 
| 200 |  | -        ) | 
| 201 |  | -        bokeh_plot = figure( | 
| 202 |  | -            sizing_mode="stretch_both", | 
| 203 |  | -            title="Lime Feature Importances", | 
| 204 |  | -            y_range=lime_data_source["feature"], | 
| 205 |  | -            tools=[htool], | 
| 206 |  | -        ) | 
| 207 |  | -        bokeh_plot.hbar( | 
| 208 |  | -            y="feature", | 
| 209 |  | -            left=0, | 
| 210 |  | -            right="saliency", | 
| 211 |  | -            fill_color="color_faded", | 
| 212 |  | -            line_color="color", | 
| 213 |  | -            hover_color="color", | 
| 214 |  | -            color="color", | 
| 215 |  | -            height=0.75, | 
| 216 |  | -            name="bars", | 
| 217 |  | -            source=source, | 
| 218 |  | -        ) | 
| 219 |  | -        bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000") | 
| 220 |  | -        bokeh_plot.xaxis.axis_label = "Saliency Value" | 
| 221 |  | -        bokeh_plot.yaxis.axis_label = "Feature" | 
| 222 |  | -        return bokeh_plot | 
| 223 |  | - | 
| 224 |  | -    def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]: | 
| 225 |  | -        return { | 
| 226 |  | -            output_name: self._get_bokeh_plot(output_name) | 
| 227 |  | -            for output_name in self.saliency_map().keys() | 
| 228 |  | -        } | 
| 229 |  | - | 
| 230 | 127 | 
 | 
| 231 | 128 | class LimeExplainer: | 
| 232 | 129 |     """*"Which features were most important to the results?"* | 
|  | 
0 commit comments