Skip to content

Commit 35fb390

Browse files
committed
[TMVA][Python] Fix feature name handling for RBDT::LoadText
If the XGBoost model encodes feature names, they will also be used in the `.txt` dump of the model. We have to use these names in `RBDT::LoadTxt` as well, so that the `.txt` file can be read correctly without errors. Closes #20267.
1 parent 775f235 commit 35fb390

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
8080
# Dump XGB model as txt file
8181
xgb_model.get_booster().dump_model(output_path)
8282

83-
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
83+
if xgb_model.get_booster().feature_names is None:
84+
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
85+
else:
86+
features = cppyy.gbl.std.vector["std::string"](xgb_model.get_booster().feature_names)
8487
bs = get_basescore(xgb_model)
8588
logistic = objective == "logistic"
8689
bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(

0 commit comments

Comments
 (0)