Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ repos:
- numpy
- packaging
- sympy
- optree

- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
- id: codespell
args: ["-LHEP"]
exclude: ^docs/usage/intro.ipynb$

- repo: https://github.com/rbubley/mirrors-prettier
rev: "v3.6.2"
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ Vectors may be included in any of these data types:

Each of these "backends" provides the same suite of properties and methods, through a common "compute" library.

### Integrations

Optionally, the vector package provides integration with other libraries. Currently, this includes:

- [PyTree integrations](https://vector.readthedocs.io/en/latest/src/pytree.html) using the [optree](https://github.com/metaopt/optree) package.

### Geometric versus momentum

Finally, vectors come in two flavors:
Expand Down Expand Up @@ -134,6 +140,10 @@ journal = {Journal of Open Source Software}
- [Interface for 3D momentum](https://vector.readthedocs.io/en/latest/src/momentum3d.html)
- [Interface for 4D momentum](https://vector.readthedocs.io/en/latest/src/momentum4d.html)

### Integrations

- [PyTree integration API](https://vector.readthedocs.io/en/latest/src/pytree_api.html)

### More ways to learn

- [Papers and talks](https://vector.readthedocs.io/en/latest/src/talks.html)
Expand Down
13 changes: 13 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ Vectors may be included in any of these data types:

Each of these "backends" provides the same suite of properties and methods, through a common "compute" library.

### Integrations

Optionally, the vector package provides integration with other libraries. Currently, this includes:

- [PyTree integrations](https://vector.readthedocs.io/en/latest/src/pytree.html) using the [optree](https://github.com/metaopt/optree) package.

### Geometric versus momentum

Finally, vectors come in two flavors:
Expand Down Expand Up @@ -133,6 +139,7 @@ src/numpy.ipynb
src/awkward.ipynb
src/numba.ipynb
src/sympy.ipynb
src/pytree.ipynb
```

```{toctree}
Expand All @@ -156,6 +163,12 @@ src/momentum3d.md
src/momentum4d.md
```

```{toctree}
:maxdepth: 1
:caption: Integrations
src/pytree_api.md
```

```{toctree}
:maxdepth: 1
:caption: More ways to learn
Expand Down
292 changes: 292 additions & 0 deletions docs/src/pytree.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions docs/src/pytree_api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# PyTree API

```{eval-rst}
.. autofunction:: vector.register_pytree
```
15 changes: 15 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ test-extras = [
"jax",
"dask_awkward",
"spark-parser",
"optree>=0.16",
]
docs = [
"awkward>=2",
Expand All @@ -89,6 +90,8 @@ docs = [
"sphinx-book-theme>=0.0.42",
"sphinx-copybutton",
"sphinx-math-dollar",
"hepunits",
"matplotlib",
]
test-all = [
{ include-group = "test"},
Expand Down Expand Up @@ -148,6 +151,9 @@ isort.required-imports = [
"src/vector/_methods.py" = [
"PLC0415",
]
"src/vector/_pytree.py" = [
"PLC0415",
]
"src/vector/backends/_numba_object.py" = [
"PGH003",
]
Expand Down Expand Up @@ -275,3 +281,12 @@ module = [
ignore_missing_imports = true
disallow_untyped_defs = false
disallow_untyped_calls = false

[[tool.mypy.overrides]]
module = "vector._pytree"
# optree register_node typing requires vectors to be Collection, unnecessarily
disable_error_code = ["call-arg", "arg-type"]

[tool.codespell]
ignore-words-list = "HEP"
ignore-regex = "[A-Za-z0-9+/]{100,}"
3 changes: 2 additions & 1 deletion src/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Vector4D,
dim,
)
from vector._pytree import register_pytree
from vector._version import version as __version__
from vector.backends.awkward_constructors import Array, zip
from vector.backends.awkward_constructors import Array as awk
Expand Down Expand Up @@ -95,7 +96,6 @@ def _import_awkward() -> None:
VectorSympy4D,
)


__all__: tuple[str, ...] = (
"Array",
"Azimuthal",
Expand Down Expand Up @@ -148,6 +148,7 @@ def _import_awkward() -> None:
"obj",
"register_awkward",
"register_numba",
"register_pytree",
"zip",
)

Expand Down
222 changes: 222 additions & 0 deletions src/vector/_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any

import numpy

if TYPE_CHECKING:
from optree.pytree import ReexportedPyTreeModule

from vector._methods import (
Vector2D,
Vector3D,
Vector4D,
)
from vector.backends.numpy import (
MomentumNumpy2D,
MomentumNumpy3D,
MomentumNumpy4D,
VectorNumpy,
VectorNumpy2D,
VectorNumpy3D,
VectorNumpy4D,
)
from vector.backends.object import (
MomentumObject2D,
MomentumObject3D,
MomentumObject4D,
VectorObject2D,
VectorObject3D,
VectorObject4D,
)

# Children returned by the flatten functions below: coordinate values, or
# tuples of coordinate values, that optree treats as the sub-trees of a vector.
Children = tuple[Any, ...]
# Static data returned alongside the children: the vector class followed by
# the classes of its coordinate objects, in the order needed to rebuild it.
MetaData = tuple[type, ...]


def _flatten2D(v: Vector2D) -> tuple[Children, MetaData]:
    """Split a 2D vector into its coordinate values and rebuild information.

    Returns a ``(children, metadata)`` pair: ``children`` is the ``elements``
    of the azimuthal coordinates, and ``metadata`` holds the vector class and
    the azimuthal coordinate class, which ``_unflatten2D`` consumes.
    """
    azimuthal = v.azimuthal
    return azimuthal.elements, (type(v), type(azimuthal))


def _unflatten2D(metadata: MetaData, children: Children) -> Vector2D:
    """Rebuild a 2D vector from the output of ``_flatten2D``.

    ``metadata`` is ``(vector class, azimuthal class)`` and ``children`` are
    the positional arguments used to construct the azimuthal coordinates.
    """
    vector_cls, azimuthal_cls = metadata
    coordinates = azimuthal_cls(*children)
    return vector_cls(azimuthal=coordinates)


def _flatten3D(v: Vector3D) -> tuple[Children, MetaData]:
    """Split a 3D vector into its coordinate values and rebuild information.

    Returns a ``(children, metadata)`` pair: ``children`` contains the
    ``elements`` of the azimuthal and longitudinal coordinates (in that
    order), and ``metadata`` holds the vector class followed by the two
    coordinate classes, which ``_unflatten3D`` consumes.
    """
    azimuthal, longitudinal = v.azimuthal, v.longitudinal
    coordinate_values = (azimuthal.elements, longitudinal.elements)
    classes = (type(v), type(azimuthal), type(longitudinal))
    return coordinate_values, classes


def _unflatten3D(metadata: MetaData, children: Children) -> Vector3D:
    """Rebuild a 3D vector from the output of ``_flatten3D``.

    ``children`` is ``(azimuthal values, longitudinal values)`` and
    ``metadata`` is ``(vector class, azimuthal class, longitudinal class)``.
    """
    azimuthal_values, longitudinal_values = children
    vector_cls, azimuthal_cls, longitudinal_cls = metadata
    azimuthal = azimuthal_cls(*azimuthal_values)
    longitudinal = longitudinal_cls(*longitudinal_values)
    return vector_cls(azimuthal=azimuthal, longitudinal=longitudinal)


def _flatten4D(v: Vector4D) -> tuple[Children, MetaData]:
    """Split a 4D vector into its coordinate values and rebuild information.

    Returns a ``(children, metadata)`` pair: ``children`` contains the
    ``elements`` of the azimuthal, longitudinal and temporal coordinates (in
    that order), and ``metadata`` holds the vector class followed by the three
    coordinate classes, which ``_unflatten4D`` consumes.
    """
    parts = (v.azimuthal, v.longitudinal, v.temporal)
    coordinate_values = tuple(part.elements for part in parts)
    classes = (type(v), *(type(part) for part in parts))
    return coordinate_values, classes


def _unflatten4D(metadata: MetaData, children: Children) -> Vector4D:
    """Rebuild a 4D vector from the output of ``_flatten4D``.

    ``children`` is ``(azimuthal values, longitudinal values, temporal
    values)`` and ``metadata`` is ``(vector class, azimuthal class,
    longitudinal class, temporal class)``.
    """
    azimuthal_values, longitudinal_values, temporal_values = children
    vector_cls, azimuthal_cls, longitudinal_cls, temporal_cls = metadata
    azimuthal = azimuthal_cls(*azimuthal_values)
    longitudinal = longitudinal_cls(*longitudinal_values)
    temporal = temporal_cls(*temporal_values)
    return vector_cls(
        azimuthal=azimuthal,
        longitudinal=longitudinal,
        temporal=temporal,
    )


def _flattenAoSdata(v: VectorNumpy) -> tuple[Children, tuple[type, numpy.dtype]]:
    """Flatten a structured (array-of-structs) vector array into one plain array.

    Returns a ``(children, metadata)`` pair: ``children`` is a 1-tuple holding
    a copy of the data viewed with the dtype shared by all fields, and
    ``metadata`` is ``(type(v), v.dtype)``, which ``_unflattenAoSdata`` uses to
    restore the structured dtype and the array class.

    Raises:
        ValueError: if ``v`` does not have a structured dtype with at least
            one field, or if its fields do not all share the same dtype.
    """
    fields = v.dtype.fields
    # Raise explicitly rather than assert: asserts are stripped under ``-O``,
    # which would turn bad input into an unrelated AttributeError/IndexError.
    if not fields:
        raise ValueError("Only structured arrays with fields can be flattened")
    field_dtypes = [dt for dt, *_ in fields.values()]
    target_dtype = field_dtypes[0]
    if not all(fd == target_dtype for fd in field_dtypes):
        raise ValueError("All fields must have the same dtype to flatten")
    # The numpy struct dtype is flattened manually here, and the information
    # needed to revert it is kept in the metadata. For safety all fields must
    # have the same dtype; this is necessary to fully flatten the array to a
    # 1D vector. numpy.array() copies, so the leaf does not alias ``v``.
    array = numpy.array(v).view(target_dtype)
    children = (array,)
    metadata = (type(v), v.dtype)
    return children, metadata


def _unflattenAoSdata(
    metadata: tuple[type, numpy.dtype], children: Children
) -> VectorNumpy:
    """Rebuild a structured vector array from the output of ``_flattenAoSdata``.

    ``children`` holds the single flattened array and ``metadata`` is
    ``(array class, structured dtype)``. The array is first reinterpreted with
    the structured dtype and then viewed as the array class.
    """
    (array,) = children
    vtype, dtype = metadata
    structured = array.view(dtype)
    return structured.view(vtype)


# Cache for register_pytree(): empty until the first successful call, after
# which it holds the single re-exported module that later calls return.
_MODULE: list[ReexportedPyTreeModule] = []


def register_pytree() -> ReexportedPyTreeModule:
    """Register vector classes as PyTree nodes with the ``optree`` package.

    The registration tells ``optree`` how to flatten vector objects into their
    coordinate data and how to rebuild them afterwards.
    See https://blog.scientific-python.org/pytrees/ for the rationale.

    Calling the function

    >>> import vector
    >>> vector.register_pytree()
    <module 'vector.pytree'>

    makes the following classes usable with ``optree`` under the ``"vector"``
    namespace:

    - VectorObject*D
    - MomentumObject*D
    - VectorNumpy*D
    - MomentumNumpy*D

    For example:

    >>> import optree
    >>> vec = vector.obj(x=1, y=2)
    >>> leaves, treedef = optree.tree_flatten(vec, namespace="vector")
    >>> vec2 = optree.tree_unflatten(treedef, leaves)
    >>> assert vec == vec2

    The return value is a re-exported module whose functions already use that
    namespace, so the ``namespace`` argument can be omitted:

    >>> pytree = vector.register_pytree()
    >>> vec = vector.obj(x=1, y=2)
    >>> leaves, treedef = pytree.flatten(vec)
    >>> vec2 = pytree.unflatten(treedef, leaves)
    >>> assert vec == vec2

    The returned module also gains a ``ravel`` function, which flattens
    VectorNumpy arrays into a 1D array and returns a function to undo that:

    >>> import numpy as np
    >>> vec = vector.array({"x": np.ones(10), "y": np.ones(10)})
    >>> flat, unravel = pytree.ravel(vec)
    >>> assert flat.shape == (20,)
    >>> vec2 = unravel(flat)
    >>> assert (vec == vec2).all()

    Registration happens only on the first call; later calls return the same
    module.

    Raises:
        ImportError: if the ``optree`` package is not installed.
    """
    if _MODULE:
        return _MODULE[0]
    try:
        import optree.pytree
        from optree import GetAttrEntry
        from optree.integrations.numpy import tree_ravel
    except ImportError as e:
        raise ImportError("Please install optree to use vector.pytree") from e

    pytree = optree.pytree.reexport(namespace="vector", module="vector.pytree")

    numpy_classes = (
        VectorNumpy2D,
        MomentumNumpy2D,
        VectorNumpy3D,
        MomentumNumpy3D,
        VectorNumpy4D,
        MomentumNumpy4D,
    )
    # (class, flatten function, unflatten function), in registration order:
    # object backends first, then every numpy backend class.
    registrations = (
        (VectorObject2D, _flatten2D, _unflatten2D),
        (MomentumObject2D, _flatten2D, _unflatten2D),
        (VectorObject3D, _flatten3D, _unflatten3D),
        (MomentumObject3D, _flatten3D, _unflatten3D),
        (VectorObject4D, _flatten4D, _unflatten4D),
        (MomentumObject4D, _flatten4D, _unflatten4D),
        *((cls, _flattenAoSdata, _unflattenAoSdata) for cls in numpy_classes),
    )
    for node_cls, flatten, unflatten in registrations:
        pytree.register_node(
            node_cls,
            flatten_func=flatten,
            unflatten_func=unflatten,
            path_entry_type=GetAttrEntry,
        )

    # A convenience function
    pytree.ravel = partial(tree_ravel, namespace="vector")  # type: ignore[attr-defined]

    _MODULE.append(pytree)
    return pytree
Loading