Skip to content

Commit c96f018

Browse files
MAINT: Updates for XGBoost3.1 model structures (#2741)
* updates for xgboost3.1 model formats * update xgboost * try another way * remove workaround for treelite incompatibilities
1 parent 823a9ff commit c96f018

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

daal4py/mb/gbt_convertors.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections import deque
2020
from copy import deepcopy
2121
from tempfile import NamedTemporaryFile
22-
from typing import Any, Deque, Dict, List, Optional, Tuple
22+
from typing import Any, Deque, Dict, List, Optional, Tuple, Union
2323

2424
import numpy as np
2525

@@ -393,7 +393,7 @@ def get_gbt_model_from_tree_list(
393393
is_regression: bool,
394394
n_features: int,
395395
n_classes: int,
396-
base_score: Optional[float] = None,
396+
base_score: Optional[Union[float, List[float]]] = None,
397397
):
398398
"""Return a GBT Model from TreeList"""
399399

@@ -412,11 +412,21 @@ def get_gbt_model_from_tree_list(
412412
else:
413413
tree_id = mb.create_tree(n_nodes=tree.n_nodes, class_label=class_label)
414414

415+
# Note: starting from xgboost>=3.1.0, multi-class classification models have
416+
# vector-valued intercepts. Since oneDAL doesn't support these, it instead
417+
# adds the scores to all of the terminal leafs in the first tree.
418+
if isinstance(base_score, list) and counter <= n_classes:
419+
intercept_add = base_score[counter - 1]
420+
else:
421+
intercept_add = 0.0
422+
415423
if counter % n_iterations == 0:
416424
class_label += 1
417425

418426
if tree.is_leaf:
419-
mb.add_leaf(tree_id=tree_id, response=tree.value, cover=tree.cover)
427+
mb.add_leaf(
428+
tree_id=tree_id, response=tree.value + intercept_add, cover=tree.cover
429+
)
420430
continue
421431

422432
root_node = tree.root_node
@@ -445,7 +455,7 @@ def get_gbt_model_from_tree_list(
445455
if node.is_leaf:
446456
mb.add_leaf(
447457
tree_id=tree_id,
448-
response=node.value,
458+
response=node.value + intercept_add,
449459
cover=node.cover,
450460
parent_id=node.parent_id,
451461
position=node.position,
@@ -468,7 +478,7 @@ def get_gbt_model_from_tree_list(
468478
child.position = position
469479
node_queue.append(child)
470480

471-
return mb.model(base_score=base_score)
481+
return mb.model(base_score=base_score if isinstance(base_score, float) else None)
472482

473483

474484
def get_gbt_model_from_lightgbm(model: Any, booster=None) -> Any:
@@ -543,7 +553,18 @@ def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
543553

544554
n_features = int(xgb_config["learner"]["learner_model_param"]["num_feature"])
545555
n_classes = int(xgb_config["learner"]["learner_model_param"]["num_class"])
546-
base_score = float(xgb_config["learner"]["learner_model_param"]["base_score"])
556+
# Note: base scores in XGBoost might be vector-valued starting from version 3.1.0.
557+
# When this is the case, the 'base_score' attribute will be a JSON list, otherwise
558+
# it will be a scalar. Note that in either case, it will be in the response scale.
559+
base_score_str: str = xgb_config["learner"]["learner_model_param"]["base_score"]
560+
if base_score_str.startswith("["):
561+
base_score = json.loads(base_score_str)
562+
if len(base_score) == 1:
563+
base_score = base_score[0]
564+
elif len(base_score) == 0:
565+
base_score = 0.5
566+
else:
567+
base_score = float(base_score_str)
547568

548569
is_regression = False
549570
objective_fun = xgb_config["learner"]["learner_train_param"]["objective"]

requirements-test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ scikit-learn==1.7.2 ; python_version >= '3.10'
1111
pandas==2.1.3 ; python_version < '3.11'
1212
pandas==2.3.3 ; python_version >= '3.11'
1313
xgboost==2.1.4 ; python_version < '3.10'
14-
xgboost==3.0.5 ; python_version >= '3.10'
14+
xgboost==3.1.1 ; python_version >= '3.10'
1515
lightgbm==4.6.0
1616
catboost==1.2.8
1717
shap==0.49.1

tests/test_model_builders.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,11 @@ def test_xgb_regression_shap(
293293
pytest.skip()
294294

295295
xgb_model = make_xgb_model(objective, base_score, sklearn_class, empty_trees)
296-
d4p_model = d4p.mb.convert_model(
297-
xgb_model if not from_treelite else treelite.frontend.from_xgboost(xgb_model)
298-
)
296+
if from_treelite:
297+
tl_model = treelite.frontend.from_xgboost(xgb_model)
298+
d4p_model = d4p.mb.convert_model(tl_model)
299+
else:
300+
d4p_model = d4p.mb.convert_model(xgb_model)
299301

300302
if sklearn_class:
301303
xgb_model = xgb_model.get_booster()
@@ -430,9 +432,11 @@ def test_xgb_binary_classification_shap(
430432
if sklearn_class and from_treelite:
431433
pytest.skip()
432434
xgb_model = make_xgb_model(objective, base_score, sklearn_class, empty_trees)
433-
d4p_model = d4p.mb.convert_model(
434-
xgb_model if not from_treelite else treelite.frontend.from_xgboost(xgb_model)
435-
)
435+
if from_treelite:
436+
tl_model = treelite.frontend.from_xgboost(xgb_model)
437+
d4p_model = d4p.mb.convert_model(tl_model)
438+
else:
439+
d4p_model = d4p.mb.convert_model(xgb_model)
436440

437441
if sklearn_class:
438442
xgb_model = xgb_model.get_booster()

0 commit comments

Comments
 (0)