Skip to content

Commit fe2d35e

Browse files
jburnimtensorflower-gardener
authored andcommitted
Add structure/type-checks NumPy version of nest.map_structure_with_path_up_to.
These checks were recently removed from dm-tree. PiperOrigin-RevId: 379382661
1 parent a954b38 commit fe2d35e

File tree

1 file changed

+30
-3
lines changed
  • tensorflow_probability/python/internal/backend/numpy

1 file changed

+30
-3
lines changed

tensorflow_probability/python/internal/backend/numpy/nest.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@
5050
from tree import flatten_with_path_up_to
5151
from tree import is_nested
5252
from tree import map_structure as dm_tree_map_structure
53-
from tree import map_structure_up_to
54-
from tree import map_structure_with_path
55-
from tree import map_structure_with_path_up_to
53+
from tree import map_structure_with_path_up_to as dm_tree_map_structure_with_path_up_to
5654
from tree import unflatten_as
5755
# pylint: enable=unused-import
5856

@@ -213,6 +211,35 @@ def flatten_with_tuple_paths_up_to(shallow_structure,
213211
check_types)
214212

215213

214+
def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
215+
return map_structure_with_path_up_to(
216+
shallow_structure,
217+
lambda _, *args: func(*args), # Discards path.
218+
*structures,
219+
**kwargs)
220+
221+
222+
def map_structure_with_path(func, *structures, **kwargs):
223+
return map_structure_with_path_up_to(structures[0], func, *structures,
224+
**kwargs)
225+
226+
227+
def map_structure_with_path_up_to(shallow_structure, func, *structures,
228+
**kwargs):
229+
"""Wraps nest.map_structure_with_path_up_to, with structure/type checking."""
230+
if not structures:
231+
raise ValueError('Cannot map over no sequences')
232+
233+
check_types = kwargs.get('check_types', True)
234+
235+
for input_tree in structures:
236+
_assert_shallow_structure(
237+
shallow_structure, input_tree, check_types=check_types)
238+
239+
return dm_tree_map_structure_with_path_up_to(
240+
shallow_structure, func, *structures, **kwargs)
241+
242+
216243
def map_structure_with_tuple_paths(func, *structures, **kwargs):
217244
return map_structure_with_path(func, *structures, **kwargs)
218245

0 commit comments

Comments
 (0)