Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 6ae66b0

Browse files
authored
Fix mapping parentage bug (#54)
* think I've fixed the bug * used feature from python 3.9 * test but doesn't yet work properly * only check subtree, not down to root * make sure choice whether to check from root is propagated * bump python version in CI * 3.10 instead of 3.1
1 parent df6925f commit 6ae66b0

File tree

6 files changed

+28
-11
lines changed

6 files changed

+28
-11
lines changed

.github/workflows/main.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
runs-on: ubuntu-latest
2222
strategy:
2323
matrix:
24-
python-version: [3.7, 3.8, 3.9]
24+
python-version: ["3.9", "3.10"]
2525
steps:
2626
- uses: actions/[email protected]
2727
- uses: conda-incubator/setup-miniconda@v2
@@ -59,7 +59,7 @@ jobs:
5959
runs-on: ubuntu-latest
6060
strategy:
6161
matrix:
62-
python-version: [3.8, 3.9]
62+
python-version: ["3.9", "3.10"]
6363
steps:
6464
- uses: actions/[email protected]
6565
- uses: conda-incubator/setup-miniconda@v2

datatree/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.dev75+g977ffe2.d20210902"
1+
__version__ = "0.1.dev94+g6c6f23c.d20211217"

datatree/mapping.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def _map_over_subtree(*args, **kwargs):
215215
# Find out how many return values we received
216216
num_return_values = _check_all_return_values(out_data_objects)
217217

218+
ancestors_of_new_root = first_tree.pathstr.removesuffix(first_tree.name)
219+
218220
# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
219221
result_trees = []
220222
for i in range(num_return_values):
@@ -228,7 +230,11 @@ def _map_over_subtree(*args, **kwargs):
228230
output_node_data = out_data_objects[p]
229231
else:
230232
output_node_data = None
231-
out_tree_contents[p] = output_node_data
233+
234+
# Discard parentage so that new trees don't include parents of input nodes
235+
# TODO use a proper relative_path method on DataTree(/TreeNode) to do this
236+
relative_path = p.removeprefix(ancestors_of_new_root)
237+
out_tree_contents[relative_path] = output_node_data
232238

233239
new_tree = DataTree.from_dict(
234240
name=first_tree.name, data_objects=out_tree_contents

datatree/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
4141
a = a.root
4242
b = b.root
4343

44-
assert a.isomorphic(b, from_root=False), diff_tree_repr(a, b, "isomorphic")
44+
assert a.isomorphic(b, from_root=from_root), diff_tree_repr(a, b, "isomorphic")
4545
else:
4646
raise TypeError(f"{type(a)} not of type DataTree")
4747

@@ -78,7 +78,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True):
7878
a = a.root
7979
b = b.root
8080

81-
assert a.equals(b), diff_tree_repr(a, b, "equals")
81+
assert a.equals(b, from_root=from_root), diff_tree_repr(a, b, "equals")
8282
else:
8383
raise TypeError(f"{type(a)} not of type DataTree")
8484

@@ -115,6 +115,6 @@ def assert_identical(a: DataTree, b: DataTree, from_root: bool = True):
115115
a = a.root
116116
b = b.root
117117

118-
assert a.identical(b), diff_tree_repr(a, b, "identical")
118+
assert a.identical(b, from_root=from_root), diff_tree_repr(a, b, "identical")
119119
else:
120120
raise TypeError(f"{type(a)} not of type DataTree")

datatree/tests/test_mapping.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,19 @@ def multiply_then_add(ds, times, add=0.0):
251251
result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0)
252252
assert_equal(result_tree, expected)
253253

254+
def test_discard_ancestry(self):
255+
# Check for datatree GH issue #48
256+
dt = create_test_datatree()
257+
subtree = dt["set1"]
258+
259+
@map_over_subtree
260+
def times_ten(ds):
261+
return 10.0 * ds
262+
263+
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
264+
result_tree = times_ten(subtree)
265+
assert_equal(result_tree, expected, from_root=False)
266+
254267

255268
@pytest.mark.xfail
256269
class TestMapOverSubTreeInplace:

setup.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,12 @@
2626
"Topic :: Scientific/Engineering",
2727
"License :: OSI Approved :: Apache License",
2828
"Operating System :: OS Independent",
29-
"Programming Language :: Python :: 3",
30-
"Programming Language :: Python :: 3.7",
31-
"Programming Language :: Python :: 3.8",
3229
"Programming Language :: Python :: 3.9",
30+
"Programming Language :: Python :: 3.10",
3331
],
3432
packages=find_packages(exclude=["docs", "tests", "tests.*", "docs.*"]),
3533
install_requires=install_requires,
36-
python_requires=">=3.7",
34+
python_requires=">=3.9",
3735
setup_requires="setuptools_scm",
3836
use_scm_version={
3937
"write_to": "datatree/_version.py",

0 commit comments

Comments
 (0)