-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
154 lines (123 loc) · 5.33 KB
/
utils.py
File metadata and controls
154 lines (123 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from timeit import default_timer
from typing import Tuple
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
class Timer(object):
    """
    Utility to measure times.

    Elapsed time is accumulated over successive start()/end() pairs using
    timeit.default_timer.

    TODO:
    - add "lap" method to make it easier to measure average time (+std) when measuring the same thing multiple times.
    """

    def __init__(self):
        # Sum of all completed start()/end() intervals.
        self.total_time = 0.0
        # Timer reading taken by the most recent start().
        self.start_time = 0.0
        # Timer reading taken by the most recent end().
        self.end_time = 0.0

    def start(self):
        """
        Begin a measurement by storing the current timer reading.
        """
        self.start_time = default_timer()

    def end(self):
        """
        Finish a measurement and add the time since the last start() to the total.
        """
        self.end_time = default_timer()
        self.total_time += self.end_time - self.start_time

    def get(self):
        """
        :return: the accumulated time of all finished measurements
        """
        return self.total_time

    def get_current(self):
        """
        :return: the time elapsed since the last start(), without touching the total
        """
        return default_timer() - self.start_time

    def reset(self):
        """
        Put all stored readings and the accumulated total back to 0.0.
        """
        self.__init__()

    def __repr__(self):
        # __repr__ must return a str; returning the raw float makes repr() raise TypeError.
        return repr(self.get())
def max_lt(seq, val):
    """
    Return the greatest strictly positive item in seq for which item < val applies.
    0 is returned if seq was empty or no item satisfies 0 < item < val.
    """
    # Walk from the back so that, among equal candidates, the same element is
    # picked as with a backwards index scan.
    candidates = (item for item in reversed(seq) if item < val and item > 0)
    return max(candidates, default=0)
def min_gt(seq, val):
    """
    Return smallest item in seq for which item >= val applies.
    np.inf is returned if seq was empty or no item in seq was >= val
    (items equal to np.inf are never selected, so the result is np.inf then as well).
    """
    # Walk from the back so that, among equal candidates, the same element is
    # picked as with a backwards index scan.
    candidates = (item for item in reversed(seq) if item >= val and item < np.inf)
    return min(candidates, default=np.inf)
def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
    """
    This function rescales an observation by an upper (and optionally lower) bound and clips the result
    :param obs: Observation that should be normalized
    :param clip_min: min value where observation will be clipped
    :param clip_max: max value where observation will be clipped
    :param fixed_radius: if > 0, used as the upper bound; otherwise the upper bound is
        max(1, max_lt(obs, 1000)) + 1
    :param normalize_to_range: if True, the lower bound is min_gt(obs, 0) (capped at the upper bound)
        instead of 0
    :return: returns normalized and clipped observation
    """
    upper = fixed_radius if fixed_radius > 0 else max(1, max_lt(obs, 1000)) + 1
    lower = min_gt(obs, 0) if normalize_to_range else 0

    # The lower bound may never exceed the upper bound.
    if lower > upper:
        lower = upper

    values = np.array(obs)
    if upper == lower:
        # Degenerate range: scale by the upper bound only.
        return np.clip(values / upper, clip_min, clip_max)

    span = np.abs(upper - lower)
    return np.clip((values - lower) / span, clip_min, clip_max)
def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if node == -np.inf:
remaining_depth = max_tree_depth - current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
data, distance, agent_data = _split_node_into_feature_groups(node)
if not node.childs:
return data, distance, agent_data
for direction in TreeObsForRailEnv.tree_explored_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))
return data, distance, agent_data
def _split_node_into_feature_groups(node) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
data = np.zeros(6)
distance = np.zeros(1)
agent_data = np.zeros(4)
data[0] = node.dist_own_target_encountered
data[1] = node.dist_other_target_encountered
data[2] = node.dist_other_agent_encountered
data[3] = node.dist_potential_conflict
data[4] = node.dist_unusable_switch
data[5] = node.dist_to_next_branch
distance[0] = node.dist_min_to_target
agent_data[0] = node.num_agents_same_direction
agent_data[1] = node.num_agents_opposite_direction
agent_data[2] = node.num_agents_malfunctioning
agent_data[3] = node.speed_min_fractional
return data, distance, agent_data
def split_tree_into_feature_groups(tree, max_tree_depth: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    This function flattens the whole tree into three different arrays of values
    :param tree: root node of the tree; its childs are looked up per explored direction
    :param max_tree_depth: depth passed on to the subtree traversal
    :return: (data, distance, agent_data), root values first, followed by each subtree in direction order
    """
    # One (data, distance, agent_data) triple for the root, then one per direction.
    triples = [_split_node_into_feature_groups(tree)]
    triples.extend(
        _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
        for direction in TreeObsForRailEnv.tree_explored_actions_char
    )

    # Regroup the triples by feature group and join each group into one array.
    data_parts, distance_parts, agent_data_parts = zip(*triples)
    return np.concatenate(data_parts), np.concatenate(distance_parts), np.concatenate(agent_data_parts)
def normalize_observation(observation, tree_depth: int, observation_radius=0):
    """
    This function normalizes the observation used by the RL algorithm
    :param observation: tree observation that is split into feature groups
    :param tree_depth: max tree depth handed to split_tree_into_feature_groups
    :param observation_radius: passed to norm_obs_clip as fixed_radius for the data group
    :return: one array made of the normalized data, distance and agent data groups, in that order
    """
    raw_data, raw_distance, raw_agent_data = split_tree_into_feature_groups(observation, tree_depth)

    scaled_data = norm_obs_clip(raw_data, fixed_radius=observation_radius)
    scaled_distance = norm_obs_clip(raw_distance, normalize_to_range=True)
    clipped_agent_data = np.clip(raw_agent_data, -1, 1)

    return np.concatenate((scaled_data, scaled_distance, clipped_agent_data))