Commit 8c83c8f

[TMVA][Python] Fix feature name handling for RBDT::LoadText
If the XGBoost model encodes feature names, they are also used in the `.txt` dump of the model. We have to use these names in `RBDT::LoadText` as well, so that the `.txt` file can be read back correctly without errors. The RBDT unit test is updated to cover this case of custom feature names, which arises when the training data comes from a pandas DataFrame. Also apply some suggestions from ruff. Closes #20267.
1 parent: 74e607e
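For context, a minimal sketch of how custom feature names arise (not part of this commit; the shapes, column names, and dump path below are illustrative): when an XGBoost sklearn estimator is fit on a pandas DataFrame, the booster records the column names, and the plain-text dump that RBDT::LoadText parses then refers to features by those names instead of the default "f0", "f1", ...

    import numpy as np
    import pandas
    import xgboost

    x = np.random.rand(100, 3)
    y = np.random.rand(100)
    df = pandas.DataFrame({f"myfeature_{i}": x[:, i] for i in range(3)})

    model = xgboost.XGBRegressor(n_estimators=1, max_depth=3)
    model.fit(df, y)

    # The booster keeps the DataFrame column names ...
    print(model.get_booster().feature_names)
    # -> ['myfeature_0', 'myfeature_1', 'myfeature_2']

    # ... and the text dump uses them in the split conditions,
    # e.g. "[myfeature_1<0.5]" rather than "[f1<0.5]".
    model.get_booster().dump_model("model_dump.txt")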

File tree: 3 files changed, +26 -18 lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 7 additions & 8 deletions
@@ -8,11 +8,10 @@
 # For the list of contributors see $ROOTSYS/README/CREDITS. #
 ################################################################################
 
-from .. import pythonization
-import cppyy
-
 import json
 
+import cppyy
+
 
 def get_basescore(model):
     """Get base score from an XGBoost sklearn estimator.
@@ -60,7 +59,7 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
        "reg:squarederror": "identity",
    }
    model_objective = xgb_model.objective
-    if not model_objective in objective_map:
+    if model_objective not in objective_map:
        raise Exception(
            'XGBoost model has unsupported objective "{}". Supported objectives are {}.'.format(
                model_objective, objective_map.keys()
@@ -74,13 +73,13 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
    # Dump XGB model as json file
    xgb_model.get_booster().dump_model(output_path, dump_format="json")
 
-    with open(output_path, "r") as json_file:
-        forest = json.load(json_file)
-
    # Dump XGB model as txt file
    xgb_model.get_booster().dump_model(output_path)
 
-    features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
+    if xgb_model.get_booster().feature_names is None:
+        features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
+    else:
+        features = cppyy.gbl.std.vector["std::string"](xgb_model.get_booster().feature_names)
    bs = get_basescore(xgb_model)
    logistic = objective == "logistic"
    bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(
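With this change, the feature names passed to RBDT::LoadText match whatever the booster was trained with, so the .txt dump is parsed correctly with or without custom names. A hedged usage sketch of the round trip through the pythonized SaveXGBoost and RBDT (the file name, key name, and the final Compute call are illustrative and not taken from this diff; it assumes the numpy-aware RBDT.Compute pythonization):

    import numpy as np
    import pandas
    import ROOT
    import xgboost

    x = np.random.rand(200, 4)
    y = np.random.rand(200)
    df = pandas.DataFrame({f"myfeature_{i}": x[:, i] for i in range(4)})

    xgb = xgboost.XGBRegressor(n_estimators=1, max_depth=3)
    xgb.fit(df, y)  # the booster picks up the DataFrame column names

    # SaveXGBoost now forwards the booster's custom feature names to RBDT::LoadText.
    ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "model.root", num_inputs=4)
    bdt = ROOT.TMVA.Experimental.RBDT("myModel", "model.root")

    # Assuming Compute accepts a 2D numpy array, this should agree with
    # xgb.predict(df) up to float precision.
    y_bdt = bdt.Compute(x)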

tmva/tmva/test/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -37,9 +37,12 @@ endif()
 
 if(dataframe)
   # Test uses the xgboost sklearn plugin, so we need to check for sklearn too.
+  # It also uses pandas to test the case where the training data is passed via
+  # a pandas DataFrame.
   ROOT_FIND_PYTHON_MODULE(xgboost)
+  ROOT_FIND_PYTHON_MODULE(pandas)
   ROOT_FIND_PYTHON_MODULE(sklearn)
-  if (ROOT_XGBOOST_FOUND AND ROOT_SKLEARN_FOUND)
+  if (ROOT_XGBOOST_FOUND AND ROOT_SKLEARN_FOUND AND ROOT_PANDAS_FOUND)
     ROOT_ADD_PYUNITTEST(rbdt_xgboost rbdt_xgboost.py)
   endif()
 endif()

tmva/tmva/test/rbdt_xgboost.py

Lines changed: 15 additions & 9 deletions
@@ -1,12 +1,9 @@
-# XGBoost has to be imported before ROOT to avoid crashes because of clashing
-# std::regexp symbols that are exported by cppyy.
-# See also: https://github.com/wlav/cppyy/issues/227
-import xgboost
-
 import unittest
-import ROOT
+
 import numpy as np
-import json
+import pandas
+import ROOT
+import xgboost
 
 np.random.seed(1234)
 
@@ -41,9 +38,18 @@ def _test_XGBRegression(label):
     """
     Compare response of XGB regressor and TMVA tree inference system.
     """
-    x, y = create_dataset(1000, 10, 1)
+    n_samples = 1000
+    n_features = 10
+    x, y = create_dataset(n_samples, n_features, 1)
+    # Other than in the XGBBinary test, we're passing the training features via
+    # a pandas DataFrame this time. In that case, XGBoost will define custom
+    # feature names according to the column names in the dataframe, and we can
+    # test the case where the feature names in the .txt dump are not the
+    # default "f0", "f1", "f2", etc.
+    df_x = pandas.DataFrame({f"myfeature_{i}": x[:, i] for i in range(n_features)})
+    assert len(x) == len(df_x)
     xgb = xgboost.XGBRegressor(n_estimators=1, max_depth=3)
-    xgb.fit(x, y)
+    xgb.fit(df_x, y)
     ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBRegression{}.root".format(label), num_inputs=10)
     bdt = ROOT.TMVA.Experimental.RBDT("myModel", "testXGBRegression{}.root".format(label))
 
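The assertion that actually compares the two models is outside the hunks shown above. A sketch of what such a check could look like, reusing the names from the diff (the tolerance, the flattening, and the use of np.testing are assumptions, not part of the commit):

    # Evaluate both models on the same inputs and require numerical agreement.
    y_xgb = xgb.predict(df_x)
    y_bdt = np.asarray(bdt.Compute(x)).flatten()
    np.testing.assert_array_almost_equal(y_xgb, y_bdt, decimal=5)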