@@ -123,15 +123,24 @@ def _saliency_to_dataframe(self, saliency, output_name):
123123
124124 return pd .DataFrame ([fnull ] + data_rows )
125125
126- def as_dataframe (self ) -> Dict [str , pd .DataFrame ]:
126+ def as_dataframe (
127+ self , output_name : str = None
128+ ) -> Union [Dict [str , pd .DataFrame ], pd .DataFrame ]:
127129 """
128130 Return the SHAP results as dataframes.
129131
132+ Parameters
133+ ----------
134+ output_name: str
135+ If an output_name is passed, that output's explanation is returned as a dataframe.
136+ Otherwise, all outputs' explanation dataframe are returned in a dictionary.
137+
130138 Returns
131139 -------
132- Dict[str, pandas.DataFrame]
133- Dictionary of DataFrames, keyed by output name, containing the results of the SHAP
134- explanation. For each model output, the table will contain the following columns:
140+ pandas.Dataframe or Dict[str, pandas.DataFrame]
141+ A dataframe or dictionary of DataFrames, keyed by output name. Each dataframe
142+ contains the results of the SHAP explanation for a particular output. Each dataframe
143+ wiil contain the following columns:
135144
136145 * ``Feature``: The name of the feature
137146 * ``Feature Value``: The value of the feature for this particular input.
@@ -140,18 +149,33 @@ def as_dataframe(self) -> Dict[str, pd.DataFrame]:
140149 * ``Confidence``: The confidence of this explanation as returned by the explainer.
141150
142151 """
143- df_dict = {}
144- for output_name , saliency in self .saliency_map ().items ():
145- df_dict [output_name ] = self ._saliency_to_dataframe (saliency , output_name )
146- return df_dict
152+ if output_name is None :
153+ df_dict = {}
154+ for output_name_key , saliency in self .saliency_map ().items ():
155+ df_dict [output_name_key ] = self ._saliency_to_dataframe (
156+ saliency , output_name_key
157+ )
158+ return df_dict
159+ return self ._saliency_to_dataframe (
160+ self .saliency_map ()[output_name ], output_name
161+ )
147162
148- def as_html (self ) -> Dict [str , pd .io .formats .style .Styler ]:
163+ def as_html (
164+ self , output_name : str = None
165+ ) -> Union [Dict [str , pd .io .formats .style .Styler ], pd .io .formats .style .Styler ]:
149166 """
150167 Return the SHAP results as Pandas Styler objects.
151168
169+ Parameters
170+ ----------
171+ output_name: str
172+ If an output_name is passed, that output's explanation is returned as a pandas Styler.
173+ Otherwise, all outputs' explanation stylers are returned in a dictionary.
174+
175+
152176 Returns
153177 -------
154- Dict[str, pandas.Styler]
178+ Pandas Styler or Dict[str, pandas.Styler]
155179 Dictionary of stylers keyed by output name. Each styler containing the results of the
156180 SHAP explanation for that particular output, in the same
157181 schema as in :func:`as_dataframe`. This will:
@@ -174,31 +198,35 @@ def _color_feature_values(feature_values, background_vals):
174198 return [None ] + formats
175199
176200 df_dict = {}
177- for output_name , saliency in self .saliency_map ().items ():
178- df = self ._saliency_to_dataframe (saliency , output_name )
179- shap_values = df ["SHAP Value" ].values [1 :]
180- background_mean_feature_values = df ["Mean Background Value" ].values [1 :]
181-
182- style = df .style .background_gradient (
183- LinearSegmentedColormap .from_list (
184- name = "rwg" ,
185- colors = [
186- ds ["negative_primary_colour" ],
187- ds ["neutral_primary_colour" ],
188- ds ["positive_primary_colour" ],
189- ],
190- ),
191- subset = (slice (1 , None ), "SHAP Value" ),
192- vmin = - 1 * max (np .abs (shap_values )),
193- vmax = max (np .abs (shap_values )),
194- )
195- style .set_caption (f"Explanation of { output_name } " )
196- df_dict [output_name ] = style .apply (
197- _color_feature_values ,
198- background_vals = background_mean_feature_values ,
199- subset = "Value" ,
200- axis = 0 ,
201- )
201+ for output_name_key , saliency in self .saliency_map ().items ():
202+ if output_name is None or output_name_key == output_name :
203+ df = self ._saliency_to_dataframe (saliency , output_name_key )
204+ shap_values = df ["SHAP Value" ].values [1 :]
205+ background_mean_feature_values = df ["Mean Background Value" ].values [1 :]
206+
207+ style = df .style .background_gradient (
208+ LinearSegmentedColormap .from_list (
209+ name = "rwg" ,
210+ colors = [
211+ ds ["negative_primary_colour" ],
212+ ds ["neutral_primary_colour" ],
213+ ds ["positive_primary_colour" ],
214+ ],
215+ ),
216+ subset = (slice (1 , None ), "SHAP Value" ),
217+ vmin = - 1 * max (np .abs (shap_values )),
218+ vmax = max (np .abs (shap_values )),
219+ )
220+ style .set_caption (f"SHAP Explanation of { output_name_key } " )
221+ df_dict [output_name_key ] = style .apply (
222+ _color_feature_values ,
223+ background_vals = background_mean_feature_values ,
224+ subset = "Value" ,
225+ axis = 0 ,
226+ )
227+
228+ if output_name is not None :
229+ return df_dict [output_name ]
202230 return df_dict
203231
204232 def _matplotlib_plot (self , output_name , block = True ) -> None :
0 commit comments