Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions src/lamp_py/ingestion/config_rt_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,81 @@ class AlertsRecord(FeedMessage):
)


class ProposedAlertsRecord(FeedMessage):
    """
    Proposed record schema for a GTFS-RT Alerts feed message.

    Declares one column, `entity`: a list (at least one element) of structs,
    each pairing an `id` with a non-nullable `alert` struct. Within `alert`,
    `cause`, `header_text` and `description_text` are declared non-nullable;
    fields commented "MBTA Enhanced field" keep the authors' original tag.
    """

    entity = dy.List(
        inner=dy.Struct(
            inner={
                # entity identifier; must not be the empty string
                "id": dy.String(min_length=1),
                "alert": dy.Struct(
                    inner={
                        # list of start/end pairs; either bound may be null
                        "active_period": dy.List(
                            inner=dy.Struct(
                                inner={
                                    "start": dy.UInt64(nullable=True),
                                    "end": dy.UInt64(nullable=True),
                                }
                            ),
                            nullable=True,
                        ),
                        # selectors for what the alert applies to; every selector is optional
                        "informed_entity": dy.List(
                            inner=dy.Struct(
                                inner={
                                    "agency_id": dy.String(nullable=True),
                                    "route_id": dy.String(nullable=True),
                                    "route_type": dy.Int32(nullable=True),
                                    "direction_id": dy.UInt8(nullable=True),
                                    "trip": with_nullable(trip_descriptor, nullable=True),
                                    "stop_id": dy.String(nullable=True),
                                    "facility_id": dy.String(nullable=True),
                                    "activities": dy.List(inner=dy.String(), nullable=True),  # MBTA Enhanced field
                                }
                            ),
                            nullable=True,
                        ),
                        # cause is required, effect is optional
                        "cause": dy.String(nullable=False),
                        "cause_detail": translated_string,
                        "effect": dy.String(nullable=True),
                        "effect_detail": translated_string,
                        "url": translated_string,
                        # header and description reuse translated_string but are made non-nullable
                        "header_text": with_nullable(translated_string, nullable=False),
                        "description_text": with_nullable(translated_string, nullable=False),
                        "severity_level": dy.String(nullable=True),
                        "severity": dy.UInt16(nullable=True),  # MBTA Enhanced field
                        "created_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
                        "last_modified_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
                        "last_push_notification_timestamp": dy.UInt64(nullable=True),  # MBTA Enhanced field
                        # NOTE(review): signed Int64 here, unlike the other UInt64 timestamps — confirm intentional
                        "closed_timestamp": dy.Int64(nullable=True),  # MBTA Enhanced field
                        "alert_lifecycle": dy.String(nullable=True),  # MBTA Enhanced field
                        "duration_certainty": dy.String(nullable=True),  # MBTA Enhanced field
                        "reminder_times": dy.List(inner=dy.UInt64(), nullable=True),  # MBTA Enhanced field
                        "short_header_text": translated_string,  # not in message Alert struct spec
                        "service_effect_text": translated_string,  # MBTA Enhanced field
                        "timeframe_text": translated_string,  # MBTA Enhanced field
                        "recurrence_text": translated_string,  # MBTA Enhanced field
                    },
                    nullable=False,
                ),
            }
        ),
        min_length=1,
    )


class AlertsTable(FeedEntityTable):
"""Flattened Alerts data."""

alert_cause = dy.String(alias="alert.cause")
alert_cause_detail = dy.String(nullable=True, alias="alert.cause_detail")
alert_cause_detail_translation = with_alias(
translated_string.inner["translation"], "alert.cause_detail.translation"
)
alert_effect = dy.String(nullable=True, alias="alert.effect")
alert_effect_detail = dy.String(nullable=True, alias="alert.effect_detail")
alert_effect_detail_translation = with_alias(
Comment thread
runkelcorey marked this conversation as resolved.
translated_string.inner["translation"], "alert.effect_detail.translation"
)
alert_severity_level = dy.String(nullable=True, alias="alert.severity_level")
alert_severity = dy.UInt16(nullable=True, alias="alert.severity")
alert_created_timestamp = dy.UInt64(nullable=True, alias="alert.created_timestamp")
Expand Down Expand Up @@ -158,3 +226,10 @@ def table_sort_order(self) -> List[Tuple[str, str]]:
("alert.effect", "ascending"),
("feed_timestamp", "ascending"),
]

def transform_for_write(self, table: pyarrow.Table) -> pyarrow.Table:
    """
    Flatten the table and pad it with any AlertsTable columns it is missing.

    The parent implementation performs the flattening. Its result is then
    concatenated with a zero-row Arrow table built from AlertsTable, using
    permissive promotion, so columns defined on AlertsTable that are absent
    from the flattened table are added to the output.

    :param table: table to prepare for writing
    :return: the flattened table concatenated with the empty AlertsTable table
    """
    flattened = super().transform_for_write(table)
    # zero-row table whose only contribution is the AlertsTable column set
    schema_only = AlertsTable.create_empty().to_arrow()
    return pyarrow.concat_tables([flattened, schema_only], promote_options="permissive")
15 changes: 13 additions & 2 deletions src/lamp_py/ingestion/convert_gtfs_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from lamp_py.runtime_utils.process_logger import ProcessLogger

from lamp_py.ingestion.config_rt_alerts import RtAlertsDetail
from lamp_py.ingestion.config_rt_alerts import RtAlertsDetail, ProposedAlertsRecord
from lamp_py.ingestion.config_busloc_trip import RtBusTripDetail
from lamp_py.ingestion.config_busloc_vehicle import RtBusVehicleDetail
from lamp_py.ingestion.config_rt_trip import RtTripDetail
Expand Down Expand Up @@ -334,7 +334,18 @@ def gz_to_pyarrow(self, filename: str) -> Tuple[Optional[datetime], str, Optiona
feed_timestamp = json_data["header"]["timestamp"]
timestamp = datetime.fromtimestamp(feed_timestamp, timezone.utc)

table = pyarrow.Table.from_pylist(json_data["entity"], schema=self.detail.import_schema)
try:
table = pyarrow.Table.from_pylist(json_data["entity"], schema=self.detail.import_schema)
except pyarrow.ArrowTypeError as e:
if self.config_type == ConfigType.RT_ALERTS:
table = pyarrow.Table.from_pylist(
json_data["entity"],
schema=pyarrow.schema(
[v.pyarrow_field(k) for k, v in ProposedAlertsRecord.entity.inner.inner.items()] # type: ignore[attr-defined]
),
)
else:
raise e
Comment thread
huangh marked this conversation as resolved.

table = table.append_column(
"year",
Expand Down
2 changes: 1 addition & 1 deletion src/lamp_py/ingestion/gtfs_rt_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"translation": dy.List(
dy.Struct(
inner={
"text": dy.String(nullable=True),
"text": dy.String(),
"language": dy.String(nullable=True),
}
),
Expand Down
36 changes: 23 additions & 13 deletions tests/ingestion/test_gtfs_rt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pyarrow import fs

from lamp_py.ingestion.gtfs_rt_detail import FeedMessage
from lamp_py.ingestion.config_busloc_vehicle import BusLocVehicleRecord
from lamp_py.ingestion.config_rt_alerts import AlertsRecord, ProposedAlertsRecord
from lamp_py.ingestion.convert_gtfs_rt import GtfsRtConverter
from lamp_py.ingestion.converter import ConfigType
from lamp_py.ingestion.utils import flatten_table_schema
Expand Down Expand Up @@ -325,8 +327,6 @@ def test_file_conversion(
== 0
), "Some ids in the original message are missing from the converted table."

assert not expected_converter.detail.table_schema.validate(table).is_empty()


def test_bus_trip_updates_file_conversion() -> None:
"""
Expand Down Expand Up @@ -367,10 +367,14 @@ def test_bus_trip_updates_file_conversion() -> None:


@pytest.mark.parametrize(
"config_type",
[
ConfigType.BUS_VEHICLE_POSITIONS,
ConfigType.RT_ALERTS,
"config_type",
"input_schemas",
],
[
(ConfigType.BUS_VEHICLE_POSITIONS, [BusLocVehicleRecord, BusLocVehicleRecord]),
(ConfigType.RT_ALERTS, [AlertsRecord, AlertsRecord]),
(ConfigType.RT_ALERTS, [AlertsRecord, ProposedAlertsRecord]),
],
)
@pytest.mark.parametrize(
Expand All @@ -379,6 +383,7 @@ def test_bus_trip_updates_file_conversion() -> None:
)
def test_convert(
config_type: ConfigType,
input_schemas: list[type[FeedMessage]],
timestamp: list[datetime],
dy_gen: dy.random.Generator,
monkeypatch: pytest.MonkeyPatch,
Expand All @@ -388,10 +393,9 @@ def test_convert(
monkeypatch.setattr("lamp_py.ingestion.convert_gtfs_rt.move_s3_objects", lambda files, __: files)
monkeypatch.setattr("lamp_py.ingestion.convert_gtfs_rt.upload_file", create_mock_upload_file(tmp_path))
dfs = []
for ts in timestamp:
for i, ts in enumerate(timestamp):
converter = GtfsRtConverter(config_type, Queue())
assert converter.detail.record_schema is not None
df = gtfs_rt_factory(converter.detail.record_schema, dy_gen, ts)
df = gtfs_rt_factory(input_schemas[i], dy_gen, ts)

incoming_file = tmp_path / f"{ts.isoformat()}.json.gz"

Expand All @@ -414,11 +418,17 @@ def test_convert(
]
)

expected_records = unnest_all_structs(
pl.union(dfs)
.select("entity", pl.col("header").struct.field("timestamp").alias("feed_timestamp"))
.explode("entity")
.unnest("entity")
expected_records = pl.concat(
[
unnest_all_structs(
df.select("entity", pl.col("header").struct.field("timestamp").alias("feed_timestamp"))
.explode("entity")
.unnest("entity")
)
for df in dfs
]
+ [converter.detail.table_schema.create_empty()],
how="diagonal_relaxed",
)

assert_frame_equal(converted_records, expected_records, check_row_order=False, check_column_order=False)
Expand Down
18 changes: 14 additions & 4 deletions tests/performance_manager/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import datetime
import os

from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union

import pandas
import polars as pl

from lamp_py.ingestion.config_rt_alerts import AlertsTable
from lamp_py.performance_manager.alerts import (
extract_alerts,
transform_translations,
Expand All @@ -16,7 +19,7 @@
)
from lamp_py.performance_manager.gtfs_utils import BOSTON_TZ_ZONEINFO

from ..test_resources import springboard_dir
from tests.test_resources import springboard_dir


def generate_sample_translations(columns: List[str]) -> pandas.DataFrame:
Expand Down Expand Up @@ -363,7 +366,7 @@ def test_explode_informed_entity() -> None:
assert set(values) == set(options), f"{column} has different values"


def test_etl() -> None:
def test_etl(tmp_path: Path) -> None:
"""
Test that the entire ETL pipeline can be used without throwing and that it
will be impacted by existing alerts that are passed into the extract_alerts
Expand All @@ -379,18 +382,25 @@ def test_etl() -> None:
"6ef6922c20064cb9a8f09a3b3b1d2783-0.parquet",
)

temp_path = tmp_path.joinpath("test_alerts.parquet").as_posix()
(
pl.read_parquet(test_file)
.match_to_schema(AlertsTable.to_polars_schema(), missing_struct_fields="insert", missing_columns="insert")
.write_parquet(temp_path)
)

key_columns = ["id", "last_modified_timestamp"]
existing = pandas.DataFrame(columns=key_columns)

alerts = extract_alerts(alert_files=[test_file], existing_id_timestamp_pairs=existing)
alerts = extract_alerts(alert_files=[temp_path], existing_id_timestamp_pairs=existing)
alerts = transform_translations(alerts)
alerts = transform_timestamps(alerts)
alerts = explode_active_periods(alerts)
alerts = explode_informed_entity(alerts)

# process it a second time with some of the id / lm timestamp pairs to filter against.
existing = alerts[key_columns].drop_duplicates().head(5)
alerts_2 = extract_alerts(alert_files=[test_file], existing_id_timestamp_pairs=existing)
alerts_2 = extract_alerts(alert_files=[temp_path], existing_id_timestamp_pairs=existing)
alerts_2 = transform_translations(alerts_2)
alerts_2 = transform_timestamps(alerts_2)
alerts_2 = explode_active_periods(alerts_2)
Expand Down
Loading