|
| 1 | +import random |
| 2 | +from typing import Any, Dict, List, Tuple |
| 3 | + |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from mediator import Mediator |
| 7 | + |
| 8 | + |
| 9 | +class MiniMetroEnv: |
| 10 | + def __init__(self, dt_ms: int | None = None) -> None: |
| 11 | + self.dt_ms_default = dt_ms |
| 12 | + self.mediator = Mediator() |
| 13 | + self.last_score = self.mediator.score |
| 14 | + |
| 15 | + def reset(self, seed: int | None = None) -> Dict[str, Any]: |
| 16 | + if seed is not None: |
| 17 | + random.seed(seed) |
| 18 | + np.random.seed(seed) |
| 19 | + self.mediator = Mediator() |
| 20 | + self.last_score = self.mediator.score |
| 21 | + return self.observe() |
| 22 | + |
| 23 | + def step( |
| 24 | + self, action: Dict[str, Any] | None = None, dt_ms: int | None = None |
| 25 | + ) -> Tuple[Dict[str, Any], int, bool, Dict[str, Any]]: |
| 26 | + if action is None: |
| 27 | + action = {"type": "noop"} |
| 28 | + action_ok = self.mediator.apply_action(action) |
| 29 | + |
| 30 | + if dt_ms is None: |
| 31 | + dt_ms = self.dt_ms_default |
| 32 | + if dt_ms is not None: |
| 33 | + self.mediator.step_time(dt_ms) |
| 34 | + |
| 35 | + obs = self.observe() |
| 36 | + reward = self.mediator.score - self.last_score |
| 37 | + self.last_score = self.mediator.score |
| 38 | + done = False |
| 39 | + info = {"action_ok": action_ok} |
| 40 | + return obs, reward, done, info |
| 41 | + |
| 42 | + def observe(self) -> Dict[str, Any]: |
| 43 | + station_id_to_index = { |
| 44 | + station.id: idx for idx, station in enumerate(self.mediator.stations) |
| 45 | + } |
| 46 | + path_id_to_index = { |
| 47 | + path.id: idx for idx, path in enumerate(self.mediator.paths) |
| 48 | + } |
| 49 | + metro_id_to_index = { |
| 50 | + metro.id: idx for idx, metro in enumerate(self.mediator.metros) |
| 51 | + } |
| 52 | + passenger_id_to_index = { |
| 53 | + passenger.id: idx for idx, passenger in enumerate(self.mediator.passengers) |
| 54 | + } |
| 55 | + |
| 56 | + passenger_locations: Dict[str, Tuple[str, str] | None] = { |
| 57 | + passenger.id: None for passenger in self.mediator.passengers |
| 58 | + } |
| 59 | + for station in self.mediator.stations: |
| 60 | + for passenger in station.passengers: |
| 61 | + passenger_locations[passenger.id] = ("station", station.id) |
| 62 | + for metro in self.mediator.metros: |
| 63 | + for passenger in metro.passengers: |
| 64 | + passenger_locations[passenger.id] = ("metro", metro.id) |
| 65 | + |
| 66 | + structured = { |
| 67 | + "stations": [ |
| 68 | + { |
| 69 | + "id": station.id, |
| 70 | + "position": (station.position.left, station.position.top), |
| 71 | + "shape_type": station.shape.type, |
| 72 | + "passenger_ids": [p.id for p in station.passengers], |
| 73 | + "passenger_count": len(station.passengers), |
| 74 | + } |
| 75 | + for station in self.mediator.stations |
| 76 | + ], |
| 77 | + "paths": [ |
| 78 | + { |
| 79 | + "id": path.id, |
| 80 | + "station_ids": [s.id for s in path.stations], |
| 81 | + "is_looped": path.is_looped, |
| 82 | + "color": path.color, |
| 83 | + } |
| 84 | + for path in self.mediator.paths |
| 85 | + ], |
| 86 | + "metros": [ |
| 87 | + { |
| 88 | + "id": metro.id, |
| 89 | + "path_id": metro.path_id, |
| 90 | + "position": ( |
| 91 | + (metro.position.left, metro.position.top) |
| 92 | + if metro.position is not None |
| 93 | + else None |
| 94 | + ), |
| 95 | + "current_station_id": ( |
| 96 | + metro.current_station.id if metro.current_station else None |
| 97 | + ), |
| 98 | + "passenger_ids": [p.id for p in metro.passengers], |
| 99 | + } |
| 100 | + for metro in self.mediator.metros |
| 101 | + ], |
| 102 | + "passengers": [ |
| 103 | + { |
| 104 | + "id": passenger.id, |
| 105 | + "destination_shape_type": passenger.destination_shape.type, |
| 106 | + "is_at_destination": passenger.is_at_destination, |
| 107 | + "location": passenger_locations[passenger.id], |
| 108 | + } |
| 109 | + for passenger in self.mediator.passengers |
| 110 | + ], |
| 111 | + "score": self.mediator.score, |
| 112 | + "time_ms": self.mediator.time_ms, |
| 113 | + "steps": self.mediator.steps, |
| 114 | + "is_paused": self.mediator.is_paused, |
| 115 | + "index": { |
| 116 | + "station_id_to_index": station_id_to_index, |
| 117 | + "path_id_to_index": path_id_to_index, |
| 118 | + "metro_id_to_index": metro_id_to_index, |
| 119 | + "passenger_id_to_index": passenger_id_to_index, |
| 120 | + }, |
| 121 | + } |
| 122 | + |
| 123 | + arrays = self._encode_numpy( |
| 124 | + station_id_to_index, |
| 125 | + path_id_to_index, |
| 126 | + metro_id_to_index, |
| 127 | + passenger_id_to_index, |
| 128 | + ) |
| 129 | + |
| 130 | + return {"structured": structured, "arrays": arrays} |
| 131 | + |
| 132 | + def _encode_numpy( |
| 133 | + self, |
| 134 | + station_id_to_index: Dict[str, int], |
| 135 | + path_id_to_index: Dict[str, int], |
| 136 | + metro_id_to_index: Dict[str, int], |
| 137 | + passenger_id_to_index: Dict[str, int], |
| 138 | + ) -> Dict[str, Any]: |
| 139 | + station_positions = np.array( |
| 140 | + [ |
| 141 | + [station.position.left, station.position.top] |
| 142 | + for station in self.mediator.stations |
| 143 | + ], |
| 144 | + dtype=np.float32, |
| 145 | + ) |
| 146 | + station_shape_types = np.array( |
| 147 | + [int(station.shape.type.value) for station in self.mediator.stations], |
| 148 | + dtype=np.int64, |
| 149 | + ) |
| 150 | + station_passenger_counts = np.array( |
| 151 | + [len(station.passengers) for station in self.mediator.stations], |
| 152 | + dtype=np.int64, |
| 153 | + ) |
| 154 | + path_station_indices = [ |
| 155 | + np.array( |
| 156 | + [station_id_to_index[s.id] for s in path.stations], dtype=np.int64 |
| 157 | + ) |
| 158 | + for path in self.mediator.paths |
| 159 | + ] |
| 160 | + path_is_looped = np.array( |
| 161 | + [int(path.is_looped) for path in self.mediator.paths], dtype=np.int64 |
| 162 | + ) |
| 163 | + |
| 164 | + metro_positions_list = [ |
| 165 | + [metro.position.left, metro.position.top] |
| 166 | + if metro.position is not None |
| 167 | + else [-1, -1] |
| 168 | + for metro in self.mediator.metros |
| 169 | + ] |
| 170 | + if metro_positions_list: |
| 171 | + metro_positions = np.array(metro_positions_list, dtype=np.float32) |
| 172 | + else: |
| 173 | + metro_positions = np.zeros((0, 2), dtype=np.float32) |
| 174 | + metro_path_indices = np.array( |
| 175 | + [ |
| 176 | + path_id_to_index.get(metro.path_id, -1) |
| 177 | + for metro in self.mediator.metros |
| 178 | + ], |
| 179 | + dtype=np.int64, |
| 180 | + ) |
| 181 | + |
| 182 | + passenger_destination_types = np.array( |
| 183 | + [ |
| 184 | + int(passenger.destination_shape.type.value) |
| 185 | + for passenger in self.mediator.passengers |
| 186 | + ], |
| 187 | + dtype=np.int64, |
| 188 | + ) |
| 189 | + passenger_station_indices = np.full( |
| 190 | + (len(self.mediator.passengers),), -1, dtype=np.int64 |
| 191 | + ) |
| 192 | + passenger_metro_indices = np.full( |
| 193 | + (len(self.mediator.passengers),), -1, dtype=np.int64 |
| 194 | + ) |
| 195 | + |
| 196 | + for station in self.mediator.stations: |
| 197 | + for passenger in station.passengers: |
| 198 | + idx = passenger_id_to_index.get(passenger.id) |
| 199 | + if idx is not None: |
| 200 | + passenger_station_indices[idx] = station_id_to_index[station.id] |
| 201 | + for metro in self.mediator.metros: |
| 202 | + for passenger in metro.passengers: |
| 203 | + idx = passenger_id_to_index.get(passenger.id) |
| 204 | + if idx is not None: |
| 205 | + passenger_metro_indices[idx] = metro_id_to_index[metro.id] |
| 206 | + |
| 207 | + return { |
| 208 | + "station_positions": station_positions, |
| 209 | + "station_shape_types": station_shape_types, |
| 210 | + "station_passenger_counts": station_passenger_counts, |
| 211 | + "path_station_indices": path_station_indices, |
| 212 | + "path_is_looped": path_is_looped, |
| 213 | + "metro_positions": metro_positions, |
| 214 | + "metro_path_indices": metro_path_indices, |
| 215 | + "passenger_destination_types": passenger_destination_types, |
| 216 | + "passenger_station_indices": passenger_station_indices, |
| 217 | + "passenger_metro_indices": passenger_metro_indices, |
| 218 | + } |
0 commit comments