@@ -111,7 +111,7 @@ def _cpp_typename(akform, subcall=False):
111111 elif isinstance (akform , awkward .forms .UnionForm ):
112112 field_typenames = [_cpp_typename (t , subcall = True ) for t in akform .contents ]
113113 typename = f"std::variant<{ ',' .join (field_typenames )} >"
114- elif isinstance (akform , awkward .forms .UnmaskedForm ):
114+ elif isinstance (akform , ( awkward .forms .UnmaskedForm , awkward . forms . IndexedForm ) ):
115115 return _cpp_typename (akform .content , subcall = True )
116116 else :
117117 raise NotImplementedError (f"Form type { type (akform )} cannot be written yet" )
@@ -484,7 +484,10 @@ def _build_field_col_records(
484484 field_name = subfield_name ,
485485 parent_fid = field_id ,
486486 )
487- elif isinstance (akform , awkward .forms .UnmaskedForm ):
487+ elif isinstance (
488+ akform , (awkward .forms .UnmaskedForm , awkward .forms .IndexedForm )
489+ ):
490+ # IndexedForms just get rearranged, so they are transparent
488491 # Do nothing
489492 self ._build_field_col_records (
490493 akform .content ,
@@ -930,18 +933,24 @@ def extend(self, file, sink, data):
930933
931934 cluster_page_data = [] # list of list of (locator, len, offset)
932935 data_buffers = awkward .to_buffers (data )[2 ]
933- for idx , key in enumerate (self ._header ._column_keys ):
934- if "switch" in key :
936+
937+ # We need to make a few modifications since not everything directly translates to RNTuples
938+ for key in list (data_buffers .keys ()):
939+ barekey = key .split ("-" )[0 ]
940+ if "offset" in key :
941+ # RNTuples don't store the first offset
942+ data_buffers [key ] = data_buffers [key ][1 :]
943+ elif "index" in key and barekey + "-tags" in data_buffers :
944+ # We group indices and tags into a single array
935945 dtype = numpy .dtype ([("index" , "int64" ), ("tag" , "int32" )])
936- indices = data_buffers [key . split ( "-" )[ 0 ] + "-index" ]
937- tags = data_buffers [key . split ( "-" )[ 0 ] + "-tags" ]
946+ indices = data_buffers [barekey + "-index" ]
947+ tags = data_buffers [barekey + "-tags" ]
938948 switches = numpy .zeros (len (indices ), dtype = dtype )
939949 switches ["index" ] = indices
940950 switches ["tag" ] = tags + 1
941- col_data = switches
942- elif "startstop " in key :
951+ data_buffers [ barekey + "-switch" ] = switches
952+ elif "start " in key :
943953 # ListArrays need to be converted to ListOffsetArrays
944- barekey = key .split ("-" )[0 ]
945954 starts = awkward .index .Index (data_buffers [f"{ barekey } -starts" ])
946955 stops = awkward .index .Index (data_buffers [f"{ barekey } -stops" ])
947956 next_barekey = f"node{ int (barekey [4 :])+ 1 } "
@@ -953,17 +962,20 @@ def extend(self, file, sink, data):
953962 starts , stops , content
954963 ).to_ListOffsetArray64 ()
955964 )[2 ]
965+ data_buffers [f"{ barekey } -startstop" ] = tmp_buffers ["node0-offsets" ][1 :]
956966 data_buffers [f"{ next_barekey } -data" ] = tmp_buffers ["node1-data" ]
957- col_data = tmp_buffers ["node0-offsets" ][1 :]
958- # no longer need the temporary data
959- del starts , stops , content , tmp_buffers
960- else :
961- col_data = data_buffers [key ]
962- if "offsets" in key :
963- col_data = col_data [1 :]
964967 elif "index" in key :
965- deltas = numpy .array (col_data != - 1 , dtype = col_data .dtype )
966- col_data = numpy .cumsum (deltas )
968+ # We need to rearrange the data
969+ next_barekey = f"node{ int (barekey [4 :])+ 1 } "
970+ index = data_buffers [key ]
971+ content = data_buffers [f"{ next_barekey } -data" ]
972+ content = content [index [index >= 0 ]] # Rearrange data
973+ deltas = numpy .array (index >= 0 , dtype = index .dtype )
974+ data_buffers [key ] = numpy .cumsum (deltas , dtype = deltas .dtype )
975+ data_buffers [f"{ next_barekey } -data" ] = content
976+
977+ for idx , key in enumerate (self ._header ._column_keys ):
978+ col_data = data_buffers [key ]
967979 col_len = len (col_data .reshape (- 1 ))
968980 raw_data = col_data .reshape (- 1 ).view ("uint8" )
969981 if col_data .dtype == numpy .dtype ("bool" ):
0 commit comments