diff --git a/runners/flashback.py b/runners/flashback.py index 7b338cd8..4c4a1ed2 100644 --- a/runners/flashback.py +++ b/runners/flashback.py @@ -1,3 +1,3 @@ from lamp_py.flashback.pipeline import pipeline -pipeline() +pipeline(local_override_path="/tmp/flashback/") diff --git a/src/lamp_py/flashback/events.py b/src/lamp_py/flashback/events.py index 75ac4595..c49591bd 100644 --- a/src/lamp_py/flashback/events.py +++ b/src/lamp_py/flashback/events.py @@ -4,151 +4,171 @@ import dataframely as dy import polars as pl -from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions +from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions, VehiclePositionsApiFormat from lamp_py.runtime_utils.process_logger import ProcessLogger -class StopEventsTable(dy.Schema): - """Flat events data, with additional information for determining stop departures.""" +class VehicleEvents(VehiclePositions): + """Vehicle Position raw events to be de-duplicated into actual events""" - id = dy.String(primary_key=True) # trip-route-vehicle + event_id = dy.String(primary_key=True) # start_date-trip-route-vehicle + status_start_timestamp = dy.Int64(nullable=True) + status_end_timestamp = dy.Int64(nullable=True) + + +class VehicleStopEvents(dy.Schema): + """Vehicle Position raw events to be de-duplicated into actual events""" + + event_id = dy.String(primary_key=True) # start_date-trip-route-vehicle timestamp = dy.Int64() - start_date = dy.String(nullable=True) - trip_id = dy.String() - direction_id = dy.Int8(min=0, max=1, nullable=True) - route_id = dy.String() - start_time = dy.String(nullable=True) - revenue = dy.Bool(nullable=True) - stop_id = dy.String(nullable=False) + start_date = dy.String(nullable=False) + trip_id = VehicleEvents.trip_id + direction_id = VehicleEvents.direction_id + route_id = VehicleEvents.route_id + start_time = VehicleEvents.start_time + revenue = VehicleEvents.revenue + stop_id = VehicleEvents.stop_id current_stop_sequence = dy.Int16(primary_key=True) - arrived = dy.Int64(nullable=True) - departed = dy.Int64(nullable=True) - latest_stopped_timestamp = dy.Int64(nullable=True) + # remove current status + # renamed status start and stop to arrival and departure for stop events schema + arrived = VehicleEvents.status_start_timestamp + departed = VehicleEvents.status_end_timestamp class StopEventsJSON(dy.Schema): """Pre-serialized stop events for trips.""" - id = dy.String(primary_key=True) + event_id = dy.String(primary_key=True) timestamp = dy.Int64() - start_date = StopEventsTable.start_date - trip_id = StopEventsTable.trip_id - direction_id = StopEventsTable.direction_id - route_id = StopEventsTable.route_id - start_time = StopEventsTable.start_time - revenue = StopEventsTable.revenue + start_date = VehicleStopEvents.start_date + trip_id = VehicleStopEvents.trip_id + direction_id = VehicleStopEvents.direction_id + route_id = VehicleStopEvents.route_id + start_time = VehicleStopEvents.start_time + revenue = VehicleStopEvents.revenue stop_events = dy.List( dy.Struct( inner={ - "stop_id": StopEventsTable.stop_id, - "current_stop_sequence": dy.Int16(), - "arrived": StopEventsTable.arrived, - "departed": StopEventsTable.departed, + "stop_id": VehicleStopEvents.stop_id, + "current_stop_sequence": VehicleStopEvents.current_stop_sequence, + "arrived": VehicleStopEvents.arrived, + "departed": VehicleStopEvents.departed, } ) ) -def unnest_vehicle_positions(vp: dy.DataFrame[VehiclePositions]) -> dy.DataFrame[StopEventsTable]: +def unnest_vehicle_positions(vp: dy.DataFrame[VehiclePositionsApiFormat]) -> dy.DataFrame[VehiclePositions]: """Unnest VehiclePositions data into flat table.""" process_logger = ProcessLogger("unnest_vehicle_positions", input_rows=vp.height) process_logger.log_start() - events = ( + + # it is what it is. note: the struct "vehicle" appears twice. + # the first is a catch all, the 2nd is vehicle_id and vehicle_label. + vehicle_positions = ( vp.select("entity") .explode("entity") .unnest("entity") .unnest("vehicle") .unnest("trip") - .filter( - pl.col("current_stop_sequence").is_not_null(), - pl.col("trip_id").is_not_null(), - pl.col("timestamp").is_not_null(), - pl.col("route_id").is_not_null(), - ) - .select( - pl.concat_str(pl.col("trip_id"), pl.col("route_id"), pl.col("id"), separator="-").alias("id"), - "timestamp", - "start_date", - "trip_id", - "direction_id", - "route_id", - "start_time", - "revenue", - "stop_id", - "current_stop_sequence", - pl.when(pl.col("current_status").eq("STOPPED_AT")).then(pl.col("timestamp")).alias("arrived"), - pl.lit(None).alias("departed"), # for schema adherence - pl.when(pl.col("current_status").eq("STOPPED_AT")) - .then(pl.col("timestamp")) - .alias("latest_stopped_timestamp"), - ) + .rename({"id": "entity_id"}) + .unnest("vehicle") + .rename({"id": "vehicle_id", "label": "vehicle_label"}) + .rename({"entity_id": "id"}) + .unnest("position") ) - valid = process_logger.log_dataframely_filter_results(*StopEventsTable.filter(events, cast=True)) + valid = process_logger.log_dataframely_filter_results(*VehiclePositions.filter(vehicle_positions, cast=True)) process_logger.log_complete() return valid -def update_records( - existing_records: dy.DataFrame[StopEventsTable], - new_records: dy.DataFrame[StopEventsTable], - max_record_age: timedelta, -) -> dy.DataFrame[StopEventsTable]: - """Return a DataFrame of recent stops using VehiclePositions.""" +def vehicle_position_to_archive_events(vp: dy.DataFrame[VehiclePositions]) -> dy.DataFrame[VehicleEvents]: + """ + Convert VehiclePositions data into VehicleEvents format. + + Filters vehicle position records to include only those with valid stop sequences, + trip IDs, timestamps, and route IDs. Generates a composite ID from start_date, + trip_id, route_id, and vehicle id, then selects relevant columns for event archival. + + Start_date is required to have a unique identifier across days, as all other identifiers are reusable. + """ + process_logger = ProcessLogger("vehicle_position_to_archive_events", input_rows=vp.height) + process_logger.log_start() + events = vp.filter( + pl.col("current_stop_sequence").is_not_null(), + pl.col("trip_id").is_not_null(), + pl.col("timestamp").is_not_null(), + pl.col("route_id").is_not_null(), + pl.col("start_date").is_not_null(), + ).with_columns( + pl.concat_str(pl.col("start_date"), pl.col("trip_id"), pl.col("route_id"), pl.col("id"), separator="-").alias( + "event_id" + ), + pl.lit(None).cast(pl.Int64).alias("status_start_timestamp"), + pl.lit(None).cast(pl.Int64).alias("status_end_timestamp"), + ) + + valid = process_logger.log_dataframely_filter_results(*VehicleEvents.filter(events, cast=True)) + + process_logger.log_complete() + + return valid + + +def aggregate_duration_with_new_records( + existing_records: dy.DataFrame[VehicleEvents], + new_records: dy.DataFrame[VehicleEvents], +) -> dy.DataFrame[VehicleEvents]: + """ + Recalculate derived duration fields for stop events based on status changes. + + Merges existing and new stop event records, groups them by vehicle ID and stop + sequence, and calculates the timestamp when each status began and ended. Returns + only records that pass VehicleEvents validation. + + Args: + existing_records: DataFrame of previously processed stop events with status information. + new_records: DataFrame of newly received stop events with status information. + max_record_age: Maximum age threshold for records (currently logged but not actively used in filtering). + + Returns: + DataFrame of stop events with validated derived duration fields (status_start_timestamp + and status_end_timestamp where applicable). + + Note: + Records are sorted by timestamp and grouped by vehicle ID, stop sequence, and current status. + Status end timestamp is only set when the first and last timestamp within a group differ. + """ process_logger = ProcessLogger( - "update_records", existing_records=existing_records.height, max_record_age=str(max_record_age) + "aggregate_duration_with_new_records", + existing_records=existing_records.height, ) process_logger.log_start() + # grab only the records that are still getting updates + existing_merge_records = existing_records.filter(pl.col("event_id").is_in(new_records["event_id"].unique())) + all_events = pl.concat([existing_merge_records, new_records], how="diagonal") + + # for all records at a current stop sequence and status, calculate the start and end times of that status combined = ( - existing_records.filter( # remove old records - datetime.now(tz=ZoneInfo("America/New_York")) - - pl.from_epoch("timestamp").dt.replace_time_zone( - "America/New_York", ambiguous="latest", non_existent="null" - ) - < max_record_age + all_events.sort(by="timestamp") + .group_by("event_id", "current_stop_sequence", "current_status") + .agg( + [ + pl.first("timestamp").alias("status_start_timestamp"), + pl.when(pl.first("timestamp").ne(pl.last("timestamp"))).then( + pl.last("timestamp").alias("status_end_timestamp") + ), + pl.all().exclude("status_start_timestamp", "status_end_timestamp").last(), + # keep the rest of the columns of the most recent one. + ] ) - .join(new_records, on=["id", "current_stop_sequence"], how="full", coalesce=True) - .select( - "id", - "current_stop_sequence", - *[ - pl.coalesce(col, f"{col}_right").alias(col) - for col in [ - "start_date", - "trip_id", - "direction_id", - "route_id", - "start_time", - "revenue", - "stop_id", - "arrived", - ] - ], - pl.coalesce( - pl.when( # if the trip has moved past this stop sequence, set departed to latest_stopped_timestamp - pl.col("current_stop_sequence").max().over("id").gt(pl.col("current_stop_sequence")) - ).then(pl.col("latest_stopped_timestamp")), - "departed", - ).alias("departed"), - pl.coalesce( - pl.when( # if departure is updated, then also update timestamp - pl.col("current_stop_sequence").max().over("id").gt(pl.col("current_stop_sequence")), - pl.col("departed").is_null(), - ).then(pl.col("timestamp_right").max().over("id")), - "timestamp", - "timestamp_right", - ).alias("timestamp"), - pl.coalesce("latest_stopped_timestamp_right", "latest_stopped_timestamp").alias( - "latest_stopped_timestamp" - ), # use value from new record - ) - .filter(pl.col("arrived").is_not_null() | pl.col("departed").is_not_null()) # keep only stops with events ) - valid = process_logger.log_dataframely_filter_results(*StopEventsTable.filter(combined, cast=True)) + valid = process_logger.log_dataframely_filter_results(*VehicleEvents.filter(combined, cast=True)) process_logger.add_metadata(new_records=new_records.height, updated_records=combined.height) @@ -157,10 +177,43 @@ def update_records( return valid -def structure_stop_events(df: dy.DataFrame[StopEventsTable]) -> dy.DataFrame[StopEventsJSON]: +def filter_stop_events( + compressed_events: dy.DataFrame[VehicleEvents], + max_record_age: timedelta, +) -> dy.DataFrame[VehicleStopEvents]: + """ + take compressed events and take only stopped_at events, + and rename the status start and end periods to stop event schema format + """ + + filtered = ( + compressed_events.filter( + (pl.col("current_status") == "STOPPED_AT") + & (pl.col("status_start_timestamp").is_not_null() | pl.col("status_end_timestamp").is_not_null()) + & ( + datetime.now(tz=ZoneInfo("America/New_York")) + - pl.from_epoch("timestamp").dt.replace_time_zone( + "America/New_York", ambiguous="latest", non_existent="null" + ) + < max_record_age + ) # remove records that are older than max_record_age - flashback usecase only requires max_record_age history + ) + .drop("current_status") + .sort("event_id", "current_stop_sequence") + .rename({"status_start_timestamp": "arrived", "status_end_timestamp": "departed"}) + ) + + valid = ProcessLogger("filter_stop_events").log_dataframely_filter_results( + *VehicleStopEvents.filter(filtered, cast=True) + ) + + return valid + + +def structure_stop_events(df: dy.DataFrame[VehicleStopEvents]) -> dy.DataFrame[StopEventsJSON]: """Structure flat table into StopEvents records.""" process_logger = ProcessLogger("structure_stop_events", input_rows=df.height) - stop_events = df.group_by("id").agg( + stop_events = df.group_by("event_id").agg( pl.max("timestamp").alias("timestamp"), pl.selectors.by_name("start_date", "trip_id", "direction_id", "route_id", "start_time", "revenue").first(), pl.struct("stop_id", "current_stop_sequence", "arrived", "departed").alias("stop_events"), diff --git a/src/lamp_py/flashback/io.py b/src/lamp_py/flashback/io.py index 7b6bd7d2..796dff26 100644 --- a/src/lamp_py/flashback/io.py +++ b/src/lamp_py/flashback/io.py @@ -4,26 +4,55 @@ import polars as pl from aiohttp import ClientError, ClientSession -from lamp_py.flashback.events import StopEventsJSON, StopEventsTable -from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions +from lamp_py.flashback.events import ( + StopEventsJSON, + VehicleEvents, + VehicleStopEvents, + unnest_vehicle_positions, +) +from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions, VehiclePositionsApiFormat from lamp_py.runtime_utils.process_logger import ProcessLogger -from lamp_py.runtime_utils.remote_files import S3Location -from lamp_py.runtime_utils.remote_files import stop_events as stop_events_location +from lamp_py.runtime_utils.remote_files import S3Location, vehicle_position_all_events, stop_events -def get_remote_events(location: S3Location = stop_events_location) -> dy.DataFrame[StopEventsTable]: +def get_remote_all_events(location: S3Location = vehicle_position_all_events) -> dy.DataFrame[VehicleEvents]: + """Fetch existing events from S3.""" + process_logger = ProcessLogger("get_remote_all_events") + process_logger.log_start() + try: + remote_events = process_logger.log_dataframely_filter_results( + *VehicleEvents.filter(pl.scan_parquet(location.s3_uri), cast=True) + ) + + existing_events = VehicleEvents.cast( + pl.concat( + [VehicleEvents.create_empty(), remote_events], + how="diagonal", + ) + ) + + except OSError as e: + process_logger.log_warning(e) + existing_events = VehicleEvents.create_empty() + + process_logger.log_complete() + + return existing_events + + +def get_remote_stop_events(location: S3Location = stop_events) -> dy.DataFrame[VehicleStopEvents]: """Fetch existing stop events from S3.""" - process_logger = ProcessLogger("get_remote_events") + process_logger = ProcessLogger("get_remote_stop_events") process_logger.log_start() try: remote_events = process_logger.log_dataframely_filter_results( *StopEventsJSON.filter(pl.scan_parquet(location.s3_uri), cast=True) ) - existing_events = StopEventsTable.cast( + existing_events = VehicleStopEvents.cast( pl.concat( [ - StopEventsTable.create_empty(), + VehicleStopEvents.create_empty(), remote_events.explode("stop_events").unnest("stop_events"), ], how="diagonal", @@ -32,7 +61,7 @@ def get_remote_events(location: S3Location = stop_events_location) -> dy.DataFra except OSError as e: process_logger.log_warning(e) - existing_events = StopEventsTable.create_empty() + existing_events = VehicleStopEvents.create_empty() process_logger.log_complete() @@ -56,23 +85,19 @@ async def get_vehicle_positions( data = await response.read() break except ClientError as e: - process_logger.log_failure(e) + process_logger.log_warning(e) if attempt == max_retries: + process_logger.log_failure(e) raise ClientError(f"Maximum retries ({max_retries}) exceeded") from e sleep(sleep_interval) - vehicle_positions = pl.read_ndjson(data, schema=VehiclePositions.to_polars_schema()) + except Exception as e: + process_logger.log_failure(e) + raise - valid = process_logger.log_dataframely_filter_results(*VehiclePositions.filter(vehicle_positions)) + vehicle_positions = pl.read_ndjson(data, schema=VehiclePositionsApiFormat.to_polars_schema()) + valid = process_logger.log_dataframely_filter_results(*VehiclePositionsApiFormat.filter(vehicle_positions)) process_logger.log_complete() - return valid - - -def write_stop_events(stop_events: dy.DataFrame[StopEventsJSON], location: S3Location = stop_events_location) -> None: - """Write stop events to specified location.""" - process_logger = ProcessLogger("write_stop_events", s3_uri=location.s3_uri) - process_logger.log_start() - stop_events.write_parquet(location.s3_uri, compression_level=9, retries=5, use_pyarrow=True) - process_logger.log_complete() + return unnest_vehicle_positions(valid) diff --git a/src/lamp_py/flashback/pipeline.py b/src/lamp_py/flashback/pipeline.py index 82aca673..75d2ff7a 100644 --- a/src/lamp_py/flashback/pipeline.py +++ b/src/lamp_py/flashback/pipeline.py @@ -6,34 +6,77 @@ import dataframely as dy from lamp_py.aws.ecs import handle_ecs_sigterm -from lamp_py.flashback.events import StopEventsTable, structure_stop_events, unnest_vehicle_positions, update_records -from lamp_py.flashback.io import get_remote_events, get_vehicle_positions, write_stop_events +from lamp_py.flashback.events import ( + VehicleEvents, + filter_stop_events, + structure_stop_events, + aggregate_duration_with_new_records, + vehicle_position_to_archive_events, +) +from lamp_py.flashback.io import get_remote_all_events, get_vehicle_positions from lamp_py.runtime_utils.env_validation import validate_environment from lamp_py.runtime_utils.process_logger import ProcessLogger +from lamp_py.runtime_utils.remote_files import stop_events, vehicle_position_all_events, stop_events_json async def flashback( - remote_events: dy.DataFrame[StopEventsTable], max_record_age: timedelta = timedelta(hours=2) + remote_events: dy.DataFrame[VehicleEvents], + max_record_age: timedelta = timedelta(hours=2), + local_override_path: str | None = None, ) -> None: """Fetch, process, and store stop events.""" - existing_events = remote_events + all_events = remote_events + + while True: process_logger = ProcessLogger("flashback") process_logger.log_start() + + # raw, flat vehicle position new_records = await get_vehicle_positions() - stop_events = update_records(existing_events, unnest_vehicle_positions(new_records), max_record_age) + # add event_id, event duration columns + new_events = vehicle_position_to_archive_events(new_records) + + # combine and update events + compressed_events = aggregate_duration_with_new_records(all_events, new_events) + + # update all_events with the newly compressed events + all_events = all_events.update( # type: ignore[assignment] + compressed_events, on=["event_id", "current_stop_sequence", "current_status"], how="full" + ) + + # take only meaningful stop events for flashback + compressed_stop_events = filter_stop_events(compressed_events, max_record_age) + + process_logger.add_metadata( + new_records=new_records.height, + compressed_events=compressed_events.height, + compressed_stop_events=compressed_stop_events.height, + ) + + if local_override_path: + stop_events_uri = f"{local_override_path}/stop_events.parquet" + stop_events_json_uri = f"{local_override_path}/stop_events.ndjson" + all_events_uri = f"{local_override_path}/vehicle_position_all_events.parquet" + else: + stop_events_uri = stop_events.s3_uri + stop_events_json_uri = stop_events_json.s3_uri + all_events_uri = vehicle_position_all_events.s3_uri - existing_events = stop_events + await asyncio.to_thread(lambda: structure_stop_events(compressed_stop_events).write_parquet(stop_events_uri)) + await asyncio.to_thread( + lambda: structure_stop_events(compressed_stop_events).write_ndjson(stop_events_json_uri) + ) - await asyncio.to_thread(lambda: write_stop_events(structure_stop_events(stop_events))) + await asyncio.to_thread(lambda: all_events.write_parquet(all_events_uri)) process_logger.log_complete() await asyncio.sleep(3) # wait before fetching new data -def pipeline() -> None: +def pipeline(local_override_path: str | None = None) -> None: """Entry point for flashback stop events pipeline.""" process_logger = ProcessLogger("main") process_logger.log_start() @@ -41,7 +84,7 @@ def pipeline() -> None: signal(SIGTERM, handle_ecs_sigterm) # configure the environment - environ["SERVICE_NAME"] = "ingestion" + environ["SERVICE_NAME"] = "flashback_event_service" validate_environment( required_variables=[ @@ -49,4 +92,4 @@ def pipeline() -> None: ], ) - asyncio.run(flashback(get_remote_events())) + asyncio.run(flashback(get_remote_all_events(), local_override_path=local_override_path)) diff --git a/src/lamp_py/ingestion/convert_gtfs_rt.py b/src/lamp_py/ingestion/convert_gtfs_rt.py index 4e5f1fe9..34ce5d35 100644 --- a/src/lamp_py/ingestion/convert_gtfs_rt.py +++ b/src/lamp_py/ingestion/convert_gtfs_rt.py @@ -61,8 +61,8 @@ from lamp_py.utils.filter_bank import FilterBankRtTripUpdates -class VehiclePositions(dy.Schema): - """Structured VehiclePositions message.""" +class VehiclePositionsApiFormat(dy.Schema): + """Api Format of VehiclePositions message.""" entity = dy.List( inner=dy.Struct( @@ -111,6 +111,32 @@ class VehiclePositions(dy.Schema): ) +class VehiclePositions(dy.Schema): + """Flat Format of VehiclePositions message.""" + + id = dy.String(primary_key=True) + trip_id = dy.String(nullable=True) + route_id = dy.String(nullable=True) + direction_id = dy.Int8(min=0, max=1, nullable=True) + start_time = dy.String(nullable=True) + start_date = dy.String(nullable=True) + revenue = dy.Bool(nullable=True) + last_trip = dy.Bool(nullable=True) + schedule_relationship = dy.String(nullable=True) + vehicle_id = dy.String(nullable=True) + vehicle_label = dy.String(nullable=True) # rename this? + bearing = dy.UInt16(nullable=True) + latitude = dy.Float64(nullable=True) + longitude = dy.Float64(nullable=True) + speed = dy.Float64(nullable=True) + current_stop_sequence = dy.Int16(nullable=True) + stop_id = dy.String(nullable=True) + timestamp = dy.Int64(nullable=True) + occupancy_status = dy.String(nullable=True) + occupancy_percentage = dy.UInt32(nullable=True) + current_status = dy.String(nullable=True) + + @dataclass class TableData: """ diff --git a/src/lamp_py/runtime_utils/process_logger.py b/src/lamp_py/runtime_utils/process_logger.py index 8138c29f..728bacd0 100644 --- a/src/lamp_py/runtime_utils/process_logger.py +++ b/src/lamp_py/runtime_utils/process_logger.py @@ -149,10 +149,12 @@ def log_warning(self, exception: Exception) -> None: duration = time.monotonic() - self.start_time self.default_data["status"] = "warned" self.default_data["duration"] = f"{duration:.2f}" + self.default_data["error_type"] = type(exception).__name__ for tb in traceback.format_exception_only(exception): for line in tb.strip("\n").split("\n"): logging.warning(f"uuid={self.default_data["uuid"]}, {line.strip('\n')}") + logging.warning(self._get_log_string()) def log_dataframely_filter_results( self, valid: dy.DataFrame, invalid: dy.FailureInfo, log_level: Optional[int] = logging.WARNING diff --git a/src/lamp_py/runtime_utils/remote_files.py b/src/lamp_py/runtime_utils/remote_files.py index 84e1b3a7..26de413f 100644 --- a/src/lamp_py/runtime_utils/remote_files.py +++ b/src/lamp_py/runtime_utils/remote_files.py @@ -275,8 +275,20 @@ def parquet_path(self, year: Union[str, int], file: str) -> S3Location: prefix=os.path.join(LAMP, "gtfs_archive"), ) +vehicle_position_all_events = S3Location( + bucket=S3_ARCHIVE, + prefix=f"{LAMP}/stop_events/vehicle_position_all_events_v0.parquet", + version="0.1.0", +) + stop_events = S3Location( bucket=S3_ARCHIVE, - prefix=f"{LAMP}/stop_events/stop_events_v0.parquet", + prefix=f"{LAMP}/stop_events/stop_events_v1.parquet", + version="0.1.0", +) + +stop_events_json = S3Location( + bucket=S3_ARCHIVE, + prefix=f"{LAMP}/stop_events/stop_events_v1.json", version="0.1.0", ) diff --git a/tests/flashback/test_events.py b/tests/flashback/test_events.py index f6f52c23..2472df61 100644 --- a/tests/flashback/test_events.py +++ b/tests/flashback/test_events.py @@ -1,12 +1,18 @@ import time -from datetime import datetime, timedelta +from datetime import datetime import dataframely as dy import polars as pl import pytest -from lamp_py.flashback.events import StopEventsTable, structure_stop_events, unnest_vehicle_positions, update_records -from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions +from lamp_py.flashback.events import ( + VehicleEvents, + VehicleStopEvents, + aggregate_duration_with_new_records, + structure_stop_events, + unnest_vehicle_positions, +) +from lamp_py.ingestion.convert_gtfs_rt import VehiclePositionsApiFormat @pytest.mark.parametrize( @@ -56,44 +62,14 @@ "id": "1234", "vehicle": { "trip": { - "start_date": "20231010", - "start_time": "08:00:00", - "direction_id": 1, - "revenue": True, - "last_trip": False, - "schedule_relationship": "SCHEDULED", - }, - "position": { - "latitude": 42.352271, - "longitude": -71.055242, - "bearing": 90.0, + "trip_id": "5678", + "route_id": "red", }, + "position": {}, "vehicle": { "id": "vehicle_1234", "label": "Bus 1234", }, - "current_stop_sequence": 5, - "stop_id": "place-dwnxg", - "timestamp": 1700000000, - "occupancy_status": "MANY_SEATS_AVAILABLE", - "occupancy_percentage": 30, - "current_status": "IN_TRANSIT_TO", - }, - }, - ], - 0, - ), - ( - [ - { - "id": "1234", - "vehicle": { - "trip": { - "trip_id": "5678", - "route_id": "red", - }, - "position": {}, - "vehicle": {}, "stop_id": "123", "current_stop_sequence": 5, "timestamp": 1700000000, @@ -109,222 +85,104 @@ ], ids=[ "complete-data", - "null-primary-keys", "null-non-primary-keys", "empty-entity", ], ) def test_unnest_vehicle_positions(entity: list[dict], valid_records: int) -> None: """It gracefully handles missing and complete data alike.""" - vp = VehiclePositions.validate( - pl.DataFrame([pl.Series(name="entity", values=[entity], dtype=VehiclePositions.entity.dtype)]) + vp = VehiclePositionsApiFormat.validate( + pl.DataFrame([pl.Series(name="entity", values=[entity], dtype=VehiclePositionsApiFormat.entity.dtype)]) ) df = unnest_vehicle_positions(vp) assert df.height == valid_records -@pytest.mark.parametrize( - [ - "existing_record_overrides", - "new_record_overrides", - "expected_events", - ], - [ - ( - { - "id": "foo", - "timestamp": [2_000_000_000 + 1, 2_000_000_000 + 2], # sometime in the future - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [2_000_000_000, 2_000_000_000 + 1], - "departed": [2_000_000_000, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 2, - "current_stop_sequence": 3, - "latest_stopped_timestamp": 2_000_000_000 + 1, - "arrived": 2_000_000_000 + 1, - "departed": None, - }, - { - ("foo", 2, 2_000_000_000 + 1, 2_000_000_000, 2_000_000_000), - ("foo", 3, 2_000_000_000 + 2, 2_000_000_000 + 1, None), - }, - ), - ( - { - "id": "foo", - "timestamp": [2_000_000_000 + 1, 2_000_000_000 + 2], # sometime in the future - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [2_000_000_000, 2_000_000_000 + 1], - "departed": [2_000_000_000, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 3, - "current_stop_sequence": 3, - "latest_stopped_timestamp": 2_000_000_000 + 2, - "arrived": 2_000_000_000 + 2, - "departed": None, - }, - { - ("foo", 2, 2_000_000_000 + 1, 2_000_000_000, 2_000_000_000), - ("foo", 3, 2_000_000_000 + 2, 2_000_000_000 + 1, None), - }, - ), - ( - { - "id": "foo", - "timestamp": [2_000_000_000 + 1, 2_000_000_000 + 2], # sometime in the future - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [2_000_000_000, 2_000_000_000 + 1], - "departed": [2_000_000_000, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 3, - "current_stop_sequence": 4, - "latest_stopped_timestamp": 2_000_000_000 + 2, - "arrived": 2_000_000_000 + 2, - "departed": None, - }, - { - ("foo", 2, 2_000_000_000 + 1, 2_000_000_000, 2_000_000_000), - ("foo", 3, 2_000_000_000 + 3, 2_000_000_000 + 1, 2_000_000_000 + 1), - ("foo", 4, 2_000_000_000 + 3, 2_000_000_000 + 2, None), - }, - ), - ( - { - "id": "foo", - "timestamp": [1_000_000_000 + 1, 1_000_000_000 + 2], # SOMETIME IN THE PAST - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [2_000_000_000, 2_000_000_000 + 1], - "departed": [2_000_000_000, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 3, - "current_stop_sequence": 4, - "latest_stopped_timestamp": 2_000_000_000 + 2, - "arrived": 2_000_000_000 + 2, - "departed": None, - }, - { - ("foo", 4, 2_000_000_000 + 3, 2_000_000_000 + 2, None), - }, - ), - ( - { - "id": "foo", - "timestamp": [2_000_000_000 + 1, 2_000_000_000 + 2], # sometime in the future - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [2_000_000_000, 2_000_000_000 + 1], - "departed": [2_000_000_000, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 3, - "current_stop_sequence": 50, - "latest_stopped_timestamp": 2_000_000_000 + 2, - "arrived": 2_000_000_000 + 2, - "departed": None, - }, - { - ("foo", 2, 2_000_000_000 + 1, 2_000_000_000, 2_000_000_000), - ("foo", 3, 2_000_000_000 + 3, 2_000_000_000 + 1, 2_000_000_000 + 1), - ("foo", 50, 2_000_000_000 + 3, 2_000_000_000 + 2, None), - }, - ), - ( - { - "id": "foo", - "timestamp": [2_000_000_000 + 1, 2_000_000_000 + 2], # sometime in the future - "current_stop_sequence": [2, 3], - "latest_stopped_timestamp": [2_000_000_000, 2_000_000_000 + 1], - "arrived": [None, None], - "departed": [None, None], - }, - { - "id": "foo", - "timestamp": 2_000_000_000 + 3, - "current_stop_sequence": 3, - "latest_stopped_timestamp": 2_000_000_000 + 2, - "arrived": 2_000_000_000 + 2, - "departed": None, - }, - { - ("foo", 2, 2_000_000_000 + 3, None, 2_000_000_000), - ("foo", 3, 2_000_000_000 + 2, 2_000_000_000 + 2, None), - }, - ), - ], - ids=[ - "old-records-only", - "same-stop-sequence-newer-timestamp", - "new-stop-sequence", - "outdated-records", - "non-sequential-stop-sequences", - "null-arrived-departed", # is this bad behavior? - ], -) -def test_update_records( - dy_gen: dy.random.Generator, - existing_record_overrides: dict, - new_record_overrides: dict, - expected_events: set[tuple[str, int, int, int | None, int | None]], -) -> None: - """It quickly and correctly updates records.""" - existing_records = StopEventsTable.sample(generator=dy_gen, overrides=existing_record_overrides) - new_records = StopEventsTable.sample(generator=dy_gen, overrides=new_record_overrides) - updated = update_records(existing_records, new_records, timedelta(hours=2)) - updated_set = set( - tuple(i.values()) - for i in updated.select("id", "current_stop_sequence", "timestamp", "arrived", "departed").to_struct().to_list() - ) - - assert updated_set == expected_events +def test_performance_update_records(dy_gen: dy.random.Generator, num_rows: int = 100000) -> None: + """It can handle 1,000,000 existing and new records in under a second.""" + statuses = ["IN_TRANSIT_TO", "STOPPED_AT", "INCOMING_TO"] -def test_performance_update_records(dy_gen: dy.random.Generator, num_rows: int = 1_000_000) -> None: - """It can handle 1,000,000 existing and new records in under a second.""" - existing_records = StopEventsTable.sample( + existing_records = VehicleEvents.sample( num_rows=num_rows, generator=dy_gen, overrides={ "timestamp": dy_gen.sample_int( num_rows, min=int(datetime(1970, 1, 1).timestamp()), max=int(datetime(2039, 1, 1).timestamp()) - ) + ), + "current_stop_sequence": dy_gen.sample_int(num_rows, min=1, max=50), + "current_status": dy_gen.sample_choice(num_rows, choices=statuses), }, ) - new_records = StopEventsTable.sample( - num_rows=num_rows // 10, + new_records_count = 1_000 + new_records = VehicleEvents.sample( + new_records_count, generator=dy_gen, overrides={ + "id": dy_gen.sample_choice(new_records_count, choices=existing_records.select("id").to_series().to_list()), "timestamp": dy_gen.sample_int( - num_rows // 10, min=int(datetime(1970, 1, 1).timestamp()), max=int(datetime(2039, 1, 1).timestamp()) - ) + new_records_count, min=int(datetime(1970, 1, 1).timestamp()), max=int(datetime(2039, 1, 1).timestamp()) + ), + "current_stop_sequence": dy_gen.sample_int(new_records_count, min=1, max=50), + "current_status": dy_gen.sample_choice(new_records_count, choices=statuses), }, ) start = time.time() - _ = update_records(existing_records, new_records, timedelta(hours=2)) + _ = aggregate_duration_with_new_records(existing_records, new_records) duration = time.time() - start assert duration < 1.0 def test_structure_stop_events(dy_gen: dy.random.Generator) -> None: """It correctly chooses the most recent timestamp and the first trip in the id.""" - events_df = StopEventsTable.sample( - num_rows=2, generator=dy_gen, overrides={"id": "foo", "timestamp": [1, 2], "route_id": ["red", "blue"]} + events_df = VehicleStopEvents.sample( + num_rows=2, + generator=dy_gen, + overrides={ + "event_id": "foo", + "timestamp": [1, 2], + "route_id": ["red", "blue"], + "start_date": ["20231010", "20231011"], + }, ) events_json = structure_stop_events(events_df) assert events_json.row(0)[1] == 2 assert events_df.select("start_date", "trip_id", "direction_id", "route_id", "start_time", "revenue").row( 0 ) == events_json.select("start_date", "trip_id", "direction_id", "route_id", "start_time", "revenue").row(0) + + +def test_aggregate_duration_with_new_records(dy_gen: dy.random.Generator) -> None: + """ + Test that aggregate_duration_with_new_records correctly updates event durations. + + Creates initial vehicle events with different statuses and timestamps, then + adds new records to test that the aggregation properly calculates duration + based on the timestamp differences between events. + """ + + events = VehicleEvents.sample( + num_rows=2, + generator=dy_gen, + overrides={ + "id": ["id1", "id2"], + "current_stop_sequence": [1, 1], + "current_status": ["IN_TRANSIT_TO", "STOPPED_AT"], + "timestamp": [100, 200], + }, + ) + + new_records = VehicleEvents.sample( + num_rows=2, + generator=dy_gen, + overrides={ + "id": ["id1", "id2"], + "current_stop_sequence": [1, 1], + "current_status": ["STOPPED_AT", "STOPPED_AT"], + "timestamp": [250, 250], + }, + ) + aggregated = aggregate_duration_with_new_records(events, new_records) + + print(aggregated) diff --git a/tests/flashback/test_io.py b/tests/flashback/test_io.py index bd9550ad..5c95a676 100644 --- a/tests/flashback/test_io.py +++ b/tests/flashback/test_io.py @@ -9,19 +9,20 @@ import polars as pl import pytest from aiohttp import ClientError -from polars.testing import assert_frame_equal from lamp_py.flashback.events import StopEventsJSON -from lamp_py.flashback.io import get_remote_events, get_vehicle_positions, write_stop_events -from lamp_py.ingestion.convert_gtfs_rt import VehiclePositions +from lamp_py.flashback.io import get_remote_stop_events, get_vehicle_positions +from lamp_py.ingestion.convert_gtfs_rt import VehiclePositionsApiFormat from tests.test_resources import LocalS3Location @pytest.fixture(name="mock_vp_response") -def fixture_mock_vp_response(tmp_path: Path) -> Callable[[dy.DataFrame[VehiclePositions]], tuple[AsyncMock, bytes]]: +def fixture_mock_vp_response( + tmp_path: Path, +) -> Callable[[dy.DataFrame[VehiclePositionsApiFormat]], tuple[AsyncMock, bytes]]: """Create mocked vehicle positions HTTP responses from dataframe.""" - def _create(vp: dy.DataFrame[VehiclePositions]) -> tuple[AsyncMock, bytes]: + def _create(vp: dy.DataFrame[VehiclePositionsApiFormat]) -> tuple[AsyncMock, bytes]: json_file = tmp_path.joinpath("test.json") vp.write_ndjson(json_file) @@ -55,9 +56,9 @@ def test_get_remote_events( if raise_network_error: # Simulate networking problems by patching scan_parquet to raise OSError with patch("polars.scan_parquet", side_effect=OSError("Network error")): - df = get_remote_events(test_location) + df = get_remote_stop_events(test_location) else: - df = get_remote_events(test_location) + df = get_remote_stop_events(test_location) assert (file_exists and not raise_network_error) == (df.height > 0) assert (not file_exists or raise_network_error) == (WARNING in [record[1] for record in caplog.record_tuples]) @@ -66,10 +67,10 @@ def test_get_remote_events( @pytest.mark.parametrize( ["overrides", "expected_valid_records", "raise_warning", "raises_error"], [ - ({"id": pl.lit("1")}, 0, True, nullcontext()), - ({"id": pl.col("id")}, 3, False, nullcontext()), - ({"id": pl.Series(values=["1", "1", "2"])}, 1, True, nullcontext()), - pytest.param({"id": pl.col("id").implode()}, 0, False, pytest.raises(pl.exceptions.PolarsError)), + ({"event_id": pl.lit("1")}, 0, True, nullcontext()), + ({"event_id": pl.col("event_id")}, 3, False, nullcontext()), + ({"event_id": pl.Series(values=["1", "1", "2"])}, 1, True, nullcontext()), + pytest.param({"event_id": pl.col("event_id").implode()}, 0, False, pytest.raises(pl.exceptions.PolarsError)), ], ids=["all-invalid", "all-valid", "1-valid", "wrong-schema"], ) @@ -92,7 +93,7 @@ def test_invalid_remote_events_schema( **overrides ).write_parquet(test_location.s3_uri) with raises_error: - df = get_remote_events(test_location) + df = get_remote_stop_events(test_location) assert df.height >= expected_valid_records assert raise_warning == (WARNING in [record[1] for record in caplog.record_tuples]) @@ -115,13 +116,13 @@ async def test_get_vehicle_positions( mock_sleep: AsyncMock, mock_get: AsyncMock, dy_gen: dy.random.Generator, - mock_vp_response: Callable[[dy.DataFrame[VehiclePositions]], tuple[AsyncMock, bytes]], + mock_vp_response: Callable[[dy.DataFrame[VehiclePositionsApiFormat]], tuple[AsyncMock, bytes]], num_failures: int, max_retries: int, caplog: pytest.LogCaptureFixture, ) -> None: """It gracefully handles (successive) non-200 responses.""" - vp = VehiclePositions.sample(generator=dy_gen) + vp = VehiclePositionsApiFormat.sample(generator=dy_gen) success_response, _ = mock_vp_response(vp) # Create mock error response @@ -135,88 +136,12 @@ async def test_get_vehicle_positions( with pytest.raises(ClientError): await get_vehicle_positions(max_retries=max_retries) else: - df = await get_vehicle_positions(max_retries=max_retries) + await get_vehicle_positions(max_retries=max_retries) - assert df.height == 1 assert mock_sleep.call_count == num_failures # Check that failures were logged (status=failed appears in log message) assert "ClientError" in caplog.text failure_logs = [record for record in caplog.record_tuples if "status=failed" in record[2]] - assert len(failure_logs) == num_failures - - -@pytest.mark.parametrize( - ["overrides", "expected_rows", "raises_error", "has_invalid_records"], - [ - ( - {"entity": pl.col("entity").list.eval(pl.element().struct.with_fields(id=pl.lit("a")))}, - 0, - nullcontext(), - True, - ), - ({"entity": pl.col("entity")}, 1, nullcontext(), False), # no change - ( - {"entity": pl.col("entity").list.eval(pl.element().struct.rename_fields(["id", "trip"]))}, - 0, - nullcontext(), - True, - ), # remove vehicle field - row exists but with empty entity list (valid data) - ( - {"entity": pl.col("entity").list.eval(pl.element().struct.with_fields(vehicle=pl.lit(1)))}, - 0, - pytest.raises(pl.exceptions.SchemaError), - True, - ), - ], - ids=["duplicate-primary-keys", "valid-data", "null-vehicle", "invalid-schema"], -) -@pytest.mark.asyncio -@patch("aiohttp.ClientSession.get") -async def test_invalid_vehicle_positions_schema( - mock_get: AsyncMock, - dy_gen: dy.random.Generator, - mock_vp_response: Callable[[dy.DataFrame[VehiclePositions]], tuple[AsyncMock, bytes]], - overrides: dict[str, pl.Expr], - expected_rows: int, - raises_error: pytest.RaisesExc, - has_invalid_records: bool, - caplog: pytest.LogCaptureFixture, -) -> None: - """It filters out events that don't comply with the schema.""" - vp = VehiclePositions.sample(generator=dy_gen).with_columns(**overrides) - mock_response, _ = mock_vp_response(vp) # type: ignore[arg-type] - mock_get.return_value.__aenter__.return_value = mock_response - - with raises_error: - df = await get_vehicle_positions() - - assert df.height == expected_rows - assert has_invalid_records == (WARNING in [record[1] for record in caplog.record_tuples]) - - -@pytest.mark.parametrize( - "should_fail", - [False, True], - ids=["success", "write-failure"], -) -def test_write_stop_events( - dy_gen: dy.random.Generator, - tmp_path: Path, - should_fail: bool, -) -> None: - """It writes stop events to S3 and handles write failures.""" - test_location = LocalS3Location(tmp_path.as_posix(), "test.parquet") - stop_events = StopEventsJSON.sample(2, generator=dy_gen) - - if should_fail: - # Simulate persistent write failure - with patch("polars.DataFrame.write_parquet", side_effect=OSError("S3 write error")): - with pytest.raises(OSError): - write_stop_events(stop_events, test_location) - else: - write_stop_events(stop_events, test_location) + warn_logs = [record for record in caplog.record_tuples if "status=warned" in record[2]] - # Verify file was written successfully - assert Path(test_location.s3_uri).exists() - written_df = pl.read_parquet(test_location.s3_uri) - assert_frame_equal(written_df, stop_events) + assert len(failure_logs) + len(warn_logs) == num_failures