22# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
44from typing import Dict
5+
6+ import bokeh .models
57import matplotlib .pyplot as plt
68import matplotlib as mpl
79from bokeh .models import ColumnDataSource , HoverTool
1012
1113from trustyai import _default_initializer # pylint: disable=unused-import
1214from trustyai .utils ._visualisation import (
13- ExplanationVisualiser ,
1415 DEFAULT_STYLE as ds ,
1516 DEFAULT_RC_PARAMS as drcp ,
1617 bold_red_html ,
1718 bold_green_html ,
1819 output_html ,
1920 feature_html ,
2021)
21-
22+ from . explanation_results import SaliencyResults
2223from trustyai .model import simple_prediction , PredUnionType
2324
2425from org .kie .trustyai .explainability .local .lime import (
2930 EncodingParams ,
3031 PredictionProvider ,
3132 Saliency ,
32- SaliencyResults ,
3333 PerturbationContext ,
3434)
3535
3838LimeConfig = _LimeConfig
3939
4040
41- class LimeResults (ExplanationVisualiser ):
41+ class LimeResults (SaliencyResults ):
4242 """Wraps LIME results. This object is returned by the :class:`~LimeExplainer`,
4343 and provides a variety of methods to visualize and interact with the explanation.
4444 """
4545
4646 def __init__ (self , saliencyResults : SaliencyResults ):
4747 """Constructor method. This is called internally, and shouldn't ever need to be used
4848 manually."""
49- self ._saliency_results = saliencyResults
49+ self ._java_saliency_results = saliencyResults
5050
51- def map (self ) -> Dict [str , Saliency ]:
51+ def saliency_map (self ) -> Dict [str , Saliency ]:
5252 """
5353 Return a dictionary of found saliencies.
5454
@@ -59,7 +59,7 @@ def map(self) -> Dict[str, Saliency]:
5959 """
6060 return {
6161 entry .getKey (): entry .getValue ()
62- for entry in self ._saliency_results .saliencies .entrySet ()
62+ for entry in self ._java_saliency_results .saliencies .entrySet ()
6363 }
6464
6565 def as_dataframe (self ) -> pd .DataFrame :
@@ -77,11 +77,11 @@ def as_dataframe(self) -> pd.DataFrame:
7777 * ``${output_name}_value``: The original value of each feature.
7878 * ``${output_name}_confidence``: The confidence of the reported saliency.
7979 """
80- outputs = self .map ().keys ()
80+ outputs = self .saliency_map ().keys ()
8181
8282 data = {}
8383 for output in outputs :
84- pfis = self .map ().get (output ).getPerFeatureImportance ()
84+ pfis = self .saliency_map ().get (output ).getPerFeatureImportance ()
8585 data [f"{ output } _features" ] = [
8686 f"{ pfi .getFeature ().getName ()} " for pfi in pfis
8787 ]
@@ -106,12 +106,12 @@ def as_html(self) -> pd.io.formats.style.Styler:
106106 """
107107 return self .as_dataframe ().style
108108
109- def plot (self , decision : str ) -> None :
109+ def _matplotlib_plot (self , output_name : str ) -> None :
110110 """Plot the LIME saliencies."""
111111 with mpl .rc_context (drcp ):
112112 dictionary = {}
113113 for feature_importance in (
114- self .map ().get (decision ).getPerFeatureImportance ()
114+ self .saliency_map ().get (output_name ).getPerFeatureImportance ()
115115 ):
116116 dictionary [
117117 feature_importance .getFeature ().name
@@ -123,7 +123,7 @@ def plot(self, decision: str) -> None:
123123 else ds ["positive_primary_colour" ]
124124 for i in dictionary .values ()
125125 ]
126- plt .title (f"LIME explanation of { decision } " )
126+ plt .title (f"LIME explanation of { output_name } " )
127127 plt .barh (
128128 range (len (dictionary )),
129129 dictionary .values (),
@@ -134,64 +134,65 @@ def plot(self, decision: str) -> None:
134134 plt .tight_layout ()
135135 plt .show ()
136136
137- def _get_bokeh_plot_dict (self ):
138- plot_dict = {}
139- for output_name , value in self .map ().items ():
140- lime_data_source = pd .DataFrame (
141- [
142- {
143- "feature" : str (pfi .getFeature ().getName ()),
144- "saliency" : pfi .getScore (),
145- }
146- for pfi in value .getPerFeatureImportance ()
147- ]
148- )
149- lime_data_source ["color" ] = lime_data_source ["saliency" ].apply (
150- lambda x : ds ["positive_primary_colour" ]
151- if x >= 0
152- else ds ["negative_primary_colour" ]
153- )
154- lime_data_source ["saliency_colored" ] = lime_data_source ["saliency" ].apply (
155- lambda x : (bold_green_html if x >= 0 else bold_red_html )(
156- "{:.2f}" .format (x )
157- )
158- )
137+ def _get_bokeh_plot (self , output_name ) -> bokeh .models .Plot :
138+ lime_data_source = pd .DataFrame (
139+ [
140+ {
141+ "feature" : str (pfi .getFeature ().getName ()),
142+ "saliency" : pfi .getScore (),
143+ }
144+ for pfi in self .saliency_map ()[output_name ].getPerFeatureImportance ()
145+ ]
146+ )
147+ lime_data_source ["color" ] = lime_data_source ["saliency" ].apply (
148+ lambda x : ds ["positive_primary_colour" ]
149+ if x >= 0
150+ else ds ["negative_primary_colour" ]
151+ )
152+ lime_data_source ["saliency_colored" ] = lime_data_source ["saliency" ].apply (
153+ lambda x : (bold_green_html if x >= 0 else bold_red_html )("{:.2f}" .format (x ))
154+ )
159155
160- lime_data_source ["color_faded" ] = lime_data_source ["saliency" ].apply (
161- lambda x : ds ["positive_primary_colour_faded" ]
162- if x >= 0
163- else ds ["negative_primary_colour_faded" ]
164- )
165- source = ColumnDataSource (lime_data_source )
166- htool = HoverTool (
167- names = ["bars" ],
168- tooltips = "<h3>LIME</h3> {} saliency to {}: @saliency_colored" .format (
169- feature_html ("@feature" ), output_html (output_name )
170- ),
171- )
172- bokeh_plot = figure (
173- sizing_mode = "stretch_both" ,
174- title = "Lime Feature Importances" ,
175- y_range = lime_data_source ["feature" ],
176- tools = [htool ],
177- )
178- bokeh_plot .hbar (
179- y = "feature" ,
180- left = 0 ,
181- right = "saliency" ,
182- fill_color = "color_faded" ,
183- line_color = "color" ,
184- hover_color = "color" ,
185- color = "color" ,
186- height = 0.75 ,
187- name = "bars" ,
188- source = source ,
189- )
190- bokeh_plot .line ([0 , 0 ], [0 , len (lime_data_source )], color = "#000" )
191- bokeh_plot .xaxis .axis_label = "Saliency Value"
192- bokeh_plot .yaxis .axis_label = "Feature"
193- plot_dict [output_name ] = bokeh_plot
194- return plot_dict
156+ lime_data_source ["color_faded" ] = lime_data_source ["saliency" ].apply (
157+ lambda x : ds ["positive_primary_colour_faded" ]
158+ if x >= 0
159+ else ds ["negative_primary_colour_faded" ]
160+ )
161+ source = ColumnDataSource (lime_data_source )
162+ htool = HoverTool (
163+ names = ["bars" ],
164+ tooltips = "<h3>LIME</h3> {} saliency to {}: @saliency_colored" .format (
165+ feature_html ("@feature" ), output_html (output_name )
166+ ),
167+ )
168+ bokeh_plot = figure (
169+ sizing_mode = "stretch_both" ,
170+ title = "Lime Feature Importances" ,
171+ y_range = lime_data_source ["feature" ],
172+ tools = [htool ],
173+ )
174+ bokeh_plot .hbar (
175+ y = "feature" ,
176+ left = 0 ,
177+ right = "saliency" ,
178+ fill_color = "color_faded" ,
179+ line_color = "color" ,
180+ hover_color = "color" ,
181+ color = "color" ,
182+ height = 0.75 ,
183+ name = "bars" ,
184+ source = source ,
185+ )
186+ bokeh_plot .line ([0 , 0 ], [0 , len (lime_data_source )], color = "#000" )
187+ bokeh_plot .xaxis .axis_label = "Saliency Value"
188+ bokeh_plot .yaxis .axis_label = "Feature"
189+ return bokeh_plot
190+
191+ def _get_bokeh_plot_dict (self ) -> Dict [str , bokeh .models .Plot ]:
192+ return {
193+ output_name : self ._get_bokeh_plot (output_name )
194+ for output_name in self .saliency_map ().keys ()
195+ }
195196
196197
197198class LimeExplainer :
0 commit comments