Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ repos:
- numpy
- packaging
- sympy
- optree

- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
Expand Down
Binary file added docs/_images/projectile-demo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ src/momentum3d.md
src/momentum4d.md
```

```{toctree}
:maxdepth: 1
:caption: Integrations
src/pytree.md
```

```{toctree}
:maxdepth: 1
:caption: More ways to learn
Expand Down
139 changes: 139 additions & 0 deletions docs/src/pytree.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# PyTree integrations

PyTrees are a [powerful mechanism](https://blog.scientific-python.org/pytrees/) for working with
nested data structures, while allowing algorithms like finite-differences, minimization, and integration routines
to run on flattened 1D arrays of the the same data. To use PyTrees with vector objects, you need to install
the [optree](https://github.com/metaopt/optree) package, and then register vector with PyTrees, for example:

```python
import vector

pytree = vector.register_pytree()

state = {
"position": vector.obj(x=1, y=2, z=3, t=0),
"momentum": vector.obj(x=0, y=10, z=0, t=14),
}
flat_state, treedef = pytree.flatten(state)
print(flat_state)
print(treedef)
```

`flat_state` is now a 1D array of length 8, and `treedef` contains the information needed to reconstruct the original structure:

```
[0, 10, 0, 14, 1, 2, 3, 0]
PyTreeSpec({'momentum': CustomTreeNode(VectorObject4D[(<class 'vector.backends.object.VectorObject4D'>, <class 'vector.backends.object.AzimuthalObjectXY'>, <class 'vector.backends.object.LongitudinalObjectZ'>, <class 'vector.backends.object.TemporalObjectT'>)], [(*, *), (*,), (*,)]), 'position': CustomTreeNode(VectorObject4D[(<class 'vector.backends.object.VectorObject4D'>, <class 'vector.backends.object.AzimuthalObjectXY'>, <class 'vector.backends.object.LongitudinalObjectZ'>, <class 'vector.backends.object.TemporalObjectT'>)], [(*, *), (*,), (*,)])}, namespace='vector')
```

The original structure can be reconstructed with:

```python
reconstructed_state = pytree.unflatten(treedef, flat_state)
```

## Registration

```{eval-rst}
.. autofunction:: vector.register_pytree
```

## Example: projectile motion with air resistance

In the following example, we use `scipy.integrate.solve_ivp` to solve the equations of motion for a projectile under the influence of gravity and air resistance. We use PyTrees to flatten the state dictionary into a 1D array for the integrator, and then unflatten the result back into the original structure.

```python
import hepunits as u
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

import vector

pytree = vector.register_pytree()


def wrapped_solve(fun, t_span, y0, t_eval):
flat_y0, treedef = pytree.flatten(y0)

def flat_fun(t, flat_y):
state = pytree.unflatten(treedef, flat_y)
dstate_dt = fun(t, state)
flat_dstate_dt, _ = pytree.flatten(dstate_dt)
return flat_dstate_dt

flat_solution = solve_ivp(
fun=flat_fun,
t_span=t_span,
y0=flat_y0,
t_eval=t_eval,
)
return pytree.unflatten(treedef, flat_solution.y)


gravitational_acceleration = vector.obj(x=0.0, y=-9.81) * (u.m / u.s**2)
air_density = 1.25 * u.kg / u.meter3
drag_coefficient = 0.47 # for a sphere
object_radius = 10 * u.cm
object_area = np.pi * object_radius**2
object_mass = 1.0 * u.kg


def force(position: vector.Vector2D, momentum: vector.Vector2D):
gravitational_force = object_mass * gravitational_acceleration
speed = momentum.rho / object_mass
drag_magnitude = 0.5 * air_density * speed**2 * drag_coefficient * object_area
drag_force = -drag_magnitude * momentum.unit()
return gravitational_force + drag_force


def tangent(t: float, state):
position, momentum = state
dposition_dt = momentum / object_mass
dmomentum_dt = force(position, momentum)
return (dposition_dt, dmomentum_dt)


time = np.linspace(0, 4, 100) * u.s
initial_position = vector.obj(
x=0 * u.m,
y=0 * u.m,
)
initial_momentum = (
vector.obj(
x=1 * u.m / u.s,
y=20 * u.m / u.s,
)
* object_mass
)

position, momentum = wrapped_solve(
fun=tangent,
t_span=(time[0], time[-1]),
y0=(initial_position, initial_momentum),
t_eval=time,
)

fig, ax = plt.subplots()
ax.plot(position.x / u.m, position.y / u.m, label="with air resistance")

position_no_drag = (
initial_momentum / object_mass * time + 0.5 * gravitational_acceleration * time**2
)

ax.plot(
position_no_drag.x / u.m,
position_no_drag.y / u.m,
ls="--",
label="no air resistance",
)

ax.set_xlabel("x (m)")
ax.set_ylabel("y (m)")
ax.set_title("Projectile motion")
ax.legend()

fig.savefig("projectile-demo.png")
```

![](../_images/projectile-demo.png)
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ test-extras = [
"jax",
"dask_awkward",
"spark-parser",
"optree>=0.16",
]
docs = [
"awkward>=2",
Expand Down Expand Up @@ -148,6 +149,9 @@ isort.required-imports = [
"src/vector/_methods.py" = [
"PLC0415",
]
"src/vector/_pytree.py" = [
"PLC0415",
]
"src/vector/backends/_numba_object.py" = [
"PGH003",
]
Expand Down Expand Up @@ -275,3 +279,8 @@ module = [
ignore_missing_imports = true
disallow_untyped_defs = false
disallow_untyped_calls = false

[[tool.mypy.overrides]]
module = "vector._pytree"
# optree register_node typing requires vectors to be Collection, unnecessarily
disable_error_code = ["call-arg", "arg-type"]
3 changes: 2 additions & 1 deletion src/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Vector4D,
dim,
)
from vector._pytree import register_pytree
from vector._version import version as __version__
from vector.backends.awkward_constructors import Array, zip
from vector.backends.awkward_constructors import Array as awk
Expand Down Expand Up @@ -95,7 +96,6 @@ def _import_awkward() -> None:
VectorSympy4D,
)


__all__: tuple[str, ...] = (
"Array",
"Azimuthal",
Expand Down Expand Up @@ -148,6 +148,7 @@ def _import_awkward() -> None:
"obj",
"register_awkward",
"register_numba",
"register_pytree",
"zip",
)

Expand Down
Loading