Skip to content

Commit 28582e0

Browse files
authored
[pyroot] add np.int16 and np.uint16 to conversion map
Fixes #18365
1 parent d1717fa commit 28582e0

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575

7676

7777
def _get_cpp_type_from_numpy_type(dtype):
78-
cpptypes = {"i4": "int", "u4": "unsigned int", "i8": "Long64_t", "u8": "ULong64_t", "f4": "float", "f8": "double", "b1": "bool"}
78+
cpptypes = {"i2": "Short_t", "u2": "UShort_t", "i4": "int", "u4": "unsigned int", "i8": "Long64_t", "u8": "ULong64_t", "f4": "float", "f8": "double", "b1": "bool"}
7979

8080
if not dtype in cpptypes:
8181
raise RuntimeError("Object not convertible: Python object has unknown data-type '" + dtype + "'.")

bindings/pyroot/pythonizations/test/rvec_asrvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class AsRVec(unittest.TestCase):
2626

2727
# Helpers
2828
dtypes = [
29-
"int32", "int64", "uint32", "uint64", "float32", "float64", "bool"
29+
"int16", "int32", "int64", "uint16", "uint32", "uint64", "float32", "float64", "bool"
3030
]
3131

3232
def check_memory_adoption(self, root_obj, np_obj):

bindings/pyroot/pythonizations/test/ttree_setbranchaddress.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,46 @@ def test_class_with_array_member(self):
184184
self.assertEqual(mc.foo[0], 1.0)
185185
self.assertEqual(mc.foo[1], 2.0)
186186

187+
def test_np_conversion(self):
188+
# 18365
189+
a = np.zeros(3, np.uint16)
190+
a[0] = 1
191+
a[1] = 2
192+
a[2] = 3
193+
c = np.zeros(3, np.int16)
194+
c[0] = 4
195+
c[1] = 5
196+
c[2] = 6
197+
t = ROOT.TTree("t", "t")
198+
t.Branch("b", a, "b[3]/s")
199+
t.Branch("d", c, "d[3]/S")
200+
t.Fill()
201+
a[0] = 10
202+
a[1] = 20
203+
a[2] = 30
204+
c[0] = 40
205+
c[1] = 50
206+
c[2] = 60
207+
t.Fill()
208+
# t.Print()
209+
# t.Scan()
210+
t.SetBranchAddress("b", a)
211+
t.SetBranchAddress("d", c)
212+
t.GetEntry(0)
213+
self.assertEqual(a[0], 1)
214+
self.assertEqual(a[1], 2)
215+
self.assertEqual(a[2], 3)
216+
self.assertEqual(c[0], 4)
217+
self.assertEqual(c[1], 5)
218+
self.assertEqual(c[2], 6)
219+
t.GetEntry(1)
220+
self.assertEqual(a[0], 10)
221+
self.assertEqual(a[1], 20)
222+
self.assertEqual(a[2], 30)
223+
self.assertEqual(c[0], 40)
224+
self.assertEqual(c[1], 50)
225+
self.assertEqual(c[2], 60)
226+
187227

188228
if __name__ == "__main__":
189229
unittest.main()

0 commit comments

Comments
 (0)