1919from copy import deepcopy
2020from tempfile import NamedTemporaryFile
2121from typing import Any , Deque , Dict , List , Optional , Tuple
22- from warnings import warn
2322
2423import numpy as np
2524
@@ -94,18 +93,9 @@ def n_iterations(self):
9493 else :
9594 return len (self .trees )
9695
97- @property
98- def bias (self ):
99- if self .is_classification :
100- return 0
101- return self .__data ["scale_and_bias" ][1 ][0 ] / self .n_iterations
102-
10396 @property
10497 def scale (self ):
105- if self .is_classification :
106- return 1
107- else :
108- return self .__data ["scale_and_bias" ][0 ]
98+ return self .__data ["scale_and_bias" ][0 ]
10999
110100 @property
111101 def default_left (self ):
@@ -223,7 +213,7 @@ def feature(self) -> int:
223213 return self .__feature
224214 if isinstance (self .__feature , str ) and self .__feature .isnumeric ():
225215 return int (self .__feature )
226- raise ValueError (
216+ raise AttributeError (
227217 f"Feature names must be integers (got ({ type (self .__feature )} ){ self .__feature } )"
228218 )
229219
@@ -242,15 +232,15 @@ def is_leaf(self) -> bool:
242232 @property
243233 def value (self ) -> float :
244234 if not self .is_leaf :
245- raise ValueError ("Tree is not a leaf-only tree" )
235+ raise AttributeError ("Tree is not a leaf-only tree" )
246236 if self .root_node .value is None :
247- raise ValueError ("Tree is leaf-only but leaf node has no value" )
237+ raise AttributeError ("Tree is leaf-only but leaf node has no value" )
248238 return self .root_node .value
249239
250240 @property
251241 def cover (self ) -> float :
252242 if not self .is_leaf :
253- raise ValueError ("Tree is not a leaf-only tree" )
243+ raise AttributeError ("Tree is not a leaf-only tree" )
254244 return self .root_node .cover
255245
256246 @property
@@ -427,6 +417,18 @@ def get_gbt_model_from_tree_list(
427417
428418
429419def get_gbt_model_from_lightgbm (model : Any , booster = None ) -> Any :
420+ model_str = model .model_to_string ()
421+ if "is_linear=1" in model_str :
422+ raise TypeError ("Linear trees are not supported." )
423+ if "[boosting: dart]" in model_str :
424+ raise TypeError ("'Dart' booster is not supported." )
425+ if "[boosting: rf]" in model_str :
426+ raise TypeError ("Random forest boosters are not supported." )
427+ if ("[objective: lambdarank]" in model_str ) or (
428+ "[objective: rank_xendcg]" in model_str
429+ ):
430+ raise TypeError ("Ranking objectives are not supported." )
431+
430432 if booster is None :
431433 booster = model .dump_model ()
432434
@@ -437,9 +439,9 @@ def get_gbt_model_from_lightgbm(model: Any, booster=None) -> Any:
437439 is_regression = False
438440 objective_fun = booster ["objective" ]
439441 if n_classes > 2 :
440- if "multiclass" not in objective_fun :
442+ if ( "ova" in objective_fun ) or ( "ovr" in objective_fun ) :
441443 raise TypeError (
442- "multiclass (softmax) objective is only supported for multiclass classification"
444+ "Only multiclass (softmax) objective is supported for multiclass classification"
443445 )
444446 elif "binary" in objective_fun : # nClasses == 1
445447 n_classes = 2
@@ -473,6 +475,13 @@ def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
473475 if xgb_config is None :
474476 xgb_config = get_xgboost_params (booster )
475477
478+ if xgb_config ["learner" ]["learner_train_param" ]["booster" ] != "gbtree" :
479+ raise TypeError ("Only 'gbtree' booster type is supported." )
480+
481+ n_targets = xgb_config ["learner" ]["learner_model_param" ].get ("num_target" )
482+ if n_targets is not None and int (n_targets ) > 1 :
483+ raise TypeError ("Multi-target boosters are not supported." )
484+
476485 n_features = int (xgb_config ["learner" ]["learner_model_param" ]["num_feature" ])
477486 n_classes = int (xgb_config ["learner" ]["learner_model_param" ]["num_class" ])
478487 base_score = float (xgb_config ["learner" ]["learner_model_param" ]["base_score" ])
@@ -504,11 +513,6 @@ def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
504513 if objective_fun == "binary:logitraw" :
505514 # daal4py always applies a sigmoid for pred_proba, whereas XGBoost
506515 # returns raw predictions with logitraw
507- warn (
508- "objective='binary:logitraw' selected\n "
509- "XGBoost returns raw class scores when calling pred_proba()\n "
510- "whilst scikit-learn-intelex always uses binary:logistic\n "
511- )
512516 base_score = float (1 / (1 + np .exp (- base_score )))
513517 else :
514518 is_regression = True
@@ -567,6 +571,22 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
567571 "Categorical features are not supported in daal4py Gradient Boosting Trees"
568572 )
569573
574+ objective = booster .get_params ().get ("objective" , "" )
575+ if (
576+ "Rank" in objective
577+ or "Query" in objective
578+ or "Pair" in objective
579+ or objective in ["LambdaMart" , "StochasticFilter" , "GroupQuantile" ]
580+ ):
581+ raise TypeError ("Ranking objectives are not supported." )
582+ if "Multi" in objective and objective != "MultiClass" :
583+ if model .is_classification :
584+ raise TypeError (
585+ "Only 'MultiClass' loss is supported for multi-class classification."
586+ )
587+ else :
588+ raise TypeError ("Multi-output models are not supported." )
589+
570590 if model .is_classification :
571591 mb = gbt_clf_model_builder (
572592 n_features = model .n_features ,
@@ -587,21 +607,37 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
587607 {"feature_index" : feature ["feature_index" ], "value" : feature_border }
588608 )
589609
610+ # Note: catboost models might have a 'bias' (intercept) which gets added
611+ # to all predictions. In the case of single-output models, this is a scalar,
612+ # but in the case of multi-output models such as multinomial logistic, it
613+ # is a vector. Since daal4py doesn't support vector-valued intercepts, this
614+ # adds the intercept to every terminal node instead, by dividing it equally
615+ # among all trees. Usually, catboost would anyway set them to zero, but it
616+ # still allows setting custom intercepts.
617+ cb_bias = booster .get_scale_and_bias ()[1 ]
618+ add_intercept_to_each_node = isinstance (cb_bias , list )
619+ if add_intercept_to_each_node :
620+ cb_bias = np .array (cb_bias ) / model .n_iterations
621+ if not model .is_classification :
622+ raise TypeError ("Multi-output regression models are not supported." )
623+
624+ def add_vector_bias (values : list [float ]) -> list [float ]:
625+ return list (np .array (values ) + cb_bias )
626+
590627 trees_explicit = []
591628 tree_symmetric = []
592629
630+ all_trees_are_empty = True
631+
593632 if model .is_symmetric_tree :
594633 for tree in model .oblivious_trees :
595- cur_tree_depth = len (tree .get ("splits" , []))
634+ tree_splits = tree .get ("splits" , [])
635+ cur_tree_depth = len (tree_splits ) if tree_splits is not None else 0
596636 tree_symmetric .append ((tree , cur_tree_depth ))
597637 else :
598638 for tree in model .trees :
599639 n_nodes = 1
600- if "split" not in tree :
601- # handle leaf node
602- values = __get_value_as_list (tree )
603- root_node = CatBoostNode (value = [value * model .scale for value in values ])
604- continue
640+
605641 # Check if node is a leaf (in case of stump)
606642 if "split" in tree :
607643 # Get number of trees and splits info via BFS
@@ -622,12 +658,15 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
622658 nodes_queue .append ((cur_node_data ["left" ], left_node ))
623659 nodes_queue .append ((cur_node_data ["right" ], right_node ))
624660 n_nodes += 2
661+ all_trees_are_empty = False
625662 else :
626663 root_node = CatBoostNode ()
627664 if model .is_classification and model .n_classes > 2 :
628665 root_node .value = [value * model .scale for value in tree ["value" ]]
666+ if add_intercept_to_each_node :
667+ root_node .value = add_vector_bias (root_node .value )
629668 else :
630- root_node .value = [tree ["value" ] * model .scale + model . bias ]
669+ root_node .value = [tree ["value" ] * model .scale ]
631670 trees_explicit .append ((root_node , n_nodes ))
632671
633672 tree_id = []
@@ -646,9 +685,15 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
646685 for i in range (model .n_iterations ):
647686 for _ in range (n_tree_each_iter ):
648687 if model .is_symmetric_tree :
649- n_nodes = 2 ** (tree_symmetric [i ][1 ] + 1 ) - 1
688+ if not len (tree_symmetric ):
689+ n_nodes = 1
690+ else :
691+ n_nodes = 2 ** (tree_symmetric [i ][1 ] + 1 ) - 1
650692 else :
651- n_nodes = trees_explicit [i ][1 ]
693+ if not len (trees_explicit ):
694+ n_nodes = 1
695+ else :
696+ n_nodes = trees_explicit [i ][1 ]
652697
653698 if model .is_classification and model .n_classes > 2 :
654699 tree_id .append (mb .create_tree (n_nodes , class_label ))
@@ -663,9 +708,9 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
663708 tree_id .append (mb .create_tree (n_nodes ))
664709
665710 if model .is_symmetric_tree :
711+ shap_ready = True # this code branch provides all info for SHAP values
666712 for class_label in range (n_tree_each_iter ):
667713 for i in range (model .n_iterations ):
668- shap_ready = True # this code branch provides all info for SHAP values
669714 cur_tree_info = tree_symmetric [i ][0 ]
670715 cur_tree_id = tree_id [i * n_tree_each_iter + class_label ]
671716 cur_tree_leaf_val = cur_tree_info ["leaf_values" ]
@@ -674,7 +719,8 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
674719 if cur_tree_depth == 0 :
675720 mb .add_leaf (
676721 tree_id = cur_tree_id ,
677- response = cur_tree_leaf_val [0 ],
722+ response = cur_tree_leaf_val [class_label ] * model .scale
723+ + (cb_bias [class_label ] if add_intercept_to_each_node else 0 ),
678724 cover = cur_tree_leaf_weights [0 ],
679725 )
680726 else :
@@ -686,6 +732,7 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
686732 cur_tree_leaf_weights
687733 )
688734 root_weight = cur_tree_weights_per_level [0 ][0 ]
735+
689736 root_id = mb .add_split (
690737 tree_id = cur_tree_id ,
691738 feature_index = cur_level_split ["feature_index" ],
@@ -704,14 +751,17 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
704751 cur_level_split = splits [
705752 cur_tree_info ["splits" ][cur_level ]["split_index" ]
706753 ]
754+ cover_nodes = next_level_weights [cur_level_node_index ]
755+ if cover_nodes == 0 :
756+ shap_ready = False
707757 cur_left_node = mb .add_split (
708758 tree_id = cur_tree_id ,
709759 parent_id = cur_parent ,
710760 position = 0 ,
711761 feature_index = cur_level_split ["feature_index" ],
712762 feature_value = cur_level_split ["value" ],
713763 default_left = model .default_left ,
714- cover = next_level_weights [ cur_level_node_index ] ,
764+ cover = cover_nodes ,
715765 )
716766 # cur_level_node_index += 1
717767 cur_right_node = mb .add_split (
@@ -721,7 +771,7 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
721771 feature_index = cur_level_split ["feature_index" ],
722772 feature_value = cur_level_split ["value" ],
723773 default_left = model .default_left ,
724- cover = next_level_weights [ cur_level_node_index ] ,
774+ cover = cover_nodes ,
725775 )
726776 # cur_level_node_index += 1
727777 cur_level_nodes .append (cur_left_node )
@@ -734,17 +784,15 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
734784 mb .add_leaf (
735785 tree_id = cur_tree_id ,
736786 response = cur_tree_leaf_val [2 * last_level_node_num ]
737- * model .scale
738- + model .bias ,
787+ * model .scale ,
739788 parent_id = prev_level_nodes [last_level_node_num ],
740789 position = 0 ,
741790 cover = cur_tree_leaf_weights [2 * last_level_node_num ],
742791 )
743792 mb .add_leaf (
744793 tree_id = cur_tree_id ,
745794 response = cur_tree_leaf_val [2 * last_level_node_num + 1 ]
746- * model .scale
747- + model .bias ,
795+ * model .scale ,
748796 parent_id = prev_level_nodes [last_level_node_num ],
749797 position = 1 ,
750798 cover = cur_tree_leaf_weights [2 * last_level_node_num + 1 ],
@@ -761,21 +809,30 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
761809 mb .add_leaf (
762810 tree_id = cur_tree_id ,
763811 response = cur_tree_leaf_val [left_index ] * model .scale
764- + model .bias ,
812+ + (
813+ cb_bias [class_label ]
814+ if add_intercept_to_each_node
815+ else 0
816+ ),
765817 parent_id = prev_level_nodes [last_level_node_num ],
766818 position = 0 ,
767819 cover = 0.0 ,
768820 )
769821 mb .add_leaf (
770822 tree_id = cur_tree_id ,
771823 response = cur_tree_leaf_val [right_index ] * model .scale
772- + model .bias ,
824+ + (
825+ cb_bias [class_label ]
826+ if add_intercept_to_each_node
827+ else 0
828+ ),
773829 parent_id = prev_level_nodes [last_level_node_num ],
774830 position = 1 ,
775831 cover = 0.0 ,
776832 )
777833 else :
778834 shap_ready = False
835+ scale = booster .get_scale_and_bias ()[0 ]
779836 for class_label in range (n_tree_each_iter ):
780837 for i in range (model .n_iterations ):
781838 root_node = trees_explicit [i ][0 ]
@@ -809,7 +866,12 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
809866 else :
810867 mb .add_leaf (
811868 tree_id = cur_tree_id ,
812- response = left_node .value [class_label ],
869+ response = scale * left_node .value [class_label ]
870+ + (
871+ cb_bias [class_label ]
872+ if add_intercept_to_each_node
873+ else 0
874+ ),
813875 parent_id = cur_node_id ,
814876 position = 0 ,
815877 cover = 0.0 ,
@@ -830,26 +892,31 @@ def get_gbt_model_from_catboost(booster: Any) -> Any:
830892 else :
831893 mb .add_leaf (
832894 tree_id = cur_tree_id ,
833- response = cur_node .right .value [class_label ],
895+ response = scale * cur_node .right .value [class_label ]
896+ + (
897+ cb_bias [class_label ]
898+ if add_intercept_to_each_node
899+ else 0
900+ ),
834901 parent_id = cur_node_id ,
835902 position = 1 ,
836903 cover = 0.0 ,
837904 )
838905
839906 else :
840907 # Tree has only one node
908+ # Note: the root node already has scale and bias added to it,
909+ # so no need to add them again here, as is done for the leaves.
841910 mb .add_leaf (
842911 tree_id = cur_tree_id ,
843912 response = root_node .value [class_label ],
844913 cover = 0.0 ,
845914 )
846915
847- if not shap_ready :
848- warn ("Converted models of this type do not support SHAP value calculation" )
849- else :
850- warn (
851- "CatBoost SHAP values seem to be incorrect. "
852- "Values from converted models will differ. "
853- "See https://github.com/catboost/catboost/issues/2556 for more details."
854- )
855- return mb .model (base_score = 0.0 )
916+ if all_trees_are_empty and not model .is_symmetric_tree :
917+ shap_ready = True
918+
919+ intercept = 0.0
920+ if not add_intercept_to_each_node :
921+ intercept = booster .get_scale_and_bias ()[1 ]
922+ return mb .model (base_score = intercept ), shap_ready
0 commit comments