Skip to content

Commit 82e1e0e

Browse files
authored
Add remove_tree and overload add_tree (#891)
* first draft for remove_tree and add_tree with tree object * use overloading for add_tree * format * format * trigger ci * fix linting * format * fix typing * clean up * beautify comment * also handle color and _enforced_id and test this * update changelog * add example script for merging NMLs * format * fix linting * move example file to correct folder
1 parent 476359f commit 82e1e0e

File tree

6 files changed

+196
-34
lines changed

6 files changed

+196
-34
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Merge Trees of two NML files
2+
3+
This example opens two [NML files](/webknossos/data_formats.html#nml-files) and copies the trees from the NML A to B and saves the output to a new NML (that includes the trees of both input NMLs).
4+
5+
```python
6+
--8<--
7+
webknossos/examples/merge_trees_of_nml_files.py
8+
--8<--
9+
```

webknossos/Changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ For upgrade instructions, please check the respective _Breaking Changes_ section
1515
### Breaking Changes
1616

1717
### Added
18+
- `Group.add_tree` now also accepts a tree object as a first parameter (instead of only a string). [#891](https://github.com/scalableminds/webknossos-libs/pull/891)
19+
- `Group.remove_tree_by_id` was added. [#891](https://github.com/scalableminds/webknossos-libs/pull/891)
1820

1921
### Changed
2022
- Upgrades `black`, `mypy`, `pylint`, `pytest`. [#873](https://github.com/scalableminds/webknossos-libs/pull/873)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import webknossos as wk
2+
3+
4+
def main() -> None:
5+
skeleton_a = wk.Skeleton.load("./a.nml")
6+
skeleton_b = wk.Skeleton.load("./b.nml")
7+
8+
for tree in skeleton_a.flattened_trees():
9+
skeleton_b.add_tree(tree)
10+
11+
skeleton_b.save("./c.nml")
12+
13+
14+
if __name__ == "__main__":
15+
main()

webknossos/tests/test_skeleton.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,72 @@ def test_load_nml(tmp_path: Path) -> None:
275275
skeleton_a = wk.Skeleton.load(input_path)
276276
skeleton_a.save(output_path)
277277
assert skeleton_a == wk.Skeleton.load(output_path)
278+
279+
280+
def test_remove_tree(tmp_path: Path) -> None:
281+
input_path = TESTDATA_DIR / "nmls" / "test_a.nml"
282+
output_path = tmp_path / "test_a.nml"
283+
skeleton_a = wk.Skeleton.load(input_path)
284+
285+
# Check that tree exists
286+
tree = skeleton_a.get_tree_by_id(1)
287+
assert tree is not None
288+
assert tree in list(skeleton_a.children)
289+
290+
# Check that tree doesn't exist anymore
291+
skeleton_a.remove_tree_by_id(1)
292+
with pytest.raises(ValueError):
293+
tree = skeleton_a.get_tree_by_id(1)
294+
295+
assert tree not in list(skeleton_a.children)
296+
297+
# Check that serialized skeleton doesn't contain
298+
# deleted tree
299+
skeleton_a.save(output_path)
300+
assert skeleton_a == wk.Skeleton.load(output_path)
301+
302+
# Load original file and check that tree is still
303+
# there (should not have been removed on disk
304+
# automatically).
305+
skeleton_a = wk.Skeleton.load(input_path)
306+
assert tree in list(skeleton_a.children)
307+
308+
309+
def test_add_tree_with_obj(tmp_path: Path) -> None:
310+
input_path = TESTDATA_DIR / "nmls" / "test_a.nml"
311+
output_path = tmp_path / "test_a.nml"
312+
skeleton_a = wk.Skeleton.load(input_path)
313+
314+
# Check that tree exists
315+
tree = skeleton_a.get_tree_by_id(1)
316+
317+
skeleton_b = wk.Skeleton(
318+
voxel_size=skeleton_a.voxel_size, dataset_name=skeleton_a.dataset_name
319+
)
320+
skeleton_b.add_tree(tree)
321+
322+
assert tree is not skeleton_b.get_tree_by_id(1)
323+
assert tree == skeleton_b.get_tree_by_id(1)
324+
325+
skeleton_b.save(output_path)
326+
327+
328+
def test_add_tree_with_obj_and_properties(tmp_path: Path) -> None:
329+
input_path = TESTDATA_DIR / "nmls" / "test_a.nml"
330+
output_path = tmp_path / "test_a.nml"
331+
skeleton_a = wk.Skeleton.load(input_path)
332+
333+
# Check that tree exists
334+
tree = skeleton_a.get_tree_by_id(1)
335+
336+
skeleton_b = wk.Skeleton(
337+
voxel_size=skeleton_a.voxel_size, dataset_name=skeleton_a.dataset_name
338+
)
339+
new_tree = skeleton_b.add_tree(tree, color=(1, 2, 3), _enforced_id=1337)
340+
341+
assert new_tree is skeleton_b.get_tree_by_id(1337)
342+
assert new_tree is not tree
343+
assert new_tree != tree
344+
assert new_tree.color == (1, 2, 3, 1)
345+
346+
skeleton_b.save(output_path)

webknossos/webknossos/skeleton/group.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from typing import TYPE_CHECKING, Iterator, Optional, Set, Tuple, Union, cast
23

34
import attr
@@ -63,25 +64,53 @@ def id(self) -> int:
6364

6465
def add_tree(
6566
self,
66-
name: str,
67+
name_or_tree: Union[str, Tree],
6768
color: Optional[Union[Vector4, Vector3]] = None,
6869
_enforced_id: Optional[int] = None,
6970
) -> Tree:
70-
"""Adds a tree to the current group with the provided name (and color if specified)."""
71+
"""Adds a tree to the current group. If the first parameter is a string,
72+
a new tree will be added with the provided name and color if specified.
73+
Otherwise, the first parameter is assumed to be a tree object (e.g., from
74+
another skeleton). A copy of that tree will then be added. If the id
75+
of the tree already exists, a new id will be generated."""
7176

7277
if color is not None and len(color) == 3:
7378
color = cast(Optional[Vector4], color + (1.0,))
7479
color = cast(Optional[Vector4], color)
75-
new_tree = Tree(
76-
name=name,
77-
color=color,
78-
group=self,
79-
skeleton=self._skeleton,
80-
enforced_id=_enforced_id,
81-
)
82-
self._child_trees.add(new_tree)
8380

84-
return new_tree
81+
if type(name_or_tree) is str:
82+
name = name_or_tree
83+
new_tree = Tree(
84+
name=name,
85+
color=color,
86+
group=self,
87+
skeleton=self._skeleton,
88+
enforced_id=_enforced_id,
89+
)
90+
self._child_trees.add(new_tree)
91+
92+
return new_tree
93+
else:
94+
tree = cast(Tree, name_or_tree)
95+
new_tree = copy.deepcopy(tree)
96+
97+
if color is not None:
98+
new_tree.color = color
99+
100+
if _enforced_id is not None:
101+
assert not self.has_tree_id(
102+
_enforced_id
103+
), "A tree with the specified _enforced_id already exists in this group."
104+
new_tree._id = _enforced_id
105+
106+
if self.has_tree_id(tree.id):
107+
new_tree._id = self._skeleton._element_id_generator.__next__()
108+
109+
new_tree.group = self
110+
new_tree.skeleton = self._skeleton
111+
112+
self._child_trees.add(new_tree)
113+
return new_tree
85114

86115
def add_graph(
87116
self,
@@ -91,7 +120,10 @@ def add_graph(
91120
) -> Tree:
92121
"""Deprecated, please use `add_tree`."""
93122
warn_deprecated("add_graph()", "add_tree()")
94-
return self.add_tree(name=name, color=color, _enforced_id=_enforced_id)
123+
return self.add_tree(name_or_tree=name, color=color, _enforced_id=_enforced_id)
124+
125+
def remove_tree_by_id(self, tree_id: int) -> None:
126+
self._child_trees.remove(self.get_tree_by_id(tree_id))
95127

96128
@property
97129
def children(self) -> Iterator[GroupOrTree]:
@@ -180,6 +212,14 @@ def get_tree_by_id(self, tree_id: int) -> Tree:
180212
return tree
181213
raise ValueError(f"No tree with id {tree_id} was found")
182214

215+
def has_tree_id(self, tree_id: int) -> bool:
216+
"""Returns true if this group (or a subgroup) contains a tree with the given id."""
217+
try:
218+
self.get_tree_by_id(tree_id)
219+
return True
220+
except ValueError:
221+
return False
222+
183223
def get_graph_by_id(self, graph_id: int) -> Tree:
184224
"""Deprecated, please use `get_tree_by_id`."""
185225
warn_deprecated("get_graph_by_id()", "get_tree_by_id()")

webknossos/webknossos/skeleton/tree.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -114,45 +114,72 @@ def __init__(
114114
group: "Group",
115115
skeleton: "Skeleton",
116116
color: Optional[Vector4] = None,
117-
enforced_id: Optional[int] = None,
117+
enforced_id: Optional[int] = None, # pylint: disable=unused-argument
118118
) -> None:
119119
"""
120120
To create a tree, it is recommended to use `Skeleton.add_tree` or
121121
`Group.add_tree`. That way, the newly created tree is automatically
122122
attached as a child to the object the method was called on.
123123
"""
124124

125-
# To be able to reference nodes by id after adding them for the first time, we use custom dict-like classes
126-
# for the networkx-graph structures, that have nodes as keys:
127-
# * `self._node`: _NodeDict
128-
# holding the attributes of nodes, keeping references from ids to nodes
129-
# * `self._adj`: _AdjDict on the first two levels
130-
# holding edge attributes on the last level, using self._node to convert ids to nodes
131-
#
132-
# It's important to set the attributes before the parent's init so that they shadow the class-attributes.
133-
#
134-
# For further details, see the *Subclasses* section here: https://networkx.org/documentation/stable/reference/classes/graph.html
135-
136-
self.node_dict_factory = _NodeDict
137-
# The lambda works because self._node is set before self._adj in networkx.Graph.__init__
138-
# and because the lambda is evaluated lazily.
139-
self.adjlist_outer_dict_factory = lambda: _AdjDict(node_dict=self._node)
140-
self.adjlist_inner_dict_factory = lambda: _AdjDict(node_dict=self._node)
141-
142125
super().__init__()
143-
126+
# Note that id is set up in __new__
144127
self.name = name
145128
self.group = group
146129
self.color = color
147130

148-
# read-only member, exposed via properties
131+
# only used internally
132+
self._skeleton = skeleton
133+
134+
def __new__(
135+
cls,
136+
name: str, # pylint: disable=unused-argument
137+
group: "Group", # pylint: disable=unused-argument
138+
skeleton: "Skeleton", # pylint: disable=unused-argument
139+
color: Optional[Vector4] = None, # pylint: disable=unused-argument
140+
enforced_id: Optional[int] = None,
141+
) -> "Tree":
142+
self = super().__new__(cls)
143+
144+
# self._id is a read-only member, exposed via properties.
145+
# It is set in __new__ instead of __init__ so that pickling/unpickling
146+
# works without problems. As long as the deserialization of a tree instance
147+
# is not finished, the object is only half-initialized. Since self._id
148+
# is needed by __hash__, an error would be raised otherwise.
149+
# Also see:
150+
# https://stackoverflow.com/questions/46283738/attributeerror-when-using-python-deepcopy
149151
if enforced_id is not None:
150152
self._id = enforced_id
151153
else:
152154
self._id = skeleton._element_id_generator.__next__()
153155

154-
# only used internally
155-
self._skeleton = skeleton
156+
return self
157+
158+
def __getnewargs__(self) -> Tuple:
159+
# pickle.dump will pickle instances of Tree so that the following
160+
# tuple is passed as arguments to __new__.
161+
return (self.name, self.group, self._skeleton, self.color, self._id)
162+
163+
# node_dict_factory, adjlist_outer_dict_factory and adjlist_inner_dict_factory are used by networkx
164+
# from which we subclass.
165+
# To be able to reference nodes by id after adding them for the first time, we use custom dict-like classes
166+
# for the networkx-graph structures, that have nodes as keys:
167+
# * `self._node`: _NodeDict
168+
# holding the attributes of nodes, keeping references from ids to nodes
169+
# * `self._adj`: _AdjDict on the first two levels
170+
# holding edge attributes on the last level, using self._node to convert ids to nodes
171+
# It's important to set the attributes before the parent's init so that they shadow the class-attributes.
172+
# For further details, see the *Subclasses* section here: https://networkx.org/documentation/stable/reference/classes/graph.html
173+
174+
node_dict_factory = _NodeDict
175+
176+
def adjlist_outer_dict_factory(self) -> _AdjDict:
177+
# self._node will already be available when this method is called, because networkx.Graph.__init__
178+
# sets up the nodes first and then the edges (i.e., adjacency list).
179+
return _AdjDict(node_dict=self._node)
180+
181+
def adjlist_inner_dict_factory(self) -> _AdjDict:
182+
return _AdjDict(node_dict=self._node)
156183

157184
def __to_tuple_for_comparison(self) -> Tuple:
158185
return (

0 commit comments

Comments
 (0)