diff --git a/src/votekit/utils.py b/src/votekit/utils.py index 6badce3b..9631360f 100644 --- a/src/votekit/utils.py +++ b/src/votekit/utils.py @@ -1,7 +1,7 @@ import math import random from itertools import permutations -from typing import Literal, Optional, Sequence, Union +from typing import Any, Literal, Optional, Sequence, Union, cast import numpy as np import pandas as pd @@ -369,6 +369,66 @@ def first_place_votes( ) +def _ballots_are_materialized(profile: RankProfile) -> bool: + """ + Checks if the ballots are materialized in a ``RankProfile``. + + Args: + profile (RankProfile): RankProfile of ballots. + + Returns: + bool: + True if ballots are materialized, False otherwise. + """ + return "ballots" in profile.__dict__ + + +def _mentions_from_df(profile: RankProfile) -> dict[str, float]: + assert profile.max_ranking_length is not None + + ranking_cols = [f"Ranking_{i}" for i in range(1, profile.max_ranking_length + 1)] + + tilde = frozenset({"~"}) + + rank_sets = cast(Any, profile.df[ranking_cols].stack()) + + mask = rank_sets.map(lambda s: isinstance(s, frozenset) and bool(s) and s != tilde) + + rank_sets = rank_sets[mask] + exploded = rank_sets.explode() + + if exploded.empty: + return {c: 0.0 for c in profile.candidates} + + weights = profile.df["Weight"].reindex(exploded.index.get_level_values(0)).to_numpy() + + totals = pd.Series(weights).groupby(exploded.to_numpy(), sort=False).sum() + + return {c: totals.get(c, 0.0) for c in profile.candidates} + + +def fast_mentions(profile: RankProfile) -> dict[str, float]: + """ + Decides which way to compute mentions based on whether ballots are materialized in the profile. + If they are, uses the traditional mentions calculation. + If not, uses a faster pandas-based approach. + + Args: + profile (RankProfile): RankProfile of ballots. + + Returns: + dict[str, float]: + Dictionary mapping candidates to mention totals (values). + """ + if not isinstance(profile, RankProfile): + raise TypeError("Profile must be of type RankProfile.") + + if _ballots_are_materialized(profile): + return mentions(profile) + + return _mentions_from_df(profile) + + def mentions( profile: RankProfile, ) -> dict[str, float]: diff --git a/tests/test_utils.py b/tests/test_utils.py index 8d52fd6a..1a943623 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ from itertools import permutations +from pathlib import Path from typing import Literal, cast import pytest @@ -12,6 +13,7 @@ borda_scores, elect_cands_from_set_ranking, expand_tied_ballot, + fast_mentions, first_place_votes, index_to_lexicographic_ballot, mentions, @@ -24,6 +26,8 @@ validate_score_vector, ) +CSV_DIR = Path(__file__).resolve().parents[0] / "data" / "csv" + profile_no_ties = RankProfile( ballots=( RankBallot(ranking=tuple(map(frozenset, [{"A"}, {"B"}])), weight=1), @@ -40,6 +44,14 @@ ) ) +profile_with_duplicates = RankProfile( + ballots=( + RankBallot(ranking=tuple(map(frozenset, [{"A"}, {"B"}, {"B"}])), weight=1), + RankBallot(ranking=tuple(map(frozenset, [{"A"}, {"B"}, {"C"}])), weight=1 / 2), + RankBallot(ranking=tuple(map(frozenset, [{"B"}, {"B"}, {"B"}])), weight=3), + ) +) + profile_with_missing = RankProfile( ballots=( RankBallot(ranking=tuple(map(frozenset, [{"A", "B"}, {"D"}])), weight=1), @@ -251,11 +263,57 @@ def test_mentions(): assert isinstance(test["A"], float) +def test_mentions_with_ties(): + correct = {"A": 9 / 2, "B": 9 / 2, "C": 7 / 2} + test = mentions(profile_with_ties) + assert correct == test + assert isinstance(test["A"], float) + + +def test_mentions_with_duplicates(): + correct = {"A": 3 / 2, "B": 23 / 2, "C": 1 / 2} + test = mentions(profile_with_duplicates) + assert correct == test + assert isinstance(test["A"], float) + + +def test_fast_mentions(): + correct = {"A": 9 / 2, "B": 9 / 2, "C": 7 / 2} + test = fast_mentions(profile_no_ties) + assert correct == test + assert isinstance(test["A"], float) + + +def test_fast_mentions_with_ties(): + correct = {"A": 9 / 2, "B": 9 / 2, "C": 7 / 2} + test = fast_mentions(profile_with_ties) + assert correct == test + assert isinstance(test["A"], float) + + +def test_fast_mentions_with_duplicates(): + correct = {"A": 3 / 2, "B": 23 / 2, "C": 1 / 2} + test = fast_mentions(profile_with_duplicates) + assert correct == test + assert isinstance(test["A"], float) + + +@pytest.mark.slow +def test_fast_and_slow_mentions_are_same(): + profile = cast(RankProfile, RankProfile.from_csv(CSV_DIR / "albany_profile.csv")) + assert mentions(profile) == fast_mentions(profile) + + def test_mentions_errors(): with pytest.raises(TypeError, match="Profile must be of type RankProfile"): mentions(cast(RankProfile, ScoreProfile(ballots=(ScoreBallot(scores={"A": 3}),)))) +def test_fast_mentions_errors(): + with pytest.raises(TypeError, match="Profile must be of type RankProfile"): + fast_mentions(cast(RankProfile, ScoreProfile(ballots=(ScoreBallot(scores={"A": 3}),)))) + + def test_borda_no_ties(): true_borda = {"A": 15 / 2, "B": 9, "C": 19 / 2}