Skip to content

Conversation

nsmith-
Copy link
Member

@nsmith- nsmith- commented Sep 16, 2025

Description

See https://blog.scientific-python.org/pytrees/ for the rationale for these functions.

As a demonstration, here's a projectile motion with air resistance simulation using vector and scipy, showing in particular the use of vector.pytree.flatten to wrap the SciPy ODE solver:

import hepunits as u
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

import vector
pytree = vector.register_pytree()

def wrapped_solve(fun, t_span, y0, t_eval):
    flat_y0, treedef = pytree.flatten(y0)

    def flat_fun(t, flat_y):
        state = pytree.unflatten(treedef, flat_y)
        dstate_dt = fun(t, state)
        flat_dstate_dt, _ = pytree.flatten(dstate_dt)
        return flat_dstate_dt

    flat_solution = solve_ivp(
        fun=flat_fun,
        t_span=t_span,
        y0=flat_y0,
        t_eval=t_eval,
    )
    return pytree.unflatten(treedef, flat_solution.y)


gravitational_acceleration = vector.obj(x=0.0, y=-9.81) * (u.m / u.s**2)
air_density = 1.25 * u.kg / u.meter3
drag_coefficient = 0.47  # for a sphere
object_radius = 10 * u.cm
object_area = np.pi * object_radius**2
object_mass = 1.0 * u.kg


def force(position: vector.Vector2D, momentum: vector.Vector2D):
    gravitational_force = object_mass * gravitational_acceleration
    speed = momentum.rho / object_mass
    drag_magnitude = 0.5 * air_density * speed**2 * drag_coefficient * object_area
    drag_force = -drag_magnitude * momentum.unit()
    return gravitational_force + drag_force


def tangent(t: float, state):
    position, momentum = state
    dposition_dt = momentum / object_mass
    dmomentum_dt = force(position, momentum)
    return (dposition_dt, dmomentum_dt)


time = np.linspace(0, 4, 100) * u.s
initial_position = vector.obj(
    x=0 * u.m,
    y=0 * u.m,
)
initial_momentum = (
    vector.obj(
        x=1 * u.m / u.s,
        y=20 * u.m / u.s,
    )
    * object_mass
)

position, momentum = wrapped_solve(
    fun=tangent,
    t_span=(time[0], time[-1]),
    y0=(initial_position, initial_momentum),
    t_eval=time,
)

fig, ax = plt.subplots()
ax.plot(position.x / u.m, position.y / u.m, label="with air resistance")

position_no_drag = (
    initial_momentum / object_mass * time + 0.5 * gravitational_acceleration * time**2
)

ax.plot(
    position_no_drag.x / u.m,
    position_no_drag.y / u.m,
    ls="--",
    label="no air resistance",
)

ax.set_xlabel("x (m)")
ax.set_ylabel("y (m)")
ax.set_title("Projectile motion")
ax.legend()

fig.savefig("demo.png")

producing
demo

Checklist

  • Have you followed the guidelines in our Contributing document?
  • Have you checked to ensure there aren't any other open Pull Requests for the required change?
  • Does your submission pass pre-commit? ($ pre-commit run --all-files or $ nox -s lint)
  • Does your submission pass tests? ($ pytest or $ nox -s tests)
  • Does the documentation build with your changes? ($ cd docs; make clean; make html or $ nox -s docs)
  • Does your submission pass the doctests? ($ pytest --doctest-plus src/vector/ or $ nox -s doctests)

Before Merging

  • Summarize the commit messages into a brief review of the Pull request.

@nsmith- nsmith- requested a review from Saransh-cpp September 16, 2025 04:28
@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

FYI @matthewfeickert @pfackeldey

@pfackeldey
Copy link
Collaborator

That's cool @nsmith-! This is a nice application for pytrees where you want to make use of the vector objects for physics, but are restricted by solvers that only work with floats.

We definitely should add this for NumPy as well - it should be rather straightforward as they are automatically considered leaf types. Awkward Arrays are much more complicated, I'm not sure if that should be supported for them, and if so how they should be decomposed.

@henryiii
Copy link
Member

Requires #622.

@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

We definitely should add this for NumPy as well - it should be rather straightforward as they are automatically considered leaf types.

Is that the case though? From

stuff = (np.ones(4), np.ones(5))
pytree.flatten(stuff)

it looks like not

([array([1., 1., 1., 1.]), array([1., 1., 1., 1., 1.])], PyTreeSpec((*, *)))

@pfackeldey
Copy link
Collaborator

We definitely should add this for NumPy as well - it should be rather straightforward as they are automatically considered leaf types.

Is that the case though? From

stuff = (np.ones(4), np.ones(5))
pytree.flatten(stuff)

it looks like not

([array([1., 1., 1., 1.]), array([1., 1., 1., 1., 1.])], PyTreeSpec((*, *)))

This is considering arrays as leafs. We're storing things as SoA in vector right? So flattening a vector obj with numpy array as leafs should work like in this example.
For the minimization you may want to consider to use https://optree.readthedocs.io/en/latest/integrations.html#optree.integrations.jax.tree_ravel instead of flatten + unflatten.

@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

Ok so it is making a distinction between flatten and ravel, where the latter also concatenates numpy arrays. It seems then that one would have to use something like https://github.com/metaopt/optree/blob/e0e772ea3bb4ab551cb6f102b8074022bb8542ec/optree/integrations/numpy.py#L133 to actually get things fully raveled?
We could add a helper function to the re-exported namespace, ravel for that purpose?

@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

We're storing things as SoA in vector right?

Well, it's a bit messy, they can be either:

vec_soa = vector.obj(x=1, y=1) * np.ones(10)
print(pytree.flatten(vec_soa))
vec_aos = vector.array(dict(x=np.ones(10), y=np.ones(10)))
print(pytree.flatten(vec_aos))

gives

([array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])], PyTreeSpec(CustomTreeNode(VectorObject2D[(<class 'vector.backends.object.VectorObject2D'>, <class 'vector.backends.object.AzimuthalObjectXY'>)], [*, *]), namespace='vector'))
([VectorNumpy2D([(1., 1.), (1., 1.), (1., 1.), (1., 1.), (1., 1.), (1., 1.),
               (1., 1.), (1., 1.), (1., 1.), (1., 1.)],
              dtype=[('x', '<f8'), ('y', '<f8')])], PyTreeSpec(*))

Note vector.obj(x=np.ones(10), y=np.ones(10)) raises, which makes sense. But the object backend is "leaky"

@pfackeldey
Copy link
Collaborator

Ok so it is making a distinction between flatten and ravel, where the latter also concatenates numpy arrays. It seems then that one would have to use something like https://github.com/metaopt/optree/blob/e0e772ea3bb4ab551cb6f102b8074022bb8542ec/optree/integrations/numpy.py#L133 to actually get things fully raveled? We could add a helper function to the re-exported namespace, ravel for that purpose?

I'm not sure what the existing ravel in optree is missing that we want to add here? Shouldn't that be enough once the tree decomposition works?

@pfackeldey
Copy link
Collaborator

We're storing things as SoA in vector right?

Well, it's a bit messy, they can be either:

vec_soa = vector.obj(x=1, y=1) * np.ones(10)
print(pytree.flatten(vec_soa))
vec_aos = vector.array(dict(x=np.ones(10), y=np.ones(10)))
print(pytree.flatten(vec_aos))

gives

([array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])], PyTreeSpec(CustomTreeNode(VectorObject2D[(<class 'vector.backends.object.VectorObject2D'>, <class 'vector.backends.object.AzimuthalObjectXY'>)], [*, *]), namespace='vector'))
([VectorNumpy2D([(1., 1.), (1., 1.), (1., 1.), (1., 1.), (1., 1.), (1., 1.),
               (1., 1.), (1., 1.), (1., 1.), (1., 1.)],
              dtype=[('x', '<f8'), ('y', '<f8')])], PyTreeSpec(*))

Note vector.obj(x=np.ones(10), y=np.ones(10)) raises, which makes sense. But the object backend is "leaky"

Ok, I see. Looks like the vector.obj already works (because of your changes in this PR). Now, it needs to register the VectorNumpy*D (and MomentumNumpy*D) for optree (basically by decomposing the numpy rec array I guess)

@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

Yes, so

from optree.integrations.numpy import ravel_pytree

ravel_pytree(vec_soa, namespace="vector")

works, but I'd rather have pytree.ravel as a shortcut

@nsmith-
Copy link
Member Author

nsmith- commented Sep 16, 2025

For AoS arrays, we may have to ask optree to implement something, as

from optree.integrations.numpy import ravel_pytree

data = np.array(
    [
        (1.0, 2.0),
        (3.0, 4.0),
        (5.0, 6.0),
    ],
    dtype=[("x", "<f8"), ("y", "<f8")],
)

flat, unravel_func = ravel_pytree(data)

returns

(array([(1., 1.), (1., 1.), (1., 1.)], dtype=[('x', '<f8'), ('y', '<f8')]),
 functools.partial(<function _tree_unravel at 0x12618f9c0>, PyTreeSpec(*), functools.partial(<function _unravel_leaves_single_dtype at 0x12618fb00>, (3,), ((3,),))))

which isn't friendly to the scipy methods:

from scipy.optimize import minimize

# works
minimize(lambda d: np.hypot(d[::2], d[1::2]).sum(), data.view(np.float64))

# does not work
minimize(lambda d: np.hypot(d['x'], d['y']).sum(), data)

Comment on lines +92 to +96
field_dtypes = [dt for dt, *_ in v.dtype.fields.values()]
target_dtype = field_dtypes[0]
if not all(fd == target_dtype for fd in field_dtypes):
raise ValueError("All fields must have the same dtype to flatten")
array = numpy.array(v).view(target_dtype)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We manually flatten the numpy struct dtype here and include the necessary information to revert it. For safety we require all fields to have the same dtype. This is necessary to fully flatten the array to a 1d vector

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL: optree

This is quite an amazing addition. Thanks for working on this, @nsmith-! I am a bit on fence with the API design.

The register_awkward functions in Vector returns nothing and lets users pass Vector behavior names to the backend library (awkward) directly (vector.register_awkward(); ak.Array(..., with_name="Momentum3D")).

It would be nice to have an identical API for pytrees. From my very short experiments, an awkward-like API for optree looks feasible -

def _flatten2D(v: Vector2D) -> tuple[Children, MetaData]:
    children = v.azimuthal.elements
    metadata = type(v), type(v.azimuthal)
    return children, metadata

def _unflatten2D(metadata: MetaData, children: Children) -> Vector2D:
    backend, azimuthal = metadata
    return backend(azimuthal=azimuthal(*children))

optree.register_pytree_node(
    VectorObject2D,
    flatten_func=_flatten2D,
    unflatten_func=_unflatten2D,
    namespace="vector.pytree"
)

and then the user can -

leaves, treedef = optree.tree_flatten(vec, namespace="vector.pytree")

akin to how the awkward backend works (ak.Array(..., with_name="Momentum3D")).

We can then have a vector specific constructor (like vector.Array and vector.zip), which I think in this case is the return value of vector.register_pytree.

Could you see if this is feasible? Looking at how the awkward backend is structured will be very useful. Thanks again for working on this!

@nsmith-
Copy link
Member Author

nsmith- commented Sep 25, 2025

The pytree return is really just a convenience. If you ignore the return argument, then what you wrote already works:

import optree

leaves, treedef = optree.tree_flatten(vec, namespace="vector")

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @nsmith-! Could you add that way of using optree for Vector classes in the example? I have a few other structural suggestions for the docs:

  • The example should be a notebook, so that it can be tested on our CI. The notebook is definitely a tutorial so it should be moved under Tutorials. Can the API documentation be moved somewhere else? Maybe a new page under Vector Constructors?
  • Vector backends in README should be updated (perhaps with a new heading below backends?)
  • The documentation tree in README should be updated
  • Same updates for docs/index.md

@nsmith-
Copy link
Member Author

nsmith- commented Oct 7, 2025

pre-commit is spellchecking the base64-encoded png image 🙄

codespell................................................................Failed
- hook id: codespell
- exit code: 65

docs/src/pytree.ipynb:232: fO ==> of, for, to, do, go

edit: potentially https://stackoverflow.com/questions/78867158/how-to-make-codespell-not-report-false-positives-in-base64-strings is a better solution

@nsmith- nsmith- requested a review from Saransh-cpp October 7, 2025 16:34
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates, @nsmith-! This looks good to me now 🚀

@henryiii @pfackeldey please feel free to merge once you approve.

Copy link
Collaborator

@pfackeldey pfackeldey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

If anything I'd suggest to add some explanation in the docs that this feature does not work with the awkward-array backend, that may not be obvious on first sight when someone sees this feature. You can decide, I have no strong opinion on this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants