From 91b3614a8f29d66278ece1fb907a245eb7ccf83e Mon Sep 17 00:00:00 2001 From: AneeshD04 Date: Fri, 12 Jun 2026 14:04:00 -0700 Subject: [PATCH 1/4] Add DemandAssessor model and demand annotation pipeline - DemandAssessor: neural IRT model predicting P(response=1 | subject, item_features) via MLP over concatenated subject embeddings and item feature vectors - Full demand annotation pipeline (DemandAnnotator, GeminiClient, RubricsCatalog, AnnotationCache) implementing the 18-dimension ADeLe rubric scoring system - Unit tests for DemandAssessor (24 tests, synthetic data, no pretrained model needed) - Unit and live end-to-end tests for annotation pipeline - Fix deferred import in LLMJudge to avoid crash when transformers not installed Closes #41 --- .gitignore | 3 + src/torch_measure/annotation/__init__.py | 62 +++ src/torch_measure/annotation/_annotator.py | 122 ++++++ src/torch_measure/annotation/_cache.py | 57 +++ src/torch_measure/annotation/_client.py | 97 +++++ src/torch_measure/annotation/_parsers.py | 54 +++ src/torch_measure/annotation/_prompts.py | 25 ++ src/torch_measure/annotation/_rubrics.py | 62 +++ src/torch_measure/annotation/_types.py | 82 ++++ src/torch_measure/annotation/_ug.py | 83 ++++ src/torch_measure/annotation/py.typed | 0 src/torch_measure/annotation/rubrics/AS.txt | 39 ++ src/torch_measure/annotation/rubrics/AT.txt | 30 ++ src/torch_measure/annotation/rubrics/CEc.txt | 32 ++ src/torch_measure/annotation/rubrics/CEe.txt | 32 ++ src/torch_measure/annotation/rubrics/CL.txt | 36 ++ src/torch_measure/annotation/rubrics/KNa.txt | 32 ++ src/torch_measure/annotation/rubrics/KNc.txt | 32 ++ src/torch_measure/annotation/rubrics/KNf.txt | 32 ++ src/torch_measure/annotation/rubrics/KNn.txt | 32 ++ src/torch_measure/annotation/rubrics/KNs.txt | 32 ++ src/torch_measure/annotation/rubrics/MCr.txt | 31 ++ src/torch_measure/annotation/rubrics/MCt.txt | 38 ++ src/torch_measure/annotation/rubrics/MCu.txt | 38 ++ src/torch_measure/annotation/rubrics/MS.txt | 38 ++ src/torch_measure/annotation/rubrics/QLl.txt | 38 ++ src/torch_measure/annotation/rubrics/QLq.txt | 38 ++ src/torch_measure/annotation/rubrics/SNs.txt | 39 ++ .../annotation/rubrics/UG_choice_num.txt | 40 ++ src/torch_measure/annotation/rubrics/VO.txt | 39 ++ src/torch_measure/models/__init__.py | 2 + src/torch_measure/models/demand_assessor.py | 332 +++++++++++++++ src/torch_measure/models/llm_judge.py | 3 +- tests/test_annotation/__init__.py | 0 tests/test_annotation/test_cache.py | 271 ++++++++++++ tests/test_annotation/test_live.py | 257 ++++++++++++ tests/test_annotation/test_parsers.py | 225 ++++++++++ tests/test_annotation/test_pipeline.py | 390 ++++++++++++++++++ tests/test_annotation/test_prompts.py | 149 +++++++ tests/test_annotation/test_rubrics.py | 235 +++++++++++ tests/test_models/test_demand_assessor.py | 389 +++++++++++++++++ 41 files changed, 3567 insertions(+), 1 deletion(-) create mode 100644 src/torch_measure/annotation/__init__.py create mode 100644 src/torch_measure/annotation/_annotator.py create mode 100644 src/torch_measure/annotation/_cache.py create mode 100644 src/torch_measure/annotation/_client.py create mode 100644 src/torch_measure/annotation/_parsers.py create mode 100644 src/torch_measure/annotation/_prompts.py create mode 100644 src/torch_measure/annotation/_rubrics.py create mode 100644 src/torch_measure/annotation/_types.py create mode 100644 src/torch_measure/annotation/_ug.py create mode 100644 src/torch_measure/annotation/py.typed create mode 100644 src/torch_measure/annotation/rubrics/AS.txt create mode 100644 src/torch_measure/annotation/rubrics/AT.txt create mode 100644 src/torch_measure/annotation/rubrics/CEc.txt create mode 100644 src/torch_measure/annotation/rubrics/CEe.txt create mode 100644 src/torch_measure/annotation/rubrics/CL.txt create mode 100644 src/torch_measure/annotation/rubrics/KNa.txt create mode 100644 src/torch_measure/annotation/rubrics/KNc.txt create mode 100644 src/torch_measure/annotation/rubrics/KNf.txt create mode 100644 src/torch_measure/annotation/rubrics/KNn.txt create mode 100644 src/torch_measure/annotation/rubrics/KNs.txt create mode 100644 src/torch_measure/annotation/rubrics/MCr.txt create mode 100644 src/torch_measure/annotation/rubrics/MCt.txt create mode 100644 src/torch_measure/annotation/rubrics/MCu.txt create mode 100644 src/torch_measure/annotation/rubrics/MS.txt create mode 100644 src/torch_measure/annotation/rubrics/QLl.txt create mode 100644 src/torch_measure/annotation/rubrics/QLq.txt create mode 100644 src/torch_measure/annotation/rubrics/SNs.txt create mode 100644 src/torch_measure/annotation/rubrics/UG_choice_num.txt create mode 100644 src/torch_measure/annotation/rubrics/VO.txt create mode 100644 src/torch_measure/models/demand_assessor.py create mode 100644 tests/test_annotation/__init__.py create mode 100644 tests/test_annotation/test_cache.py create mode 100644 tests/test_annotation/test_live.py create mode 100644 tests/test_annotation/test_parsers.py create mode 100644 tests/test_annotation/test_pipeline.py create mode 100644 tests/test_annotation/test_prompts.py create mode 100644 tests/test_annotation/test_rubrics.py create mode 100644 tests/test_models/test_demand_assessor.py diff --git a/.gitignore b/.gitignore index 49217854..fed9b249 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Test annotation cache (API responses — do not commit) +tests/test_annotation/paper_comparison_cache.jsonl + # Python __pycache__/ *.py[cod] diff --git a/src/torch_measure/annotation/__init__.py b/src/torch_measure/annotation/__init__.py new file mode 100644 index 00000000..ea47547a --- /dev/null +++ b/src/torch_measure/annotation/__init__.py @@ -0,0 +1,62 @@ +"""ADeLe demand annotation pipeline (Gemini re-implementation). + +Reproduces the annotation methodology from: + Zhou et al. (2026) "General scales unlock AI evaluation with explanatory + and predictive power." Nature. + +Public API +---------- +DemandAnnotator — main entry point: annotates one item or a full dataset +GeminiClient — Gemini API wrapper (caller supplies pinned model string) +RubricsCatalog — loads the 19 bundled rubric files +AnnotationCache — append-only JSONL result cache + +Data types +---------- +AnnotationJob — input: item_id, content, reference_answer +DemandAnnotation — one (item, rubric) result with CoT response +UGAnnotation — UG classification result +ItemAnnotation — all 19 annotations for one item (.to_feature_vector()) +DemandVector — full-dataset tensor (n_items × 19) for DemandAssessor +CacheEntry — one persisted cache record + +Constants +--------- +DIMENSION_ORDER — canonical ordering of all 19 dimensions +DEMAND_DIMENSIONS — the first 18 (excludes UG) +""" +from ._annotator import DemandAnnotator +from ._cache import AnnotationCache +from ._client import GeminiClient +from ._rubrics import RubricsCatalog +from ._types import ( + DEMAND_DIMENSIONS, + DIMENSION_ORDER, + N_DIMENSIONS, + AnnotationJob, + CacheEntry, + DemandAnnotation, + DemandVector, + ItemAnnotation, + Rubric, + UGAnnotation, +) +from ._ug import UGAnnotator + +__all__ = [ + "DemandAnnotator", + "GeminiClient", + "RubricsCatalog", + "AnnotationCache", + "UGAnnotator", + "AnnotationJob", + "DemandAnnotation", + "UGAnnotation", + "ItemAnnotation", + "DemandVector", + "CacheEntry", + "Rubric", + "DIMENSION_ORDER", + "DEMAND_DIMENSIONS", + "N_DIMENSIONS", +] diff --git a/src/torch_measure/annotation/_annotator.py b/src/torch_measure/annotation/_annotator.py new file mode 100644 index 00000000..b97b45bb --- /dev/null +++ b/src/torch_measure/annotation/_annotator.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import hashlib +from typing import Optional + +from ._cache import AnnotationCache, make_cache_key +from ._client import GeminiClient +from ._parsers import extract_demand_level +from ._prompts import get_full_instruction +from ._rubrics import RubricsCatalog +from ._types import ( + AnnotationJob, + CacheEntry, + DemandAnnotation, + DemandVector, + ItemAnnotation, + Rubric, +) +from ._ug import UGAnnotator + + +class DemandAnnotator: + """Runs the full 19-call ADeLe annotation pipeline for one benchmark item. + + One API call per demand rubric (18 sequential calls) plus one UG call. + Results are cached to avoid redundant API calls across runs. + """ + + def __init__( + self, + client: GeminiClient, + rubrics: RubricsCatalog, + cache: Optional[AnnotationCache] = None, + ) -> None: + self._client = client + self._rubrics = rubrics + self._cache = cache + self._ug = UGAnnotator(client, rubrics, cache) + + def annotate(self, job: AnnotationJob) -> ItemAnnotation: + """Annotate one item across all 18 demand rubrics plus UG.""" + demands: dict[str, DemandAnnotation] = {} + for rubric in self._rubrics.all_demand_rubrics(): + demands[rubric.acronym] = self._annotate_one(job, rubric) + ug = self._ug.annotate(job) + return ItemAnnotation(item_id=job.item_id, demands=demands, ug=ug) + + def annotate_dataset(self, jobs: list[AnnotationJob]) -> DemandVector: + """Annotate all items and return a (n_items × 19) tensor. + + Row ordering in the returned ``DemandVector.tensor`` mirrors the order + of ``jobs``. To pass the result to ``DemandAssessor.fit()``, supply + ``jobs`` in the same order as ``data.to_fit_tensors()["item_ids"]``:: + + item_ids = data.to_fit_tensors()["item_ids"] # canonical order + jobs = [AnnotationJob(iid, content[iid], ref[iid]) for iid in item_ids] + dv = annotator.annotate_dataset(jobs) + model.fit(data, item_features=dv.tensor) + """ + import torch + + item_ids: list[str] = [] + rows: list[list[float]] = [] + for job in jobs: + item_ann = self.annotate(job) + item_ids.append(job.item_id) + rows.append(item_ann.to_feature_vector()) + + tensor = torch.tensor(rows, dtype=torch.float32) + return DemandVector(item_ids=item_ids, tensor=tensor) + + def _annotate_one(self, job: AnnotationJob, rubric: Rubric) -> DemandAnnotation: + key = make_cache_key( + content=job.content, + acronym=rubric.acronym, + model_id=self._client.model, + rubric_hash=rubric.rubric_hash, + ) + + if self._cache is not None: + entry = self._cache.get(key) + if entry is not None: + return DemandAnnotation( + item_id=job.item_id, + demand=rubric.acronym, + level=entry.level, + finish_reason=entry.finish_reason, + model_response=entry.model_response, + ) + + prompt = get_full_instruction( + dimension=rubric.dimension_name, + rubric_content=rubric.content, + item_text=job.content, + ) + model_response, finish_reason = self._client.generate(prompt) + level = extract_demand_level(model_response) + + annotation = DemandAnnotation( + item_id=job.item_id, + demand=rubric.acronym, + level=level, + finish_reason=finish_reason, + model_response=model_response, + ) + + if self._cache is not None: + content_hash = hashlib.sha256(job.content.encode()).hexdigest()[:16] + self._cache.put(CacheEntry( + key=key, + item_id=job.item_id, + demand=rubric.acronym, + level=level, + finish_reason=finish_reason, + model_response=model_response, + rubric_hash=rubric.rubric_hash, + model_id=self._client.model, + content_hash=content_hash, + timestamp=AnnotationCache.now_iso(), + )) + + return annotation diff --git a/src/torch_measure/annotation/_cache.py b/src/torch_measure/annotation/_cache.py new file mode 100644 index 00000000..ba208902 --- /dev/null +++ b/src/torch_measure/annotation/_cache.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import dataclasses +import hashlib +import json +import math +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from ._types import CacheEntry + + +def make_cache_key(content: str, acronym: str, model_id: str, rubric_hash: str) -> str: + """sha256(content)[:16] : acronym : model_id : rubric_hash""" + content_hash = hashlib.sha256(content.encode()).hexdigest()[:16] + return f"{content_hash}:{acronym}:{model_id}:{rubric_hash}" + + +class AnnotationCache: + """Append-only JSONL cache keyed by sha256(content)[:16]:acronym:model_id:rubric_hash.""" + + def __init__(self, path: Path) -> None: + self._path = path + self._index: dict[str, CacheEntry] = {} + if path.exists(): + self._load() + + def _load(self) -> None: + with open(self._path, encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if not line: + continue + data = json.loads(line) + # NaN is serialised as null (RFC-compliant); restore here. + level = data.get("level") + if level is None: + data["level"] = math.nan + entry = CacheEntry(**data) + self._index[entry.key] = entry + + def get(self, key: str) -> Optional[CacheEntry]: + return self._index.get(key) + + def put(self, entry: CacheEntry) -> None: + self._index[entry.key] = entry + self._path.parent.mkdir(parents=True, exist_ok=True) + record = dataclasses.asdict(entry) + if isinstance(record["level"], float) and math.isnan(record["level"]): + record["level"] = None + with open(self._path, "a", encoding="utf-8") as fh: + fh.write(json.dumps(record) + "\n") + + @staticmethod + def now_iso() -> str: + return datetime.now(timezone.utc).isoformat() diff --git a/src/torch_measure/annotation/_client.py b/src/torch_measure/annotation/_client.py new file mode 100644 index 00000000..dc86a530 --- /dev/null +++ b/src/torch_measure/annotation/_client.py @@ -0,0 +1,97 @@ +"""Gemini API client — the only file that imports google.genai.""" +from __future__ import annotations + +import time + +_FINISH_REASON_MAP: dict[str, str] = { + "STOP": "stop", + "MAX_TOKENS": "length", +} + + +def _is_retryable(exc: BaseException) -> bool: + # Network-level transient errors — server closed keep-alive connection + try: + import httpx + if isinstance(exc, (httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError)): + return True + except ImportError: + pass + # Gemini API errors + try: + from google.genai import errors as genai_errors + if isinstance(exc, genai_errors.ServerError): + return True + if isinstance(exc, genai_errors.ClientError): + code = getattr(exc, "status_code", None) or getattr(exc, "code", None) + return code == 429 + except AttributeError: + pass + return False + + +class GeminiClient: + """Thin wrapper around google.genai with retry and finish-reason normalisation. + + Parameters + ---------- + api_key: + Gemini API key. + model: + Pinned model string, e.g. "gemini-2.0-flash-001". No default — caller + must supply the exact version to guarantee reproducibility. + """ + + _TEMPERATURE = 0.0 + _MAX_OUTPUT_TOKENS = 4096 # 2.5 Flash generates longer CoT than 2.0 Flash + + def __init__(self, api_key: str, model: str, rpm: int = 0) -> None: + import google.genai as genai + from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential + + self._client = genai.Client(api_key=api_key) + self.model = model + self._min_interval = (60.0 / rpm) if rpm > 0 else 0.0 + self._last_call_time: float = 0.0 + + self._generate_with_retry = retry( + retry=retry_if_exception(_is_retryable), + wait=wait_exponential(min=2, max=256), + stop=stop_after_attempt(10), + reraise=True, + )(self._call_api) + + def generate(self, prompt: str) -> tuple[str, str]: + """Call the API and return (response_text, finish_reason). + + Retries up to 10 times on transient errors with exponential backoff + (min 2 s, max 256 s), matching the paper's tenacity settings. + """ + return self._generate_with_retry(prompt) + + def _call_api(self, prompt: str) -> tuple[str, str]: + from google.genai import types as genai_types + + if self._min_interval > 0: + elapsed = time.monotonic() - self._last_call_time + if elapsed < self._min_interval: + time.sleep(self._min_interval - elapsed) + self._last_call_time = time.monotonic() + + response = self._client.models.generate_content( + model=self.model, + contents=prompt, + config=genai_types.GenerateContentConfig( + temperature=self._TEMPERATURE, + max_output_tokens=self._MAX_OUTPUT_TOKENS, + ), + ) + text: str = response.text or "" + candidate = response.candidates[0] if response.candidates else None + raw_reason = ( + candidate.finish_reason.name + if candidate and candidate.finish_reason + else "FINISH_REASON_UNSPECIFIED" + ) + finish_reason = _FINISH_REASON_MAP.get(raw_reason, "other") + return text, finish_reason diff --git a/src/torch_measure/annotation/_parsers.py b/src/torch_measure/annotation/_parsers.py new file mode 100644 index 00000000..e7ad7dd2 --- /dev/null +++ b/src/torch_measure/annotation/_parsers.py @@ -0,0 +1,54 @@ +"""Pure parsing functions, verbatim from adgomant/delean-batch-manager/parse.py.""" +from __future__ import annotations + +import math +import re + + +def extract_demand_level(response: str) -> float: + """Verbatim from extract_demand_level_from_response() in the paper repo. + + Splits on blank lines, takes the last paragraph as the conclusion, + extracts the last integer, validates 0-5, rejects leading section numbers. + Returns math.nan on any failure. + """ + segments = response.split("\n\n") + conclusion = segments[-1] + + digits = re.findall(r"\d+", conclusion) + if not digits: + return math.nan + + score = int(digits[-1]) + if not 0 <= score <= 5: + return math.nan + + # Reject if the only integer found is a leading section-header number + # (e.g. "4. Conclusion:" where 4 is not the actual score). + if len(digits) == 1 and re.search(rf"^{score}\.", conclusion, re.MULTILINE): + return math.nan + + return float(score) + + +def extract_ug_score(response: str) -> tuple[str, float]: + """Parse a UG classification response into (raw_output, ug_score). + + Model is instructed to output a single line: an integer N or the word "open". + Formula: ug_score = (1 - 1/N) * 100 for MCQ, or 100.0 for open-ended. + Returns math.nan as ug_score on any parse failure. + """ + raw = response.strip().split("\n")[0].strip() + + if raw.lower() == "open": + return raw, 100.0 + + try: + n = int(raw) + except ValueError: + return raw, math.nan + + if n < 1: + return raw, math.nan + + return raw, round((1.0 - 1.0 / n) * 100.0, 6) diff --git a/src/torch_measure/annotation/_prompts.py b/src/torch_measure/annotation/_prompts.py new file mode 100644 index 00000000..985f4445 --- /dev/null +++ b/src/torch_measure/annotation/_prompts.py @@ -0,0 +1,25 @@ +"""Pure prompt-construction functions, verbatim from the ADeLe paper pipeline.""" +from __future__ import annotations + + +def get_full_instruction(dimension: str, rubric_content: str, item_text: str) -> str: + """Verbatim prompt template from adgomant/delean-batch-manager/src/.../files.py.""" + return ( + f"The following rubric describes six distinct levels of *{dimension}*" + f" required by different tasks:\n" + f"{rubric_content}\n" + f"\nTASK INSTANCE: {item_text}\n" + f"\nINSTRUCTION: Score the level of *{dimension}* demanded by the given" + f" TASK INSTANCE using a discrete value from 0 to 5. Use CHAIN-OF-THOUGHTS" + f" REASONING to reason step by step before assigning the score. After the" + f" CHAIN-OF-THOUGHTS REASONING STEPS, conclude your assessment with the" + f' statement: "Thus, the level of *{dimension}* demanded by the given TASK' + f' INSTANCE is: SCORE", where SCORE is an integer score you have determined.\n' + f"\nCHAIN-OF-THOUGHTS REASONING STEPS to score the level of *{dimension}*" + f" demanded by the given TASK INSTANCE above:\n" + ) + + +def get_ug_instruction(item_text: str, reference_answer: str, ug_rubric_content: str) -> str: + """Best-faith reconstruction of the UG prompt (prepend format undocumented in paper repos).""" + return f"{item_text}\n\nReference answer: {reference_answer}\n\n{ug_rubric_content}" diff --git a/src/torch_measure/annotation/_rubrics.py b/src/torch_measure/annotation/_rubrics.py new file mode 100644 index 00000000..13f38fb0 --- /dev/null +++ b/src/torch_measure/annotation/_rubrics.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path + +from ._types import DEMAND_DIMENSIONS, Rubric + + +class RubricsCatalog: + """Loads rubric .txt files from the bundled rubrics/ directory.""" + + def __init__(self, rubrics_dir: Path | None = None) -> None: + if rubrics_dir is None: + rubrics_dir = Path(__file__).parent / "rubrics" + self._rubrics: dict[str, Rubric] = {} + self._ug_content: str = "" + self._ug_hash: str = "" + self._load(rubrics_dir) + missing = [a for a in DEMAND_DIMENSIONS if a not in self._rubrics] + if missing: + raise RuntimeError(f"Missing rubric files: {missing}") + + def _load(self, rubrics_dir: Path) -> None: + for path in rubrics_dir.glob("*.txt"): + acronym = path.stem + text = path.read_text(encoding="utf-8") + lines = text.splitlines(keepends=True) + + if lines and lines[0].startswith("#"): + dimension_name = lines[0].lstrip("#").strip() + content = "".join(lines[1:]).strip("\n") + else: + dimension_name = acronym + content = text + + rubric_hash = hashlib.sha256(content.encode()).hexdigest()[:16] + + if acronym == "UG_choice_num": + self._ug_content = content + self._ug_hash = rubric_hash + else: + self._rubrics[acronym] = Rubric( + acronym=acronym, + dimension_name=dimension_name, + content=content, + rubric_hash=rubric_hash, + ) + + def get(self, acronym: str) -> Rubric: + return self._rubrics[acronym] + + @property + def ug_content(self) -> str: + return self._ug_content + + @property + def ug_hash(self) -> str: + return self._ug_hash + + def all_demand_rubrics(self) -> list[Rubric]: + """Return all 18 demand rubrics in canonical DIMENSION_ORDER.""" + return [self._rubrics[a] for a in DEMAND_DIMENSIONS] diff --git a/src/torch_measure/annotation/_types.py b/src/torch_measure/annotation/_types.py new file mode 100644 index 00000000..b238b066 --- /dev/null +++ b/src/torch_measure/annotation/_types.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + +DIMENSION_ORDER: tuple[str, ...] = ( + "AS", "CEc", "CEe", "CL", "MCr", "MCt", "MCu", "MS", "QLl", "QLq", "SNs", + "KNa", "KNc", "KNf", "KNn", "KNs", "AT", "VO", "UG", +) + +DEMAND_DIMENSIONS: tuple[str, ...] = DIMENSION_ORDER[:18] +N_DIMENSIONS: int = 19 + + +@dataclass +class Rubric: + acronym: str + dimension_name: str + content: str # verbatim file text after the # Title line + rubric_hash: str # sha256(content)[:16] + + +@dataclass +class AnnotationJob: + item_id: str + content: str + reference_answer: str + + +@dataclass +class DemandAnnotation: + item_id: str + demand: str # rubric acronym + level: float # 0-5 or math.nan + finish_reason: str + model_response: str + + +@dataclass +class UGAnnotation: + item_id: str + raw_output: str + ug_score: float # 0-100 or math.nan + finish_reason: str + model_response: str + + +@dataclass +class ItemAnnotation: + item_id: str + demands: dict[str, DemandAnnotation] # acronym -> DemandAnnotation + ug: UGAnnotation + + def to_feature_vector(self) -> list[float]: + result: list[float] = [] + for dim in DEMAND_DIMENSIONS: + ann = self.demands.get(dim) + result.append(ann.level if ann is not None else math.nan) + result.append(self.ug.ug_score) + return result + + +@dataclass +class DemandVector: + item_ids: list[str] + tensor: Any # torch.Tensor at runtime; torch not imported here + + +@dataclass +class CacheEntry: + key: str + item_id: str + demand: str # rubric acronym or "UG" + level: float # demand level 0-5 or UG score 0-100; math.nan on parse failure + finish_reason: str + model_response: str + rubric_hash: str + model_id: str + content_hash: str + timestamp: str + raw_output: str = "" # UG only: the raw model token ("3", "open", etc.) diff --git a/src/torch_measure/annotation/_ug.py b/src/torch_measure/annotation/_ug.py new file mode 100644 index 00000000..f495c3f5 --- /dev/null +++ b/src/torch_measure/annotation/_ug.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import hashlib +from typing import Optional + +from ._cache import AnnotationCache, make_cache_key +from ._client import GeminiClient +from ._parsers import extract_ug_score +from ._prompts import get_ug_instruction +from ._rubrics import RubricsCatalog +from ._types import AnnotationJob, CacheEntry, UGAnnotation + + +class UGAnnotator: + """Classifies benchmark items as MCQ or open-ended and computes the UG score. + + UG (Unguessability) is separate from the 18 demand rubrics: + - MCQ with N choices → ug_score = (1 - 1/N) * 100 + - open-ended → ug_score = 100.0 + """ + + def __init__( + self, + client: GeminiClient, + rubrics: RubricsCatalog, + cache: Optional[AnnotationCache] = None, + ) -> None: + self._client = client + self._rubrics = rubrics + self._cache = cache + + def annotate(self, job: AnnotationJob) -> UGAnnotation: + key = make_cache_key( + content=job.content, + acronym="UG", + model_id=self._client.model, + rubric_hash=self._rubrics.ug_hash, + ) + + if self._cache is not None: + entry = self._cache.get(key) + if entry is not None: + return UGAnnotation( + item_id=job.item_id, + raw_output=entry.raw_output, + ug_score=entry.level, + finish_reason=entry.finish_reason, + model_response=entry.model_response, + ) + + prompt = get_ug_instruction( + item_text=job.content, + reference_answer=job.reference_answer, + ug_rubric_content=self._rubrics.ug_content, + ) + model_response, finish_reason = self._client.generate(prompt) + raw_output, ug_score = extract_ug_score(model_response) + + annotation = UGAnnotation( + item_id=job.item_id, + raw_output=raw_output, + ug_score=ug_score, + finish_reason=finish_reason, + model_response=model_response, + ) + + if self._cache is not None: + content_hash = hashlib.sha256(job.content.encode()).hexdigest()[:16] + self._cache.put(CacheEntry( + key=key, + item_id=job.item_id, + demand="UG", + level=ug_score, + finish_reason=finish_reason, + model_response=model_response, + rubric_hash=self._rubrics.ug_hash, + model_id=self._client.model, + content_hash=content_hash, + timestamp=AnnotationCache.now_iso(), + raw_output=raw_output, + )) + + return annotation diff --git a/src/torch_measure/annotation/py.typed b/src/torch_measure/annotation/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/torch_measure/annotation/rubrics/AS.txt b/src/torch_measure/annotation/rubrics/AS.txt new file mode 100644 index 00000000..23c0eae4 --- /dev/null +++ b/src/torch_measure/annotation/rubrics/AS.txt @@ -0,0 +1,39 @@ +# Attention and Search +This criterion assesses the level of attention and scan required to focus on or locate specific elements within a given stream of information or environment in the whole process of solving a task. During this process, there is the need to actively scan for or retrieve elements that meet predetermined criteria. The level represents the extent to which the task requires locating and focusing on specific target information, ranging from situations where the target is immediately obvious to those requiring sustained tracking of multiple targets among numerous distractors—any elements that are irrelevant to solve the task, such as visual objects, sounds, pieces of text, noise, or other stimuli, but compete for attention with the target information—in complex, dynamic environments. The challenge is not on determining what to look for but focusing the attention to find it within a larger context. This differs from tasks where there's a need to identify which pieces of information are relevant from a set already under consideration. While both processes may overlap in complex tasks like reading comprehension or image understanding, "attention and scan" specifically focuses on the deployment of attention during scan processes when solving the task, rather than the selection or evaluation of information. + +Level 0: None. No attention or scan is required. The target information is immediately obvious or is the only information present. +Examples: +* "Given a single word input, determine if it starts with a capital letter." +* "Look at the only object in the centre of the white page and tell what colour it is." +* "Is Madrid the capital of Spain?" + +Level 1: Very low. Minimal attention or scanning is required. The target information is easily distinguishable with little to almost no distraction. +Examples: +* "Find the only blue car in a car park full of red cars." +* "Find the letter 'X' among a row of 'O's" +* "Spot the tall tree in a row of short bushes." + +Level 2: Low. Some attention or basic scanning is required. The target information is visible among a few distractors or in a small scan area. +Examples: +* "Find all the vowels in the following sentence: 'The quick brown fox jumps over the lazy dog.'" +* "Find who's wearing glasses in this photo of students at commencement, with 2 rows of 5 students each, all facing forward, taken by a professional photographer." +* "Who authored the Queensberry rules, which were published in 1867 for the sport of boxing? Choices: A. John Douglas (in his late twenties)\nB. John Graham Chambers (in his mid-twenties)\nC. Marquess of Queensberry (in his early thirties)\nD. James Figg (in his forties)." + +Level 3: Intermediate. Moderate attention and scan are required. The target information is mixed with several distractors or spread over a fairly large scan area. +Examples: +* "Find everyone wearing glasses in this casual BBQ photo where 15 people are gathered around a table. Some are sitting, some standing, some looking at the camera while others are in conversation." +* "In a 5-page technical document about basic geometry, locate all explicit references to the Pythagorean theorem (a² + b² = c²), where the equation appears 5 times mixed among references to 15 other geometric formulas, with occasional inconsistent equation numbering but standard mathematical notation. +* "While reading a podcast interview, keep track of how many times the guest explicitly discusses content about their new book." +* "As we all know, the Queensberry Rules are a set of rules for boxing that govern both amateur and professional matches. Who authored the Queensberry rules, which were published in 1867 for the sport of boxing? Choices: A. John Douglas (in his late twenties)\nB. John Graham Chambers (in his mid-twenties)\nC. Marquess of Queensberry (in his early thirties)\nD. James Figg (in his forties)\nE. James Zou (in his fifties)\nF. Lucy Grande (in her late twenties)\nG. Xiaoxiao Li (in her early forties)\nH. Enrique Garcia (in his late thirties)." + +Level 4: High. Sustained tracking of one or various targets is required. The target information is in an environment mixed with numerous distractors and changing conditions. requires some continuous monitoring amid competing signals. +Examples: +* "Listening to a symphony, identify all instances where the clarinet plays in a minor key, even when it's not playing the main melody. +* "Track three orange spheres among twenty red spheres as they move randomly across a black screen (40 cm × 30 cm) at varying speeds (1-3 cm/s), with spheres frequently intersecting paths and maintaining a minimum separation distance of 2 cm. Each sphere is 1 cm in diameter." +* "In a real-time video feed of a busy airport, finding the locations of ten blue suitcases." + +Level 5: Very high. Requires sustained attention and scan for simultaneous tracking of multiple targets across different domains or contexts, with continuous adaptation to fast-changing conditions. The target information is extremely difficult to distinguish from distractors or is hidden in a vast or constantly changing environment. +Examples: +* "While seated courtside at a professional basketball game, track two specific players throughout the entire game as they move at speeds up to 8m/s, frequently cluster with other players during rebounds, and weave through screens and defensive formations." +* "Monitor four simultaneous video feeds of a crowded airport terminal from different angles, detecting subtle security-relevant changes (e.g. brief interactions < 2 seconds, crowd flow changes, small object exchanges) across feeds." +* "While monitoring multiple simultaneous customer service chat conversations in different languages, identify instances where customers are expressing the same underlying technical issue, even though they're describing it using different metaphors, technical terms, or cultural references specific to their region." diff --git a/src/torch_measure/annotation/rubrics/AT.txt b/src/torch_measure/annotation/rubrics/AT.txt new file mode 100644 index 00000000..3e5b4556 --- /dev/null +++ b/src/torch_measure/annotation/rubrics/AT.txt @@ -0,0 +1,30 @@ +# Atypicality +Level 0: None. The task is a staple one. Exactly the same instance of the task appears many times on the Internet, textbooks or common psychological or achievement tests, and the solution is generally well-known and memorized. Examples: +* "What is 2 + 2?" +* "Name the capital of France." +* "What gets wetter and wetter the more it dries?" + +Level 1: Very Low. The task is very common and the specific task instance is likely to frequently appear on the Internet, textbooks or common psychological or achievement tests, so the chance that the solution is well-known and memorized is high. Examples: +* "What is the derivative of sin(x)?" +* "Define opportunity cost." +* "Name the seven continents." + +Level 2: Low. The task is moderately common and the specific task instance varies somewhat from other common examples or is unlikely to have seen it before in exactly the same form, but possibly in variations. Examples: +* "What is 21251 + 2835?" +* "Given the molecular SMILES: COC[C@H]1OC(=O)c2coc3c2[C@@]1(C)C1=C(C3=O)[C@@H]2CCC(=O)[C@@]2(C)C[C@H]1OC(C)=O, your task is to provide the detailed description of the molecule using your experienced chemical Molecular knowledge." +* "Solve the following Math Olympiad question: Determine the greatest real number $ C $, such that for every positive integer $ n\ge 2 $, there exists $ x_1, x_2,..., x_n \in [-1,1]$, so that $$\prod_{1\le i