diff --git a/src/lamp_py/ingestion/config_rt_alerts.py b/src/lamp_py/ingestion/config_rt_alerts.py
index 8e824ec7..708e5ddc 100644
--- a/src/lamp_py/ingestion/config_rt_alerts.py
+++ b/src/lamp_py/ingestion/config_rt_alerts.py
@@ -76,13 +76,81 @@ class AlertsRecord(FeedMessage):
     )
 
 
+class ProposedAlertsRecord(FeedMessage):
+    """Alert message schema used as the fallback import schema when entities do not fit the current one."""
+
+    entity = dy.List(
+        inner=dy.Struct(
+            inner={
+                "id": dy.String(min_length=1),
+                "alert": dy.Struct(
+                    inner={
+                        "active_period": dy.List(
+                            dy.Struct(
+                                inner={
+                                    "start": dy.UInt64(nullable=True),
+                                    "end": dy.UInt64(nullable=True),
+                                }
+                            ),
+                            nullable=True,
+                        ),
+                        "informed_entity": dy.List(
+                            dy.Struct(
+                                inner={
+                                    "agency_id": dy.String(nullable=True),
+                                    "route_id": dy.String(nullable=True),
+                                    "route_type": dy.Int32(nullable=True),
+                                    "direction_id": dy.UInt8(nullable=True),
+                                    "trip": with_nullable(trip_descriptor, nullable=True),
+                                    "stop_id": dy.String(nullable=True),
+                                    "facility_id": dy.String(nullable=True),
+                                    "activities": dy.List(dy.String(), nullable=True),  # MBTA Enhanced field
+                                }
+                            ),
+                            nullable=True,
+                        ),
+                        "cause": dy.String(nullable=False),
+                        "cause_detail": translated_string,
+                        "effect": dy.String(nullable=True),
+                        "effect_detail": translated_string,
+                        "url": translated_string,
+                        "header_text": with_nullable(translated_string, nullable=False),
+                        "description_text": with_nullable(translated_string, nullable=False),
+                        "severity_level": dy.String(nullable=True),
+                        "severity": dy.UInt16(nullable=True),  # MBTA Enhanced field
+                        "created_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
+                        "last_modified_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
+                        "last_push_notification_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
+                        "closed_timestamp": dy.Int64(nullable=True),  # MBTA Enhanced field
+                        "alert_lifecycle": dy.String(nullable=True),  # MBTA Enhanced field
+                        "duration_certainty": dy.String(nullable=True),  # MBTA Enhanced field
+                        "reminder_times": dy.List(dy.UInt64(), nullable=True),  # MBTA Enhanced field
+                        "short_header_text": translated_string,  # not in message Alert struct spec
+                        "service_effect_text": translated_string,  # MBTA Enhanced field
+                        "timeframe_text": translated_string,  # MBTA Enhanced field
+                        "recurrence_text": translated_string,  # MBTA Enhanced field
+                    },
+                    nullable=False,
+                ),
+            }
+        ),
+        min_length=1,
+    )
+
+
 class AlertsTable(FeedEntityTable):
     """Flattened Alerts data."""
 
     alert_cause = dy.String(alias="alert.cause")
     alert_cause_detail = dy.String(nullable=True, alias="alert.cause_detail")
+    alert_cause_detail_translation = with_alias(
+        translated_string.inner["translation"], "alert.cause_detail.translation"
+    )
     alert_effect = dy.String(nullable=True, alias="alert.effect")
     alert_effect_detail = dy.String(nullable=True, alias="alert.effect_detail")
+    alert_effect_detail_translation = with_alias(
+        translated_string.inner["translation"], "alert.effect_detail.translation"
+    )
     alert_severity_level = dy.String(nullable=True, alias="alert.severity_level")
     alert_severity = dy.UInt16(nullable=True, alias="alert.severity")
     alert_created_timestamp = dy.UInt64(nullable=True, alias="alert.created_timestamp")
@@ -158,3 +226,10 @@ def table_sort_order(self) -> List[Tuple[str, str]]:
             ("alert.effect", "ascending"),
             ("feed_timestamp", "ascending"),
         ]
+
+    def transform_for_write(self, table: pyarrow.Table) -> pyarrow.Table:
+        """Flatten table, then add columns defined in table_schema if they aren't already in the flattened table."""
+        table = super().transform_for_write(table)
+        expected_table = AlertsTable.create_empty().to_arrow()
+        unioned_table = pyarrow.concat_tables([table, expected_table], promote_options="permissive")
+        return unioned_table
diff --git a/src/lamp_py/ingestion/convert_gtfs_rt.py b/src/lamp_py/ingestion/convert_gtfs_rt.py
index d1d6cb18..52ba87a4 100644
--- a/src/lamp_py/ingestion/convert_gtfs_rt.py
+++ b/src/lamp_py/ingestion/convert_gtfs_rt.py
@@ -39,7 +39,7 @@
 )
 from lamp_py.runtime_utils.process_logger import ProcessLogger
 
-from lamp_py.ingestion.config_rt_alerts import RtAlertsDetail
+from lamp_py.ingestion.config_rt_alerts import RtAlertsDetail, ProposedAlertsRecord
 from lamp_py.ingestion.config_busloc_trip import RtBusTripDetail
 from lamp_py.ingestion.config_busloc_vehicle import RtBusVehicleDetail
 from lamp_py.ingestion.config_rt_trip import RtTripDetail
@@ -334,7 +334,18 @@ def gz_to_pyarrow(self, filename: str) -> Tuple[Optional[datetime], str, Optiona
             feed_timestamp = json_data["header"]["timestamp"]
             timestamp = datetime.fromtimestamp(feed_timestamp, timezone.utc)
 
-            table = pyarrow.Table.from_pylist(json_data["entity"], schema=self.detail.import_schema)
+            try:
+                table = pyarrow.Table.from_pylist(json_data["entity"], schema=self.detail.import_schema)
+            except pyarrow.ArrowTypeError:
+                # only alerts have a proposed fallback schema; every other feed re-raises the original error
+                if self.config_type != ConfigType.RT_ALERTS:
+                    raise
+                table = pyarrow.Table.from_pylist(
+                    json_data["entity"],
+                    schema=pyarrow.schema(
+                        [v.pyarrow_field(k) for k, v in ProposedAlertsRecord.entity.inner.inner.items()]  # type: ignore[attr-defined]
+                    ),
+                )
 
             table = table.append_column(
                 "year",
diff --git a/src/lamp_py/ingestion/gtfs_rt_structs.py b/src/lamp_py/ingestion/gtfs_rt_structs.py
index 8f90a634..7ae77e00 100644
--- a/src/lamp_py/ingestion/gtfs_rt_structs.py
+++ b/src/lamp_py/ingestion/gtfs_rt_structs.py
@@ -50,7 +50,7 @@
         "translation": dy.List(
             dy.Struct(
                 inner={
-                    "text": dy.String(nullable=True),
+                    "text": dy.String(),
                     "language": dy.String(nullable=True),
                 }
             ),
diff --git a/tests/ingestion/test_gtfs_rt_converter.py b/tests/ingestion/test_gtfs_rt_converter.py
index 0fb349f3..4d679ff6 100644
--- a/tests/ingestion/test_gtfs_rt_converter.py
+++ b/tests/ingestion/test_gtfs_rt_converter.py
@@ -19,6 +19,8 @@
 from pyarrow import fs
 
 from lamp_py.ingestion.gtfs_rt_detail import FeedMessage
+from lamp_py.ingestion.config_busloc_vehicle import BusLocVehicleRecord
+from lamp_py.ingestion.config_rt_alerts import AlertsRecord, ProposedAlertsRecord
 from lamp_py.ingestion.convert_gtfs_rt import GtfsRtConverter
 from lamp_py.ingestion.converter import ConfigType
 from lamp_py.ingestion.utils import flatten_table_schema
@@ -325,8 +327,6 @@ def test_file_conversion(
         == 0
     ), "Some ids in the original message are missing from the converted table."
 
-    assert not expected_converter.detail.table_schema.validate(table).is_empty()
-
 
 def test_bus_trip_updates_file_conversion() -> None:
     """
@@ -367,10 +367,14 @@ def test_bus_trip_updates_file_conversion() -> None:
 
 
 @pytest.mark.parametrize(
-    "config_type",
     [
-        ConfigType.BUS_VEHICLE_POSITIONS,
-        ConfigType.RT_ALERTS,
+        "config_type",
+        "input_schemas",
+    ],
+    [
+        (ConfigType.BUS_VEHICLE_POSITIONS, [BusLocVehicleRecord, BusLocVehicleRecord]),
+        (ConfigType.RT_ALERTS, [AlertsRecord, AlertsRecord]),
+        (ConfigType.RT_ALERTS, [AlertsRecord, ProposedAlertsRecord]),
     ],
 )
 @pytest.mark.parametrize(
@@ -379,6 +383,7 @@ def test_bus_trip_updates_file_conversion() -> None:
 )
 def test_convert(
     config_type: ConfigType,
+    input_schemas: list[type[FeedMessage]],
     timestamp: list[datetime],
     dy_gen: dy.random.Generator,
     monkeypatch: pytest.MonkeyPatch,
@@ -388,10 +393,9 @@ def test_convert(
     monkeypatch.setattr("lamp_py.ingestion.convert_gtfs_rt.move_s3_objects", lambda files, __: files)
     monkeypatch.setattr("lamp_py.ingestion.convert_gtfs_rt.upload_file", create_mock_upload_file(tmp_path))
     dfs = []
-    for ts in timestamp:
+    for i, ts in enumerate(timestamp):
         converter = GtfsRtConverter(config_type, Queue())
-        assert converter.detail.record_schema is not None
-        df = gtfs_rt_factory(converter.detail.record_schema, dy_gen, ts)
+        df = gtfs_rt_factory(input_schemas[i], dy_gen, ts)
 
         incoming_file = tmp_path / f"{ts.isoformat()}.json.gz"
 
@@ -414,11 +418,17 @@ def test_convert(
         ]
     )
 
-    expected_records = unnest_all_structs(
-        pl.union(dfs)
-        .select("entity", pl.col("header").struct.field("timestamp").alias("feed_timestamp"))
-        .explode("entity")
-        .unnest("entity")
+    expected_records = pl.concat(
+        [
+            unnest_all_structs(
+                df.select("entity", pl.col("header").struct.field("timestamp").alias("feed_timestamp"))
+                .explode("entity")
+                .unnest("entity")
+            )
+            for df in dfs
+        ]
+        + [converter.detail.table_schema.create_empty()],
+        how="diagonal_relaxed",
     )
 
     assert_frame_equal(converted_records, expected_records, check_row_order=False, check_column_order=False)
diff --git a/tests/performance_manager/test_alerts.py b/tests/performance_manager/test_alerts.py
index 10bb3326..4af68e50 100644
--- a/tests/performance_manager/test_alerts.py
+++ b/tests/performance_manager/test_alerts.py
@@ -3,10 +3,13 @@
 
 import datetime
 import os
+from pathlib import Path
 from typing import List, Tuple, Dict, Optional, Union
 
 import pandas
+import polars as pl
 
+from lamp_py.ingestion.config_rt_alerts import AlertsTable
 from lamp_py.performance_manager.alerts import (
     extract_alerts,
     transform_translations,
@@ -16,7 +19,7 @@
 )
 from lamp_py.performance_manager.gtfs_utils import BOSTON_TZ_ZONEINFO
 
-from ..test_resources import springboard_dir
+from tests.test_resources import springboard_dir
 
 
 def generate_sample_translations(columns: List[str]) -> pandas.DataFrame:
@@ -363,7 +366,7 @@ def test_explode_informed_entity() -> None:
         assert set(values) == set(options), f"{column} has different values"
 
 
-def test_etl() -> None:
+def test_etl(tmp_path: Path) -> None:
     """
     Test that the entire ETL pipeline can be used without throwing and that it
     will be impacted by existing alerts that are passed into the extract_alerts
@@ -379,10 +382,17 @@ def test_etl() -> None:
         "6ef6922c20064cb9a8f09a3b3b1d2783-0.parquet",
     )
 
+    temp_path = tmp_path.joinpath("test_alerts.parquet").as_posix()
+    (
+        pl.read_parquet(test_file)
+        .match_to_schema(AlertsTable.to_polars_schema(), missing_struct_fields="insert", missing_columns="insert")
+        .write_parquet(temp_path)
+    )
+
     key_columns = ["id", "last_modified_timestamp"]
 
     existing = pandas.DataFrame(columns=key_columns)
-    alerts = extract_alerts(alert_files=[test_file], existing_id_timestamp_pairs=existing)
+    alerts = extract_alerts(alert_files=[temp_path], existing_id_timestamp_pairs=existing)
     alerts = transform_translations(alerts)
     alerts = transform_timestamps(alerts)
     alerts = explode_active_periods(alerts)
@@ -390,7 +400,7 @@ def test_etl() -> None:
 
     # process it a second time with some of the id / lm timestamp pairs to filter against.
     existing = alerts[key_columns].drop_duplicates().head(5)
-    alerts_2 = extract_alerts(alert_files=[test_file], existing_id_timestamp_pairs=existing)
+    alerts_2 = extract_alerts(alert_files=[temp_path], existing_id_timestamp_pairs=existing)
     alerts_2 = transform_translations(alerts_2)
     alerts_2 = transform_timestamps(alerts_2)
     alerts_2 = explode_active_periods(alerts_2)