Skip to content

Commit bc4b5b2

Browse files
authored
Fix '_wrap_result' for methods called on ak.Record. (#100)
1 parent 7978163 commit bc4b5b2

File tree

2 files changed

+76
-26
lines changed

2 files changed

+76
-26
lines changed

src/vector/_backends/awkward_.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,14 @@ def _class_to_name(cls: typing.Type[VectorProtocol]) -> str:
332332
# the vector class ############################################################
333333

334334

335+
def _yes_record(x: ak.Array) -> typing.Optional[typing.Union[float, ak.Record]]:
336+
return x[0]
337+
338+
339+
def _no_record(x: ak.Array) -> typing.Optional[ak.Array]:
340+
return x
341+
342+
335343
class VectorAwkward:
336344
lib: types.ModuleType = numpy
337345

@@ -362,7 +370,18 @@ def _wrap_result(
362370
if returns == [float] or returns == [bool]:
363371
return result
364372

365-
elif (
373+
if all(not isinstance(x, ak.Array) for x in result):
374+
maybe_record = _yes_record
375+
result = [
376+
ak.Array(x.layout.array[x.layout.at : x.layout.at + 1])
377+
if isinstance(x, ak.Record)
378+
else ak.Array([x])
379+
for x in result
380+
]
381+
else:
382+
maybe_record = _no_record
383+
384+
if (
366385
len(returns) == 1
367386
and isinstance(returns[0], type)
368387
and issubclass(returns[0], Azimuthal)
@@ -396,11 +415,13 @@ def _wrap_result(
396415
else:
397416
cls = cls.ProjectionClass2D
398417

399-
return ak.zip(
400-
dict(zip(names, arrays)),
401-
depth_limit=first.layout.purelist_depth,
402-
with_name=_class_to_name(cls),
403-
behavior=None if vector._awkward_registered else first.behavior,
418+
return maybe_record(
419+
ak.zip(
420+
dict(zip(names, arrays)),
421+
depth_limit=first.layout.purelist_depth,
422+
with_name=_class_to_name(cls),
423+
behavior=None if vector._awkward_registered else first.behavior,
424+
)
404425
)
405426

406427
elif (
@@ -440,11 +461,13 @@ def _wrap_result(
440461
names.append(name)
441462
arrays.append(self[name])
442463

443-
return ak.zip(
444-
dict(zip(names, arrays)),
445-
depth_limit=first.layout.purelist_depth,
446-
with_name=_class_to_name(cls.ProjectionClass2D),
447-
behavior=None if vector._awkward_registered else first.behavior,
464+
return maybe_record(
465+
ak.zip(
466+
dict(zip(names, arrays)),
467+
depth_limit=first.layout.purelist_depth,
468+
with_name=_class_to_name(cls.ProjectionClass2D),
469+
behavior=None if vector._awkward_registered else first.behavior,
470+
)
448471
)
449472

450473
elif (
@@ -491,11 +514,13 @@ def _wrap_result(
491514
else:
492515
cls = cls.ProjectionClass3D
493516

494-
return ak.zip(
495-
dict(zip(names, arrays)),
496-
depth_limit=first.layout.purelist_depth,
497-
with_name=_class_to_name(cls),
498-
behavior=None if vector._awkward_registered else first.behavior,
517+
return maybe_record(
518+
ak.zip(
519+
dict(zip(names, arrays)),
520+
depth_limit=first.layout.purelist_depth,
521+
with_name=_class_to_name(cls),
522+
behavior=None if vector._awkward_registered else first.behavior,
523+
)
499524
)
500525

501526
elif (
@@ -547,11 +572,13 @@ def _wrap_result(
547572
names.append(name)
548573
arrays.append(self[name])
549574

550-
return ak.zip(
551-
dict(zip(names, arrays)),
552-
depth_limit=first.layout.purelist_depth,
553-
with_name=_class_to_name(cls.ProjectionClass3D),
554-
behavior=None if vector._awkward_registered else first.behavior,
575+
return maybe_record(
576+
ak.zip(
577+
dict(zip(names, arrays)),
578+
depth_limit=first.layout.purelist_depth,
579+
with_name=_class_to_name(cls.ProjectionClass3D),
580+
behavior=None if vector._awkward_registered else first.behavior,
581+
)
555582
)
556583

557584
elif (
@@ -611,11 +638,13 @@ def _wrap_result(
611638
names.append(name)
612639
arrays.append(self[name])
613640

614-
return ak.zip(
615-
dict(zip(names, arrays)),
616-
depth_limit=first.layout.purelist_depth,
617-
with_name=_class_to_name(cls.ProjectionClass4D),
618-
behavior=None if vector._awkward_registered else first.behavior,
641+
return maybe_record(
642+
ak.zip(
643+
dict(zip(names, arrays)),
644+
depth_limit=first.layout.purelist_depth,
645+
with_name=_class_to_name(cls.ProjectionClass4D),
646+
behavior=None if vector._awkward_registered else first.behavior,
647+
)
619648
)
620649

621650
else:

tests/test_issues.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2019-2021, Jonas Eschle, Jim Pivarski, Eduardo Rodrigues, and Henry Schreiner.
2+
#
3+
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
4+
# or https://github.com/scikit-hep/vector for details.
5+
6+
import pytest
7+
8+
import vector
9+
10+
11+
def test_issue_99():
12+
ak = pytest.importorskip("awkward")
13+
vector.register_awkward()
14+
vec = ak.Array([{"x": 1.0, "y": 2.0, "z": 3.0}], with_name="Vector3D")
15+
assert vec.to_xyz().tolist() == [{"x": 1.0, "y": 2.0, "z": 3.0}]
16+
assert vec[0].to_xyz().tolist() == {"x": 1.0, "y": 2.0, "z": 3.0}
17+
assert vec[0].to_rhophiz().tolist() == {
18+
"rho": 2.23606797749979,
19+
"phi": 1.1071487177940904,
20+
"z": 3.0,
21+
}

0 commit comments

Comments
 (0)