diff --git a/bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py b/bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py index 5e033d3b77f9f..f5a0bf107db09 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py @@ -46,12 +46,12 @@ def _NumbaDeclareDecorator(input_types, return_type=None, name=None): "match_pattern": r"(?:ROOT::)?(?:VecOps::)?RVec\w+|(?:ROOT::)?(?:VecOps::)?RVec<[\w\s]+>", "cpp_name": ["ROOT::RVec", "ROOT::VecOps::RVec"], }, - "std::vector": { - "match_pattern": r"std::vector<[\w\s]+>", + "vector": { + "match_pattern": r"(?:std::)?vector<[\w\s]+>", "cpp_name": ["std::vector"], }, - "std::array": { - "match_pattern": r"std::array<[\w\s,<>]+>", + "array": { + "match_pattern": r"(?:std::)?array<[\w\s,<>]+>", "cpp_name": ["std::array"], }, } @@ -233,7 +233,6 @@ def inner(func, input_types=input_types, return_type=return_type, name=name): """ Inner decorator without arguments, see outer decorator for documentation """ - # Jit the given Python callable with numba try: nb_return_type, nb_input_types = get_numba_signature(input_types, return_type) @@ -255,6 +254,13 @@ def inner(func, input_types=input_types, return_type=return_type, name=name): "See https://cppyy.readthedocs.io/en/latest/numba.html#numba-support" ) nbjit = nb.jit(nopython=True, inline="always")(func) + # In this case, the user has to explictly provide the return type, cannot be inferred + if return_type is None: + raise RuntimeError( + "Failed to infer the return type for the provided function. " + "Please specify the signature explicitly in the decorator, e.g.: " + "@ROOT.NumbaDeclare(['double'], 'double')" + ) except: # noqa E722 raise Exception("Failed to jit Python callable {} with numba.jit".format(func)) func.numba_func = nbjit diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py index 1ff3e9dcb0ac0..d82ee84554339 100755 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py @@ -81,17 +81,14 @@ def find_type(self, x): t = self.rdf.GetColumnType(x) if t in TREE_TO_NUMBA: # The column is a fundamental type from tree return TREE_TO_NUMBA[t] - elif "<" in t: # The column type is a RVec - if ">>" in t: # It is a RVec> - raise TypeError( - f"Only columns with 'RVec' where T is is a fundamental type are supported, not '{t}'." - ) - g = re.match("(.*)<(.*)>", t).groups(0) - if g[1] in TREE_TO_NUMBA: - return "RVec<" + TREE_TO_NUMBA[g[1]] + ">" - # There are data type that leak into here. Not sure from where. But need to implement something here such that this condition is never met. - return "RVec<" + str(g[1]) + ">" + match = re.match(r"([\w:]+)<(.+)>", t) + if match: + container_type, inner_type = match.groups() + container_type = container_type.strip() + inner_type = inner_type.strip() + inner_mapped = TREE_TO_NUMBA.get(inner_type, inner_type) + return f"{container_type}<{inner_mapped}>" else: return t else: diff --git a/bindings/pyroot/pythonizations/test/numbadeclare.py b/bindings/pyroot/pythonizations/test/numbadeclare.py index cb63c07f32e73..81f5b57a430d0 100644 --- a/bindings/pyroot/pythonizations/test/numbadeclare.py +++ b/bindings/pyroot/pythonizations/test/numbadeclare.py @@ -633,5 +633,95 @@ def pass_reference(v): self.assertTrue(np.array_equal(rvecf, np.array([1.0, 4.0]))) +class NumbaDeclareInferred(unittest.TestCase): + """ + Test decorator created with a reconstructed list of arguments using RDF column types, + and a return type inferred from the numba jitted function. + """ + + def test_fund_types(self): + """ + Test fundamental types + """ + df = ROOT.RDataFrame(4).Define("x", "rdfentry_") + + with self.subTest("function"): + def is_even(x): + return x % 2 == 0 + df = df.Define("is_even_x_1", is_even, ["x"]) + results = df.Take["bool"]("is_even_x_1").GetValue()[0] + self.assertEqual(results, True) + + with self.subTest("lambda"): + df = df.Define("is_even_x_2", lambda x: x % 2 == 0, ["x"]) + results = df.Take["bool"]("is_even_x_2").GetValue()[0] + self.assertEqual(results, True) + + def test_rvec(self): + """ + Test RVec + """ + df = ROOT.RDataFrame(4).Define("x", "ROOT::VecOps::RVec({1, 2, 3})") + + with self.subTest("function"): + def square_rvec(v): + return v*v + df = df.Define("square_rvec_1", square_rvec, ["x"]) + results = df.Take["RVec"]("square_rvec_1").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + with self.subTest("lambda"): + df = df.Define("square_rvec_2", lambda v: v*v, ["x"]) + results = df.Take["RVec"]("square_rvec_2").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + def test_std_vec(self): + """ + Test std::vector + """ + df = ROOT.RDataFrame(4).Define("x", "std::vector({1, 2, 3})") + + with self.subTest("function"): + def square_std_vec(v): + return v*v + df = df.Define("square_std_vec_1", square_std_vec, ["x"]) + results = df.Take["RVec"]("square_std_vec_1").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + with self.subTest("lambda"): + df = df.Define("square_std_vec_2", lambda v: v*v, ["x"]) + results = df.Take["RVec"]("square_std_vec_2").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + def test_std_array(self): + """ + Test std::array + """ + df = ROOT.RDataFrame(4).Define("x", "std::array({1, 2, 3})") + + with self.subTest("function"): + def square_std_arr(v): + return v*v + df = df.Define("square_std_arr_1", square_std_arr, ["x"]) + results = df.Take["RVec"]("square_std_arr_1").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + with self.subTest("lambda"): + df = df.Define("square_std_arr_2", lambda v: v*v, ["x"]) + results = df.Take["RVec"]("square_std_arr_2").GetValue()[0] + self.assertTrue(np.array_equal(results, np.array([1, 4, 9]))) + + def test_missing_signature_raises(self): + """ + Ensure an Exception is raised when return type cannot be inferred + and no explicit signature is provided in the decorator. + """ + def f(x): + return x.M() + + with self.assertRaises(Exception): + ROOT.RDataFrame(4).Define("v", "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)").Define("m", f, ["v"]) + + if __name__ == "__main__": unittest.main()