Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ repos:
- numpy
- packaging
- sympy
- optree

- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ optional-dependencies.test-extras = [
"dask_awkward",
"spark-parser",
'uncompyle6; python_version == "3.8"',
"optree>=0.16",
]
urls."Bug Tracker" = "https://github.com/scikit-hep/vector/issues"
urls.Changelog = "https://vector.readthedocs.io/en/latest/changelog.html"
Expand Down Expand Up @@ -146,6 +147,9 @@ isort.required-imports = [
"src/vector/_methods.py" = [
"PLC0415",
]
"src/vector/_pytree.py" = [
"PLC0415",
]
"src/vector/backends/_numba_object.py" = [
"PGH003",
]
Expand Down Expand Up @@ -274,3 +278,8 @@ module = [
ignore_missing_imports = true
disallow_untyped_defs = false
disallow_untyped_calls = false

[[tool.mypy.overrides]]
module = "vector._pytree"
# optree's register_node type hints unnecessarily require vectors to be a Collection
disable_error_code = ["call-arg", "arg-type"]
3 changes: 2 additions & 1 deletion src/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Vector4D,
dim,
)
from vector._pytree import register_pytree
from vector._version import version as __version__
from vector.backends.awkward_constructors import Array, zip
from vector.backends.awkward_constructors import Array as awk
Expand Down Expand Up @@ -95,7 +96,6 @@ def _import_awkward() -> None:
VectorSympy4D,
)


__all__: tuple[str, ...] = (
"Array",
"Azimuthal",
Expand Down Expand Up @@ -148,6 +148,7 @@ def _import_awkward() -> None:
"obj",
"register_awkward",
"register_numba",
"register_pytree",
"zip",
)

Expand Down
176 changes: 176 additions & 0 deletions src/vector/_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""PyTree operations for vector objects.

This module defines how vector objects are handled within optree.
See https://blog.scientific-python.org/pytrees/ for the rationale for these functions.
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any

import numpy

if TYPE_CHECKING:
from optree.pytree import ReexportedPyTreeModule

from vector._methods import (
Vector2D,
Vector3D,
Vector4D,
)
from vector.backends.numpy import (
MomentumNumpy2D,
MomentumNumpy3D,
MomentumNumpy4D,
VectorNumpy,
VectorNumpy2D,
VectorNumpy3D,
VectorNumpy4D,
)
from vector.backends.object import (
MomentumObject2D,
MomentumObject3D,
MomentumObject4D,
VectorObject2D,
VectorObject3D,
VectorObject4D,
)

# First element returned by the flatten functions: the tuple of pytree children.
Children = tuple[Any, ...]
# Second element returned by the object-vector flatten functions: the classes
# (vector class followed by its coordinate classes) needed to rebuild a vector.
MetaData = tuple[type, ...]


def _flatten2D(v: Vector2D) -> tuple[Children, MetaData]:
children = v.azimuthal.elements
metadata = type(v), type(v.azimuthal)
return children, metadata


def _unflatten2D(metadata: MetaData, children: Children) -> Vector2D:
backend, azimuthal = metadata
return backend(azimuthal=azimuthal(*children))


def _flatten3D(v: Vector3D) -> tuple[Children, MetaData]:
children = v.azimuthal.elements, v.longitudinal.elements
metadata = type(v), type(v.azimuthal), type(v.longitudinal)
return children, metadata


def _unflatten3D(metadata: MetaData, children: Children) -> Vector3D:
coords_azimuthal, coords_longitudinal = children
backend, azimuthal, longitudinal = metadata
return backend(
azimuthal=azimuthal(*coords_azimuthal),
longitudinal=longitudinal(*coords_longitudinal),
)


def _flatten4D(v: Vector4D) -> tuple[Children, MetaData]:
children = (
v.azimuthal.elements,
v.longitudinal.elements,
v.temporal.elements,
)
metadata = type(v), type(v.azimuthal), type(v.longitudinal), type(v.temporal)
return children, metadata


def _unflatten4D(metadata: MetaData, children: Children) -> Vector4D:
coords_azimuthal, coords_longitudinal, coords_temporal = children
backend, azimuthal, longitudinal, temporal = metadata
return backend(
azimuthal=azimuthal(*coords_azimuthal),
longitudinal=longitudinal(*coords_longitudinal),
temporal=temporal(*coords_temporal),
)


def _flattenAoSdata(v: VectorNumpy) -> tuple[Children, tuple[type, numpy.dtype]]:
assert v.dtype.fields is not None
field_dtypes = [dt for dt, *_ in v.dtype.fields.values()]
target_dtype = field_dtypes[0]
if not all(fd == target_dtype for fd in field_dtypes):
raise ValueError("All fields must have the same dtype to flatten")
array = numpy.array(v).view(target_dtype)
Comment on lines +86 to +90
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We manually flatten the numpy struct dtype here and include the necessary information to revert it. For safety we require all fields to have the same dtype. This is necessary to fully flatten the array to a 1d vector

children = (array,)
metadata = (type(v), v.dtype)
return children, metadata


def _unflattenAoSdata(
metadata: tuple[type, numpy.dtype], children: Children
) -> VectorNumpy:
(array,) = children
(vtype, dtype) = metadata
return array.view(dtype).view(vtype)


def register_pytree() -> ReexportedPyTreeModule:
    """Register vector objects with the optree module.

    Every object and numpy vector class is registered as a pytree node in
    the ``"vector"`` namespace, and a ``ravel`` helper bound to that
    namespace is attached to the returned module.

    Returns:
        The module produced by ``optree.pytree.reexport``.

    Raises:
        ImportError: If optree cannot be imported.
    """
    try:
        import optree.pytree
        from optree import GetAttrEntry
        from optree.integrations.numpy import tree_ravel
    except ImportError as e:
        raise ImportError("Please install optree to use vector.pytree") from e

    pytree = optree.pytree.reexport(namespace="vector", module="vector.pytree")

    # (class, flatten function, unflatten function), in registration order.
    registrations = (
        (VectorObject2D, _flatten2D, _unflatten2D),
        (MomentumObject2D, _flatten2D, _unflatten2D),
        (VectorObject3D, _flatten3D, _unflatten3D),
        (MomentumObject3D, _flatten3D, _unflatten3D),
        (VectorObject4D, _flatten4D, _unflatten4D),
        (MomentumObject4D, _flatten4D, _unflatten4D),
        (VectorNumpy2D, _flattenAoSdata, _unflattenAoSdata),
        (MomentumNumpy2D, _flattenAoSdata, _unflattenAoSdata),
        (VectorNumpy3D, _flattenAoSdata, _unflattenAoSdata),
        (MomentumNumpy3D, _flattenAoSdata, _unflattenAoSdata),
        (VectorNumpy4D, _flattenAoSdata, _unflattenAoSdata),
        (MomentumNumpy4D, _flattenAoSdata, _unflattenAoSdata),
    )
    for cls, flatten, unflatten in registrations:
        pytree.register_node(
            cls,
            flatten_func=flatten,
            unflatten_func=unflatten,
            path_entry_type=GetAttrEntry,
        )

    # A convenience function
    pytree.ravel = partial(tree_ravel, namespace="vector")  # type: ignore[attr-defined]

    return pytree
Loading
Loading