Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 104 additions & 3 deletions src/event.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import json
from datetime import datetime
from typing import Tuple
from datetime import datetime, timezone, timedelta
from typing import Tuple, Optional
import pandas as pd
from ddtrace import tracer
import warnings

from config import CONFIG
from constants import BUS_STOPS, ROUTES_CR, ROUTES_RAPID
from logger import set_up_logging
from trip_state import TripsStateManager
from trip_state import TripsStateManager, TripState

import disk
import gtfs
Expand Down Expand Up @@ -50,6 +50,7 @@ def arr_or_dep_event(

@tracer.wrap()
def reduce_update_event(update: dict) -> Tuple:
id = update["id"]
current_status = update["attributes"]["current_status"]
event_type = EVENT_TYPE_MAP[current_status]
updated_at = datetime.fromisoformat(update["attributes"]["updated_at"])
Expand Down Expand Up @@ -81,6 +82,7 @@ def reduce_update_event(update: dict) -> Tuple:
stop_id = None

return (
id,
current_status,
event_type,
update["attributes"]["current_stop_sequence"],
Expand All @@ -100,6 +102,7 @@ def reduce_update_event(update: dict) -> Tuple:
def process_event(update, trips_state: TripsStateManager):
"""Process a single event from the MBTA's realtime API."""
(
id,
current_status,
event_type,
current_stop_sequence,
Expand Down Expand Up @@ -182,6 +185,7 @@ def process_event(update, trips_state: TripsStateManager):
route_id,
trip_id,
{
"id": id,
"stop_sequence": current_stop_sequence,
"stop_id": stop_id,
"updated_at": updated_at,
Expand All @@ -193,6 +197,103 @@ def process_event(update, trips_state: TripsStateManager):
)


@tracer.wrap()
def enrich_remove_event(update, trips_state: TripsStateManager) -> Optional[dict]:
"""
Enrich remove event to have the expected attributes for the process event function

Args:
update: A dictionary only containing the "id"
trips_state: The trips state

Returns:
dict: Returns Dict that attempts to be equivalent in structure to an update event.
"""
target_id = update["id"]
eastern = timezone(timedelta(hours=-4))
now = datetime.now(eastern)
route_id, trip_id, last_known_trip_state = trips_state.get_trip_by_id(target_id)

if last_known_trip_state:
if trip_id:
direction_id = get_direction_id(trip_id)
# TODO: Infer direction from stop_sequence, stop_id
stop_info = get_next_stop_in_sequence(last_known_trip_state, trip_id)
if stop_info:
template = {
"attributes": {
"bearing": None,
"current_status": "INCOMING_TO",
"current_stop_sequence": stop_info["stop_sequence"],
"direction_id": direction_id,
"label": target_id,
"latitude": None,
"longitude": None,
"speed": None,
"updated_at": now.isoformat(),
},
"id": target_id,
"links": {"self": f"/vehicles/{target_id}"},
"relationships": {
"route": {"data": {"id": route_id, "type": "route"}},
"stop": {"data": {"id": stop_info["stop_id"], "type": "stop"}},
"trip": {"data": {"id": trip_id, "type": "trip"}},
},
"type": "vehicle",
}
return template
else:
return None


@tracer.wrap()
def get_next_stop_in_sequence(last_known_trip_state: TripState, trip_id: str):
"""
Find the next stop in the sequence for a given trip from GTFS data.

Args:
last_known_trip_state: The last known state of the trip
trip_id: The trip identifier

Returns:
dict: Next stop information or None if at end of route
"""
current_stop_sequence = last_known_trip_state["stop_sequence"]

gtfs_archive = gtfs.get_current_gtfs_archive()
# Get stop times for this specific trip
trip_stop_times = gtfs_archive.stop_times[gtfs_archive.stop_times["trip_id"] == trip_id].sort_values(
"stop_sequence"
)

# Find the next stop in sequence
next_stops = trip_stop_times[trip_stop_times["stop_sequence"] > current_stop_sequence]

if len(next_stops) > 0:
next_stop = next_stops.iloc[0] # Get the first (next) stop

# Get stop name from stops dataframe
stop_name = get_stop_name(gtfs_archive.stops, next_stop["stop_id"])

return {
"stop_id": next_stop["stop_id"],
"stop_sequence": next_stop["stop_sequence"],
"stop_name": stop_name,
}

return None # End of route reached


@tracer.wrap()
def get_direction_id(trip_id: str) -> Optional[str]:
gtfs_archive = gtfs.get_current_gtfs_archive()
trips = gtfs_archive.trips[gtfs_archive.trips["trip_id" == trip_id]]
if trips["direction_id"].unique() == 1:
return str(trips["direction_id"].unique()[0])
else:
return None


@tracer.wrap()
def enrich_event(df: pd.DataFrame, gtfs_archive: gtfs.GtfsArchive):
"""
Expand Down
13 changes: 11 additions & 2 deletions src/gobble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from constants import ROUTES_BUS, ROUTES_CR, ROUTES_RAPID
from config import CONFIG
from event import process_event
from event import process_event, enrich_remove_event
from logger import set_up_logging
from trip_state import TripsStateManager
import gtfs
Expand Down Expand Up @@ -93,7 +93,16 @@ def client_thread(routes: Set[str]):
def process_events(client: sseclient.SSEClient, trips_state: TripsStateManager):
for event in client.events():
try:
if event.event != "update":
logger.info(f"[{datetime.now().isoformat()}] Recieved {event.event} event")
if event.event == "update":
update = json.loads(event.data)
process_event(update, trips_state)
if event.event == "remove":
update = json.loads(event.data)
update = enrich_remove_event(update, trips_state)
if update:
process_event(update, trips_state)
else:
continue
update = json.loads(event.data)
process_event(update, trips_state)
Expand Down
11 changes: 11 additions & 0 deletions src/trip_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class TripState(TypedDict):
Holds the current state of a single trip
"""

# What is the id?
id: str
# How far into the trip are we?
stop_sequence: int
# What stop are we at?
Expand Down Expand Up @@ -173,3 +175,12 @@ def get_trip_state(self, route_id: str, trip_id: str) -> Optional[TripState]:
if route_id not in self.route_states:
return None
return self.route_states[route_id].get_trip_state(trip_id)

def get_trip_by_id(self, target_id: str) -> tuple[Optional[str], Optional[str], Optional[TripState]]:
# Find a trip state by trip ID across all routes
# for key, val
for route_id, route_state in self.route_states.items():
for trip_id, trip_state in route_state["trip_states"].items():
if trip_state["id"] == target_id:
return (route_id, trip_id, self.get_trip_state(route_id, trip_id))
return (None, None, None)
Loading