Skip to content

Commit 9c3dfed

Browse files
authored
Merge pull request #15 from warrior25/gtfs-service-day
GTFS service day handling
2 parents 1e8bfd3 + 3dcf1eb commit 9c3dfed

File tree

4 files changed

+89
-59
lines changed

4 files changed

+89
-59
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Each station creates a sensor which contains data for departures from that stati
2727
| last_refresh | Timestamp (ISO 8601 format) indicating when real-time departures were last fetched. |
2828
| departures | A list of departure objects representing the next available departures. |
2929
| station_name | Name of the monitored stop. |
30+
| station_id | Unique identifier of the monitored stop. |
3031

3132
### Departures
3233

custom_components/nysse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ async def async_setup_entry(
1515
hass.data[DOMAIN][entry.entry_id] = entry.data
1616

1717
# Forward the setup to the sensor platform.
18-
await hass.config_entries.async_forward_entry_setup(entry, "sensor")
18+
await hass.config_entries.async_forward_entry_setups(entry, ["sensor"])
1919
entry.async_on_unload(entry.add_update_listener(options_update_listener))
2020
return True
2121

custom_components/nysse/fetch_api.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
import os
88
import pathlib
99
import sqlite3
10+
from typing import NamedTuple
1011
import zipfile
1112

1213
import aiohttp
14+
from dateutil import parser
15+
16+
import homeassistant.util.dt as dt_util
1317

1418
from .const import DOMAIN, GTFS_URL
1519

@@ -240,8 +244,7 @@ def _get_file_modified_time(file_path):
240244
def _parse_csv_file(file_path):
241245
with open(file_path, newline="", encoding="utf-8") as csvfile:
242246
reader = csv.DictReader(csvfile)
243-
data = [row.copy() for row in reader]
244-
return data
247+
return [row.copy() for row in reader]
245248

246249

247250
async def get_stops():
@@ -289,6 +292,16 @@ async def get_route_ids(stop_id):
289292
return route_ids
290293

291294

295+
class StopTime(NamedTuple):
296+
route_id: str
297+
trip_headsign: str
298+
departure_time: datetime
299+
aimed_departure_time: datetime | None
300+
delay: int | None
301+
delta_days: int
302+
realtime: bool
303+
304+
292305
async def get_stop_times(stop_id, route_ids, amount, from_time):
293306
"""Get the stop times for a given stop ID, route IDs, and amount.
294307
@@ -306,26 +319,52 @@ async def get_stop_times(stop_id, route_ids, amount, from_time):
306319
conn, cursor = _get_database()
307320
today = datetime.now().strftime("%Y%m%d")
308321
weekday = datetime.strptime(today, "%Y%m%d").strftime("%A").lower()
309-
stop_times = []
322+
stop_times: list[StopTime] = []
310323
delta_days = 0
324+
start_time = from_time.strftime("%H:%M:%S")
311325
while len(stop_times) < amount:
312326
cursor.execute(
313327
f"""
314-
SELECT stop_times.trip_id, route_id, trip_headsign, departure_time, {delta_days} as delta_days
328+
SELECT route_id, trip_headsign, departure_time
315329
FROM stop_times
316330
JOIN trips ON stop_times.trip_id = trips.trip_id
317331
JOIN calendar ON trips.service_id = calendar.service_id
318332
WHERE stop_id = ?
319-
AND trips.route_id IN ({','.join(['?']*len(route_ids))})
333+
AND trips.route_id IN ({",".join(["?"] * len(route_ids))})
320334
AND calendar.{weekday} = '1'
321335
AND calendar.start_date <= ?
322336
AND calendar.end_date >= ?
323337
AND departure_time > ?
324338
LIMIT ?
325339
""",
326-
[stop_id, *route_ids, today, today, from_time.strftime("%H:%M:%S"), amount],
340+
[stop_id, *route_ids, today, today, start_time, amount],
327341
)
328-
stop_times += cursor.fetchall()
342+
343+
for row in cursor.fetchall():
344+
route_id, trip_headsign, departure_time_str = row
345+
row_delta_days = delta_days
346+
hours, minutes, seconds = map(int, departure_time_str.split(":"))
347+
348+
if hours > 23:
349+
hours -= 24
350+
row_delta_days += 1
351+
352+
valid_time_str = f"{hours:02}:{minutes:02}:{seconds:02}"
353+
354+
departure_time = dt_util.as_local(parser.parse(valid_time_str))
355+
356+
stop_times.append(
357+
StopTime(
358+
route_id,
359+
trip_headsign,
360+
departure_time,
361+
None,
362+
None,
363+
row_delta_days,
364+
False,
365+
)
366+
)
367+
329368
if len(stop_times) >= amount:
330369
break
331370
# If there are no more stop times for today, move to the next day
@@ -338,5 +377,6 @@ async def get_stop_times(stop_id, route_ids, amount, from_time):
338377
next_day = datetime.strptime(today, "%Y%m%d") + timedelta(days=1)
339378
today = next_day.strftime("%Y%m%d")
340379
weekday = next_day.strftime("%A").lower()
380+
start_time = "00:00:00"
341381
conn.close()
342382
return stop_times[:amount]

custom_components/nysse/sensor.py

Lines changed: 40 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88

99
from dateutil import parser
10+
from dateutil.parser import ParserError
1011
import isodate
1112

1213
from homeassistant import config_entries, core
@@ -24,7 +25,7 @@
2425
STOP_URL,
2526
TRAM_LINES,
2627
)
27-
from .fetch_api import get_stop_times, get_stops
28+
from .fetch_api import StopTime, get_stop_times, get_stops
2829
from .network import get
2930

3031
_LOGGER = logging.getLogger(__name__)
@@ -81,19 +82,17 @@ def __init__(self, stop_code, maximum, timelimit, lines) -> None:
8182

8283
self._last_update_time = None
8384

84-
def _remove_unwanted_departures(self, departures):
85+
def _remove_unwanted_departures(self, departures: list[StopTime]):
8586
try:
8687
removed_departures_count = 0
8788

8889
# Remove unwanted departures based on departure time and line number
8990
for departure in departures[:]:
90-
departure_local = dt_util.as_local(
91-
parser.parse(departure["departure_time"])
92-
)
91+
departure_local = dt_util.as_local(departure.departure_time)
9392
if (
9493
departure_local
9594
< self._last_update_time + timedelta(minutes=self._timelimit)
96-
or departure["route_id"] not in self._lines
95+
or departure.route_id not in self._lines
9796
):
9897
departures.remove(departure)
9998
removed_departures_count += 1
@@ -129,7 +128,7 @@ async def _fetch_departures(self):
129128
self._stop_code,
130129
url,
131130
)
132-
return
131+
return None
133132
unformatted_departures = json.loads(data)
134133
return self._format_departures(unformatted_departures)
135134
except OSError as err:
@@ -139,25 +138,20 @@ async def _fetch_departures(self):
139138
def _format_departures(self, departures):
140139
try:
141140
body = departures["body"][self._stop_code]
142-
formatted_data = []
141+
formatted_data: list[StopTime] = []
143142
for departure in body:
144143
try:
145-
formatted_departure = {
146-
"route_id": departure["lineRef"],
147-
"trip_headsign": self._get_stop_name(
148-
departure["destinationShortName"]
149-
),
150-
"departure_time": departure["call"]["expectedDepartureTime"],
151-
"aimed_departure_time": departure["call"]["aimedDepartureTime"],
152-
"delay": departure["delay"],
153-
"realtime": True,
154-
}
155-
if (
156-
formatted_departure["departure_time"] is not None
157-
and formatted_departure["aimed_departure_time"] is not None
158-
):
159-
formatted_data.append(formatted_departure)
160-
except KeyError as err:
144+
formatted_departure = StopTime(
145+
departure["lineRef"],
146+
self._get_stop_name(departure["destinationShortName"]),
147+
parser.parse(departure["call"]["expectedDepartureTime"]),
148+
parser.parse(departure["call"]["aimedDepartureTime"]),
149+
self._delay_to_display_format(departure["delay"]),
150+
0,
151+
True,
152+
)
153+
formatted_data.append(formatted_departure)
154+
except (KeyError, ParserError) as err:
161155
_LOGGER.info(
162156
"%s: Failed to process realtime departure: %s",
163157
self._stop_code,
@@ -200,13 +194,9 @@ async def async_update(self) -> None:
200194
)
201195
for journey in self._journeys[:]:
202196
for departure in departures:
203-
departure_time = parser.parse(departure["aimed_departure_time"])
204-
journey_time = dt_util.as_local(
205-
parser.parse(journey["departure_time"])
206-
)
207197
if (
208-
journey_time == departure_time
209-
and journey["route_id"] == departure["route_id"]
198+
journey.departure_time == departure.aimed_departure_time
199+
and journey.route_id == departure.route_id
210200
):
211201
self._journeys.remove(journey)
212202
else:
@@ -223,24 +213,24 @@ async def async_update(self) -> None:
223213
except (OSError, ValueError) as err:
224214
_LOGGER.error("%s: Failed to update sensor: %s", self._stop_code, err)
225215

226-
def _data_to_display_format(self, data):
216+
def _data_to_display_format(self, data: list[StopTime]):
227217
try:
228218
formatted_data = []
229219
for item in data:
230220
departure = {
231-
"destination": item["trip_headsign"],
232-
"line": item["route_id"],
233-
"departure": parser.parse(item["departure_time"]).strftime("%H:%M"),
221+
"destination": item.trip_headsign,
222+
"line": item.route_id,
223+
"departure": item.departure_time.strftime("%H:%M"),
234224
"time_to_station": self._time_to_station(item),
235-
"icon": self._get_line_icon(item["route_id"]),
236-
"realtime": item["realtime"] if "realtime" in item else False,
225+
"icon": self._get_line_icon(item.route_id),
226+
"realtime": item.realtime,
237227
}
238-
if "aimed_departure_time" in item:
239-
departure["aimed_departure"] = parser.parse(
240-
item["aimed_departure_time"]
241-
).strftime("%H:%M")
242-
if "delay" in item:
243-
departure["delay"] = self._delay_to_display_format(item["delay"])
228+
if item.aimed_departure_time is not None:
229+
departure["aimed_departure"] = item.aimed_departure_time.strftime(
230+
"%H:%M"
231+
)
232+
if item.delay is not None:
233+
departure["delay"] = item.delay
244234
formatted_data.append(departure)
245235
return sorted(formatted_data, key=lambda x: x["time_to_station"])
246236
except (OSError, ValueError) as err:
@@ -252,11 +242,11 @@ def _get_line_icon(self, line_no):
252242
return "mdi:tram"
253243
return "mdi:bus"
254244

255-
def _time_to_station(self, item):
245+
def _time_to_station(self, item: StopTime):
256246
try:
257-
departure_local = dt_util.as_local(parser.parse(item["departure_time"]))
258-
if "delta_days" in item:
259-
departure_local += timedelta(days=item["delta_days"])
247+
departure_local = dt_util.as_local(item.departure_time)
248+
if item.delta_days > 0:
249+
departure_local += timedelta(days=item.delta_days)
260250
next_departure_time = (departure_local - self._last_update_time).seconds
261251
return int(next_departure_time / 60)
262252
except OSError as err:
@@ -323,12 +313,12 @@ def state(self) -> str:
323313
@property
324314
def extra_state_attributes(self):
325315
"""Sensor attributes."""
326-
attributes = {
316+
return {
327317
"last_refresh": self._last_update_time,
328318
"departures": self._all_data,
329319
"station_name": self._get_stop_name(self._stop_code),
320+
"station_id": self._stop_code,
330321
}
331-
return attributes
332322

333323

334324
class ServiceAlertSensor(SensorEntity):
@@ -363,7 +353,7 @@ async def _fetch_service_alerts(self):
363353
"Nysse API error: failed to fetch service alerts: no data received from %s",
364354
SERVICE_ALERTS_URL,
365355
)
366-
return
356+
return None
367357
json_data = json.loads(data)
368358

369359
self._last_update = self._timestamp_to_local(
@@ -428,8 +418,7 @@ def state(self) -> str:
428418
@property
429419
def extra_state_attributes(self):
430420
"""Sensor attributes."""
431-
attributes = {
421+
return {
432422
"last_refresh": self._last_update,
433423
"alerts": self._alerts,
434424
}
435-
return attributes

0 commit comments

Comments
 (0)