Skip to content

Commit 6a6ed1a

Browse files
lobisdpiparo
authored andcommitted
[python][RDF] Extent AsNumpy to convert nested collection data to numpy arrays
1 parent 9b85ca5 commit 6a6ed1a

File tree

2 files changed

+122
-42
lines changed

2 files changed

+122
-42
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdataframe.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,35 @@ def pypowarray(numpyvec, pow):
9696
9797
Eventually, you probably would like to inspect the content of the RDataFrame or process the data further
9898
with Python libraries. For this purpose, we provide the `AsNumpy()` function, which returns the columns
99-
of your RDataFrame as a dictionary of NumPy arrays. See a simple example below or a full tutorial [here](df026__AsNumpyArrays_8py.html).
99+
of your RDataFrame as a dictionary of NumPy arrays. See a few simple examples below or a full tutorial [here](df026__AsNumpyArrays_8py.html).
100100
101+
\anchor asnumpy_scalar_columns
102+
##### Scalar columns
103+
If your column contains scalar values of fundamental types (e.g., integers, floats), `AsNumpy()` produces NumPy arrays with the appropriate `dtype`:
101104
~~~{.py}
102-
df = ROOT.RDataFrame("myTree", "myFile.root")
103-
cols = df.Filter("x > 10").AsNumpy(["x", "y"]) # retrieve columns "x" and "y" as NumPy arrays
104-
print(cols["x"], cols["y"]) # the values of the cols dictionary are NumPy arrays
105+
rdf = ROOT.RDataFrame(10).Define("int_col", "1").Define("float_col", "2.3")
106+
print(rdf.AsNumpy(["int_col", "float_col"]))
107+
# Output: {'int_col': array([...], dtype=int32), 'float_col': array([...], dtype=float64)}
108+
~~~
109+
110+
Columns containing non-fundamental types (e.g., objects, strings) will result in NumPy arrays with `dtype=object`.
111+
112+
##### Collection Columns
113+
If your column contains collections of fundamental types (e.g., std::vector<int>), `AsNumpy()` produces a NumPy array with `dtype=object` where each
114+
element is a NumPy array representing the collection for its corresponding entry in the column.
115+
116+
If the collection at a certain entry contains values of fundamental types, or if it is a regularly shaped multi-dimensional array of a fundamental type,
117+
then the numpy array representing the collection for that entry will have the `dtype` associated with the value type of the collection, for example:
118+
~~~{.py}
119+
rdf = rdf.Define("v_col", "std::vector<int>{{1, 2, 3}}")
120+
print(rdf.AsNumpy(["v_col", "int_col", "float_col"]))
121+
# Output: {'v_col': array([array([1, 2, 3], dtype=int32), ...], dtype=object), ...}
105122
~~~
106123
124+
If the collection at a certain entry contains values of a non-fundamental type, `AsNumpy()` will fallback on the [default behavior](\ref asnumpy_scalar_columns) and produce a NumPy array with `dtype=object` for that collection.
125+
126+
For more complex collection types in your entries, e.g. when every entry has a jagged array value, refer to the section on [interoperability with AwkwardArray](\ref awkward_interop).
127+
107128
#### Processing data stored in NumPy arrays
108129
109130
In case you have data in NumPy arrays in Python and you want to process the data with ROOT, you can easily
@@ -124,6 +145,8 @@ def pypowarray(numpyvec, pow):
124145
df.Define("z", "x + y").Snapshot("tree", "file.root")
125146
~~~
126147
148+
149+
\anchor awkward_interop
127150
### Interoperability with [AwkwardArray](https://awkward-array.org/doc/main/user-guide/how-to-convert-rdataframe.html)
128151
129152
The function for RDataFrame to Awkward conversion is ak.from_rdataframe(). The argument to this function accepts a tuple of strings that are the RDataFrame column names. By default this function returns ak.Array type.
@@ -204,11 +227,20 @@ def pypowarray(numpyvec, pow):
204227
\endpythondoc
205228
'''
206229

230+
from __future__ import annotations
231+
232+
from typing import Iterable, Optional
233+
207234
from . import pythonization
208235
from ._pyz_utils import MethodTemplateGetter, MethodTemplateWrapper
209236

210237

211-
def RDataFrameAsNumpy(df, columns=None, exclude=None, lazy=False):
238+
def RDataFrameAsNumpy(
239+
df: ROOT.RDataFrame, # noqa: F821
240+
columns: Optional[Iterable[str]] = None,
241+
exclude: Optional[Iterable[str]] = None,
242+
lazy: bool = False,
243+
):
212244
"""Read-out the RDataFrame as a collection of numpy arrays.
213245
214246
The values of the dataframe are read out as numpy array of the respective type
@@ -226,6 +258,7 @@ def RDataFrameAsNumpy(df, columns=None, exclude=None, lazy=False):
226258
event-loop.
227259
228260
Parameters:
261+
df: The RDataFrame to read out.
229262
columns: If None return all branches as columns, otherwise specify names in iterable.
230263
exclude: Exclude branches from selection.
231264
lazy: Determines whether this action is instant (False, default) or lazy (True).
@@ -240,9 +273,9 @@ def RDataFrameAsNumpy(df, columns=None, exclude=None, lazy=False):
240273

241274
# Sanitize input arguments
242275
if isinstance(columns, str):
243-
raise TypeError("The columns argument requires a list of strings")
276+
raise TypeError("The columns argument requires an iterable of strings")
244277
if isinstance(exclude, str):
245-
raise TypeError("The exclude argument requires a list of strings")
278+
raise TypeError("The exclude argument requires an iterable of strings")
246279

247280
# Early check for numpy
248281
try:
@@ -310,7 +343,7 @@ def __init__(self, result_ptrs, columns):
310343
self._columns = columns
311344
self._py_arrays = None
312345

313-
def GetValue(self):
346+
def GetValue(self) -> dict:
314347
"""Triggers, if necessary, the event loop to run the Take actions for
315348
the requested columns and produce the NumPy arrays as result.
316349
@@ -334,7 +367,11 @@ def GetValue(self):
334367
else:
335368
tmp = numpy.empty(len(cpp_reference), dtype=object)
336369
for i, x in enumerate(cpp_reference):
337-
tmp[i] = x # This creates only the wrapping of the objects and does not copy.
370+
if hasattr(x, "__array_interface__"):
371+
tmp[i] = numpy.asarray(x)
372+
else:
373+
tmp[i] = x
374+
338375
self._py_arrays[column] = ndarray(tmp, self._result_ptrs[column])
339376

340377
return self._py_arrays

bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import unittest
2-
import ROOT
3-
import numpy as np
41
import pickle
2+
import platform
3+
import tempfile
4+
import unittest
5+
from pathlib import Path
56

7+
import numpy as np
8+
import ROOT
69
from ROOT._pythonization._rdataframe import _clone_asnumpyresult
710

811

@@ -38,8 +41,7 @@ def make_tree(*dtypes):
3841
elif "O" in dtype:
3942
var = np.empty(1, dtype=np.uint8)
4043
else:
41-
raise Exception(
42-
"Type {} not known to create branch.".format(dtype))
44+
raise Exception("Type {} not known to create branch.".format(dtype))
4345
col_vars.append(var)
4446

4547
for dtype, name, var in zip(dtypes, col_names, col_vars):
@@ -71,6 +73,7 @@ class RDataFrameAsNumpy(unittest.TestCase):
7173
"""
7274
Testing of RDataFrame.AsNumpy pythonization
7375
"""
76+
7477
def test_branch_dtypes(self):
7578
"""
7679
Test supported data-types for read-out
@@ -89,8 +92,8 @@ def test_branch_bool(self):
8992
"""
9093
df = ROOT.RDataFrame(2).Define("x", "bool(rdfentry_)")
9194
npy = df.AsNumpy()
92-
self.assertTrue(bool(npy["x"][0]) == False)
93-
self.assertTrue(bool(npy["x"][1]) == True)
95+
self.assertFalse(bool(npy["x"][0]))
96+
self.assertTrue(bool(npy["x"][1]))
9497

9598
def test_read_array(self):
9699
"""
@@ -131,12 +134,11 @@ def test_read_vector_constantsize(self):
131134
return std::vector<unsigned int>({n, n, n});
132135
}
133136
""")
134-
df = ROOT.ROOT.RDataFrame(5).Define("x",
135-
"create_vector_constantsize(rdfentry_)")
137+
df = ROOT.ROOT.RDataFrame(5).Define("x", "create_vector_constantsize(rdfentry_)")
136138
npy = df.AsNumpy()
137139
self.assertEqual(npy["x"].size, 5)
138140
self.assertEqual(list(npy["x"][0]), [0, 0, 0])
139-
self.assertIn("vector<unsigned int>", str(type(npy["x"][0])))
141+
self.assertTrue(isinstance(npy["x"], np.ndarray))
140142

141143
def test_read_vector_variablesize(self):
142144
"""
@@ -147,12 +149,11 @@ def test_read_vector_variablesize(self):
147149
return std::vector<unsigned int>(n);
148150
}
149151
""")
150-
df = ROOT.ROOT.RDataFrame(5).Define("x",
151-
"create_vector_variablesize(rdfentry_)")
152+
df = ROOT.ROOT.RDataFrame(5).Define("x", "create_vector_variablesize(rdfentry_)")
152153
npy = df.AsNumpy()
153154
self.assertEqual(npy["x"].size, 5)
154155
self.assertEqual(list(npy["x"][3]), [0, 0, 0])
155-
self.assertIn("vector<unsigned int>", str(type(npy["x"][0])))
156+
self.assertTrue(isinstance(npy["x"], np.ndarray))
156157

157158
def test_read_tlorentzvector(self):
158159
"""
@@ -197,8 +198,7 @@ def test_define_columns(self):
197198
"""
198199
Testing reading defined columns
199200
"""
200-
df = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define(
201-
"z", "3")
201+
df = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define("z", "3")
202202
npy = df.AsNumpy(columns=["x", "y"])
203203
ref = {"x": np.array([1] * 4), "y": np.array([2] * 4)}
204204
self.assertTrue(sorted(["x", "y"]) == sorted(npy.keys()))
@@ -209,16 +209,14 @@ def test_exclude_columns(self):
209209
"""
210210
Testing excluding columns from read-out
211211
"""
212-
df = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define(
213-
"z", "3")
212+
df = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define("z", "3")
214213
npy = df.AsNumpy(exclude=["z"])
215214
ref = {"x": np.array([1] * 4), "y": np.array([2] * 4)}
216215
self.assertTrue(sorted(["x", "y"]) == sorted(npy.keys()))
217216
self.assertTrue(all(ref["x"] == npy["x"]))
218217
self.assertTrue(all(ref["y"] == npy["y"]))
219218

220-
df2 = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define(
221-
"z", "3")
219+
df2 = ROOT.ROOT.RDataFrame(4).Define("x", "1").Define("y", "2").Define("z", "3")
222220
npy = df2.AsNumpy(columns=["x", "y"], exclude=["y"])
223221
ref = {"x": np.array([1] * 4)}
224222
self.assertTrue(["x"] == list(npy.keys()))
@@ -264,7 +262,7 @@ def test_empty_array(self):
264262
df = ROOT.ROOT.RDataFrame(1).Define("x", "std::vector<float>()")
265263
npy = df.AsNumpy(["x"])
266264
self.assertEqual(npy["x"].size, 1)
267-
self.assertTrue(npy["x"][0].empty())
265+
self.assertEqual(npy["x"][0].size, 0)
268266

269267
def test_empty_selection(self):
270268
"""
@@ -319,19 +317,15 @@ def test_cloning(self):
319317

320318
# Get the result for the first range
321319
(begin, end) = ranges.pop(0)
322-
ROOT.Internal.RDF.ChangeEmptyEntryRange(
323-
ROOT.RDF.AsRNode(df), (begin, end))
320+
ROOT.Internal.RDF.ChangeEmptyEntryRange(ROOT.RDF.AsRNode(df), (begin, end))
324321
asnumpyres = df.AsNumpy(["x"], lazy=True) # To return an AsNumpyResult
325-
self.assertSequenceEqual(
326-
asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())
322+
self.assertSequenceEqual(asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())
327323

328324
# Clone the result for following ranges
329-
for (begin, end) in ranges:
330-
ROOT.Internal.RDF.ChangeEmptyEntryRange(
331-
ROOT.RDF.AsRNode(df), (begin, end))
325+
for begin, end in ranges:
326+
ROOT.Internal.RDF.ChangeEmptyEntryRange(ROOT.RDF.AsRNode(df), (begin, end))
332327
asnumpyres = _clone_asnumpyresult(asnumpyres)
333-
self.assertSequenceEqual(
334-
asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())
328+
self.assertSequenceEqual(asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())
335329

336330
def test_bool_column(self):
337331
"""
@@ -343,8 +337,57 @@ def test_bool_column(self):
343337
df = ROOT.RDataFrame(n_events).Define(name, f"(int)rdfentry_ > {cut}")
344338
arr = df.AsNumpy([name])[name]
345339
ref = np.arange(0, n_events) > cut
346-
self.assertTrue(all(arr == ref)) # test values
347-
self.assertEqual(arr.dtype, ref.dtype) # test type
348-
349-
if __name__ == '__main__':
340+
self.assertTrue(all(arr == ref)) # test values
341+
self.assertEqual(arr.dtype, ref.dtype) # test type
342+
343+
def test_rdataframe_as_numpy_array_regular(self):
344+
column_name = "vector"
345+
n = 10
346+
for from_file in [False, True]:
347+
for shape, declaration in [
348+
((n, 3), "std::vector<int>{1,2,3}"),
349+
((n, 3), "std::vector<float>{1,2,3}"),
350+
((n, 3), "std::vector<double>{1,2,3}"),
351+
]:
352+
df = ROOT.RDataFrame(10).Define(column_name, declaration)
353+
temp_file_path = None
354+
if from_file:
355+
# save to disk and read back
356+
temp_file = tempfile.NamedTemporaryFile(delete=False)
357+
temp_file_path = Path(temp_file.name)
358+
temp_file.close()
359+
360+
df.Snapshot("tree", str(temp_file_path))
361+
df = ROOT.RDataFrame("tree", str(temp_file_path))
362+
363+
array = df.AsNumpy([column_name])[column_name]
364+
self.assertTrue(isinstance(array, np.ndarray))
365+
# self.assertEqual(array.shape, shape) # when we implement regular array handling
366+
self.assertTrue(array.shape[0] == n)
367+
self.assertTrue(all(x.shape[0] == shape[1] for x in array))
368+
369+
if from_file and platform.system() != "Windows":
370+
temp_file_path.unlink()
371+
372+
def test_rdataframe_as_numpy_array_jagged(self):
373+
jagged_array = ROOT.std.vector(float)()
374+
column_name = "jagged_array"
375+
tree = ROOT.TTree("tree", "Tree with Jagged Array")
376+
tree.Branch(column_name, jagged_array)
377+
n = 10
378+
for i in range(n):
379+
jagged_array.clear()
380+
for j in range(i):
381+
jagged_array.push_back(j)
382+
tree.Fill()
383+
384+
df = ROOT.RDataFrame(tree)
385+
array = df.AsNumpy([column_name])[column_name]
386+
self.assertTrue(isinstance(array, np.ndarray))
387+
self.assertTrue(array.shape[0] == n)
388+
self.assertTrue(all(isinstance(x, np.ndarray) for x in array))
389+
self.assertTrue(all(len(x) == i for i, x in enumerate(array)))
390+
391+
392+
if __name__ == "__main__":
350393
unittest.main()

0 commit comments

Comments
 (0)