-
Notifications
You must be signed in to change notification settings - Fork 32
feat: add PyTree support for vector objects #637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nsmith-
wants to merge
19
commits into
scikit-hep:main
Choose a base branch
from
nsmith-:pytrees
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
bc4730a
feat: add PyTree support for vector objects
nsmith- b20f8da
lint
nsmith- 5b6e9b2
Fix linter errors
nsmith- 00b25ef
Add numpy vectors and switch to register_pytree()
nsmith- eb8c4fb
Add some documentation
nsmith- 2ff291c
style: pre-commit fixes
pre-commit-ci[bot] 57e41ff
Only register once
nsmith- 9af994b
style: pre-commit fixes
pre-commit-ci[bot] b10ce54
Try 3.8-friendly syntax
nsmith- 37cdee0
Merge branch 'main' into pytrees
Saransh-cpp ae28060
Revert "Try 3.8-friendly syntax"
nsmith- 88e98ff
Address some comments
nsmith- 8f02315
style: pre-commit fixes
pre-commit-ci[bot] 2c324ad
Add to readme
nsmith- 07c536a
exclude spellcheck
nsmith- d78a526
better codespell
nsmith- 0f5fde0
style: pre-commit fixes
pre-commit-ci[bot] 290be54
Merge branch 'main' into pytrees
nsmith- 5c76ace
md and ipynb alias each other
nsmith- File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""PyTree operations for vector objects. | ||
This module defines how vector objects are handled within optree. | ||
See https://blog.scientific-python.org/pytrees/ for the rationale for these functions. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from types import ModuleType | ||
from typing import Any | ||
|
||
from vector._methods import ( | ||
Vector2D, | ||
Vector3D, | ||
Vector4D, | ||
) | ||
from vector.backends.object import ( | ||
MomentumObject2D, | ||
MomentumObject3D, | ||
MomentumObject4D, | ||
VectorObject2D, | ||
VectorObject3D, | ||
VectorObject4D, | ||
) | ||
nsmith- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Children = tuple[Any, ...] | ||
MetaData = tuple[type, ...] | ||
nsmith- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def flatten2D(v: Vector2D) -> tuple[Children, MetaData]: | ||
children = v.azimuthal.elements | ||
metadata = type(v), type(v.azimuthal) | ||
return children, metadata | ||
|
||
|
||
def unflatten2D(metadata: MetaData, children: Children) -> Vector2D: | ||
backend, azimuthal = metadata | ||
return backend(azimuthal=azimuthal(*children)) | ||
|
||
|
||
def flatten3D(v: Vector3D) -> tuple[Children, MetaData]: | ||
children = v.azimuthal.elements, v.longitudinal.elements | ||
metadata = type(v), type(v.azimuthal), type(v.longitudinal) | ||
return children, metadata | ||
|
||
|
||
def unflatten3D(metadata: MetaData, children: Children) -> Vector3D: | ||
coords_azimuthal, coords_longitudinal = children | ||
backend, azimuthal, longitudinal = metadata | ||
return backend( | ||
azimuthal=azimuthal(*coords_azimuthal), | ||
longitudinal=longitudinal(*coords_longitudinal), | ||
) | ||
|
||
|
||
def flatten4D(v: Vector4D) -> tuple[Children, MetaData]: | ||
children = ( | ||
v.azimuthal.elements, | ||
v.longitudinal.elements, | ||
v.temporal.elements, | ||
) | ||
metadata = type(v), type(v.azimuthal), type(v.longitudinal), type(v.temporal) | ||
return children, metadata | ||
|
||
|
||
def unflatten4D(metadata: MetaData, children: Children) -> Vector4D: | ||
coords_azimuthal, coords_longitudinal, coords_temporal = children | ||
backend, azimuthal, longitudinal, temporal = metadata | ||
return backend( | ||
azimuthal=azimuthal(*coords_azimuthal), | ||
longitudinal=longitudinal(*coords_longitudinal), | ||
temporal=temporal(*coords_temporal), | ||
) | ||
|
||
|
||
def _register(pytree: ModuleType) -> None: | ||
"""Register vector objects with the given pytree module.""" | ||
|
||
pytree.register_node( | ||
VectorObject2D, flatten_func=flatten2D, unflatten_func=unflatten2D | ||
) | ||
pytree.register_node( | ||
MomentumObject2D, flatten_func=flatten2D, unflatten_func=unflatten2D | ||
) | ||
pytree.register_node( | ||
VectorObject3D, flatten_func=flatten3D, unflatten_func=unflatten3D | ||
) | ||
pytree.register_node( | ||
MomentumObject3D, flatten_func=flatten3D, unflatten_func=unflatten3D | ||
) | ||
pytree.register_node( | ||
VectorObject4D, flatten_func=flatten4D, unflatten_func=unflatten4D | ||
) | ||
pytree.register_node( | ||
MomentumObject4D, flatten_func=flatten4D, unflatten_func=unflatten4D | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright (c) 2025, Nick Smith | ||
# | ||
# Distributed under the 3-clause BSD license, see accompanying file LICENSE | ||
# or https://github.com/scikit-hep/vector for details. | ||
|
||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
import vector | ||
|
||
if vector.pytree is None: | ||
pytest.skip("optree is not installed", allow_module_level=True) | ||
|
||
|
||
def test_pytree_roundtrip_VectorObject2D(): | ||
vec = vector.obj(x=1, y=2) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
|
||
vec = vector.obj(rho=1, phi=2) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
|
||
|
||
def test_pytree_roundtrip_MomentumObject2D(): | ||
vec = vector.obj(px=1, py=2) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
|
||
vec = vector.obj(pt=1, phi=2) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
|
||
|
||
def test_pytree_roundtrip_VectorObject3D(): | ||
vec = vector.obj(x=1, y=2, z=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
|
||
vec = vector.obj(rho=1, phi=2, z=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
|
||
vec = vector.obj(x=1, y=2, theta=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
|
||
|
||
def test_pytree_roundtrip_MomentumObject3D(): | ||
vec = vector.obj(px=1, py=2, pz=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
|
||
vec = vector.obj(pt=1, phi=2, pz=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
|
||
vec = vector.obj(px=1, py=2, theta=3) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
|
||
|
||
def test_pytree_roundtrip_VectorObject4D(): | ||
vec = vector.obj(x=1, y=2, z=3, t=4) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3, 4] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
|
||
vec = vector.obj(rho=1, phi=2, z=3, t=4) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3, 4] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
assert type(vec.temporal) is type(vec2.temporal) | ||
|
||
vec = vector.obj(x=1, y=2, theta=3, t=4) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3, 4] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
assert type(vec.temporal) is type(vec2.temporal) | ||
|
||
vec = vector.obj(pt=1, phi=2, eta=3, mass=4) | ||
leaves, treedef = vector.pytree.flatten(vec) | ||
assert leaves == [1, 2, 3, 4] | ||
vec2 = vector.pytree.unflatten(treedef, leaves) | ||
assert vec == vec2 | ||
assert type(vec.azimuthal) is type(vec2.azimuthal) | ||
assert type(vec.longitudinal) is type(vec2.longitudinal) | ||
assert type(vec.temporal) is type(vec2.temporal) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.