diff --git a/src/lamp_py/bus_performance_manager/combined_bus_schedule.py b/src/lamp_py/bus_performance_manager/combined_bus_schedule.py index 57dc5134..b059f138 100644 --- a/src/lamp_py/bus_performance_manager/combined_bus_schedule.py +++ b/src/lamp_py/bus_performance_manager/combined_bus_schedule.py @@ -145,7 +145,7 @@ def join_tm_schedule_to_gtfs_schedule( ) .with_columns( pl.col("plan_stop_departure_dt") - .sub(pl.col("service_date").cast(pl.Datetime)) + .sub(pl.col("service_date").cast(pl.Datetime(time_zone="UTC"))) .dt.total_seconds() .alias("plan_stop_departure_sam"), pl.struct("plan_stop_departure_dt", "tm_stop_sequence", "gtfs_stop_sequence", "plan_start_dt") diff --git a/src/lamp_py/bus_performance_manager/events_gtfs_schedule.py b/src/lamp_py/bus_performance_manager/events_gtfs_schedule.py index fc4fa38b..9b62cfe4 100644 --- a/src/lamp_py/bus_performance_manager/events_gtfs_schedule.py +++ b/src/lamp_py/bus_performance_manager/events_gtfs_schedule.py @@ -34,8 +34,8 @@ class GTFSBusSchedule(BusBaseSchema): plan_travel_time_seconds = dy.Int64(nullable=True) plan_route_direction_headway_seconds = dy.Int64(nullable=True) plan_direction_destination_headway_seconds = dy.Int64(nullable=True) - plan_start_dt = dy.Datetime(nullable=True, time_zone=None) - plan_stop_departure_dt = dy.Datetime(nullable=False, time_zone=None) + plan_start_dt = dy.Datetime(nullable=True, time_zone="UTC") + plan_stop_departure_dt = dy.Datetime(nullable=False, time_zone="UTC") service_date = dy.Date(primary_key=True) @@ -227,14 +227,14 @@ def stop_events_for_date(service_date: date) -> pl.DataFrame: pl.datetime(service_date.year, service_date.month, service_date.day) + pl.duration(seconds=pl.col("plan_start_time")) ) - .alias("plan_start_dt") - .dt.replace_time_zone(None), + .dt.replace_time_zone("UTC") + .alias("plan_start_dt"), ( pl.datetime(service_date.year, service_date.month, service_date.day) + pl.duration(seconds=pl.col("departure_seconds")) ) - .alias("plan_stop_departure_dt") - .dt.replace_time_zone(None), + .dt.replace_time_zone("UTC") + .alias("plan_stop_departure_dt"), ) .drop( "arrival_time", diff --git a/src/lamp_py/bus_performance_manager/events_metrics.py b/src/lamp_py/bus_performance_manager/events_metrics.py index e03105d6..8648557e 100644 --- a/src/lamp_py/bus_performance_manager/events_metrics.py +++ b/src/lamp_py/bus_performance_manager/events_metrics.py @@ -206,6 +206,9 @@ def calculate_derived_bus_performance_metrics( .alias("gtfs_first_in_transit_seconds"), (pl.col("stop_arrival_dt") - pl.col("service_date")).dt.total_seconds().alias("stop_arrival_seconds"), (pl.col("stop_departure_dt") - pl.col("service_date")).dt.total_seconds().alias("stop_departure_seconds"), + (pl.col("plan_stop_departure_dt") - pl.col("service_date")) + .dt.total_seconds() + .alias("plan_stop_departure_seconds"), ) # add metrics columns to events .with_columns( @@ -220,8 +223,8 @@ def calculate_derived_bus_performance_metrics( .alias("travel_time_seconds"), (pl.col("stop_departure_seconds") - pl.col("stop_arrival_seconds")).alias("stopped_duration_seconds"), ( - pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds"]) - - pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds"]) + pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds", "plan_stop_departure_seconds"]) + - pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds", "plan_stop_departure_seconds"]) .shift() .over( ["service_date", "stop_id", "direction_id", "route_id"], @@ -229,6 +232,7 @@ def calculate_derived_bus_performance_metrics( "stop_departure_dt", "stop_arrival_dt", "gtfs_last_in_transit_dt", + "plan_stop_departure_dt", ), ) ).alias("route_direction_headway_seconds"), @@ -243,6 +247,7 @@ def calculate_derived_bus_performance_metrics( "stop_departure_dt", "stop_arrival_dt", "gtfs_last_in_transit_dt", + "plan_stop_departure_dt", ), ) ) diff --git a/src/lamp_py/bus_performance_manager/events_tm_schedule.py b/src/lamp_py/bus_performance_manager/events_tm_schedule.py index bc9f612b..9ecaed49 100644 --- a/src/lamp_py/bus_performance_manager/events_tm_schedule.py +++ b/src/lamp_py/bus_performance_manager/events_tm_schedule.py @@ -29,7 +29,7 @@ class TransitMasterSchedule(BusBaseSchema): tm_planned_sequence_end = dy.Int64(nullable=True) tm_planned_sequence_start = dy.Int64(nullable=True) service_date = dy.Date(nullable=True) - tm_stop_departure_dt = dy.Datetime(nullable=False, time_zone=None) + tm_stop_departure_dt = dy.Datetime(nullable=False, time_zone="UTC") timepoint_order = dy.UInt32(nullable=True) waiver_remark = dy.String(nullable=True, regex=r"^([[:alpha:]]{1,5}|Unrecognized Code)$") STOP_CROSSING_ID = dy.Int64(nullable=False) @@ -151,7 +151,9 @@ def generate_tm_schedule(service_date: date) -> dy.DataFrame[TransitMasterSchedu ( pl.col("CALENDAR_ID").cast(pl.String).str.to_datetime("1%Y%m%d", time_unit="us") + pl.duration(seconds=pl.col("SCHEDULED_TIME")) - ).alias("tm_stop_departure_dt"), + ) + .dt.replace_time_zone("UTC") + .alias("tm_stop_departure_dt"), pl.col(["PATTERN_GEO_NODE_SEQ"]).rank(method="dense").over(["PATTERN_ID"]).alias("timepoint_order"), pl.col("ROUTE_ABBR").str.strip_chars_start(pl.lit("0")).alias("route_id"), pl.col("PATTERN_GEO_NODE_SEQ").max().over(["TRIP_SERIAL_NUMBER"]).alias("tm_planned_sequence_end"), diff --git a/src/lamp_py/tableau/conversions/convert_bus_performance_data.py b/src/lamp_py/tableau/conversions/convert_bus_performance_data.py index b5c20722..b9731654 100644 --- a/src/lamp_py/tableau/conversions/convert_bus_performance_data.py +++ b/src/lamp_py/tableau/conversions/convert_bus_performance_data.py @@ -50,8 +50,8 @@ def apply_bus_analysis_conversions(polars_df: pl.DataFrame) -> pl.DataFrame: pl.col("tm_actual_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), pl.col("gtfs_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), pl.col("gtfs_arrival_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), - # pl.col("plan_start_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), - # pl.col("plan_stop_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), + pl.col("plan_start_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), + pl.col("plan_stop_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None), ) # Convert seconds columns to be aligned with Eastern Time diff --git a/tests/bus_performance_manager/test_events_metrics.py b/tests/bus_performance_manager/test_events_metrics.py index bad7255f..585c89bf 100644 --- a/tests/bus_performance_manager/test_events_metrics.py +++ b/tests/bus_performance_manager/test_events_metrics.py @@ -1,11 +1,13 @@ -# pylint: disable=too-many-positional-arguments,too-many-arguments +# pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-ancestors from contextlib import nullcontext -from datetime import datetime, date +from datetime import date, datetime +import dataframely as dy import polars as pl import pytest -from dataframely.random import Generator from dataframely.exc import ValidationError +from dataframely.random import Generator + from lamp_py.bus_performance_manager.events_metrics import BusPerformanceMetrics @@ -40,7 +42,7 @@ def test_dy_departure_after_arrival( stopped_duration_seconds: int | None, num_rows: pytest.RaisesExc, ) -> None: - "It returns false if the departure dt is earlier than the arrival dt." + """It returns false if the departure dt is earlier than the arrival dt.""" df = BusPerformanceMetrics.sample(num_rows=1, generator=dy_gen).with_columns( stop_arrival_dt=stop_arrival_dt, stop_departure_dt=stop_departure_dt, @@ -111,7 +113,7 @@ def test_dy_stop_sequence_implies_time_order( stopped_duration_seconds: list[int], num_rows: pytest.RaisesExc, ) -> None: - "It returns false if any departure or arrival time is earlier than the preceding record." + """It returns false if any departure or arrival time is earlier than the preceding record.""" df = BusPerformanceMetrics.sample( num_rows=2, generator=dy_gen, @@ -177,7 +179,7 @@ def test_dy_travel_time_plus_stopped_duration_equals_total_trip( stopped_duration_seconds: list[int], num_rows: pytest.RaisesExc, ) -> None: - "It returns false if the travel times and stopped durations don't add up to the total trip duration." + """It returns false if the travel times and stopped durations don't add up to the total trip duration.""" df = BusPerformanceMetrics.sample(num_rows=3, generator=dy_gen).with_columns( trip_id=pl.lit("1"), vehicle_label=pl.lit("x"), @@ -191,3 +193,71 @@ def test_dy_travel_time_plus_stopped_duration_equals_total_trip( with num_rows: assert BusPerformanceMetrics.validate(df, cast=True).height == num_rows.enter_result # type: ignore[attr-defined] + + +class TestBusPerformanceMetrics(BusPerformanceMetrics): + """Production schema plus test-only rules.""" + + @dy.rule() + def headways_if_planned_or_actual_departure(cls) -> pl.Expr: + """Headways aren't null if there is a planned or actual departure time.""" + return pl.all_horizontal( + pl.coalesce("stop_departure_dt", "plan_stop_departure_dt").is_not_null(), + pl.coalesce("stop_departure_dt", "plan_stop_departure_dt").shift().is_not_null(), + ) == pl.all_horizontal(pl.selectors.contains("headway").is_not_null()) + + +@pytest.mark.parametrize( + [ + "stop_departure_dt", + "plan_stop_departure_dt", + "route_direction_headway_seconds", + "direction_destination_headway_seconds", + "num_rows", + ], + [ + ( + [None, None], + [datetime(2000, 1, 1), datetime(2000, 1, 1, 1)], + [None, None], + [None, None], + pytest.raises(ValidationError, match="headways_if_planned_or_actual_departure"), + ), + ( + [None, None], + [datetime(2000, 1, 1), datetime(2000, 1, 1, 1)], + [None, 60 * 60], + [None, 60 * 60], + nullcontext(2), + ), + ], + ids=[ + "planned_departure_no_headways", + "departure_with_headways", + ], +) +def test_dy_headways_if_planned_or_actual_departure( + dy_gen: Generator, + stop_departure_dt: list[datetime], + plan_stop_departure_dt: list[datetime], + route_direction_headway_seconds: list[int], + direction_destination_headway_seconds: list[int], + num_rows: pytest.RaisesExc, +) -> None: + """It returns false if there is a planned or actual departure time, but the route_direction_headway_seconds or direction_destination_headway_seconds are null.""" + df = TestBusPerformanceMetrics.sample(num_rows=2, generator=dy_gen).with_columns( + route_id=pl.lit("1"), + service_date=pl.lit(date(2000, 1, 1)), + stop_id=pl.lit("a"), + direction_id=pl.lit(0), + direction_destination=pl.lit("b"), + travel_time_seconds=pl.lit(None), + stopped_duration_seconds=pl.lit(None), + stop_departure_dt=pl.Series(values=stop_departure_dt), + plan_stop_departure_dt=pl.Series(values=plan_stop_departure_dt), + route_direction_headway_seconds=pl.Series(values=route_direction_headway_seconds), + direction_destination_headway_seconds=pl.Series(values=direction_destination_headway_seconds), + ) + + with num_rows: + assert TestBusPerformanceMetrics.validate(df, cast=True).height == num_rows.enter_result # type: ignore[attr-defined]