-
Notifications
You must be signed in to change notification settings - Fork 32
feat: add PyTree support for vector objects #637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nsmith-
wants to merge
19
commits into
scikit-hep:main
Choose a base branch
from
nsmith-:pytrees
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
bc4730a
feat: add PyTree support for vector objects
nsmith- b20f8da
lint
nsmith- 5b6e9b2
Fix linter errors
nsmith- 00b25ef
Add numpy vectors and switch to register_pytree()
nsmith- eb8c4fb
Add some documentation
nsmith- 2ff291c
style: pre-commit fixes
pre-commit-ci[bot] 57e41ff
Only register once
nsmith- 9af994b
style: pre-commit fixes
pre-commit-ci[bot] b10ce54
Try 3.8-friendly syntax
nsmith- 37cdee0
Merge branch 'main' into pytrees
Saransh-cpp ae28060
Revert "Try 3.8-friendly syntax"
nsmith- 88e98ff
Address some comments
nsmith- 8f02315
style: pre-commit fixes
pre-commit-ci[bot] 2c324ad
Add to readme
nsmith- 07c536a
exclude spellcheck
nsmith- d78a526
better codespell
nsmith- 0f5fde0
style: pre-commit fixes
pre-commit-ci[bot] 290be54
Merge branch 'main' into pytrees
nsmith- 5c76ace
md and ipynb alias each other
nsmith- File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
"""PyTree operations for vector objects. | ||
|
||
This module defines how vector objects are handled within optree. | ||
See https://blog.scientific-python.org/pytrees/ for the rationale for these functions. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Any | ||
|
||
import numpy | ||
|
||
if TYPE_CHECKING: | ||
from optree.pytree import ReexportedPyTreeModule | ||
|
||
from vector._methods import ( | ||
Vector2D, | ||
Vector3D, | ||
Vector4D, | ||
) | ||
from vector.backends.numpy import ( | ||
MomentumNumpy2D, | ||
MomentumNumpy3D, | ||
MomentumNumpy4D, | ||
VectorNumpy, | ||
VectorNumpy2D, | ||
VectorNumpy3D, | ||
VectorNumpy4D, | ||
) | ||
from vector.backends.object import ( | ||
MomentumObject2D, | ||
MomentumObject3D, | ||
MomentumObject4D, | ||
VectorObject2D, | ||
VectorObject3D, | ||
VectorObject4D, | ||
) | ||
|
||
Children = tuple[Any, ...] | ||
MetaData = tuple[type, ...] | ||
nsmith- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _flatten2D(v: Vector2D) -> tuple[Children, MetaData]: | ||
children = v.azimuthal.elements | ||
metadata = type(v), type(v.azimuthal) | ||
return children, metadata | ||
|
||
|
||
def _unflatten2D(metadata: MetaData, children: Children) -> Vector2D: | ||
backend, azimuthal = metadata | ||
return backend(azimuthal=azimuthal(*children)) | ||
|
||
|
||
def _flatten3D(v: Vector3D) -> tuple[Children, MetaData]: | ||
children = v.azimuthal.elements, v.longitudinal.elements | ||
metadata = type(v), type(v.azimuthal), type(v.longitudinal) | ||
return children, metadata | ||
|
||
|
||
def _unflatten3D(metadata: MetaData, children: Children) -> Vector3D: | ||
coords_azimuthal, coords_longitudinal = children | ||
backend, azimuthal, longitudinal = metadata | ||
return backend( | ||
azimuthal=azimuthal(*coords_azimuthal), | ||
longitudinal=longitudinal(*coords_longitudinal), | ||
) | ||
|
||
|
||
def _flatten4D(v: Vector4D) -> tuple[Children, MetaData]: | ||
children = ( | ||
v.azimuthal.elements, | ||
v.longitudinal.elements, | ||
v.temporal.elements, | ||
) | ||
metadata = type(v), type(v.azimuthal), type(v.longitudinal), type(v.temporal) | ||
return children, metadata | ||
|
||
|
||
def _unflatten4D(metadata: MetaData, children: Children) -> Vector4D: | ||
coords_azimuthal, coords_longitudinal, coords_temporal = children | ||
backend, azimuthal, longitudinal, temporal = metadata | ||
return backend( | ||
azimuthal=azimuthal(*coords_azimuthal), | ||
longitudinal=longitudinal(*coords_longitudinal), | ||
temporal=temporal(*coords_temporal), | ||
) | ||
|
||
|
||
def _flattenAoSdata(v: VectorNumpy) -> tuple[Children, tuple[type, numpy.dtype]]: | ||
assert v.dtype.fields is not None | ||
field_dtypes = [dt for dt, *_ in v.dtype.fields.values()] | ||
target_dtype = field_dtypes[0] | ||
if not all(fd == target_dtype for fd in field_dtypes): | ||
raise ValueError("All fields must have the same dtype to flatten") | ||
array = numpy.array(v).view(target_dtype) | ||
Comment on lines
+86
to
+90
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We manually flatten the numpy struct dtype here and include the necessary information to revert it. For safety we require all fields to have the same dtype. This is necessary to fully flatten the array to a 1d vector |
||
children = (array,) | ||
metadata = (type(v), v.dtype) | ||
return children, metadata | ||
|
||
|
||
def _unflattenAoSdata( | ||
metadata: tuple[type, numpy.dtype], children: Children | ||
) -> VectorNumpy: | ||
(array,) = children | ||
(vtype, dtype) = metadata | ||
return array.view(dtype).view(vtype) | ||
|
||
|
||
def register_pytree() -> ReexportedPyTreeModule: | ||
"""Register vector objects with the optree module.""" | ||
try: | ||
import optree.pytree | ||
from optree import GetAttrEntry | ||
from optree.integrations.numpy import tree_ravel | ||
except ImportError as e: | ||
raise ImportError("Please install optree to use vector.pytree") from e | ||
|
||
pytree = optree.pytree.reexport(namespace="vector", module="vector.pytree") | ||
|
||
pytree.register_node( | ||
VectorObject2D, | ||
flatten_func=_flatten2D, | ||
unflatten_func=_unflatten2D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
pytree.register_node( | ||
MomentumObject2D, | ||
flatten_func=_flatten2D, | ||
unflatten_func=_unflatten2D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
pytree.register_node( | ||
VectorObject3D, | ||
flatten_func=_flatten3D, | ||
unflatten_func=_unflatten3D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
pytree.register_node( | ||
MomentumObject3D, | ||
flatten_func=_flatten3D, | ||
unflatten_func=_unflatten3D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
pytree.register_node( | ||
VectorObject4D, | ||
flatten_func=_flatten4D, | ||
unflatten_func=_unflatten4D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
pytree.register_node( | ||
MomentumObject4D, | ||
flatten_func=_flatten4D, | ||
unflatten_func=_unflatten4D, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
|
||
for cls in ( | ||
VectorNumpy2D, | ||
MomentumNumpy2D, | ||
VectorNumpy3D, | ||
MomentumNumpy3D, | ||
VectorNumpy4D, | ||
MomentumNumpy4D, | ||
): | ||
pytree.register_node( | ||
cls, | ||
flatten_func=_flattenAoSdata, | ||
unflatten_func=_unflattenAoSdata, | ||
path_entry_type=GetAttrEntry, | ||
) | ||
|
||
# A convenience function | ||
pytree.ravel = partial(tree_ravel, namespace="vector") # type: ignore[attr-defined] | ||
|
||
return pytree |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.