Several places in the core library currently implement their own generic data structure traversals. These should be unified to ensure correct and consistent behavior:
- `evaluate`
- `internals.product_n.map_sequence`
- `internals.unification.nested_type`
- PyTree APIs in PyTorch and JAX via `handlers.torch`/`handlers.jax`
To start out, it should be relatively straightforward to refactor `evaluate` and `map_sequence` to share their `singledispatch`-based traversal logic for non-`Term` data structures. Reusing the same traversal code in `nested_type` may also be feasible, as may integrating with `jax.tree.map` for PyTree types.
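
As a rough illustration of the kind of shared traversal proposed here, the sketch below uses `functools.singledispatch` to map a function over the leaves of nested containers. The name `map_structure` and its signature are hypothetical and are not taken from the library. Handling of `Term` values is deliberately left out.

```python
from functools import singledispatch
from typing import Any, Callable


# Hypothetical helper, not an existing library function.
# singledispatch selects an implementation from the type of the first argument.
@singledispatch
def map_structure(value: Any, fn: Callable[[Any], Any]) -> Any:
    """Fallback: any type without a registered handler is treated as a leaf."""
    return fn(value)


@map_structure.register(list)
def _(value: list, fn: Callable[[Any], Any]) -> list:
    return [map_structure(v, fn) for v in value]


@map_structure.register(tuple)
def _(value: tuple, fn: Callable[[Any], Any]) -> tuple:
    # Plain tuples only; namedtuples would need their own handler.
    return tuple(map_structure(v, fn) for v in value)


@map_structure.register(dict)
def _(value: dict, fn: Callable[[Any], Any]) -> dict:
    # Keys are kept as-is; only values are traversed.
    return {k: map_structure(v, fn) for k, v in value.items()}


nested = {"a": [1, 2], "b": (3, {"c": 4})}
doubled = map_structure(nested, lambda x: x * 2)
```

If callers like `evaluate` and `map_sequence` went through one dispatcher of this shape, support for a new container type could be added with a single `register` call instead of a change in each traversal.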