Skip to content

Commit 1359273

Browse files
authored
feat: add RNTuple writing support for IndexedArray and fix IndexedOptionArray (#1493)
* Added writing support for IndexedArray and fixed IndexedOptionArray * Slightly better test * Fixed issue with 32-bit indices * Updated ROOT test
1 parent d4ac22a commit 1359273

File tree

3 files changed

+56
-20
lines changed

3 files changed

+56
-20
lines changed

src/uproot/models/RNTuple.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,11 @@ def field_form(self, this_id, keys, ak_add_doc=False):
544544
if this_id in self._related_ids:
545545
child_id = self._related_ids[this_id][0]
546546
inner = self.field_form(child_id, keys, ak_add_doc=ak_add_doc)
547+
idx_type = (
548+
"i32" if self._column_records_dict[cfid][0].nbits == 32 else "i64"
549+
)
547550
return ak.forms.ListOffsetForm(
548-
"i64", inner, form_key=keyname, parameters=parameters
551+
idx_type, inner, form_key=keyname, parameters=parameters
549552
)
550553
elif structural_role == uproot.const.RNTupleFieldRole.RECORD:
551554
newids = []

src/uproot/writing/_cascadentuple.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _cpp_typename(akform, subcall=False):
111111
elif isinstance(akform, awkward.forms.UnionForm):
112112
field_typenames = [_cpp_typename(t, subcall=True) for t in akform.contents]
113113
typename = f"std::variant<{','.join(field_typenames)}>"
114-
elif isinstance(akform, awkward.forms.UnmaskedForm):
114+
elif isinstance(akform, (awkward.forms.UnmaskedForm, awkward.forms.IndexedForm)):
115115
return _cpp_typename(akform.content, subcall=True)
116116
else:
117117
raise NotImplementedError(f"Form type {type(akform)} cannot be written yet")
@@ -484,7 +484,10 @@ def _build_field_col_records(
484484
field_name=subfield_name,
485485
parent_fid=field_id,
486486
)
487-
elif isinstance(akform, awkward.forms.UnmaskedForm):
487+
elif isinstance(
488+
akform, (awkward.forms.UnmaskedForm, awkward.forms.IndexedForm)
489+
):
490+
# IndexedForms just get rearranged, so they are transparent
488491
# Do nothing
489492
self._build_field_col_records(
490493
akform.content,
@@ -930,18 +933,24 @@ def extend(self, file, sink, data):
930933

931934
cluster_page_data = [] # list of list of (locator, len, offset)
932935
data_buffers = awkward.to_buffers(data)[2]
933-
for idx, key in enumerate(self._header._column_keys):
934-
if "switch" in key:
936+
937+
# We need to modify make a few modifications since not everything directly translates to RNTuples
938+
for key in list(data_buffers.keys()):
939+
barekey = key.split("-")[0]
940+
if "offset" in key:
941+
# RNTuples don't store the first offset
942+
data_buffers[key] = data_buffers[key][1:]
943+
elif "index" in key and barekey + "-tags" in data_buffers:
944+
# We group indices and tags into a single array
935945
dtype = numpy.dtype([("index", "int64"), ("tag", "int32")])
936-
indices = data_buffers[key.split("-")[0] + "-index"]
937-
tags = data_buffers[key.split("-")[0] + "-tags"]
946+
indices = data_buffers[barekey + "-index"]
947+
tags = data_buffers[barekey + "-tags"]
938948
switches = numpy.zeros(len(indices), dtype=dtype)
939949
switches["index"] = indices
940950
switches["tag"] = tags + 1
941-
col_data = switches
942-
elif "startstop" in key:
951+
data_buffers[barekey + "-switch"] = switches
952+
elif "start" in key:
943953
# ListArrays need to be converted to ListOffsetArrays
944-
barekey = key.split("-")[0]
945954
starts = awkward.index.Index(data_buffers[f"{barekey}-starts"])
946955
stops = awkward.index.Index(data_buffers[f"{barekey}-stops"])
947956
next_barekey = f"node{int(barekey[4:])+1}"
@@ -953,17 +962,20 @@ def extend(self, file, sink, data):
953962
starts, stops, content
954963
).to_ListOffsetArray64()
955964
)[2]
965+
data_buffers[f"{barekey}-startstop"] = tmp_buffers["node0-offsets"][1:]
956966
data_buffers[f"{next_barekey}-data"] = tmp_buffers["node1-data"]
957-
col_data = tmp_buffers["node0-offsets"][1:]
958-
# no longer need the temporary data
959-
del starts, stops, content, tmp_buffers
960-
else:
961-
col_data = data_buffers[key]
962-
if "offsets" in key:
963-
col_data = col_data[1:]
964967
elif "index" in key:
965-
deltas = numpy.array(col_data != -1, dtype=col_data.dtype)
966-
col_data = numpy.cumsum(deltas)
968+
# We need to rearrange the data
969+
next_barekey = f"node{int(barekey[4:])+1}"
970+
index = data_buffers[key]
971+
content = data_buffers[f"{next_barekey}-data"]
972+
content = content[index[index >= 0]] # Rearrange data
973+
deltas = numpy.array(index >= 0, dtype=index.dtype)
974+
data_buffers[key] = numpy.cumsum(deltas, dtype=deltas.dtype)
975+
data_buffers[f"{next_barekey}-data"] = content
976+
977+
for idx, key in enumerate(self._header._column_keys):
978+
col_data = data_buffers[key]
967979
col_len = len(col_data.reshape(-1))
968980
raw_data = col_data.reshape(-1).view("uint8")
969981
if col_data.dtype == numpy.dtype("bool"):

tests/test_1395_rntuple_writing_lists_and_structs.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@
4444
ak.index.Index([1, 2, 4]),
4545
ak.contents.NumpyArray([0, 1, 2, 3, 4, 5]),
4646
),
47+
"indexed_option_array": ak.contents.IndexedOptionArray(
48+
ak.index.Index([3, -1, 1]),
49+
ak.contents.NumpyArray([0, 1, 2, 3, 4, 5]),
50+
),
51+
"indexed_option_array32": ak.contents.IndexedOptionArray(
52+
ak.index.Index32([3, -1, 1]),
53+
ak.contents.NumpyArray([0, 1, 2, 3, 4, 5]),
54+
),
55+
"indexed_array": ak.contents.IndexedArray(
56+
ak.index.Index([3, 0, 1]),
57+
ak.contents.NumpyArray([0, 1, 2, 3, 4, 5]),
58+
),
4759
}
4860
)
4961

@@ -71,7 +83,7 @@ def test_writing_and_reading(tmp_path):
7183
arrays = obj.arrays()
7284

7385
for f in data.fields:
74-
if f == "optional":
86+
if f in ("optional", "indexed_option_array", "indexed_option_array32"):
7587
assert [t[0] if len(t) > 0 else None for t in arrays[f][:3]] == data[
7688
f
7789
].tolist()
@@ -143,6 +155,15 @@ def test_writing_then_reading_with_ROOT(tmp_path, capfd):
143155
in out
144156
)
145157
assert "* Field 17 : list_array (std::vector<std::int64_t>)" in out
158+
assert (
159+
"* Field 18 : indexed_option_array (std::optional<std::int64_t>)"
160+
in out
161+
)
162+
assert (
163+
"* Field 19 : indexed_option_array32 (std::optional<std::int64_t>)"
164+
in out
165+
)
166+
assert "* Field 20 : indexed_array (std::int64_t)" in out
146167

147168

148169
def test_field_descriptions(tmp_path):

0 commit comments

Comments
 (0)