Skip to content

Commit 1f18c67

Browse files
committed
typing
1 parent c73dd57 commit 1f18c67

File tree

2 files changed

+45
-39
lines changed

2 files changed

+45
-39
lines changed

yaramo/geo_node.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import math
2-
from abc import ABC, abstractmethod
1+
from __future__ import annotations
32

3+
from abc import ABC, abstractmethod
44
from haversine import Unit, haversine
5+
import math
56

67
from yaramo.base_element import BaseElement
78

@@ -21,27 +22,27 @@ def __init__(self, x, y, dbref_crs: str = "ER0", **kwargs):
2122
self.dbref_crs = dbref_crs
2223

2324
@abstractmethod
24-
def get_distance_to_other_geo_node(self, geo_node_b: "GeoNode"):
25+
def get_distance_to_other_geo_node(self, geo_node_b: GeoNode):
2526
"""Returns to distance to the given other GeoNode."""
2627
pass
2728

2829
@abstractmethod
29-
def to_wgs84(self) -> "Wgs84GeoNode":
30+
def to_wgs84(self) -> Wgs84GeoNode:
3031
pass
3132

3233
@abstractmethod
33-
def to_dbref(self) -> "DbrefGeoNode":
34+
def to_dbref(self) -> DbrefGeoNode:
3435
pass
3536

3637
@abstractmethod
37-
def to_euclidean(self) -> "EuclideanGeoNode":
38+
def to_euclidean(self) -> EuclideanGeoNode:
3839
pass
3940

4041
def to_serializable(self):
4142
return self.__dict__, {}
4243

4344
@staticmethod
44-
def get_new_geo_node_same_type(old_geo_node: "GeoNode", x: float, y: float) -> "GeoNode":
45+
def get_new_geo_node_same_type(old_geo_node: GeoNode, x: float, y: float) -> GeoNode:
4546
if isinstance(old_geo_node, Wgs84GeoNode):
4647
return Wgs84GeoNode(x, y)
4748
elif isinstance(old_geo_node, DbrefGeoNode):
@@ -52,39 +53,39 @@ def get_new_geo_node_same_type(old_geo_node: "GeoNode", x: float, y: float) -> "
5253

5354

5455
class Wgs84GeoNode(GeoNode):
55-
def get_distance_to_other_geo_node(self, geo_node_b: "GeoNode"):
56+
def get_distance_to_other_geo_node(self, geo_node_b: GeoNode):
5657
geo_node_b = geo_node_b.to_wgs84()
5758
return self.__haversine_distance(geo_node_b)
5859

59-
def __haversine_distance(self, geo_node_b: "GeoNode"):
60+
def __haversine_distance(self, geo_node_b: GeoNode):
6061
own = (self.x, self.y)
6162
other = (geo_node_b.x, geo_node_b.y)
6263
return haversine(own, other, unit=Unit.METERS)
6364

64-
def to_wgs84(self) -> "Wgs84GeoNode":
65+
def to_wgs84(self) -> Wgs84GeoNode:
6566
return self
6667

67-
def to_dbref(self) -> "DbrefGeoNode":
68+
def to_dbref(self) -> DbrefGeoNode:
6869
x, y = transform_wgs84_to_dbref(self.x, self.y, self.dbref_crs)
6970
return DbrefGeoNode(x, y, self.dbref_crs, uuid=self.uuid)
7071

71-
def to_euclidean(self) -> "EuclideanGeoNode":
72+
def to_euclidean(self) -> EuclideanGeoNode:
7273
return self.to_dbref().to_euclidean()
7374

7475

7576
class DbrefGeoNode(GeoNode):
76-
def get_distance_to_other_geo_node(self, geo_node_b: "GeoNode"):
77+
def get_distance_to_other_geo_node(self, geo_node_b: GeoNode):
7778
# Separate DB Ref distance method not implemented yet, therefore use WGS84 distance
7879
return self.to_wgs84().get_distance_to_other_geo_node(geo_node_b)
7980

80-
def to_wgs84(self) -> "Wgs84GeoNode":
81+
def to_wgs84(self) -> Wgs84GeoNode:
8182
x, y = transform_dbref_to_wgs84(self.x, self.y, self.dbref_crs)
8283
return Wgs84GeoNode(x, y, self.dbref_crs, uuid=self.uuid)
8384

84-
def to_dbref(self) -> "DbrefGeoNode":
85+
def to_dbref(self) -> DbrefGeoNode:
8586
return self
8687

87-
def to_euclidean(self) -> "EuclideanGeoNode":
88+
def to_euclidean(self) -> EuclideanGeoNode:
8889
# This transformation is just for testing purposes and not correct, see documentation in EuclideanGeoNode.
8990
_x_shift = 4533770.0
9091
_y_shift = 5625780.0
@@ -100,24 +101,24 @@ class EuclideanGeoNode(GeoNode):
100101
DBRef geo nodes, but the conversion is just a simple coordinate shift and with this, probably incorrect.
101102
"""
102103

103-
def get_distance_to_other_geo_node(self, geo_node_b: "EuclideanGeoNode"):
104+
def get_distance_to_other_geo_node(self, geo_node_b: EuclideanGeoNode):
104105
geo_node_b = geo_node_b.to_euclidean()
105106
return self.__eucldian_distance(geo_node_b)
106107

107-
def __eucldian_distance(self, geo_node_b: "GeoNode"):
108+
def __eucldian_distance(self, geo_node_b: GeoNode):
108109
min_x = min(self.x, geo_node_b.x)
109110
min_y = min(self.y, geo_node_b.y)
110111
max_x = max(self.x, geo_node_b.x)
111112
max_y = max(self.y, geo_node_b.y)
112113
return math.sqrt(math.pow(max_x - min_x, 2) + math.pow(max_y - min_y, 2))
113114

114-
def to_wgs84(self) -> "Wgs84GeoNode":
115+
def to_wgs84(self) -> Wgs84GeoNode:
115116
return self.to_dbref().to_wgs84()
116117

117-
def to_dbref(self) -> "DbrefGeoNode":
118+
def to_dbref(self) -> DbrefGeoNode:
118119
_x_shift = 4533770.0
119120
_y_shift = 5625780.0
120121
return DbrefGeoNode(self.x + _x_shift, self.y + _y_shift, self.dbref_crs, uuid=self.uuid)
121122

122-
def to_euclidean(self) -> "EuclideanGeoNode":
123+
def to_euclidean(self) -> EuclideanGeoNode:
123124
return self

yaramo/node.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
from __future__ import annotations
2+
13
import sys
24
from enum import Enum
35
from itertools import permutations
46
from math import atan2, cos, sin
5-
from typing import List
7+
from typing import TYPE_CHECKING
68

79
from yaramo.base_element import BaseElement
810
from yaramo.geo_node import GeoNode
911

12+
if TYPE_CHECKING:
13+
from yaramo.edge import Edge
14+
1015

1116
class EdgeConnectionDirection(Enum):
1217
Spitze = 0
@@ -33,17 +38,17 @@ def __init__(self, turnout_side=None, **kwargs):
3338
"""
3439

3540
super().__init__(**kwargs)
36-
self.connected_edge_on_head = None
37-
self.connected_edge_on_right = None
38-
self.connected_edge_on_left = None
41+
self.connected_edge_on_head: Edge = None
42+
self.connected_edge_on_right: Edge = None
43+
self.connected_edge_on_left: Edge = None
3944
self.maximum_speed_on_left = None
4045
self.maximum_speed_on_right = None
41-
self.connected_edges: list["Edge"] = []
46+
self.connected_edges: list[Edge] = []
4247
self.geo_node: GeoNode = kwargs.get("geo_node", None)
4348
self.turnout_side: str = turnout_side
4449
self.drive_amount = 0
4550

46-
def maximum_speed(self, node_a: "Node", node_b: "Node"):
51+
def maximum_speed(self, node_a: Node, node_b: Node):
4752
"""Return the maximum allowed speed for traversing this node,
4853
coming from node_a and going to node_b
4954
"""
@@ -56,49 +61,49 @@ def maximum_speed(self, node_a: "Node", node_b: "Node"):
5661
return None
5762

5863
@property
59-
def connected_nodes(self):
64+
def connected_nodes(self) -> list[Node]:
6065
return [edge.get_opposite_node(self) for edge in self.connected_edges]
6166

6267
@property
63-
def connected_on_head(self):
68+
def connected_on_head(self) -> Node:
6469
if self.connected_edge_on_head is None:
6570
self.calc_anschluss_of_all_edges()
6671
if self.connected_edge_on_head is None:
6772
return None
6873
return self.connected_edge_on_head.get_opposite_node(self)
6974

7075
@property
71-
def connected_on_left(self):
76+
def connected_on_left(self) -> Node:
7277
if self.connected_edge_on_head is None:
7378
self.calc_anschluss_of_all_edges()
7479
if self.connected_edge_on_left is None:
7580
return None
7681
return self.connected_edge_on_left.get_opposite_node(self)
7782

7883
@property
79-
def connected_on_right(self):
84+
def connected_on_right(self) -> Node:
8085
if self.connected_edge_on_head is None:
8186
self.calc_anschluss_of_all_edges()
8287
if self.connected_edge_on_right is None:
8388
return None
8489
return self.connected_edge_on_right.get_opposite_node(self)
8590

86-
def set_connection_head_edge(self, edge: "Edge"):
91+
def set_connection_head_edge(self, edge: Edge):
8792
self.connected_edge_on_head = edge
8893
if edge not in self.connected_edges:
8994
self.connected_edges.append(edge)
9095

91-
def set_connection_left_edge(self, edge: "Edge"):
96+
def set_connection_left_edge(self, edge: Edge):
9297
self.connected_edge_on_left = edge
9398
if edge not in self.connected_edges:
9499
self.connected_edges.append(edge)
95100

96-
def set_connection_right_edge(self, edge: "Edge"):
101+
def set_connection_right_edge(self, edge: Edge):
97102
self.connected_edge_on_right = edge
98103
if edge not in self.connected_edges:
99104
self.connected_edges.append(edge)
100105

101-
def remove_edge(self, edge: "Edge"):
106+
def remove_edge(self, edge: Edge):
102107
self.connected_edges.remove(edge)
103108
if self.connected_edge_on_head == edge:
104109
self.connected_edge_on_head = None
@@ -107,7 +112,7 @@ def remove_edge(self, edge: "Edge"):
107112
if self.connected_edge_on_right == edge:
108113
self.connected_edge_on_right = None
109114

110-
def remove_edge_to_node(self, node: "Node"):
115+
def remove_edge_to_node(self, node: Node):
111116
"""Removes the edge to the given node and removes the node from the connected_nodes list."""
112117
edge = self.get_edge_to_node(node)
113118
self.remove_edge(edge)
@@ -116,7 +121,7 @@ def get_edge_to_node(self, node):
116121
"""Returns the edge to the given neighbor node."""
117122
return next(edge for edge in self.connected_edges if edge.get_opposite_node(self) == node)
118123

119-
def get_possible_followers(self, source: "Edge") -> List["Edge"]:
124+
def get_possible_followers(self, source: Edge) -> list[Edge]:
120125
"""Returns the `Edge`s that could follow (head, left, right) when comming from a source `Edge` connected to this `Node`."""
121126
if source is None:
122127
return self.connected_edges
@@ -131,7 +136,7 @@ def get_possible_followers(self, source: "Edge") -> List["Edge"]:
131136
return [self.connected_edge_on_left, self.connected_edge_on_right]
132137
return [self.connected_edge_on_head]
133138

134-
def get_anschluss_for_edge(self, edge: "Edge") -> EdgeConnectionDirection:
139+
def get_anschluss_for_edge(self, edge: Edge) -> EdgeConnectionDirection:
135140
"""Gets the Anschluss (Ende, Links, Rechts, Spitze) of other node.
136141
Idea: We assume, the current node is a point and we want to estimate the Anschluss of the other node.
137142
@@ -203,7 +208,7 @@ def get_rad_between_nodes(geo_node_a: GeoNode, geo_node_b: GeoNode) -> float:
203208
self.connected_edge_on_right = right
204209
break
205210

206-
def is_point(self):
211+
def is_point(self) -> bool:
207212
"""
208213
Returns true if this node is a point.
209214
A point is a `Node` with at least 3 connected tracks (yaramo only supports three edges at the moment)

0 commit comments

Comments
 (0)