Skip to content

Commit 4a6d8c2

Browse files
Implement exclude list in compare
1 parent 6362918 commit 4a6d8c2

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

test/compare_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,54 @@ def test_identical_topologies_but_ids():
9494
assert result.edge_matching.element_matching[edge_a5] == edge_b5
9595

9696

97+
def test_exclude_element_list():
98+
topology_a = Topology()
99+
node_a1 = Node(geo_node=EuclideanGeoNode(0, 0))
100+
node_a2 = Node(geo_node=EuclideanGeoNode(0, 10))
101+
node_a3 = Node(geo_node=EuclideanGeoNode(10, 0))
102+
node_a4 = Node(geo_node=EuclideanGeoNode(20, 0))
103+
node_a5 = Node(geo_node=EuclideanGeoNode(30, 0))
104+
node_a6 = Node(geo_node=EuclideanGeoNode(30, 10))
105+
edge_a1 = Edge(node_a1, node_a3)
106+
edge_a2 = Edge(node_a2, node_a3)
107+
edge_a3 = Edge(node_a4, node_a3)
108+
edge_a4 = Edge(node_a4, node_a5)
109+
edge_a5 = Edge(node_a4, node_a6)
110+
topology_a.add_nodes([node_a1, node_a2, node_a3, node_a4, node_a5, node_a6])
111+
topology_a.add_edges([edge_a1, edge_a2, edge_a3, edge_a4, edge_a5])
112+
topology_a.update_edge_lengths()
113+
114+
topology_b = Topology()
115+
node_b1 = Node(geo_node=EuclideanGeoNode(0, 0))
116+
node_b2 = Node(geo_node=EuclideanGeoNode(0, 10))
117+
node_b3 = Node(geo_node=EuclideanGeoNode(12, 0)) # x differs from Node A3
118+
node_b4 = Node(geo_node=EuclideanGeoNode(20, 3)) # y differs from Node A4
119+
node_b5 = Node(geo_node=EuclideanGeoNode(30, 0))
120+
node_b6 = Node(geo_node=EuclideanGeoNode(30, 10))
121+
edge_b1 = Edge(node_b1, node_b3)
122+
edge_b2 = Edge(node_b2, node_b3)
123+
edge_b3 = Edge(node_b4, node_b3)
124+
edge_b4 = Edge(node_b4, node_b5)
125+
edge_b5 = Edge(node_b4, node_b6)
126+
topology_b.add_nodes([node_b1, node_b2, node_b3, node_b4, node_b5, node_b6])
127+
topology_b.add_edges([edge_b1, edge_b2, edge_b3, edge_b4, edge_b5])
128+
topology_b.update_edge_lengths()
129+
130+
compare_modes = [CompareMode.ISOMORPHIC]
131+
132+
for compare_mode in compare_modes:
133+
result = Compare.compare(
134+
topology_a, topology_b, compare_mode, given_node_matching={node_a1: node_b1}, exclude_element_list=[node_b4]
135+
)
136+
assert result.node_distance == 2.0
137+
assert node_a1 in result.node_matching.element_matching
138+
assert result.node_matching.element_matching[node_a1] == node_b1
139+
assert node_a3 in result.node_matching.element_matching
140+
assert result.node_matching.element_matching[node_a3] == node_b3
141+
assert edge_a5 in result.edge_matching.element_matching
142+
assert result.edge_matching.element_matching[edge_a5] == edge_b5
143+
144+
97145
def test_edge_diff_and_signal_distance():
98146
topology_a = Topology()
99147
node_a1 = Node(geo_node=EuclideanGeoNode(0, 0))

yaramo/operations/compare.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from enum import Enum, auto
2-
from typing import Dict, List, Set, Tuple
2+
from typing import Dict, List, Tuple
33

44
import networkx as nx
55

66
from ..model import (
7-
DbrefGeoNode,
87
Edge,
98
GeoNode,
109
Node,
1110
Signal,
1211
SignalDirection,
1312
Topology,
14-
Wgs84GeoNode,
1513
)
1614

1715

@@ -44,11 +42,13 @@ def compare(
4442
topology_b: Topology,
4543
compare_mode: CompareMode,
4644
given_node_matching: Dict[Node, Node] | None = None,
47-
exclude_ends_in_calculation: bool = False,
45+
exclude_element_list=None,
4846
skip_signals: bool = False,
4947
) -> CompareResult:
5048
if given_node_matching is None:
5149
given_node_matching = {}
50+
if exclude_element_list is None:
51+
exclude_element_list = []
5252

5353
result: CompareResult = CompareResult()
5454

@@ -66,14 +66,14 @@ def compare(
6666
)
6767

6868
result.node_distance = Compare._calc_distance_for_matching(
69-
result.node_matching, exclude_ends_in_calculation
69+
result.node_matching, exclude_element_list
7070
)
7171
result.edge_length_difference = Compare._calc_distance_for_matching(
72-
result.edge_matching, exclude_ends_in_calculation, element_type="edge"
72+
result.edge_matching, exclude_element_list, element_type="edge"
7373
)
7474
if not skip_signals:
7575
result.signal_distance = Compare._calc_distance_for_matching(
76-
result.signal_matching, exclude_ends_in_calculation, element_type="signal"
76+
result.signal_matching, exclude_element_list, element_type="signal"
7777
)
7878
return result
7979

@@ -125,6 +125,7 @@ def __add_to_open_nodes(__node_a: Node, __node_b: Node):
125125
def __add_edges_to_matching(__edge_a: Edge, __edge_b: Edge):
126126
if __edge_a in result.edge_matching.element_matching:
127127
if result.edge_matching.element_matching[__edge_a] != __edge_b:
128+
print(f"bad edge {__edge_a.uuid} {__edge_b.uuid}")
128129
raise ValueError(
129130
"Graph topology is isomorphic, but railway network graph differs (edge graph broken)"
130131
)
@@ -225,7 +226,7 @@ def _are_topologies_isomorphic(topology_a: Topology, topology_b: Topology):
225226
@staticmethod
226227
def _calc_distance_for_matching(
227228
matching: CompareMatching,
228-
exclude_ends_in_calculation,
229+
exclude_element_list,
229230
element_type: str = "node",
230231
):
231232
if not matching.element_matching:
@@ -234,20 +235,16 @@ def _calc_distance_for_matching(
234235
distance_sum: float = 0.0
235236
for element_a in matching.element_matching:
236237
element_b = matching.element_matching[element_a]
238+
if element_a in exclude_element_list or element_b in exclude_element_list:
239+
continue
237240

238241
if element_type == "node":
239-
if exclude_ends_in_calculation and not element_a.is_point():
240-
continue
241242
geo_node_a: GeoNode = element_a.geo_node
242243
geo_node_b: GeoNode = element_b.geo_node
243244
distance = abs(geo_node_a.get_distance_to_other_geo_node(geo_node_b))
244245
print(f"From {element_a.uuid} to {element_b.uuid}: {distance}")
245246
distance_sum += distance
246247
elif element_type == "edge":
247-
if exclude_ends_in_calculation and (
248-
not element_a.node_a.is_point() or not element_a.node_b.is_point()
249-
):
250-
continue
251248
print(
252249
f"Edge {element_a.uuid} compared to {element_b.uuid}: {abs(element_a.length - element_b.length)}"
253250
)

0 commit comments

Comments
 (0)