Skip to content

Commit 65b4af6

Browse files
authored
fix: correctly write all Awkward array types to RNTuple (#1496)
* Fixed writing indexed (option) arrays * Fixed writing of ListArrays * Updated test filename * Minor tweaks * Fixed issue with 32-bit indices * Test with 32-bit indices * Fix 32-bit indices for string arrays * Implemented custom carry to make sure that form is preserved * Fixed writing of trimmed ListOffsetArrays * Switched to using to_packed * Simplified code a bit * Read std::optional as IndexedOptionArray * Fixed warning message * Fixed typo
1 parent 90f69d2 commit 65b4af6

File tree

6 files changed

+318
-119
lines changed

6 files changed

+318
-119
lines changed

src/uproot/behaviors/RNTuple.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1750,6 +1750,27 @@ def _recursive_find(form, res):
17501750
_recursive_find(form.content, res)
17511751

17521752

1753+
def _cupy_insert(arr, obj, value):
    """
    Insert ``value`` into ``arr`` before each index listed in ``obj``.

    CuPy counterpart of ``numpy.insert(arr, obj, value)`` for the restricted
    case handled here: both ``arr`` and ``obj`` are flat arrays and ``obj``
    is sorted in ascending order.

    Args:
        arr: Flat CuPy array to insert into.
        obj: Flat, ascending CuPy integer array of positions in ``arr``
            (each in ``0..arr.size``) before which ``value`` is inserted.
            A repeated position inserts ``value`` that many times.
        value: Scalar written at every inserted position.

    Returns:
        A new CuPy array of size ``arr.size + obj.size`` with ``arr.dtype``.
        ``arr`` and ``obj`` are not modified.
    """
    cupy = uproot.extras.cupy()
    out_size = arr.size + obj.size
    out = cupy.empty(out_size, dtype=arr.dtype)

    # Since obj is sorted, the k-th inserted element is preceded by obj[k]
    # original elements and k earlier insertions, so it lands at obj[k] + k.
    # These destinations are strictly increasing, hence unique.
    inserted_at = obj + cupy.arange(obj.size, dtype=obj.dtype)

    # Every output slot that is not an insertion slot receives the original
    # elements, in their original order.
    keep = cupy.ones(out_size, dtype=bool)
    keep[inserted_at] = False

    out[keep] = arr
    out[inserted_at] = value
    return out
1772+
1773+
17531774
def _fill_container_dict(container_dict, content, key, dtype_byte):
17541775
array_library_string = uproot._util.get_array_library(content)
17551776

@@ -1758,7 +1779,19 @@ def _fill_container_dict(container_dict, content, key, dtype_byte):
17581779
if "cardinality" in key:
17591780
content = library.diff(content)
17601781

1761-
if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]:
1782+
if "optional" in key:
1783+
# We need to convert from a ListOffsetArray to an IndexedOptionArray
1784+
diff = library.diff(content)
1785+
missing = library.nonzero(diff == 0)[0]
1786+
missing -= library.arange(len(missing), dtype=missing.dtype)
1787+
dtype = "int64" if content.dtype == library.uint64 else "int32"
1788+
indices = library.arange(len(content) - len(missing), dtype=dtype)
1789+
if array_library_string == "numpy":
1790+
indices = numpy.insert(indices, missing, -1)
1791+
else:
1792+
indices = _cupy_insert(indices, missing, -1)
1793+
container_dict[f"{key}-index"] = indices
1794+
elif dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]:
17621795
kindex, tags = uproot.models.RNTuple._split_switch_bits(content)
17631796
# Find invalid variants and adjust buffers accordingly
17641797
invalid = numpy.flatnonzero(tags == -1)

src/uproot/models/RNTuple.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,9 @@ def col_form(self, field_id, extra_parameters=None, is_cardinality=False):
441441
parameters = {"__array__": "string"}
442442
if extra_parameters is not None:
443443
parameters.update(extra_parameters)
444+
idx_type = "i32" if rel_crs[0].nbits == 32 else "i64"
444445
return ak.forms.ListOffsetForm(
445-
"i64", inner, form_key=form_key, parameters=parameters
446+
idx_type, inner, form_key=form_key, parameters=parameters
446447
)
447448
else:
448449
raise (RuntimeError(f"Missing special case: {field_id}"))
@@ -547,6 +548,11 @@ def field_form(self, this_id, keys, ak_add_doc=False):
547548
idx_type = (
548549
"i32" if self._column_records_dict[cfid][0].nbits == 32 else "i64"
549550
)
551+
if self._all_fields[cfid].record.type_name.startswith("std::optional"):
552+
keyname = keyname + "-optional"
553+
return ak.forms.IndexedOptionForm(
554+
idx_type, inner, form_key=keyname, parameters=parameters
555+
)
550556
return ak.forms.ListOffsetForm(
551557
idx_type, inner, form_key=keyname, parameters=parameters
552558
)

0 commit comments

Comments
 (0)