1919from collections import deque
2020from copy import deepcopy
2121from tempfile import NamedTemporaryFile
22- from typing import Any , Deque , Dict , List , Optional , Tuple
22+ from typing import Any , Deque , Dict , List , Optional , Tuple , Union
2323
2424import numpy as np
2525
@@ -393,7 +393,7 @@ def get_gbt_model_from_tree_list(
393393 is_regression : bool ,
394394 n_features : int ,
395395 n_classes : int ,
396- base_score : Optional [float ] = None ,
396+ base_score : Optional [Union [ float , List [ float ]] ] = None ,
397397):
398398 """Return a GBT Model from TreeList"""
399399
@@ -412,11 +412,21 @@ def get_gbt_model_from_tree_list(
412412 else :
413413 tree_id = mb .create_tree (n_nodes = tree .n_nodes , class_label = class_label )
414414
415+ # Note: starting from xgboost>=3.1.0, multi-class classification models have
416+ # vector-valued intercepts. Since oneDAL doesn't support these, it instead
417+ # adds the scores to all of the terminal leaves in the first tree.
418+ if isinstance (base_score , list ) and counter <= n_classes :
419+ intercept_add = base_score [counter - 1 ]
420+ else :
421+ intercept_add = 0.0
422+
415423 if counter % n_iterations == 0 :
416424 class_label += 1
417425
418426 if tree .is_leaf :
419- mb .add_leaf (tree_id = tree_id , response = tree .value , cover = tree .cover )
427+ mb .add_leaf (
428+ tree_id = tree_id , response = tree .value + intercept_add , cover = tree .cover
429+ )
420430 continue
421431
422432 root_node = tree .root_node
@@ -445,7 +455,7 @@ def get_gbt_model_from_tree_list(
445455 if node .is_leaf :
446456 mb .add_leaf (
447457 tree_id = tree_id ,
448- response = node .value ,
458+ response = node .value + intercept_add ,
449459 cover = node .cover ,
450460 parent_id = node .parent_id ,
451461 position = node .position ,
@@ -468,7 +478,7 @@ def get_gbt_model_from_tree_list(
468478 child .position = position
469479 node_queue .append (child )
470480
471- return mb .model (base_score = base_score )
481+ return mb .model (base_score = base_score if isinstance ( base_score , float ) else None )
472482
473483
474484def get_gbt_model_from_lightgbm (model : Any , booster = None ) -> Any :
@@ -543,7 +553,18 @@ def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
543553
544554 n_features = int (xgb_config ["learner" ]["learner_model_param" ]["num_feature" ])
545555 n_classes = int (xgb_config ["learner" ]["learner_model_param" ]["num_class" ])
546- base_score = float (xgb_config ["learner" ]["learner_model_param" ]["base_score" ])
556+ # Note: base scores in XGBoost might be vector-valued starting from version 3.1.0.
557+ # When this is the case, the 'base_score' attribute will be a JSON list, otherwise
558+ # it will be a scalar. Note that in either case, it will be in the response scale.
559+ base_score_str : str = xgb_config ["learner" ]["learner_model_param" ]["base_score" ]
560+ if base_score_str .startswith ("[" ):
561+ base_score = json .loads (base_score_str )
562+ if len (base_score ) == 1 :
563+ base_score = base_score [0 ]
564+ elif len (base_score ) == 0 :
565+ base_score = 0.5
566+ else :
567+ base_score = float (base_score_str )
547568
548569 is_regression = False
549570 objective_fun = xgb_config ["learner" ]["learner_train_param" ]["objective" ]
0 commit comments