Skip to content

Commit 3a6f248

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
FunMC: Use tree_util.tree_map instead of tree_util.tree_multimap to silence warning.
PiperOrigin-RevId: 450797057
1 parent 89d248c commit 3a6f248

File tree

1 file changed

+3
-3
lines changed
  • spinoffs/fun_mc/fun_mc/dynamic/backend_jax

1 file changed

+3
-3
lines changed

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
def map_tree(fn, tree, *args):
4949
"""Maps `fn` over the leaves of a nested structure."""
50-
return tree_util.tree_multimap(fn, tree, *args)
50+
return tree_util.tree_map(fn, tree, *args)
5151

5252

5353
def flatten_tree(tree):
@@ -66,7 +66,7 @@ def map_tree_up_to(shallow, fn, tree, *rest):
6666
def wrapper(_, *rest):
6767
return fn(*rest)
6868

69-
return tree_util.tree_multimap(wrapper, shallow, tree, *rest)
69+
return tree_util.tree_map(wrapper, shallow, tree, *rest)
7070

7171

7272
def get_shallow_tree(is_leaf, tree):
@@ -76,7 +76,7 @@ def get_shallow_tree(is_leaf, tree):
7676

7777
def assert_same_shallow_tree(shallow, tree):
7878
"""Asserts that `tree` has the same shallow structure as `shallow`."""
79-
# Do a dummy multimap for the side-effect of verifying that the structures are
79+
# Do a dummy map for the side-effect of verifying that the structures are
8080
# the same. This doesn't catch all the errors we actually care about, sadly.
8181
map_tree_up_to(shallow, lambda *args: (), tree)
8282

0 commit comments

Comments
 (0)