Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ repos:
- numpy
- packaging
- sympy
- optree

- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ optional-dependencies.test-extras = [
"dask_awkward",
"spark-parser",
'uncompyle6; python_version == "3.8"',
"optree>=0.16",
]
urls."Bug Tracker" = "https://github.com/scikit-hep/vector/issues"
urls.Changelog = "https://vector.readthedocs.io/en/latest/changelog.html"
Expand Down
13 changes: 13 additions & 0 deletions src/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import importlib.metadata
import typing
from types import ModuleType

import packaging.version

Expand Down Expand Up @@ -96,6 +97,17 @@ def _import_awkward() -> None:
)


pytree: ModuleType | None = None
try:
import optree.pytree

from vector._pytree import _register

pytree = optree.pytree.reexport(namespace="vector")
_register(pytree)
except ImportError:
pass

__all__: tuple[str, ...] = (
"Array",
"Azimuthal",
Expand Down Expand Up @@ -146,6 +158,7 @@ def _import_awkward() -> None:
"awkward_transform",
"dim",
"obj",
"pytree",
"register_awkward",
"register_numba",
"zip",
Expand Down
96 changes: 96 additions & 0 deletions src/vector/_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""PyTree operations for vector objects.
This module defines how vector objects are handled within optree.
See https://blog.scientific-python.org/pytrees/ for the rationale for these functions.
"""

from __future__ import annotations

from types import ModuleType
from typing import Any

from vector._methods import (
Vector2D,
Vector3D,
Vector4D,
)
from vector.backends.object import (
MomentumObject2D,
MomentumObject3D,
MomentumObject4D,
VectorObject2D,
VectorObject3D,
VectorObject4D,
)

Children = tuple[Any, ...]
MetaData = tuple[type, ...]


def flatten2D(v: Vector2D) -> tuple[Children, MetaData]:
children = v.azimuthal.elements
metadata = type(v), type(v.azimuthal)
return children, metadata


def unflatten2D(metadata: MetaData, children: Children) -> Vector2D:
backend, azimuthal = metadata
return backend(azimuthal=azimuthal(*children))


def flatten3D(v: Vector3D) -> tuple[Children, MetaData]:
children = v.azimuthal.elements, v.longitudinal.elements
metadata = type(v), type(v.azimuthal), type(v.longitudinal)
return children, metadata


def unflatten3D(metadata: MetaData, children: Children) -> Vector3D:
coords_azimuthal, coords_longitudinal = children
backend, azimuthal, longitudinal = metadata
return backend(
azimuthal=azimuthal(*coords_azimuthal),
longitudinal=longitudinal(*coords_longitudinal),
)


def flatten4D(v: Vector4D) -> tuple[Children, MetaData]:
children = (
v.azimuthal.elements,
v.longitudinal.elements,
v.temporal.elements,
)
metadata = type(v), type(v.azimuthal), type(v.longitudinal), type(v.temporal)
return children, metadata


def unflatten4D(metadata: MetaData, children: Children) -> Vector4D:
coords_azimuthal, coords_longitudinal, coords_temporal = children
backend, azimuthal, longitudinal, temporal = metadata
return backend(
azimuthal=azimuthal(*coords_azimuthal),
longitudinal=longitudinal(*coords_longitudinal),
temporal=temporal(*coords_temporal),
)


def _register(pytree: ModuleType) -> None:
"""Register vector objects with the given pytree module."""

pytree.register_node(
VectorObject2D, flatten_func=flatten2D, unflatten_func=unflatten2D
)
pytree.register_node(
MomentumObject2D, flatten_func=flatten2D, unflatten_func=unflatten2D
)
pytree.register_node(
VectorObject3D, flatten_func=flatten3D, unflatten_func=unflatten3D
)
pytree.register_node(
MomentumObject3D, flatten_func=flatten3D, unflatten_func=unflatten3D
)
pytree.register_node(
VectorObject4D, flatten_func=flatten4D, unflatten_func=unflatten4D
)
pytree.register_node(
MomentumObject4D, flatten_func=flatten4D, unflatten_func=unflatten4D
)
126 changes: 126 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) 2025, Nick Smith
#
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

from __future__ import annotations

import pytest

import vector

if vector.pytree is None:
pytest.skip("optree is not installed", allow_module_level=True)


def test_pytree_roundtrip_VectorObject2D():
vec = vector.obj(x=1, y=2)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2

vec = vector.obj(rho=1, phi=2)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)


def test_pytree_roundtrip_MomentumObject2D():
vec = vector.obj(px=1, py=2)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2

vec = vector.obj(pt=1, phi=2)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)


def test_pytree_roundtrip_VectorObject3D():
vec = vector.obj(x=1, y=2, z=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2

vec = vector.obj(rho=1, phi=2, z=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)

vec = vector.obj(x=1, y=2, theta=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)


def test_pytree_roundtrip_MomentumObject3D():
vec = vector.obj(px=1, py=2, pz=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2

vec = vector.obj(pt=1, phi=2, pz=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)

vec = vector.obj(px=1, py=2, theta=3)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)


def test_pytree_roundtrip_VectorObject4D():
vec = vector.obj(x=1, y=2, z=3, t=4)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3, 4]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2

vec = vector.obj(rho=1, phi=2, z=3, t=4)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3, 4]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)
assert type(vec.temporal) is type(vec2.temporal)

vec = vector.obj(x=1, y=2, theta=3, t=4)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3, 4]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)
assert type(vec.temporal) is type(vec2.temporal)

vec = vector.obj(pt=1, phi=2, eta=3, mass=4)
leaves, treedef = vector.pytree.flatten(vec)
assert leaves == [1, 2, 3, 4]
vec2 = vector.pytree.unflatten(treedef, leaves)
assert vec == vec2
assert type(vec.azimuthal) is type(vec2.azimuthal)
assert type(vec.longitudinal) is type(vec2.longitudinal)
assert type(vec.temporal) is type(vec2.temporal)
Loading