Skip to content

Commit fd33615

Browse files
committed
[Python][RDF] Extend FunctionJitter with newest supported numba jitting types
1 parent 9810263 commit fd33615

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ def _NumbaDeclareDecorator(input_types, return_type=None, name=None):
4646
"match_pattern": r"(?:ROOT::)?(?:VecOps::)?RVec\w+|(?:ROOT::)?(?:VecOps::)?RVec<[\w\s]+>",
4747
"cpp_name": ["ROOT::RVec", "ROOT::VecOps::RVec"],
4848
},
49-
"std::vector": {
50-
"match_pattern": r"std::vector<[\w\s]+>",
49+
"vector": {
50+
"match_pattern": r"(?:std::)?vector<[\w\s]+>",
5151
"cpp_name": ["std::vector"],
5252
},
53-
"std::array": {
54-
"match_pattern": r"std::array<[\w\s,<>]+>",
53+
"array": {
54+
"match_pattern": r"(?:std::)?array<[\w\s,<>]+>",
5555
"cpp_name": ["std::array"],
5656
},
5757
}
@@ -233,7 +233,6 @@ def inner(func, input_types=input_types, return_type=return_type, name=name):
233233
"""
234234
Inner decorator without arguments, see outer decorator for documentation
235235
"""
236-
237236
# Jit the given Python callable with numba
238237
try:
239238
nb_return_type, nb_input_types = get_numba_signature(input_types, return_type)
@@ -255,6 +254,13 @@ def inner(func, input_types=input_types, return_type=return_type, name=name):
255254
"See https://cppyy.readthedocs.io/en/latest/numba.html#numba-support"
256255
)
257256
nbjit = nb.jit(nopython=True, inline="always")(func)
257+
# In this case, the user has to explictly provide the return type, cannot be inferred
258+
if return_type is None:
259+
raise RuntimeError(
260+
"Failed to infer the return type for the provided function. "
261+
"Please specify the signature explicitly in the decorator, e.g.: "
262+
"@ROOT.NumbaDeclare(['double'], 'double')"
263+
)
258264
except: # noqa E722
259265
raise Exception("Failed to jit Python callable {} with numba.jit".format(func))
260266
func.numba_func = nbjit

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,14 @@ def find_type(self, x):
8181
t = self.rdf.GetColumnType(x)
8282
if t in TREE_TO_NUMBA: # The column is a fundamental type from tree
8383
return TREE_TO_NUMBA[t]
84-
elif "<" in t: # The column type is a RVec<type>
85-
if ">>" in t: # It is a RVec<RVec<T>>
86-
raise TypeError(
87-
f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{t}'."
88-
)
89-
g = re.match("(.*)<(.*)>", t).groups(0)
90-
if g[1] in TREE_TO_NUMBA:
91-
return "RVec<" + TREE_TO_NUMBA[g[1]] + ">"
92-
# There are data type that leak into here. Not sure from where. But need to implement something here such that this condition is never met.
93-
return "RVec<" + str(g[1]) + ">"
9484

85+
match = re.match(r"([\w:]+)<(.+)>", t)
86+
if match:
87+
container_type, inner_type = match.groups()
88+
container_type = container_type.strip()
89+
inner_type = inner_type.strip()
90+
inner_mapped = TREE_TO_NUMBA.get(inner_type, inner_type)
91+
return f"{container_type}<{inner_mapped}>"
9592
else:
9693
return t
9794
else:

0 commit comments

Comments
 (0)