Skip to content

Commit bfd8a32

Browse files
committed
[TMVA][Python] Fix feature name handling for RBDT::LoadText
If the XGBoost model encodes feature names, they will also be used in the `.txt` dump of the model. We have to use these names in `RBDT::LoadText` as well, so that the `.txt` file can be read correctly without errors. The RBDT unit test is also updated to cover this case of custom feature names, which happens when the training data comes from a pandas DataFrame. Closes #20267.
1 parent 775f235 commit bfd8a32

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
8080
# Dump XGB model as txt file
8181
xgb_model.get_booster().dump_model(output_path)
8282

83-
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
83+
if xgb_model.get_booster().feature_names is None:
84+
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
85+
else:
86+
features = cppyy.gbl.std.vector["std::string"](xgb_model.get_booster().feature_names)
8487
bs = get_basescore(xgb_model)
8588
logistic = objective == "logistic"
8689
bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(

tmva/tmva/test/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ endif()
3737

3838
if(dataframe)
3939
# Test uses the xgboost sklearn plugin, so we need to check for sklearn too.
40+
# It also uses pandas to test the case where the training data is passed via
41+
# a pandas DataFrame.
4042
ROOT_FIND_PYTHON_MODULE(xgboost)
43+
ROOT_FIND_PYTHON_MODULE(pandas)
4144
ROOT_FIND_PYTHON_MODULE(sklearn)
42-
if (ROOT_XGBOOST_FOUND AND ROOT_SKLEARN_FOUND)
45+
if (ROOT_XGBOOST_FOUND AND ROOT_SKLEARN_FOUND AND ROOT_PANDAS_FOUND)
4346
ROOT_ADD_PYUNITTEST(rbdt_xgboost rbdt_xgboost.py)
4447
endif()
4548
endif()

tmva/tmva/test/rbdt_xgboost.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
# XGBoost has to be imported before ROOT to avoid crashes because of clashing
2-
# std::regexp symbols that are exported by cppyy.
3-
# See also: https://github.com/wlav/cppyy/issues/227
4-
import xgboost
1+
import json
52

6-
import unittest
73
import ROOT
84
import numpy as np
9-
import json
5+
import pandas
6+
import unittest
7+
import xgboost
108

119
np.random.seed(1234)
1210

@@ -41,9 +39,18 @@ def _test_XGBRegression(label):
4139
"""
4240
Compare response of XGB regressor and TMVA tree inference system.
4341
"""
44-
x, y = create_dataset(1000, 10, 1)
42+
n_samples = 1000
43+
n_features = 10
44+
x, y = create_dataset(n_samples, n_features, 1)
45+
# Unlike in the XGBBinary test, we're passing the training features via
46+
# a pandas DataFrame this time. In that case, XGBoost will define custom
47+
# feature names according to the column names in the dataframe, and we can
48+
# test the case where the feature names in the .txt dump are not the
49+
# default "f0", "f1", "f2", etc.
50+
df_x = pandas.DataFrame({f"myfeature_{i}": x[:, i] for i in range(n_features)})
51+
assert len(x) == len(df_x)
4552
xgb = xgboost.XGBRegressor(n_estimators=1, max_depth=3)
46-
xgb.fit(x, y)
53+
xgb.fit(df_x, y)
4754
ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBRegression{}.root".format(label), num_inputs=10)
4855
bdt = ROOT.TMVA.Experimental.RBDT("myModel", "testXGBRegression{}.root".format(label))
4956

0 commit comments

Comments
 (0)