Skip to content

Commit b18e0b2

Browse files
committed
[Python][RDF] Add tests for Numba.Declare with std::vector/std::array
1 parent ae8626d commit b18e0b2

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

bindings/pyroot/pythonizations/test/numbadeclare.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,31 @@ def pass_temporary(v):
113113

114114
self.assertTrue(np.array_equal(rvecf, np.array([4.])))
115115

116+
def test_rdataframe_std_vector(self):
117+
"""
118+
Test function call as part of RDataFrame
119+
"""
120+
@ROOT.Numba.Declare(["std::vector<int>"], "std::vector<int>")
121+
def square_vec(x):
122+
return x * x
123+
df = ROOT.RDataFrame(4).Define("x", "std::vector{1, 2, 3}").Define("x_sq", "Numba::square_vec(x)")
124+
df.Display().Print()
125+
self.assertEqual(df.Sum("x").GetValue(), 24)
126+
self.assertEqual(df.Sum("x_sq").GetValue(), 56)
127+
128+
def test_rdataframe_std_array(self):
129+
"""
130+
Test function call as part of RDataFrame with std::array
131+
"""
132+
@ROOT.Numba.Declare(["std::array<int, 3>"], "std::array<int, 3>")
133+
def square_array(x):
134+
return x * x
135+
136+
df = ROOT.RDataFrame(4).Define("x", "std::array{1, 2, 3}").Define("x_sq", "Numba::square_array(x)")
137+
df.Display().Print()
138+
self.assertEqual(df.Sum("x").GetValue(), 24)
139+
self.assertEqual(df.Sum("x_sq").GetValue(), 56)
140+
116141
# Test wrappings
117142
def test_wrapper_in_void(self):
118143
"""

0 commit comments

Comments
 (0)