Skip to content

Commit 48714b0

Browse files
ahuber21 and icfaust authored
feat: SHAP value support for XGBoost's binary classification models (#1660)
* dispatch SHAP settings for XGBoost clsf * enable XGBoost Classification SHAP check * Update checks for SHAP binary classification * add pred_contribs/pred_interactions keyword support * fix classification SHAP value tests * fix some test failures * include daal_check_version * remove circular import * forgotten evaluation * disable tests for older onedal versions * change tolerances * change correct tolerances * return to original design * fix for 2024.7 * modify tests * forgotten formatting --------- Co-authored-by: icfaust <[email protected]>
1 parent 45fc83d commit 48714b0

File tree

3 files changed

+173
-20
lines changed

3 files changed

+173
-20
lines changed

daal4py/mb/model_builders.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def _convert_model(self, model):
190190
else:
191191
raise TypeError(f"Unknown model format {submodule_name}.{class_name}")
192192

193-
def _predict_classification(self, X, fptype, resultsToEvaluate):
193+
def _predict_classification(
194+
self, X, fptype, resultsToEvaluate, pred_contribs=False, pred_interactions=False
195+
):
194196
if X.shape[1] != self.n_features_in_:
195197
raise ValueError("Shape of input is different from what was seen in `fit`")
196198

@@ -203,8 +205,34 @@ def _predict_classification(self, X, fptype, resultsToEvaluate):
203205
)
204206

205207
# Prediction
208+
try:
209+
return self._predict_classification_with_results_to_compute(
210+
X, fptype, resultsToEvaluate, pred_contribs, pred_interactions
211+
)
212+
except TypeError as e:
213+
if "unexpected keyword argument 'resultsToCompute'" in str(e):
214+
if pred_contribs or pred_interactions:
215+
# SHAP values requested, but not supported by this version
216+
raise TypeError(
217+
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} not supported by this version of daal4py"
218+
) from e
219+
else:
220+
# unknown type error
221+
raise
222+
except RuntimeError as e:
223+
if "Method is not implemented" in str(e):
224+
if pred_contribs or pred_interactions:
225+
raise NotImplementedError(
226+
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
227+
)
228+
else:
229+
raise
230+
231+
# fallback to calculation without `resultsToCompute`
206232
predict_algo = d4p.gbt_classification_prediction(
207-
fptype=fptype, nClasses=self.n_classes_, resultsToEvaluate=resultsToEvaluate
233+
nClasses=self.n_classes_,
234+
fptype=fptype,
235+
resultsToEvaluate=resultsToEvaluate,
208236
)
209237
predict_result = predict_algo.compute(X, self.daal_model_)
210238

@@ -213,6 +241,40 @@ def _predict_classification(self, X, fptype, resultsToEvaluate):
213241
else:
214242
return predict_result.probabilities
215243

244+
def _predict_classification_with_results_to_compute(
245+
self,
246+
X,
247+
fptype,
248+
resultsToEvaluate,
249+
pred_contribs=False,
250+
pred_interactions=False,
251+
):
252+
"""Assume daal4py supports the resultsToCompute kwarg"""
253+
resultsToCompute = ""
254+
if pred_contribs:
255+
resultsToCompute = "shapContributions"
256+
elif pred_interactions:
257+
resultsToCompute = "shapInteractions"
258+
259+
predict_algo = d4p.gbt_classification_prediction(
260+
nClasses=self.n_classes_,
261+
fptype=fptype,
262+
resultsToCompute=resultsToCompute,
263+
resultsToEvaluate=resultsToEvaluate,
264+
)
265+
predict_result = predict_algo.compute(X, self.daal_model_)
266+
267+
if pred_contribs:
268+
return predict_result.prediction.ravel().reshape((-1, X.shape[1] + 1))
269+
elif pred_interactions:
270+
return predict_result.prediction.ravel().reshape(
271+
(-1, X.shape[1] + 1, X.shape[1] + 1)
272+
)
273+
elif resultsToEvaluate == "computeClassLabels":
274+
return predict_result.prediction.ravel().astype(np.int64, copy=False)
275+
else:
276+
return predict_result.probabilities
277+
216278
def _predict_regression(
217279
self, X, fptype, pred_contribs=False, pred_interactions=False
218280
):
@@ -278,11 +340,13 @@ def predict(self, X, pred_contribs=False, pred_interactions=False):
278340
if self._is_regression:
279341
return self._predict_regression(X, fptype, pred_contribs, pred_interactions)
280342
else:
281-
if pred_contribs or pred_interactions:
343+
if (pred_contribs or pred_interactions) and self.model_type != "xgboost":
282344
raise NotImplementedError(
283345
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
284346
)
285-
return self._predict_classification(X, fptype, "computeClassLabels")
347+
return self._predict_classification(
348+
X, fptype, "computeClassLabels", pred_contribs, pred_interactions
349+
)
286350

287351
def _check_proba(self):
288352
return not self._is_regression

daal4py/sklearn/ensemble/GBTDAAL.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def fit(self, X, y):
193193
# Return the classifier
194194
return self
195195

196-
def _predict(self, X, resultsToEvaluate):
196+
def _predict(
197+
self, X, resultsToEvaluate, pred_contribs=False, pred_interactions=False
198+
):
197199
# Input validation
198200
if not self.allow_nan_:
199201
X = check_array(X, dtype=[np.single, np.double])
@@ -208,17 +210,21 @@ def _predict(self, X, resultsToEvaluate):
208210
return np.full(X.shape[0], self.classes_[0])
209211

210212
fptype = getFPType(X)
211-
predict_result = self._predict_classification(X, fptype, resultsToEvaluate)
213+
predict_result = self._predict_classification(
214+
X, fptype, resultsToEvaluate, pred_contribs, pred_interactions
215+
)
212216

213-
if resultsToEvaluate == "computeClassLabels":
217+
if resultsToEvaluate == "computeClassLabels" and not (
218+
pred_contribs or pred_interactions
219+
):
214220
# Decode labels
215221
le = preprocessing.LabelEncoder()
216222
le.classes_ = self.classes_
217223
return le.inverse_transform(predict_result)
218224
return predict_result
219225

220-
def predict(self, X):
221-
return self._predict(X, "computeClassLabels")
226+
def predict(self, X, pred_contribs=False, pred_interactions=False):
227+
return self._predict(X, "computeClassLabels", pred_contribs, pred_interactions)
222228

223229
def predict_proba(self, X):
224230
return self._predict(X, "computeClassProbabilities")

tests/test_model_builders.py

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,14 @@
4848

4949

5050
shap_required_version = (2024, "P", 1)
51+
shap_api_change_version = (2025, "P", 0)
5152
shap_supported = daal_check_version(shap_required_version)
53+
shap_api_changed = daal_check_version(shap_api_change_version)
5254
shap_not_supported_str = (
5355
f"SHAP value calculation only supported for version {shap_required_version} or later"
5456
)
5557
shap_unavailable_str = "SHAP Python package not available"
58+
shap_api_change_str = "SHAP calculation requires 2025.0 API"
5659
cb_unavailable_str = "CatBoost not available"
5760

5861
# CatBoost's SHAP value calculation seems to be buggy
@@ -208,15 +211,15 @@ def test_model_predict_shap_contribs_missing_values(self):
208211
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-6)
209212

210213

211-
# duplicate all tests for bae_score=0.0
214+
# duplicate all tests for base_score=0.0
212215
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
213216
class XGBoostRegressionModelBuilder_base_score0(XGBoostRegressionModelBuilder):
214217
@classmethod
215218
def setUpClass(cls):
216219
XGBoostRegressionModelBuilder.setUpClass(0)
217220

218221

219-
# duplicate all tests for bae_score=100
222+
# duplicate all tests for base_score=100
220223
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
221224
class XGBoostRegressionModelBuilder_base_score100(XGBoostRegressionModelBuilder):
222225
@classmethod
@@ -235,7 +238,7 @@ def setUpClass(cls, base_score=0.5, n_classes=2, objective="binary:logistic"):
235238
n_samples=500,
236239
n_classes=n_classes,
237240
n_features=n_features,
238-
n_informative=10,
241+
n_informative=(2 * n_features) // 3,
239242
random_state=42,
240243
)
241244
cls.X_test = X[:2, :]
@@ -282,25 +285,59 @@ def test_missing_value_support(self):
282285
def test_model_predict_shap_contribs(self):
283286
booster = self.xgb_model.get_booster()
284287
m = d4p.mb.convert_model(booster)
285-
with self.assertRaises(NotImplementedError):
286-
m.predict(self.X_test, pred_contribs=True)
288+
if not shap_api_changed:
289+
with self.assertRaises(NotImplementedError):
290+
m.predict(self.X_test, pred_contribs=True)
291+
elif self.n_classes > 2:
292+
with self.assertRaisesRegex(
293+
RuntimeError, "Multiclass classification SHAP values not supported"
294+
):
295+
m.predict(self.X_test, pred_contribs=True)
296+
else:
297+
d4p_pred = m.predict(self.X_test, pred_contribs=True)
298+
xgboost_pred = booster.predict(
299+
xgb.DMatrix(self.X_test),
300+
pred_contribs=True,
301+
approx_contribs=False,
302+
validate_features=False,
303+
)
304+
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5)
287305

288306
def test_model_predict_shap_interactions(self):
289307
booster = self.xgb_model.get_booster()
290308
m = d4p.mb.convert_model(booster)
291-
with self.assertRaises(NotImplementedError):
292-
m.predict(self.X_test, pred_contribs=True)
293-
294-
295-
# duplicate all tests for bae_score=0.3
309+
if not shap_api_changed:
310+
with self.assertRaises(NotImplementedError):
311+
m.predict(self.X_test, pred_contribs=True)
312+
elif self.n_classes > 2:
313+
with self.assertRaisesRegex(
314+
RuntimeError, "Multiclass classification SHAP values not supported"
315+
):
316+
m.predict(self.X_test, pred_interactions=True)
317+
else:
318+
d4p_pred = m.predict(self.X_test, pred_interactions=True)
319+
xgboost_pred = booster.predict(
320+
xgb.DMatrix(self.X_test),
321+
pred_interactions=True,
322+
approx_contribs=False,
323+
validate_features=False,
324+
)
325+
# hitting floating precision limits for classification where class probabilities
326+
# are between 0 and 1
327+
# we need to accept large relative differences, as long as the absolute difference
328+
# remains small (<1e-6)
329+
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-2, atol=1e-6)
330+
331+
332+
# duplicate all tests for base_score=0.3
296333
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
297334
class XGBoostClassificationModelBuilder_base_score03(XGBoostClassificationModelBuilder):
298335
@classmethod
299336
def setUpClass(cls):
300337
XGBoostClassificationModelBuilder.setUpClass(base_score=0.3)
301338

302339

303-
# duplicate all tests for bae_score=0.7
340+
# duplicate all tests for base_score=0.7
304341
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
305342
class XGBoostClassificationModelBuilder_base_score07(XGBoostClassificationModelBuilder):
306343
@classmethod
@@ -328,6 +365,16 @@ def setUpClass(cls):
328365
class XGBoostClassificationModelBuilder_objective_logitraw(
329366
XGBoostClassificationModelBuilder
330367
):
368+
"""
369+
Caveat: logitraw is not per se supported in daal4py because we always
370+
371+
1. apply the bias
372+
2. normalize to probabilities ("activation") using sigmoid
373+
(exception: SHAP values, the scores defining phi_ij are the raw class scores)
374+
375+
However, by undoing the activation and bias we can still compare if the original probas and SHAP values are aligned.
376+
"""
377+
331378
@classmethod
332379
def setUpClass(cls):
333380
XGBoostClassificationModelBuilder.setUpClass(
@@ -352,6 +399,42 @@ def test_model_predict_proba(self):
352399
# accept an rtol of 1e-5
353400
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5)
354401

402+
@unittest.skipUnless(shap_api_changed, reason=shap_api_change_str)
403+
def test_model_predict_shap_contribs(self):
404+
booster = self.xgb_model.get_booster()
405+
with self.assertWarns(UserWarning):
406+
# expect a warning that logitraw behaves differently and/or
407+
# that base_score is ignored / fixed to 0.5
408+
m = d4p.mb.convert_model(self.xgb_model.get_booster())
409+
d4p_pred = m.predict(self.X_test, pred_contribs=True)
410+
xgboost_pred = booster.predict(
411+
xgb.DMatrix(self.X_test),
412+
pred_contribs=True,
413+
approx_contribs=False,
414+
validate_features=False,
415+
)
416+
# undo bias
417+
d4p_pred[:, -1] += 0.5
418+
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-6)
419+
420+
@unittest.skipUnless(shap_api_changed, reason=shap_api_change_str)
421+
def test_model_predict_shap_interactions(self):
422+
booster = self.xgb_model.get_booster()
423+
with self.assertWarns(UserWarning):
424+
# expect a warning that logitraw behaves differently and/or
425+
# that base_score is ignored / fixed to 0.5
426+
m = d4p.mb.convert_model(self.xgb_model.get_booster())
427+
d4p_pred = m.predict(self.X_test, pred_interactions=True)
428+
xgboost_pred = booster.predict(
429+
xgb.DMatrix(self.X_test),
430+
pred_interactions=True,
431+
approx_contribs=False,
432+
validate_features=False,
433+
)
434+
# undo bias
435+
d4p_pred[:, -1, -1] += 0.5
436+
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-5)
437+
355438

356439
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
357440
class LightGBMRegressionModelBuilder(unittest.TestCase):

0 commit comments

Comments (0)