Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def join_tm_schedule_to_gtfs_schedule(
)
.with_columns(
pl.col("plan_stop_departure_dt")
.sub(pl.col("service_date").cast(pl.Datetime))
.sub(pl.col("service_date").cast(pl.Datetime(time_zone="UTC")))
.dt.total_seconds()
.alias("plan_stop_departure_sam"),
pl.struct("plan_stop_departure_dt", "tm_stop_sequence", "gtfs_stop_sequence", "plan_start_dt")
Expand Down
12 changes: 6 additions & 6 deletions src/lamp_py/bus_performance_manager/events_gtfs_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class GTFSBusSchedule(BusBaseSchema):
plan_travel_time_seconds = dy.Int64(nullable=True)
plan_route_direction_headway_seconds = dy.Int64(nullable=True)
plan_direction_destination_headway_seconds = dy.Int64(nullable=True)
plan_start_dt = dy.Datetime(nullable=True, time_zone=None)
plan_stop_departure_dt = dy.Datetime(nullable=False, time_zone=None)
plan_start_dt = dy.Datetime(nullable=True, time_zone="UTC")
plan_stop_departure_dt = dy.Datetime(nullable=False, time_zone="UTC")
service_date = dy.Date(primary_key=True)


Expand Down Expand Up @@ -227,14 +227,14 @@ def stop_events_for_date(service_date: date) -> pl.DataFrame:
pl.datetime(service_date.year, service_date.month, service_date.day)
+ pl.duration(seconds=pl.col("plan_start_time"))
)
.alias("plan_start_dt")
.dt.replace_time_zone(None),
.dt.replace_time_zone("UTC")
.alias("plan_start_dt"),
(
pl.datetime(service_date.year, service_date.month, service_date.day)
+ pl.duration(seconds=pl.col("departure_seconds"))
)
.alias("plan_stop_departure_dt")
.dt.replace_time_zone(None),
.dt.replace_time_zone("UTC")
.alias("plan_stop_departure_dt"),
)
.drop(
"arrival_time",
Expand Down
9 changes: 7 additions & 2 deletions src/lamp_py/bus_performance_manager/events_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ def calculate_derived_bus_performance_metrics(
.alias("gtfs_first_in_transit_seconds"),
(pl.col("stop_arrival_dt") - pl.col("service_date")).dt.total_seconds().alias("stop_arrival_seconds"),
(pl.col("stop_departure_dt") - pl.col("service_date")).dt.total_seconds().alias("stop_departure_seconds"),
(pl.col("plan_stop_departure_dt") - pl.col("service_date"))
.dt.total_seconds()
.alias("plan_stop_departure_seconds"),
)
# add metrics columns to events
.with_columns(
Expand All @@ -220,15 +223,16 @@ def calculate_derived_bus_performance_metrics(
.alias("travel_time_seconds"),
(pl.col("stop_departure_seconds") - pl.col("stop_arrival_seconds")).alias("stopped_duration_seconds"),
(
pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds"])
- pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds"])
pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds", "plan_stop_departure_seconds"])
- pl.coalesce(["stop_departure_seconds", "stop_arrival_seconds", "plan_stop_departure_seconds"])
.shift()
.over(
["service_date", "stop_id", "direction_id", "route_id"],
order_by=pl.coalesce(
"stop_departure_dt",
"stop_arrival_dt",
"gtfs_last_in_transit_dt",
"plan_stop_departure_dt",
),
)
).alias("route_direction_headway_seconds"),
Expand All @@ -243,6 +247,7 @@ def calculate_derived_bus_performance_metrics(
"stop_departure_dt",
"stop_arrival_dt",
"gtfs_last_in_transit_dt",
"plan_stop_departure_dt",
),
)
)
Expand Down
6 changes: 4 additions & 2 deletions src/lamp_py/bus_performance_manager/events_tm_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TransitMasterSchedule(BusBaseSchema):
tm_planned_sequence_end = dy.Int64(nullable=True)
tm_planned_sequence_start = dy.Int64(nullable=True)
service_date = dy.Date(nullable=True)
tm_stop_departure_dt = dy.Datetime(nullable=False, time_zone=None)
tm_stop_departure_dt = dy.Datetime(nullable=False, time_zone="UTC")
timepoint_order = dy.UInt32(nullable=True)
waiver_remark = dy.String(nullable=True, regex=r"^([[:alpha:]]{1,5}|Unrecognized Code)$")
STOP_CROSSING_ID = dy.Int64(nullable=False)
Expand Down Expand Up @@ -151,7 +151,9 @@ def generate_tm_schedule(service_date: date) -> dy.DataFrame[TransitMasterSchedu
(
pl.col("CALENDAR_ID").cast(pl.String).str.to_datetime("1%Y%m%d", time_unit="us")
+ pl.duration(seconds=pl.col("SCHEDULED_TIME"))
).alias("tm_stop_departure_dt"),
)
.dt.replace_time_zone("UTC")
.alias("tm_stop_departure_dt"),
pl.col(["PATTERN_GEO_NODE_SEQ"]).rank(method="dense").over(["PATTERN_ID"]).alias("timepoint_order"),
pl.col("ROUTE_ABBR").str.strip_chars_start(pl.lit("0")).alias("route_id"),
pl.col("PATTERN_GEO_NODE_SEQ").max().over(["TRIP_SERIAL_NUMBER"]).alias("tm_planned_sequence_end"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def apply_bus_analysis_conversions(polars_df: pl.DataFrame) -> pl.DataFrame:
pl.col("tm_actual_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
pl.col("gtfs_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
pl.col("gtfs_arrival_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
# pl.col("plan_start_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
# pl.col("plan_stop_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
pl.col("plan_start_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
pl.col("plan_stop_departure_dt").dt.convert_time_zone(time_zone="America/New_York").dt.replace_time_zone(None),
)

# Convert seconds columns to be aligned with Eastern Time
Expand Down
82 changes: 76 additions & 6 deletions tests/bus_performance_manager/test_events_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# pylint: disable=too-many-positional-arguments,too-many-arguments
# pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-ancestors
from contextlib import nullcontext
from datetime import datetime, date
from datetime import date, datetime

import dataframely as dy
import polars as pl
import pytest
from dataframely.random import Generator
from dataframely.exc import ValidationError
from dataframely.random import Generator

from lamp_py.bus_performance_manager.events_metrics import BusPerformanceMetrics


Expand Down Expand Up @@ -40,7 +42,7 @@ def test_dy_departure_after_arrival(
stopped_duration_seconds: int | None,
num_rows: pytest.RaisesExc,
) -> None:
"It returns false if the departure dt is earlier than the arrival dt."
"""It returns false if the departure dt is earlier than the arrival dt."""
df = BusPerformanceMetrics.sample(num_rows=1, generator=dy_gen).with_columns(
stop_arrival_dt=stop_arrival_dt,
stop_departure_dt=stop_departure_dt,
Expand Down Expand Up @@ -111,7 +113,7 @@ def test_dy_stop_sequence_implies_time_order(
stopped_duration_seconds: list[int],
num_rows: pytest.RaisesExc,
) -> None:
"It returns false if any departure or arrival time is earlier than the preceding record."
"""It returns false if any departure or arrival time is earlier than the preceding record."""
df = BusPerformanceMetrics.sample(
num_rows=2,
generator=dy_gen,
Expand Down Expand Up @@ -177,7 +179,7 @@ def test_dy_travel_time_plus_stopped_duration_equals_total_trip(
stopped_duration_seconds: list[int],
num_rows: pytest.RaisesExc,
) -> None:
"It returns false if the travel times and stopped durations don't add up to the total trip duration."
"""It returns false if the travel times and stopped durations don't add up to the total trip duration."""
df = BusPerformanceMetrics.sample(num_rows=3, generator=dy_gen).with_columns(
trip_id=pl.lit("1"),
vehicle_label=pl.lit("x"),
Expand All @@ -191,3 +193,71 @@ def test_dy_travel_time_plus_stopped_duration_equals_total_trip(

with num_rows:
assert BusPerformanceMetrics.validate(df, cast=True).height == num_rows.enter_result # type: ignore[attr-defined]


class TestBusPerformanceMetrics(BusPerformanceMetrics):
    """BusPerformanceMetrics schema extended with rules that are only checked in tests."""

    @dy.rule()
    def headways_if_planned_or_actual_departure(cls) -> pl.Expr:
        """Require headways exactly when this row and the previous row both have a departure.

        A row's departure is the actual departure time, falling back to the
        planned one. The "previous row" is taken in frame order; no grouping
        is applied in this rule.
        """
        # Actual departure if present, otherwise the planned departure.
        departure = pl.coalesce("stop_departure_dt", "plan_stop_departure_dt")

        # True only when both the current and the preceding row have a departure.
        has_consecutive_departures = pl.all_horizontal(
            departure.is_not_null(),
            departure.shift().is_not_null(),
        )

        # True only when every column whose name contains "headway" is populated.
        has_all_headways = pl.all_horizontal(pl.selectors.contains("headway").is_not_null())

        return has_consecutive_departures == has_all_headways


@pytest.mark.parametrize(
    [
        "stop_departure_dt",
        "plan_stop_departure_dt",
        "route_direction_headway_seconds",
        "direction_destination_headway_seconds",
        "num_rows",
    ],
    [
        # Both rows have a planned departure but neither headway is populated:
        # the second row has consecutive departures without headways, so the rule fails.
        (
            [None, None],
            [datetime(2000, 1, 1), datetime(2000, 1, 1, 1)],
            [None, None],
            [None, None],
            pytest.raises(ValidationError, match="headways_if_planned_or_actual_departure"),
        ),
        # Same planned departures, one hour apart, with the matching one-hour
        # headways on the second row: both rows are expected to validate.
        (
            [None, None],
            [datetime(2000, 1, 1), datetime(2000, 1, 1, 1)],
            [None, 60 * 60],
            [None, 60 * 60],
            nullcontext(2),
        ),
    ],
    ids=[
        "planned_departure_no_headways",
        "departure_with_headways",
    ],
)
def test_dy_headways_if_planned_or_actual_departure(
    dy_gen: Generator,
    stop_departure_dt: list[datetime | None],
    plan_stop_departure_dt: list[datetime],
    route_direction_headway_seconds: list[int | None],
    direction_destination_headway_seconds: list[int | None],
    num_rows: pytest.RaisesExc,
) -> None:
    """It returns false if a departure exists but the headways are null.

    The rule under test must fail when there is a planned or actual departure
    time but route_direction_headway_seconds or
    direction_destination_headway_seconds are null.

    Args:
        dy_gen: dataframely random generator used to sample the base rows
            (fixture defined outside this view).
        stop_departure_dt: actual departure per row; elements may be None.
        plan_stop_departure_dt: planned departure per row.
        route_direction_headway_seconds: headway per row; elements may be None.
        direction_destination_headway_seconds: headway per row; elements may be None.
        num_rows: context manager for the validation call; either a
            ``pytest.raises`` expectation or a ``nullcontext`` carrying the
            expected number of valid rows.
    """
    # Sample two rows, then pin the columns the rule and headway grouping depend on
    # so only the parametrized departure/headway values vary between cases.
    df = TestBusPerformanceMetrics.sample(num_rows=2, generator=dy_gen).with_columns(
        route_id=pl.lit("1"),
        service_date=pl.lit(date(2000, 1, 1)),
        stop_id=pl.lit("a"),
        direction_id=pl.lit(0),
        direction_destination=pl.lit("b"),
        travel_time_seconds=pl.lit(None),
        stopped_duration_seconds=pl.lit(None),
        stop_departure_dt=pl.Series(values=stop_departure_dt),
        plan_stop_departure_dt=pl.Series(values=plan_stop_departure_dt),
        route_direction_headway_seconds=pl.Series(values=route_direction_headway_seconds),
        direction_destination_headway_seconds=pl.Series(values=direction_destination_headway_seconds),
    )

    with num_rows:
        # In the passing case `enter_result` is the value given to nullcontext (expected row count).
        assert TestBusPerformanceMetrics.validate(df, cast=True).height == num_rows.enter_result  # type: ignore[attr-defined]
Loading