Skip to content

Commit 4ce00d4

Browse files
committed
Support playing programmatically
1 parent d583e88 commit 4ce00d4

6 files changed

Lines changed: 642 additions & 0 deletions

File tree

.cursor/rules/update-progress.mdc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
alwaysApply: true
3+
---
4+
Before you do anything, check `PROGRESS.md` to check what is already done.
5+
After you do anything, update `PROGRESS.md` after to keep track of what is done.
6+
This serves as long-term context.

PROGRESS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## 2026-02-08
2+
3+
- Added Gym-like programmatic play interface in `src/env.py` with structured and numpy observations.
4+
- Added high-level programmatic actions in `src/mediator.py` (create/remove paths, pause/resume, step time).
5+
- Expanded programmatic-play tests in `test/test_env.py` for loops, invalid actions, limits, reward delivery, and observations.
6+
7+
Tests:
8+
- `python -m unittest -v test.test_env`

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,19 @@ This repo uses `pygame-ce` to implement Mini Metro, a fun 2D strategic game wher
1717
* The number of grey circles on top of the screen is the number of availabel metro lines left.
1818
* Click on the colored circle at the top to cancel an established line.
1919

20+
# Programmatic play
21+
Use the Gym-like environment in `src/env.py`:
22+
23+
```
24+
from env import MiniMetroEnv
25+
26+
env = MiniMetroEnv(dt_ms=16)
27+
obs = env.reset(seed=42)
28+
obs, reward, done, info = env.step(
29+
{"type": "create_path", "stations": [0, 1, 2], "loop": False}
30+
)
31+
obs, reward, done, info = env.step({"type": "remove_path", "path_index": 0})
32+
```
33+
2034
# Testing
2135
`python -m unittest -v`

src/env.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import random
2+
from typing import Any, Dict, List, Tuple
3+
4+
import numpy as np
5+
6+
from mediator import Mediator
7+
8+
9+
class MiniMetroEnv:
10+
def __init__(self, dt_ms: int | None = None) -> None:
11+
self.dt_ms_default = dt_ms
12+
self.mediator = Mediator()
13+
self.last_score = self.mediator.score
14+
15+
def reset(self, seed: int | None = None) -> Dict[str, Any]:
16+
if seed is not None:
17+
random.seed(seed)
18+
np.random.seed(seed)
19+
self.mediator = Mediator()
20+
self.last_score = self.mediator.score
21+
return self.observe()
22+
23+
def step(
24+
self, action: Dict[str, Any] | None = None, dt_ms: int | None = None
25+
) -> Tuple[Dict[str, Any], int, bool, Dict[str, Any]]:
26+
if action is None:
27+
action = {"type": "noop"}
28+
action_ok = self.mediator.apply_action(action)
29+
30+
if dt_ms is None:
31+
dt_ms = self.dt_ms_default
32+
if dt_ms is not None:
33+
self.mediator.step_time(dt_ms)
34+
35+
obs = self.observe()
36+
reward = self.mediator.score - self.last_score
37+
self.last_score = self.mediator.score
38+
done = False
39+
info = {"action_ok": action_ok}
40+
return obs, reward, done, info
41+
42+
def observe(self) -> Dict[str, Any]:
43+
station_id_to_index = {
44+
station.id: idx for idx, station in enumerate(self.mediator.stations)
45+
}
46+
path_id_to_index = {
47+
path.id: idx for idx, path in enumerate(self.mediator.paths)
48+
}
49+
metro_id_to_index = {
50+
metro.id: idx for idx, metro in enumerate(self.mediator.metros)
51+
}
52+
passenger_id_to_index = {
53+
passenger.id: idx for idx, passenger in enumerate(self.mediator.passengers)
54+
}
55+
56+
passenger_locations: Dict[str, Tuple[str, str] | None] = {
57+
passenger.id: None for passenger in self.mediator.passengers
58+
}
59+
for station in self.mediator.stations:
60+
for passenger in station.passengers:
61+
passenger_locations[passenger.id] = ("station", station.id)
62+
for metro in self.mediator.metros:
63+
for passenger in metro.passengers:
64+
passenger_locations[passenger.id] = ("metro", metro.id)
65+
66+
structured = {
67+
"stations": [
68+
{
69+
"id": station.id,
70+
"position": (station.position.left, station.position.top),
71+
"shape_type": station.shape.type,
72+
"passenger_ids": [p.id for p in station.passengers],
73+
"passenger_count": len(station.passengers),
74+
}
75+
for station in self.mediator.stations
76+
],
77+
"paths": [
78+
{
79+
"id": path.id,
80+
"station_ids": [s.id for s in path.stations],
81+
"is_looped": path.is_looped,
82+
"color": path.color,
83+
}
84+
for path in self.mediator.paths
85+
],
86+
"metros": [
87+
{
88+
"id": metro.id,
89+
"path_id": metro.path_id,
90+
"position": (
91+
(metro.position.left, metro.position.top)
92+
if metro.position is not None
93+
else None
94+
),
95+
"current_station_id": (
96+
metro.current_station.id if metro.current_station else None
97+
),
98+
"passenger_ids": [p.id for p in metro.passengers],
99+
}
100+
for metro in self.mediator.metros
101+
],
102+
"passengers": [
103+
{
104+
"id": passenger.id,
105+
"destination_shape_type": passenger.destination_shape.type,
106+
"is_at_destination": passenger.is_at_destination,
107+
"location": passenger_locations[passenger.id],
108+
}
109+
for passenger in self.mediator.passengers
110+
],
111+
"score": self.mediator.score,
112+
"time_ms": self.mediator.time_ms,
113+
"steps": self.mediator.steps,
114+
"is_paused": self.mediator.is_paused,
115+
"index": {
116+
"station_id_to_index": station_id_to_index,
117+
"path_id_to_index": path_id_to_index,
118+
"metro_id_to_index": metro_id_to_index,
119+
"passenger_id_to_index": passenger_id_to_index,
120+
},
121+
}
122+
123+
arrays = self._encode_numpy(
124+
station_id_to_index,
125+
path_id_to_index,
126+
metro_id_to_index,
127+
passenger_id_to_index,
128+
)
129+
130+
return {"structured": structured, "arrays": arrays}
131+
132+
def _encode_numpy(
133+
self,
134+
station_id_to_index: Dict[str, int],
135+
path_id_to_index: Dict[str, int],
136+
metro_id_to_index: Dict[str, int],
137+
passenger_id_to_index: Dict[str, int],
138+
) -> Dict[str, Any]:
139+
station_positions = np.array(
140+
[
141+
[station.position.left, station.position.top]
142+
for station in self.mediator.stations
143+
],
144+
dtype=np.float32,
145+
)
146+
station_shape_types = np.array(
147+
[int(station.shape.type.value) for station in self.mediator.stations],
148+
dtype=np.int64,
149+
)
150+
station_passenger_counts = np.array(
151+
[len(station.passengers) for station in self.mediator.stations],
152+
dtype=np.int64,
153+
)
154+
path_station_indices = [
155+
np.array(
156+
[station_id_to_index[s.id] for s in path.stations], dtype=np.int64
157+
)
158+
for path in self.mediator.paths
159+
]
160+
path_is_looped = np.array(
161+
[int(path.is_looped) for path in self.mediator.paths], dtype=np.int64
162+
)
163+
164+
metro_positions_list = [
165+
[metro.position.left, metro.position.top]
166+
if metro.position is not None
167+
else [-1, -1]
168+
for metro in self.mediator.metros
169+
]
170+
if metro_positions_list:
171+
metro_positions = np.array(metro_positions_list, dtype=np.float32)
172+
else:
173+
metro_positions = np.zeros((0, 2), dtype=np.float32)
174+
metro_path_indices = np.array(
175+
[
176+
path_id_to_index.get(metro.path_id, -1)
177+
for metro in self.mediator.metros
178+
],
179+
dtype=np.int64,
180+
)
181+
182+
passenger_destination_types = np.array(
183+
[
184+
int(passenger.destination_shape.type.value)
185+
for passenger in self.mediator.passengers
186+
],
187+
dtype=np.int64,
188+
)
189+
passenger_station_indices = np.full(
190+
(len(self.mediator.passengers),), -1, dtype=np.int64
191+
)
192+
passenger_metro_indices = np.full(
193+
(len(self.mediator.passengers),), -1, dtype=np.int64
194+
)
195+
196+
for station in self.mediator.stations:
197+
for passenger in station.passengers:
198+
idx = passenger_id_to_index.get(passenger.id)
199+
if idx is not None:
200+
passenger_station_indices[idx] = station_id_to_index[station.id]
201+
for metro in self.mediator.metros:
202+
for passenger in metro.passengers:
203+
idx = passenger_id_to_index.get(passenger.id)
204+
if idx is not None:
205+
passenger_metro_indices[idx] = metro_id_to_index[metro.id]
206+
207+
return {
208+
"station_positions": station_positions,
209+
"station_shape_types": station_shape_types,
210+
"station_passenger_counts": station_passenger_counts,
211+
"path_station_indices": path_station_indices,
212+
"path_is_looped": path_is_looped,
213+
"metro_positions": metro_positions,
214+
"metro_path_indices": metro_path_indices,
215+
"passenger_destination_types": passenger_destination_types,
216+
"passenger_station_indices": passenger_station_indices,
217+
"passenger_metro_indices": passenger_metro_indices,
218+
}

src/mediator.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def __init__(self) -> None:
7474
self.is_paused = False
7575
self.score = 0
7676

77+
def step_time(self, dt_ms: int) -> None:
78+
self.increment_time(dt_ms)
79+
7780
def assign_paths_to_buttons(self) -> None:
7881
for path_button in self.path_buttons:
7982
path_button.remove_path()
@@ -162,6 +165,19 @@ def remove_path(self, path: Path) -> None:
162165
self.assign_paths_to_buttons()
163166
self.find_travel_plan_for_passengers()
164167

168+
def remove_path_by_id(self, path_id: str) -> bool:
169+
for path in self.paths:
170+
if path.id == path_id:
171+
self.remove_path(path)
172+
return True
173+
return False
174+
175+
def remove_path_by_index(self, path_index: int) -> bool:
176+
if 0 <= path_index < len(self.paths):
177+
self.remove_path(self.paths[path_index])
178+
return True
179+
return False
180+
165181
def start_path_on_station(self, station: Station) -> None:
166182
if len(self.paths) < self.num_paths:
167183
self.is_creating_path = True
@@ -178,6 +194,30 @@ def start_path_on_station(self, station: Station) -> None:
178194
self.path_being_created = path
179195
self.paths.append(path)
180196

197+
def create_path_from_station_indices(
198+
self, station_indices: List[int], loop: bool = False
199+
) -> Path | None:
200+
if len(station_indices) < 2 or len(self.paths) >= self.num_paths:
201+
return None
202+
if any(
203+
idx < 0 or idx >= len(self.stations) for idx in station_indices
204+
):
205+
return None
206+
207+
self.start_path_on_station(self.stations[station_indices[0]])
208+
if not self.path_being_created:
209+
return None
210+
211+
for idx in station_indices[1:-1]:
212+
self.add_station_to_path(self.stations[idx])
213+
214+
if loop:
215+
self.end_path_on_station(self.stations[station_indices[0]])
216+
else:
217+
self.end_path_on_station(self.stations[station_indices[-1]])
218+
219+
return self.paths[-1] if self.paths else None
220+
181221
def add_station_to_path(self, station: Station) -> None:
182222
assert self.path_being_created is not None
183223
if self.path_being_created.stations[-1] == station:
@@ -217,6 +257,31 @@ def finish_path_creation(self) -> None:
217257
self.path_being_created = None
218258
self.assign_paths_to_buttons()
219259

260+
def set_paused(self, paused: bool) -> None:
261+
self.is_paused = paused
262+
263+
def apply_action(self, action: Dict) -> bool:
264+
action_type = action.get("type")
265+
if action_type == "create_path":
266+
stations = action.get("stations", [])
267+
loop = bool(action.get("loop", False))
268+
return self.create_path_from_station_indices(stations, loop) is not None
269+
if action_type == "remove_path":
270+
if "path_id" in action:
271+
return self.remove_path_by_id(action["path_id"])
272+
if "path_index" in action:
273+
return self.remove_path_by_index(action["path_index"])
274+
return False
275+
if action_type == "pause":
276+
self.set_paused(True)
277+
return True
278+
if action_type == "resume":
279+
self.set_paused(False)
280+
return True
281+
if action_type == "noop" or action_type is None:
282+
return True
283+
return False
284+
220285
def end_path_on_station(self, station: Station) -> None:
221286
assert self.path_being_created is not None
222287
# current station de-dupe

0 commit comments

Comments
 (0)