Skip to content

Commit bc4392b

Browse files
committed
[Python][RDF] Add tests
1 parent 0e75f57 commit bc4392b

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

bindings/pyroot/pythonizations/test/numbadeclare.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,5 +633,67 @@ def pass_reference(v):
633633
self.assertTrue(np.array_equal(rvecf, np.array([1.0, 4.0])))
634634

635635

636+
class NumbaDeclareInferred(unittest.TestCase):
637+
"""
638+
Test decorator created with a reconstructed list of arguments using RDF column types,
639+
and a return type inferred from the numba jitted function.
640+
"""
641+
642+
def test_fund_types(self):
643+
"""
644+
Test fundamental types
645+
"""
646+
def is_even(x):
647+
return x % 2 == 0
648+
649+
df = ROOT.RDataFrame(4).Define("x", "rdfentry_").Define("is_even_x", is_even, ["x"])
650+
results = df.Take["bool"]("is_even_x").GetValue()[0]
651+
self.assertEqual(results, True)
652+
653+
def test_rvec(self):
654+
"""
655+
Test RVec
656+
"""
657+
def square_rvec(v):
658+
return v*v
659+
660+
df = ROOT.RDataFrame(4).Define("x", "ROOT::VecOps::RVec<int>({1, 2, 3})").Define("square_rvec", square_rvec, ["x"])
661+
results = df.Take["RVec<int>"]("square_rvec").GetValue()[0]
662+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
663+
664+
def test_std_vec(self):
665+
"""
666+
Test std::vector
667+
"""
668+
def square_std_vec(v):
669+
return v*v
670+
671+
df = ROOT.RDataFrame(4).Define("x", "std::vector<int>({1, 2, 3})").Define("square_std_vec", square_std_vec, ["x"])
672+
results = df.Take["RVec<int>"]("square_std_vec").GetValue()[0]
673+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
674+
675+
def test_std_array(self):
676+
"""
677+
Test std::array
678+
"""
679+
def square_std_arr(v):
680+
return v*v
681+
682+
df = ROOT.RDataFrame(4).Define("x", "std::array<int, 3>({1, 2, 3})").Define("square_std_arr", square_std_arr, ["x"])
683+
results = df.Take["RVec<int>"]("square_std_arr").GetValue()[0]
684+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
685+
686+
def test_missing_signature_raises(self):
687+
"""
688+
Ensure an Exception is raised when return type cannot be inferred
689+
and no explicit signature is provided in the decorator.
690+
"""
691+
def f(x):
692+
return x.M()
693+
694+
with self.assertRaises(Exception):
695+
ROOT.RDataFrame(4).Define("v", "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)").Define("m", f, ["v"])
696+
697+
636698
if __name__ == "__main__":
637699
unittest.main()

0 commit comments

Comments
 (0)