Skip to content

Commit a83ed2d

Browse files
authored
feat: implement transforms that do zeros like of branches at the schema level (#1466)
* fun with forms * updates * delete alias transform * avoid bad egm nano name * photon energy fix too * use float32 * itemsize should be 4 * remove useless parameters * fix zeros from content
1 parent f7356e5 commit a83ed2d

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/coffea/nanoevents/schemas/nanoaod.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,24 @@ def _build_collections(self, field_names, input_contents):
296296
if all(k in branch_forms for k in args):
297297
branch_forms[name] = fcn(*(branch_forms[k] for k in args))
298298

299+
# Add mass and charge fields for Photon collection (always zero for photons)
300+
if "oPhoton" in branch_forms:
301+
if "Photon_mass" not in branch_forms:
302+
branch_forms["Photon_mass"] = transforms.zeros_from_offsets_form(
303+
branch_forms["oPhoton"]
304+
)
305+
if "Photon_charge" not in branch_forms:
306+
branch_forms["Photon_charge"] = transforms.zeros_from_offsets_form(
307+
branch_forms["oPhoton"]
308+
)
309+
310+
# Rename Electron/Photon_energy to Electron/Photon_regrEnergy to avoid conflict with mixin
311+
# Present in EGamma NanoAOD flavor
312+
if "Electron_energy" in branch_forms:
313+
branch_forms["Electron_regrEnergy"] = branch_forms.pop("Electron_energy")
314+
if "Photon_energy" in branch_forms:
315+
branch_forms["Photon_regrEnergy"] = branch_forms.pop("Photon_energy")
316+
299317
output = {}
300318
for name in collections:
301319
mixin = self.mixins.get(name, "NanoCollection")

src/coffea/nanoevents/transforms.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,61 @@ def eventindex(stack):
487487
stack.append(out)
488488

489489

490+
def zeros_from_content_form(source_form):
    """Return the form of a zeros-like array derived from ``source_form``.

    The result is a deep copy of ``source_form`` (the input is never
    mutated) with a new form key and without the ``__doc__`` parameter.

    * ``NumpyArray``: the top-level ``form_key`` becomes the source key
      followed by ``!zeros_from_content``.
    * ``ListOffset*``: only the content's ``form_key`` is rewritten (source
      key followed by ``!zeros_from_content`` and ``!content``); the outer
      ``form_key`` is left as copied from the source.

    Parameters
    ----------
    source_form : dict
        Form dictionary with at least ``"class"`` and ``"form_key"``; list
        forms must also provide a ``"content"`` dictionary.

    Returns
    -------
    dict
        The new form dictionary.

    Raises
    ------
    RuntimeError
        If the form class is neither ``NumpyArray`` nor ``ListOffset*``.
    """
    form = copy.deepcopy(source_form)
    form_class = form["class"]

    if form_class == "NumpyArray":
        form["form_key"] = concat(source_form["form_key"], "!zeros_from_content")
        documented_nodes = (form,)
    elif form_class.startswith("ListOffset"):
        form["content"]["form_key"] = concat(
            source_form["form_key"], "!zeros_from_content", "!content"
        )
        documented_nodes = (form, form["content"])
    else:
        raise RuntimeError(
            "zeros_from_content_form expects a NumpyArray or ListOffset* form, "
            f"got class {form_class!r}"
        )

    # Drop the inherited description (presumably because it documents the
    # source branch rather than the derived zeros -- confirm with the schema
    # authors). A missing or None "parameters" entry is tolerated instead of
    # raising KeyError/AttributeError.
    for node in documented_nodes:
        parameters = node.get("parameters")
        if parameters:
            parameters.pop("__doc__", None)
    return form
504+
505+
506+
def zeros_from_content(stack):
    """Swap the top of ``stack`` for its ``awkward.zeros_like`` counterpart.

    Pops one item from ``stack`` and pushes the zeros-like result in its
    place; nothing is returned.
    """
    template = stack.pop()
    zeroed = awkward.zeros_like(template)
    stack.append(zeroed)
509+
510+
511+
def zeros_from_offsets_form(offsets_form):
    """Return the form of a float32 jagged array built from an offsets form.

    The result describes a ``ListOffsetArray`` with 64-bit offsets whose
    content is a ``NumpyArray`` of ``float32``. Form keys are derived from
    the offsets form key: ``!zeros_from_offsets`` is appended for the list
    node and ``!zeros_from_offsets`` plus ``!content`` for its content.

    Parameters
    ----------
    offsets_form : dict
        Form dictionary with ``"class"`` and ``"form_key"`` entries; the
        class must start with ``NumpyArray``.

    Returns
    -------
    dict
        The newly built form dictionary.

    Raises
    ------
    RuntimeError
        If ``offsets_form`` is not a ``NumpyArray`` form.
    """
    offsets_class = offsets_form["class"]
    if not offsets_class.startswith("NumpyArray"):
        raise RuntimeError(
            "zeros_from_offsets_form expects a NumpyArray offsets form, "
            f"got class {offsets_class!r}"
        )

    list_key = concat(offsets_form["form_key"], "!zeros_from_offsets")
    content_key = concat(
        offsets_form["form_key"], "!zeros_from_offsets", "!content"
    )
    return {
        "class": "ListOffsetArray",
        "offsets": "i64",
        "content": {
            "class": "NumpyArray",
            "primitive": "float32",
            "itemsize": 4,
            # "f" is the 4-byte float format code, matching the float32
            # primitive and itemsize above (the previous "i" denotes an int).
            "format": "f",
            "form_key": content_key,
        },
        "form_key": list_key,
    }
530+
531+
532+
def zeros_from_offsets(stack):
    """Replace the offsets on top of ``stack`` with a jagged array of zeros.

    Pops one item, converts it with ``ensure_array``, and pushes an
    ``awkward.Array`` whose layout is a ``ListOffsetArray`` using those
    offsets over a flat ``float32`` buffer of zeros. Nothing is returned.
    """
    boundaries = ensure_array(stack.pop())
    # The final offset is used as the length of the flat content buffer.
    flat_zeros = numpy.zeros(boundaries[-1], dtype=numpy.float32)
    layout = awkward.contents.ListOffsetArray(
        awkward.index.Index64(boundaries),
        awkward.contents.NumpyArray(flat_zeros),
    )
    stack.append(awkward.Array(layout))
543+
544+
490545
# For EDM4HEPSchema and FCCSchema:
491546

492547

0 commit comments

Comments
 (0)