Skip to content

Commit 9624c0e

Browse files
committed
[Python] Fix TMVA pythonizations when built with dataframe=OFF
The registration of the TMVA RBatchGenerator pythonizations already works fine with `dataframe=OFF` in principle because it's inside a try-except block in `_facade.py`. However, the `_rbatchgenerator` module that is conditional on `dataframe=ON` is imported globally in the `__init__.py` of the TMVA pythonizations, and this will still result in an uncaught exception. Importing `_rbatchgenerator` locally fixes the problem.
1 parent 714ff7e commit 9624c0e

File tree

1 file changed

+31
-15
lines changed
  • bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva

1 file changed

+31
-15
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/__init__.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,40 @@
2222

2323
from ._rbdt import Compute, pythonize_rbdt
2424

25-
from ._batchgenerator import (
26-
CreateNumPyGenerators,
27-
CreateTFDatasets,
28-
CreatePyTorchGenerators,
29-
)
30-
31-
python_batchgenerator_functions = [
32-
CreateNumPyGenerators,
33-
CreateTFDatasets,
34-
CreatePyTorchGenerators,
35-
]
3625

3726
def inject_rbatchgenerator(ns):
27+
from ._batchgenerator import (
28+
CreateNumPyGenerators,
29+
CreateTFDatasets,
30+
CreatePyTorchGenerators,
31+
)
32+
33+
python_batchgenerator_functions = [
34+
CreateNumPyGenerators,
35+
CreateTFDatasets,
36+
CreatePyTorchGenerators,
37+
]
38+
3839
for python_func in python_batchgenerator_functions:
3940
func_name = python_func.__name__
4041
setattr(ns.Experimental, func_name, python_func)
4142

4243
return ns
4344

45+
4446
from ._gnn import RModel_GNN, RModel_GraphIndependent
4547

4648
hasRDF = "dataframe" in cppyy.gbl.ROOT.GetROOT().GetConfigFeatures()
4749
if hasRDF:
48-
from ._rtensor import get_array_interface, add_array_interface_property, RTensorGetitem, pythonize_rtensor, _AsRTensor
49-
50-
#this should be available only when xgboost is there ?
50+
from ._rtensor import (
51+
get_array_interface,
52+
add_array_interface_property,
53+
RTensorGetitem,
54+
pythonize_rtensor,
55+
_AsRTensor,
56+
)
57+
58+
# this should be available only when xgboost is there ?
5159
# We probably don't need a protection here since the code is run only when there is xgboost
5260
from ._tree_inference import SaveXGBoost
5361

@@ -67,7 +75,15 @@ def get_defined_attributes(klass, consider_base_classes=False):
6775
any of its base classes (except for `object`).
6876
"""
6977

70-
blacklist = ["__dict__", "__doc__", "__hash__", "__module__", "__weakref__", "__firstlineno__", "__static_attributes__"]
78+
blacklist = [
79+
"__dict__",
80+
"__doc__",
81+
"__hash__",
82+
"__module__",
83+
"__weakref__",
84+
"__firstlineno__",
85+
"__static_attributes__",
86+
]
7187

7288
if not consider_base_classes:
7389
return sorted([attr for attr in klass.__dict__.keys() if attr not in blacklist])

0 commit comments

Comments
 (0)