diff --git a/src/event.py b/src/event.py
index fa63c68..e06f7b9 100644
--- a/src/event.py
+++ b/src/event.py
@@ -1,6 +1,6 @@
 import json
-from datetime import datetime
-from typing import Tuple
+from datetime import datetime, timezone, timedelta
+from typing import Tuple, Optional
 import pandas as pd
 from ddtrace import tracer
 import warnings
@@ -8,7 +8,7 @@
 from config import CONFIG
 from constants import BUS_STOPS, ROUTES_CR, ROUTES_RAPID
 from logger import set_up_logging
-from trip_state import TripsStateManager
+from trip_state import TripsStateManager, TripState
 import disk
 import gtfs
 
@@ -50,6 +50,7 @@ def arr_or_dep_event(
 
 @tracer.wrap()
 def reduce_update_event(update: dict) -> Tuple:
+    vehicle_id = update["id"]
     current_status = update["attributes"]["current_status"]
     event_type = EVENT_TYPE_MAP[current_status]
     updated_at = datetime.fromisoformat(update["attributes"]["updated_at"])
@@ -81,6 +82,7 @@ def reduce_update_event(update: dict) -> Tuple:
         stop_id = None
 
     return (
+        vehicle_id,
         current_status,
         event_type,
         update["attributes"]["current_stop_sequence"],
@@ -100,6 +102,7 @@ def reduce_update_event(update: dict) -> Tuple:
 def process_event(update, trips_state: TripsStateManager):
     """Process a single event from the MBTA's realtime API."""
     (
+        vehicle_id,
         current_status,
         event_type,
         current_stop_sequence,
@@ -182,6 +185,7 @@ def process_event(update, trips_state: TripsStateManager):
         route_id,
         trip_id,
         {
+            "id": vehicle_id,
             "stop_sequence": current_stop_sequence,
             "stop_id": stop_id,
             "updated_at": updated_at,
@@ -193,6 +197,106 @@ def process_event(update, trips_state: TripsStateManager):
     )
 
 
+@tracer.wrap()
+def enrich_remove_event(update, trips_state: TripsStateManager) -> Optional[dict]:
+    """
+    Enrich a remove event so it has the attributes the process event function expects.
+
+    Args:
+        update: A dictionary only containing the "id"
+        trips_state: The trips state
+
+    Returns:
+        dict: A dict that attempts to be equivalent in structure to an update event,
+            or None when no trip state, trip id or next stop is known for the id.
+    """
+    target_id = update["id"]
+    route_id, trip_id, last_known_trip_state = trips_state.get_trip_by_id(target_id)
+    if not last_known_trip_state or not trip_id:
+        return None
+
+    # TODO: Infer direction from stop_sequence, stop_id
+    direction_id = get_direction_id(trip_id)
+    stop_info = get_next_stop_in_sequence(last_known_trip_state, trip_id)
+    if not stop_info:
+        return None
+
+    # NOTE(review): this is a fixed UTC-4 offset. The instant is correct all year, but the
+    # rendered offset never switches to UTC-5; confirm nothing downstream needs that.
+    eastern = timezone(timedelta(hours=-4))
+    now = datetime.now(eastern)
+    return {
+        "attributes": {
+            "bearing": None,
+            "current_status": "INCOMING_TO",
+            "current_stop_sequence": stop_info["stop_sequence"],
+            "direction_id": direction_id,
+            "label": target_id,
+            "latitude": None,
+            "longitude": None,
+            "speed": None,
+            "updated_at": now.isoformat(),
+        },
+        "id": target_id,
+        "links": {"self": f"/vehicles/{target_id}"},
+        "relationships": {
+            "route": {"data": {"id": route_id, "type": "route"}},
+            "stop": {"data": {"id": stop_info["stop_id"], "type": "stop"}},
+            "trip": {"data": {"id": trip_id, "type": "trip"}},
+        },
+        "type": "vehicle",
+    }
+
+
+@tracer.wrap()
+def get_next_stop_in_sequence(last_known_trip_state: TripState, trip_id: str):
+    """
+    Find the next stop in the sequence for a given trip from GTFS data.
+
+    Args:
+        last_known_trip_state: The last known state of the trip
+        trip_id: The trip identifier
+
+    Returns:
+        dict: Next stop information or None if at end of route
+    """
+    current_stop_sequence = last_known_trip_state["stop_sequence"]
+
+    gtfs_archive = gtfs.get_current_gtfs_archive()
+    # Get stop times for this specific trip
+    trip_stop_times = gtfs_archive.stop_times[gtfs_archive.stop_times["trip_id"] == trip_id].sort_values(
+        "stop_sequence"
+    )
+
+    # Find the next stop in sequence
+    next_stops = trip_stop_times[trip_stop_times["stop_sequence"] > current_stop_sequence]
+    if next_stops.empty:
+        return None  # End of route reached
+
+    next_stop = next_stops.iloc[0]  # Get the first (next) stop
+
+    # Get stop name from stops dataframe
+    stop_name = get_stop_name(gtfs_archive.stops, next_stop["stop_id"])
+
+    return {
+        "stop_id": next_stop["stop_id"],
+        # TripState declares stop_sequence as int; the dataframe may hand back a numpy scalar
+        "stop_sequence": int(next_stop["stop_sequence"]),
+        "stop_name": stop_name,
+    }
+
+
+@tracer.wrap()
+def get_direction_id(trip_id: str) -> Optional[str]:
+    """Return the trip's GTFS direction_id as a string, or None unless exactly one is found."""
+    gtfs_archive = gtfs.get_current_gtfs_archive()
+    trips = gtfs_archive.trips[gtfs_archive.trips["trip_id"] == trip_id]
+    direction_ids = trips["direction_id"].unique()
+    if len(direction_ids) != 1:
+        return None
+    return str(direction_ids[0])
+
+
 @tracer.wrap()
 def enrich_event(df: pd.DataFrame, gtfs_archive: gtfs.GtfsArchive):
     """
diff --git a/src/gobble.py b/src/gobble.py
index e09504e..336450c 100644
--- a/src/gobble.py
+++ b/src/gobble.py
@@ -10,7 +10,7 @@
 
 from constants import ROUTES_BUS, ROUTES_CR, ROUTES_RAPID
 from config import CONFIG
-from event import process_event
+from event import process_event, enrich_remove_event
 from logger import set_up_logging
 from trip_state import TripsStateManager
 import gtfs
@@ -93,6 +93,14 @@ def client_thread(routes: Set[str]):
 def process_events(client: sseclient.SSEClient, trips_state: TripsStateManager):
     for event in client.events():
         try:
+            logger.info("Received %s event", event.event)
+            if event.event == "remove":
+                # A remove event carries only an id: rebuild an update-shaped event from the
+                # last known trip state, and never fall through to the raw payload below.
+                update = enrich_remove_event(json.loads(event.data), trips_state)
+                if update:
+                    process_event(update, trips_state)
+                continue
             if event.event != "update":
                 continue
             update = json.loads(event.data)
diff --git a/src/trip_state.py b/src/trip_state.py
index b731df6..ad856ae 100644
--- a/src/trip_state.py
+++ b/src/trip_state.py
@@ -16,6 +16,8 @@ class TripState(TypedDict):
     Holds the current state of a single trip
     """
 
+    # Which realtime event "id" (the vehicle's id) last reported this trip?
+    id: str
     # How far into the trip are we?
     stop_sequence: int
     # What stop are we at?
@@ -173,3 +175,19 @@ def get_trip_state(self, route_id: str, trip_id: str) -> Optional[TripState]:
         if route_id not in self.route_states:
             return None
         return self.route_states[route_id].get_trip_state(trip_id)
+
+    def get_trip_by_id(self, target_id: str) -> tuple[Optional[str], Optional[str], Optional[TripState]]:
+        """
+        Find a trip across all routes by the "id" stored in its trip state.
+
+        Returns:
+            A (route_id, trip_id, trip_state) tuple, or (None, None, None) if no trip matches.
+        """
+        for route_id, route_state in self.route_states.items():
+            # NOTE(review): get_trip_state() calls a method on the route state, so confirm
+            # that subscripting it with "trip_states" is actually supported.
+            for trip_id, trip_state in route_state["trip_states"].items():
+                # Tolerate states that carry no "id" key instead of raising KeyError
+                if trip_state.get("id") == target_id:
+                    return (route_id, trip_id, self.get_trip_state(route_id, trip_id))
+        return (None, None, None)