Skip to content

Commit 3521c4f

Browse files
Split: Implement node label assignments
1 parent c1201e2 commit 3521c4f

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

yaramo/operations/split.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def split(
2222
topology: Topology,
2323
split_edges: Dict[Edge, float],
2424
add_missing_elements_to_topology: None | Label = None,
25-
node_label_assignments: Dict[Node, Label] = None,
25+
node_label_assignments: Node | Dict[Node, Label] = None,
2626
) -> Tuple[Topology, Topology, Dict[Edge, Tuple[Node, Node]]]:
2727
Split._validate_split_edges(split_edges)
2828

@@ -32,7 +32,11 @@ def split(
3232
OperationsHelper.copy_topology_metadata(topology, topology_b)
3333

3434
new_end_nodes = Split._split_edges(topology, split_edges)
35-
node_labels, edge_labels, signal_labels = Split._label_elements(topology, new_end_nodes)
35+
36+
if node_label_assignments is None:
37+
node_label_assignments = {}
38+
39+
node_labels, edge_labels, signal_labels = Split._label_elements(topology, new_end_nodes, node_label_assignments)
3640

3741
if add_missing_elements_to_topology is not None:
3842
Split._assign_missing_elements_to_label(
@@ -157,7 +161,9 @@ def _get_new_geo_node_same_type(_old_geo_node: GeoNode, x: float, y: float) -> G
157161

158162
@staticmethod
159163
def _label_elements(
160-
topology: Topology, new_end_nodes: Dict[Edge, Tuple[Node, Node]]
164+
topology: Topology,
165+
new_end_nodes: Dict[Edge, Tuple[Node, Node]],
166+
node_label_assignments: Node | Dict[Node, Label],
161167
) -> Tuple[Dict[Node, Label], Dict[Edge, Label], Dict[Signal, Label]]:
162168
node_labels: Dict[Node, Label] = {}
163169
edge_labels: Dict[Edge, Label] = {}
@@ -187,6 +193,10 @@ def _dfs(_start_node: Node, _label):
187193
if not new_end_nodes:
188194
raise ValueError("No new end nodes found. Split not possible")
189195

196+
for node, label in node_label_assignments.items():
197+
_dfs(node, label)
198+
199+
190200
for end_node_pair in new_end_nodes.values():
191201
node_a = end_node_pair[0]
192202
node_b = end_node_pair[1]

0 commit comments

Comments
 (0)