From 2f8d06b2ec3aa0b42115b7c90ba991094253fe2c Mon Sep 17 00:00:00 2001 From: Deepak kudi Date: Fri, 22 May 2026 23:37:42 +0530 Subject: [PATCH] fix(scorer): count sentence boundaries accurately --- assayer/scorer.py | 51 ++++++++++++++++++++++++++++++++++++++------ tests/test_scorer.py | 14 ++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/assayer/scorer.py b/assayer/scorer.py index db77d74..b34cd1e 100644 --- a/assayer/scorer.py +++ b/assayer/scorer.py @@ -1,9 +1,13 @@ from __future__ import annotations +import re + from assayer.models import ModelResult _model = None +_NON_BOUNDARY_ABBREVIATIONS = {"dr", "mr", "mrs", "ms", "prof", "sr", "jr", "st"} + def _get_model(): global _model @@ -34,22 +38,57 @@ def compute_similarity(results: list[ModelResult]) -> dict[tuple[str, str], floa for i in range(len(valid)): for j in range(i + 1, len(valid)): score = float(np.dot(normalized[i], normalized[j])) + score = max(-1.0, min(1.0, score)) similarity[(valid[i].model, valid[j].model)] = score return similarity def readability_stats(text: str) -> dict[str, float]: - sentences = [ - s - for s in text.replace("!", ".").replace("?", ".").split(".") - if s.strip() - ] words = text.split() word_count = len(words) - sentence_count = len(sentences) or 1 + sentence_count = _count_sentences(text) return { "word_count": float(word_count), "sentence_count": float(sentence_count), "avg_sentence_length": word_count / sentence_count, } + + +def _count_sentences(text: str) -> int: + count = 0 + start = 0 + + for match in re.finditer(r"[.!?]+", text): + punct_start, punct_end = match.span() + if punct_end < len(text) and not text[punct_end].isspace(): + continue + if _is_non_boundary_period(text, punct_start, punct_end): + continue + + if text[start:punct_end].strip(): + count += 1 + start = punct_end + + if text[start:].strip(): + count += 1 + + return count or 1 + + +def _is_non_boundary_period(text: str, punct_start: int, punct_end: int) -> bool: + if text[punct_start] != ".": + return False + if ( + punct_start > 0 + and punct_start + 1 < len(text) + and text[punct_start - 1].isdigit() + and text[punct_start + 1].isdigit() + ): + return True + + token_match = re.search(r"([A-Za-z]+)\.$", text[:punct_end]) + if not token_match: + return False + + return token_match.group(1).lower() in _NON_BOUNDARY_ABBREVIATIONS diff --git a/tests/test_scorer.py b/tests/test_scorer.py index 007e7fa..790813a 100644 --- a/tests/test_scorer.py +++ b/tests/test_scorer.py @@ -73,6 +73,20 @@ def test_readability_stats_basic(): assert stats["avg_sentence_length"] == pytest.approx(8 / 3) +@pytest.mark.parametrize( + ("text", "sentence_count"), + [ + ("Dr. Smith scored 3.5. Well done.", 2), + ("Visit example.com for details.", 1), + ("The price is $3.99. Cheap!", 2), + ("Mr. Jones paid at 4.30 p.m. Done?", 2), + ], +) +def test_readability_stats_ignores_non_boundary_periods(text, sentence_count): + stats = readability_stats(text) + assert stats["sentence_count"] == sentence_count + + def test_readability_stats_empty(): stats = readability_stats("") assert stats["word_count"] == 0