@@ -17,7 +17,11 @@ def get_opposite_label(label):
1717
1818class Split :
1919 @staticmethod
20- def split (topology : Topology , split_edges : Dict [Edge , float ]) -> Tuple [Topology , Topology ]:
20+ def split (
21+ topology : Topology ,
22+ split_edges : Dict [Edge , float ],
23+ add_missing_elements_to_topology : None | Label = None ,
24+ ) -> Tuple [Topology , Topology ]:
2125 Split ._validate_split_edges (split_edges )
2226
2327 topology_a = Topology ()
@@ -26,6 +30,11 @@ def split(topology: Topology, split_edges: Dict[Edge, float]) -> Tuple[Topology,
2630 new_end_nodes = Split ._split_edges (topology , split_edges )
2731 node_labels , edge_labels , signal_labels = Split ._label_elements (topology , new_end_nodes )
2832
33+ if add_missing_elements_to_topology is not None :
34+ Split ._assign_missing_elements_to_label (
35+ topology , node_labels , edge_labels , signal_labels , add_missing_elements_to_topology
36+ )
37+
2938 # Add nodes, edges and signals to destination topologies.
3039 def _add_elements_to_topology (
3140 _element_label_matching , _topology_a_method , _topology_b_method
@@ -188,6 +197,24 @@ def _dfs(_start_node: Node, _label):
188197
189198 return node_labels , edge_labels , signal_labels
190199
200+ @staticmethod
201+ def _assign_missing_elements_to_label (
202+ topology : Topology ,
203+ node_labels : Dict [Node , Label ],
204+ edge_labels : Dict [Edge , Label ],
205+ signal_labels : Dict [Signal , Label ],
206+ label : Label ,
207+ ):
208+ for node in topology .nodes .values ():
209+ if node not in node_labels :
210+ node_labels [node ] = label
211+ for edge in topology .edges .values ():
212+ if edge not in edge_labels :
213+ edge_labels [edge ] = label
214+ for signal in topology .signals .values ():
215+ if signal not in signal_labels :
216+ signal_labels [signal ] = label
217+
191218 @staticmethod
192219 def _validate_for_data_loss (
193220 topology : Topology ,
0 commit comments