Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions assayer/scorer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import re

from assayer.models import ModelResult

_model = None

_NON_BOUNDARY_ABBREVIATIONS = {"dr", "mr", "mrs", "ms", "prof", "sr", "jr", "st"}


def _get_model():
global _model
Expand Down Expand Up @@ -34,22 +38,57 @@ def compute_similarity(results: list[ModelResult]) -> dict[tuple[str, str], floa
for i in range(len(valid)):
for j in range(i + 1, len(valid)):
score = float(np.dot(normalized[i], normalized[j]))
score = max(-1.0, min(1.0, score))
similarity[(valid[i].model, valid[j].model)] = score

return similarity


def readability_stats(text: str) -> dict[str, float]:
sentences = [
s
for s in text.replace("!", ".").replace("?", ".").split(".")
if s.strip()
]
words = text.split()
word_count = len(words)
sentence_count = len(sentences) or 1
sentence_count = _count_sentences(text)
return {
"word_count": float(word_count),
"sentence_count": float(sentence_count),
"avg_sentence_length": word_count / sentence_count,
}


def _count_sentences(text: str) -> int:
count = 0
start = 0

for match in re.finditer(r"[.!?]+", text):
punct_start, punct_end = match.span()
if punct_end < len(text) and not text[punct_end].isspace():
continue
if _is_non_boundary_period(text, punct_start, punct_end):
continue

if text[start:punct_end].strip():
count += 1
start = punct_end

if text[start:].strip():
count += 1

return count or 1


def _is_non_boundary_period(text: str, punct_start: int, punct_end: int) -> bool:
if text[punct_start] != ".":
return False
if (
punct_start > 0
and punct_start + 1 < len(text)
and text[punct_start - 1].isdigit()
and text[punct_start + 1].isdigit()
):
return True

token_match = re.search(r"([A-Za-z]+)\.$", text[:punct_end])
if not token_match:
return False

return token_match.group(1).lower() in _NON_BOUNDARY_ABBREVIATIONS
14 changes: 14 additions & 0 deletions tests/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def test_readability_stats_basic():
assert stats["avg_sentence_length"] == pytest.approx(8 / 3)


@pytest.mark.parametrize(
("text", "sentence_count"),
[
("Dr. Smith scored 3.5. Well done.", 2),
("Visit example.com for details.", 1),
("The price is $3.99. Cheap!", 2),
("Mr. Jones paid at 4.30 p.m. Done?", 2),
],
)
def test_readability_stats_ignores_non_boundary_periods(text, sentence_count):
stats = readability_stats(text)
assert stats["sentence_count"] == sentence_count


def test_readability_stats_empty():
stats = readability_stats("")
assert stats["word_count"] == 0
Expand Down
Loading