diff --git a/pyproject.toml b/pyproject.toml index 004fa3f0..a603847b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,41 +67,33 @@ name = "pruna_internal" url = "https://prunaai.pythonanywhere.com/simple/" explicit = true -[[tool.uv.index]] -name = "intel-pytorch-extension" -url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" -explicit = true - [tool.uv] index-strategy = "first-index" +exclude-newer = "1 week" # protection against compromised dependencies +# trusted dev wheels that are missing an upload date +exclude-newer-package = { gptqmodel = false, "stable-fast-pruna" = false } conflicts = [ [{ extra = "awq" }, { extra = "vbench" }], [{ extra = "vllm" }, { extra = "vbench" }], - [{ extra = "intel" }, { extra = "awq" }], [{ extra = "gptq" }, { extra = "awq" }], - # intel is incompatible with all stable-fast variants and vllm - [{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "vllm" }], [{ extra = "kvpress" }, { extra = "vbench" }], ] [tool.uv.sources] gptqmodel = { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'" } -intel-extension-for-pytorch = { index = "intel-pytorch-extension" } stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" } [project] name = "pruna" -version = "0.3.2" +version = "0.3.3" description = "Smash your AI models" authors = [ {name = "Pruna AI", email = "hello@pruna.ai"} ] license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" keywords = ["AI", "machine learning", "model optimization", "pruning"] classifiers = [ "Development Status :: 4 - Beta", @@ -246,12 +238,6 @@ lmharness = [ "lm-eval>=0.4.0" ] -# Intel extension is tightly coupled with the torch version -intel = [ - "intel-extension-for-pytorch>=2.7.0", - "torch>=2.7.0,<2.9.0", - "torchvision>=0.22.0,<0.24.0", -] kvpress = [ "kvpress>=0.5.2", ] diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index f00749e2..36e50c5e 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -66,19 +66,19 @@ class BenchmarkRegistry: paper (see reference URL). All entries verified from paper evaluation sections (ar5iv/HTML or PDF) as of verification pass: - - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222. - - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP. + - Parti Prompts (2206.10789 ?5.2, ?5.4): human side-by-side only on P222. + - DrawBench (2205.11487 ?4.3): human raters only; COCO uses FID + CLIP. - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed). - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric. - - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment. - - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy. - - WikiText (1609.07843 §5): perplexity on validation/test. - - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score. + - COCO (2205.11487 ?4.1): FID and CLIP score for fidelity and alignment. + - ImageNet (1409.0575 ?4): top-1/top-5 classification accuracy. + - WikiText (1609.07843 ?5): perplexity on validation/test. + - GenEval (2310.11513 ?3.2): Mask2Former + CLIP color pipeline, binary score. - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2). - - ImgEdit (2505.20275 §4.2): GPT-4o 1âÿÿ5 ratings and ImgEdit-Judge. - - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B). - - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). - - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.). + - ImgEdit (2505.20275 ?4.2): GPT-4o 1ÿÿÿ5 ratings and ImgEdit-Judge. + - Long Text Bench (2507.22058 ?4): Text Accuracy (OCR, Qwen2.5-VL-7B). + - GEditBench (2504.17761 ?4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). + - OneIG (2506.07977 ?4.1): per-dimension metrics (semantic alignment, ED, etc.). - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator. """ @@ -195,7 +195,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports " "FID for fidelity and CLIP score for image-text alignment." ), - metrics=["fid", "clip_score"], # §4.1: FID + CLIP score + metrics=["fid", "clip_score"], # ?4.1: FID + CLIP score task_type="text_to_image", reference="https://arxiv.org/abs/2205.11487", ), diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index f7386198..4a08d622 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -22,12 +22,13 @@ from pruna.evaluation.metrics.metric_evalharness import LMEvalMetric from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric -from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_oneig_reasoning import OneIGReasoningMetric +from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric -from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper from pruna.evaluation.metrics.vlm_base import ( BaseVLM, @@ -57,6 +58,7 @@ "AestheticLAION", "LMEvalMetric", "OneIGAlignmentMetric", + "OneIGReasoningMetric", "OneIGTextScoreMetric", "QAAccuracyMetric", "RapidataMetric", diff --git a/src/pruna/evaluation/metrics/metric_oneig_alignment.py b/src/pruna/evaluation/metrics/metric_oneig_alignment.py index 0f372f4f..a8827dd7 100644 --- a/src/pruna/evaluation/metrics/metric_oneig_alignment.py +++ b/src/pruna/evaluation/metrics/metric_oneig_alignment.py @@ -151,8 +151,6 @@ class OneIGAlignmentMetric(QAAccuracyMetric): (default ``2 x 2``), score **one question per VLM call** across all cells, apply dependency masking per cell, then average cell scores. - Scoring semantics - ----------------- OneIG Q_D probes are phrased so **Yes = aligned**. Each call requests :meth:`~pruna.evaluation.metrics.vlm_base.BaseVLM.score` with expected answer ``"Yes"`` (probability of Yes). Low scores act as semantic **No** for dependency @@ -178,11 +176,9 @@ class OneIGAlignmentMetric(QAAccuracyMetric): api_key : str | None, optional API key for litellm. call_type : str, optional - Call type for the metric. - aggregation : str, optional - Unused; kept for registry compatibility with :class:`QAAccuracyMetric`. + Call type for the metric (``"single"`` or ``"pairwise"``). **kwargs : Any - Additional keyword arguments for :class:`QAAccuracyMetric`. + Forwarded to :class:`QAAccuracyMetric` (e.g. ``aggregation``). Examples -------- @@ -199,7 +195,6 @@ class OneIGAlignmentMetric(QAAccuracyMetric): def __init__( self, - *args: Any, grid_size: tuple[int, int] = (2, 2), vlm: Any | None = None, vlm_type: Literal["litellm", "transformers"] = "transformers", @@ -212,7 +207,6 @@ def __init__( **kwargs: Any, ) -> None: super().__init__( - *args, vlm=vlm, vlm_type=vlm_type, model_name=model_name, @@ -220,10 +214,11 @@ def __init__( structured_output=structured_output, device=device, api_key=api_key, - call_type=call_type if call_type is not None else "y_gt", + call_type=call_type, **kwargs, ) self.grid_size = (int(grid_size[0]), int(grid_size[1])) + self.metric_units = type(self).metric_units def _score_sample(self, image: Any, aux: dict[str, Any]) -> float: if not isinstance(image, Image.Image): diff --git a/src/pruna/evaluation/metrics/metric_oneig_reasoning.py b/src/pruna/evaluation/metrics/metric_oneig_reasoning.py new file mode 100644 index 00000000..23889f0c --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_oneig_reasoning.py @@ -0,0 +1,420 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OneIG reasoning score via LLM2CLIP text-image similarity. + +Llama-derived checkpoints may require ``HF_TOKEN`` and ``huggingface-cli login``. + +Hugging Face download tuning (optional): + +- ``PRUNA_ONEIG_HF_VERBOSE=1`` or ``HF_DEBUG=1`` — hub **debug** logging and tqdm + progress bars (helps when stderr is piped; pair with ``python -u`` or + ``PYTHONUNBUFFERED=1`` for line-buffered output). +- ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1`` — enable **hf_transfer** multi-part downloads + (requires ``pruna[evaluation]``, which lists ``hf_transfer``). Alternatively, set + ``HF_HUB_ENABLE_HF_TRANSFER=1`` **before** starting Python so the hub picks it up at + import time. + +``transformers`` is pinned to ``<5`` in ``pyproject.toml``. The LLM2CLIP loading path +(``CLIPImageProcessor``, ``AutoModel``, ``LlamaEncoderModel``) is exercised on **4.x** +releases in CI and manual smoke runs. ``transformers`` 5.x has had reports of +``from_pretrained`` not fully initializing some non-persistent buffers (for example +``position_ids``) for certain architectures; the pin avoids that class of failures +until those issues are clearly resolved upstream. +""" + +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import Any, Iterator + +import torch +from PIL import Image + +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_utils import ( + _process_images, + _tensor_to_pil, + resolve_oneig_reasoning_device, + split_mxn_grid, +) +from pruna.logging.logger import pruna_logger + + +def _is_cuda_device(device: str) -> bool: + return device == "cuda" or device.startswith("cuda:") + + +@contextmanager +def _oneig_hf_download_env() -> Iterator[None]: + """Apply optional OneIG HF hub env tweaks, then restore prior values.""" + keys = ("HF_HUB_ENABLE_HF_TRANSFER",) + saved = {k: os.environ.get(k) for k in keys} + try: + _prepare_huggingface_hub_for_oneig_downloads() + yield + finally: + for key, val in saved.items(): + if val is None: + os.environ.pop(key, None) + else: + os.environ[key] = val + + +def _env_truthy(raw: str | None) -> bool: + if raw is None: + return False + return raw.strip().upper() in {"1", "ON", "YES", "TRUE"} + + +def _prepare_huggingface_hub_for_oneig_downloads() -> None: + """ + Apply Hugging Face Hub verbosity and optional fast downloads before checkpoints load. + + ``HF_HUB_ENABLE_HF_TRANSFER`` is read when ``huggingface_hub`` loads; if it was + false, we flip the in-module flag after importing ``hf_transfer`` when + ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1``. + """ + if _env_truthy(os.environ.get("PRUNA_ONEIG_HF_VERBOSE")) or _env_truthy(os.environ.get("HF_DEBUG")): + from huggingface_hub.utils import enable_progress_bars + from huggingface_hub.utils.logging import set_verbosity_debug + + set_verbosity_debug() + enable_progress_bars() + + if not _env_truthy(os.environ.get("PRUNA_ONEIG_HF_FAST_DOWNLOAD")): + return + + import hf_transfer # noqa: F401 # type: ignore[import-not-found] + import huggingface_hub.constants as hf_constants + + hf_constants.HF_HUB_ENABLE_HF_TRANSFER = True + pruna_logger.info("oneig_reasoning: enabled hf_transfer downloads (PRUNA_ONEIG_HF_FAST_DOWNLOAD=1).") + + +def _to_pil_list(images: list) -> list: + """Convert images to list of PIL.Image (RGB).""" + import numpy as np + from PIL import Image + + out: list = [] + for img in images: + if isinstance(img, Image.Image): + out.append(img.convert("RGB")) + elif isinstance(img, torch.Tensor): + if img.ndim == 4: + img = img[0] + if img.max() > 1: + img = img / 255.0 + np_img = (img.cpu().numpy() * 255).astype("uint8") + if np_img.shape[0] == 3: + np_img = np_img.transpose(1, 2, 0) + out.append(Image.fromarray(np_img)) + elif hasattr(img, "__array__"): + out.append(Image.fromarray(np.asarray(img)).convert("RGB")) + else: + out.append(img) + return out + + +class _LLM2CLIPScorer: + """ + Thin wrapper around LLM2CLIP text-image similarity. + + Accepts PIL images and a single answer string; returns per-image scores. + Best-effort alignment with OneIG-Benchmark scripts (CUDA + bfloat16). + """ + + def __init__( + self, + processor_model: str = "openai/clip-vit-large-patch14-336", + model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336", + llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned", + device: str = "cuda", + attn_implementation: str | None = None, + ) -> None: + self.processor_model = processor_model + self.model_name = model_name + self.llm_model_name = llm_model_name + self.device = device + self.attn_implementation = attn_implementation + self._processor = None + self._clip_model = None + self._l2v = None + + def _load_models(self) -> None: + if self._clip_model is not None: + return + with _oneig_hf_download_env(): + self._load_models_inner() + + def _load_models_inner(self) -> None: + from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor + + from pruna.evaluation.metrics.vendor.oneig_llm2vec import LLM2Vec + from pruna.evaluation.metrics.vendor.oneig_llm2vec.modeling_llama_encoder import LlamaEncoderModel + + pruna_logger.info( + "oneig_reasoning: downloading or loading LLM2CLIP checkpoints " + "(%s, %s). First run can take many minutes and several gigabytes; " + "Hugging Face download progress may look idle when logs are piped.", + self.model_name, + self.llm_model_name, + ) + dtype = torch.bfloat16 if _is_cuda_device(str(self.device)) else torch.float32 + self._processor = CLIPImageProcessor.from_pretrained(self.processor_model) + self._clip_model = AutoModel.from_pretrained( + self.model_name, + dtype=dtype, + trust_remote_code=True, + ).to(self.device) + self._clip_model.train(mode=False) + + config = AutoConfig.from_pretrained(self.llm_model_name, trust_remote_code=True) + if self.attn_implementation is not None: + # User override (e.g. sdpa, flash_attention_2, eager). Upstream OneIG leaves HF default. + config.attn_implementation = self.attn_implementation + if hasattr(config, "_attn_implementation"): + config._attn_implementation = self.attn_implementation + elif _is_cuda_device(str(self.device)): + config.attn_implementation = "sdpa" + if hasattr(config, "_attn_implementation"): + config._attn_implementation = "sdpa" + llm_model = LlamaEncoderModel.from_pretrained( + self.llm_model_name, + dtype=dtype, + config=config, + trust_remote_code=True, + ) + llm_model.config._name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name) + self._l2v = LLM2Vec(llm_model, tokenizer, pooling_mode="mean", max_length=512, doc_max_length=512) + + def score(self, images: list, text_prompt: str) -> list[float] | None: + """ + Compute similarity scores between images and text. + + Parameters + ---------- + images : list + List of PIL.Image.Image. + text_prompt : str + Reference text (e.g. ground-truth answer). + + Returns + ------- + list[float] | None + Per-image scores, or None on failure. + """ + self._load_models() + pil_images = _to_pil_list(images) + if not pil_images: + return None + input_pixels = self._processor(images=pil_images, return_tensors="pt").pixel_values.to(self.device) + captions = [text_prompt] + with torch.no_grad(): + text_features = self._l2v.encode(captions, convert_to_tensor=True, device=self.device).to(self.device) + text_features = self._clip_model.get_text_features(text_features) + if _is_cuda_device(str(self.device)): + with torch.amp.autocast(device_type="cuda"): + image_features = self._clip_model.get_image_features(input_pixels) + else: + image_features = self._clip_model.get_image_features(input_pixels.float()) + + image_features = image_features.float() + text_features = text_features.float() + eps = 1e-8 + image_features /= image_features.norm(dim=-1, keepdim=True) + eps + text_features /= text_features.norm(dim=-1, keepdim=True) + eps + + text_probs = (image_features @ text_features.T).cpu().tolist() + return [p[0] for p in text_probs] + + +@MetricRegistry.register("oneig_reasoning") +class OneIGReasoningMetric(StatefulMetric): + """ + OneIG reasoning score: LLM2CLIP similarity between GT answer text and generated image. + + Uses ``reasoning_gt_answer`` from aux (populated by OneIG Knowledge_Reasoning loader; + language is chosen at dataset load via ``reasoning_language``). Splits a ``2 x 2`` grid + by default (OneIG ``split_mxn_grid``) and averages cell scores. Llama-derived checkpoints may require + ``HF_TOKEN`` and ``huggingface-cli login``. + + Parameters + ---------- + processor_model : str, optional + CLIP processor model ID. + model_name : str, optional + LLM2CLIP model ID. + llm_model_name : str, optional + LLM2Vec model ID. + grid_size : tuple[int, int], optional + ``(columns, rows)`` for :func:`~pruna.evaluation.metrics.vlm_utils.split_mxn_grid``. + attn_implementation : str | None, optional + Transformers attention backend for the LLM encoder (``None`` uses HF default; + Pruna defaults to ``sdpa`` on CUDA when unset). + device : str | torch.device | None, optional + Device for inference. + scorer : _LLM2CLIPScorer | None, optional + Optional scorer instance for testing (injected mock). + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments for :class:`StatefulMetric`. + + Notes + ----- + Prompt benchmarks yield ``(prompts, aux_list)``. With default ``call_type`` + ``y_gt``, ``aux_list`` is the list (or tensor coerced to a list) of per-sample + dicts parallel to generated images. Each dict must include a non-empty + ``reasoning_gt_answer`` for Knowledge/Reasoning samples. Missing GT, scorer + failures, or :meth:`compute` with no scored samples raise ``ValueError`` or + ``RuntimeError`` instead of returning a placeholder score. + """ + + metric_name: str = "oneig_reasoning" + default_call_type: str = "y_gt" + higher_is_better: bool = True + runs_on: list[str] = ["cuda", "cpu"] + + def __init__( + self, + processor_model: str = "openai/clip-vit-large-patch14-336", + model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336", + llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned", + grid_size: tuple[int, int] = (2, 2), + attn_implementation: str | None = None, + device: str | torch.device | None = None, + scorer: _LLM2CLIPScorer | None = None, + call_type: str | None = None, + **kwargs: Any, + ) -> None: + resolved = resolve_oneig_reasoning_device(device) + super().__init__(device=resolved, **kwargs) + self.device = resolved + self.call_type = get_call_type_for_single_metric( + call_type if call_type is not None else SINGLE, self.default_call_type + ) + self.processor_model = processor_model + self.model_name = model_name + self.llm_model_name = llm_model_name + self.grid_size = (int(grid_size[0]), int(grid_size[1])) + self.attn_implementation = attn_implementation + self._injected_scorer = scorer + self._llm2clip_scorer: _LLM2CLIPScorer | None = None + self.add_state("scores", default=[]) + + def _get_scorer(self) -> _LLM2CLIPScorer: + if self._injected_scorer is not None: + return self._injected_scorer + if self._llm2clip_scorer is None: + self._llm2clip_scorer = _LLM2CLIPScorer( + processor_model=self.processor_model, + model_name=self.model_name, + llm_model_name=self.llm_model_name, + device=str(self.device), + attn_implementation=self.attn_implementation, + ) + return self._llm2clip_scorer + + def reset(self) -> None: + """Clear scores and release cached LLM2CLIP weights.""" + super().reset() + self._llm2clip_scorer = None + + def _get_gt_text(self, aux: dict) -> str: + val = aux.get("reasoning_gt_answer") + if val is None or (isinstance(val, str) and not val.strip()): + raise ValueError( + "oneig_reasoning requires 'reasoning_gt_answer' in aux for Knowledge_Reasoning rows. " + f"Got keys: {list(aux.keys())}." + ) + return str(val).strip() + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Score each image against its GT answer text via LLM2CLIP similarity. + + Parameters + ---------- + x : list[Any] | torch.Tensor + Unused batch metadata. + gt : torch.Tensor + Ground-truth slot with per-sample aux dicts containing ``reasoning_gt_answer``. + outputs : torch.Tensor + Model outputs (generated images). + + Raises + ------ + ValueError + If a per-sample aux entry is not a dict or lacks a non-empty + ``reasoning_gt_answer``. + RuntimeError + If the LLM2CLIP scorer returns no scores for a sample. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + aux_slot = inputs[1] if len(inputs) > 1 else [] + if isinstance(aux_slot, torch.Tensor): + raise ValueError("oneig_reasoning expects gt as list[dict] with 'reasoning_gt_answer'.") + + scorer = self._get_scorer() + + for i, image in enumerate(images): + aux_row = aux_slot[i] if isinstance(aux_slot, (list, tuple)) and i < len(aux_slot) else {} + if not isinstance(aux_row, dict): + raise ValueError(f"oneig_reasoning requires aux[{i}] to be a dict. Got: {type(aux_row)}.") + text = self._get_gt_text(aux_row) + if isinstance(image, Image.Image): + pil = image.convert("RGB") + elif isinstance(image, torch.Tensor): + pil = _tensor_to_pil(image) + else: + pil = Image.fromarray(image).convert("RGB") + cells = split_mxn_grid(pil, self.grid_size) + result = scorer.score(cells, text) + if result is None or len(result) == 0: + raise RuntimeError(f"oneig_reasoning: LLM2CLIP scorer returned no scores for sample {i}.") + self.scores.append(float(sum(result) / len(result))) + + def compute(self) -> MetricResult: + """ + Compute the mean reasoning score across all samples. + + Returns + ------- + MetricResult + Mean LLM2CLIP similarity. + + Raises + ------ + RuntimeError + If :meth:`update` was not called or scored no samples. + """ + if not self.scores: + raise RuntimeError( + "oneig_reasoning: no samples were scored; call update() with valid " + "batches and non-empty reasoning_gt_answer before compute()." + ) + mean_score = sum(self.scores) / len(self.scores) + return MetricResult(self.metric_name, self.__dict__, float(mean_score)) diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py index f954c0eb..ba5ed118 100644 --- a/src/pruna/evaluation/metrics/metric_qa_accuracy.py +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -55,8 +55,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): Parameters ---------- - *args : Any - Additional positional arguments. vlm : BaseVLM | None, optional Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. vlm_type : {"litellm", "transformers"}, optional @@ -76,8 +74,10 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): API key for litellm. call_type : str, optional Call type for the metric. + aggregation : {"mean", "all_or_nothing"}, optional + Per-image score aggregation (keyword-only). Default is ``"mean"``. **kwargs : Any - Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``. + Additional keyword arguments forwarded to the parent class. Raises ------ @@ -111,7 +111,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): def __init__( self, - *args, vlm: BaseVLM | None = None, vlm_type: Literal["litellm", "transformers"] = "litellm", model_name: str | None = None, @@ -119,7 +118,7 @@ def __init__( structured_output: bool = True, device: str | torch.device | None = None, api_key: str | None = None, - call_type: str = SINGLE, + call_type: str | None = None, *, aggregation: str = "mean", **kwargs: Any, @@ -139,7 +138,7 @@ def __init__( structured_output=structured_output, device=device, api_key=api_key, - call_type=call_type, + call_type=call_type if call_type is not None else SINGLE, ) def _extract_questions(self, gt: Any, n: int) -> list[list[str]]: diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..b2c16f00 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. + + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -152,29 +170,22 @@ class TorchMetrics(Enum): """ Enumeration of torchmetrics metrics for evaluation. - This enum provides a tuple per member (metric_factory, update_fn, call_type): - metric_factory builds the metric (typically a torchmetrics class, or - functools.partial when some constructor arguments are fixed); update_fn is - an optional custom update handler; call_type describes how inputs are paired - for the metric. + Each member value is a ``(metric_factory, update_fn, call_type)`` tuple. Parameters ---------- value : tuple - Tuple holding metric_factory, update_fn, and call_type as described above. + ``(metric_factory, update_fn, call_type)`` for this enum member. names : str - The name of the enum member. + Enum member name. module : str - The module where the enum is defined. + Defining module name. qualname : str - The qualified name of the enum. + Qualified name of the enum class. type : type - The type of the enum. + Enum metaclass type. start : int - The start index for auto-numbering enum values. - boundary : enum.FlagBoundary or None - Boundary handling mode used by the Enum functional API for Flag and - IntFlag enums. + Auto-numbering start index for functional API enums. """ fid = (FrechetInceptionDistance, fid_update, "gt_y") @@ -246,6 +257,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +271,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index 5efd721a..650f8a76 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib from functools import partial from inspect import isclass from typing import Any, Callable, Dict, Iterable, List @@ -29,9 +30,17 @@ class MetricRegistry: Registry for metrics. The registry is a dictionary that maps metric names to metric classes. + + Notes + ----- + ``_lazy_metrics`` lists names that :meth:`has_metric` treats as registered before the + implementing module is loaded. The ``oneig_reasoning`` metric imports the LLM2CLIP-related + stack (vendored helpers, heavy optional dependencies); it is imported only when + :meth:`get_metric` is called with that name so other code paths avoid that cost. """ _registry: Dict[str, Callable[..., Any]] = {} + _lazy_metrics: frozenset[str] = frozenset({"oneig_reasoning"}) @classmethod def register(cls, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -104,7 +113,7 @@ def has_metric(cls, name: str) -> bool: bool True if the metric is registered, False otherwise. """ - return name in cls._registry + return name in cls._registry or name in cls._lazy_metrics @classmethod def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: @@ -122,6 +131,9 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: ------- The metric instance. """ + if name in cls._lazy_metrics and name not in cls._registry: + importlib.import_module("pruna.evaluation.metrics.metric_oneig_reasoning") + if name not in cls._registry: raise ValueError(f"Metric '{name}' is not registered.") diff --git a/tests/conftest.py b/tests/conftest.py index 80d54825..6dff757b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ +import os from typing import Any import pytest +if os.environ.get("PRUNA_CI_CPU_ONLY") == "1": + import torch + + if hasattr(torch.backends, "mps"): + torch.backends.mps.is_available = lambda: False # type: ignore[method-assign] + # import all fixtures to make them avaliable for pytest from .fixtures import * # noqa: F403, F401 diff --git a/tests/evaluation/test_text_metrics.py b/tests/evaluation/test_text_metrics.py index a5931bae..d566390d 100644 --- a/tests/evaluation/test_text_metrics.py +++ b/tests/evaluation/test_text_metrics.py @@ -120,20 +120,3 @@ def test_oneig_alignment_all_padding_questions_yields_zero_without_vlm() -> None assert metric.compute().result == 0.0 mock_vlm.score.assert_not_called() - -def test_to_oneig_record_strips_null_questions_and_dependencies() -> None: - """Null-valued Q_D entries are filtered out at record construction time.""" - row = {"category": "Anime_Stylization", "id": "001", "class": "None", "prompt_en": "a cat"} - questions_by_key = { - "anime_001": { - "questions": {"1": "Is there a cat?", "21": None}, - "dependencies": {"1": [0], "21": None}, - } - } - record = _to_oneig_record(row, questions_by_key, {}, {}) - assert "21" not in record["questions"] - assert "21" not in record["dependencies"] - assert record["questions"] == {"1": "Is there a cat?"} - assert record["dependencies"] == {"1": [0]} - - diff --git a/tests/evaluation/test_vlm_base_infrastructure.py b/tests/evaluation/test_vlm_base_infrastructure.py index a4eaa139..b6ac9b1c 100644 --- a/tests/evaluation/test_vlm_base_infrastructure.py +++ b/tests/evaluation/test_vlm_base_infrastructure.py @@ -1,50 +1,12 @@ -"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" +"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" from unittest.mock import MagicMock, patch import pytest import torch -from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric -from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric -from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric -from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric -from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric -from pruna.evaluation.metrics.metric_vqa import VQAMetric -from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm -from pruna.evaluation.metrics.vlm_utils import ( - FloatOutput, - VLM_AUX_IMAGE_BYTES_KEY_ORDER, - get_score_from_response, - yes_no_first_token_id_groups, -) - -from ._vlm_batch_snapshot_helpers import ( - BenchmarkVlmBatchOutcome, - pred_tensor_from_auxiliaries, - safe_json_for_snapshot, - vlm_benchmark_batch_to_json_record, -) - -SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" - -_ALL_VLM = ( - VQAMetric, - ImageEditScoreMetric, - QAAccuracyMetric, - OneIGAlignmentMetric, - TextScoreMetric, - OneIGTextScoreMetric, - VieScoreMetric, -) - -_SLOW_SMOL_SUBSET = ( - VQAMetric, - OneIGAlignmentMetric, - ImageEditScoreMetric, - VieScoreMetric, -) +from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups @pytest.mark.parametrize( @@ -64,115 +26,6 @@ def test_get_score_from_response(raw: object, expected: float) -> None: assert get_score_from_response(raw) == pytest.approx(expected) -def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: - return torch.rand(batch, 3, size, size) - - -def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: - if isinstance(metric, OneIGAlignmentMetric): - metric.update( - prompts, - [ - { - "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, - "dependencies": {"1": [0], "2": [1]}, - } - ], - images, - ) - elif isinstance(metric, QAAccuracyMetric): - metric.update( - prompts, - [{"questions": {"1": "Is there a cat?"}}], - images, - ) - elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): - metric.update(prompts, ["cat"], images) - else: - metric.update(prompts, images, images) - - -@pytest.mark.cpu -@pytest.mark.slow -@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) -def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: - """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" - metric = metric_cls( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - assert result.name == metric.metric_name - assert isinstance(result.result, float) - if metric.higher_is_better: - assert 0.0 <= result.result <= 1.0 - else: - assert result.result >= 0.0 - - -@pytest.mark.cpu -@pytest.mark.parametrize("metric_cls", _ALL_VLM) -def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: - """Each VLM metric runs end-to-end with mocked litellm.""" - pytest.importorskip("litellm") - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): - mock_response.choices[0].message.content = '{"answer": "Yes"}' - else: - mock_response.choices[0].message.content = '{"score": 8}' - - with patch("litellm.completion") as mock_completion: - mock_completion.return_value = mock_response - - metric = metric_cls( - vlm_type="litellm", - model_name="gpt-4o", - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - - assert result.name == metric.metric_name - assert isinstance(result.result, float) - assert mock_completion.called - - -@pytest.mark.cpu -def test_vlm_metrics_empty_compute_returns_zero() -> None: - """No updates → compute is 0.0 (same for all stateful VLM metrics).""" - metric = VQAMetric( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - assert metric.compute().result == 0.0 - - -@pytest.mark.cpu -def test_vlm_metrics_custom_vlm() -> None: - """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["Yes"] - mock_vlm.score.return_value = [1.0] - - metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) - images = _dummy_image(batch=1) - prompts = ["a cat"] - metric.update(prompts, images, images) - assert metric.compute().result == 1.0 - mock_vlm.score.assert_called() - - @pytest.mark.cpu def test_get_vlm_returns_custom() -> None: """get_vlm returns the provided VLM instance unchanged.""" @@ -200,286 +53,15 @@ def test_get_vlm_requires_model_name_without_vlm() -> None: get_vlm(vlm=None, vlm_type="litellm") -@pytest.mark.cpu -@pytest.mark.parametrize( - "metric_cls, expected_name, expected_result", - [ - (TextScoreMetric, "text_score", 1.0), - (OneIGTextScoreMetric, "oneig_text_score", 1.0), - ], -) -def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: - """Text metrics accept plain string ground-truth and return the expected score.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") - images = _dummy_image(batch=1) - metric.update(["a prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == expected_result - assert result.name == expected_name - mock_vlm.generate.assert_called_once() - - -@pytest.mark.cpu -def test_text_score_result_in_zero_one_range() -> None: - """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" - mock_vlm = MagicMock(spec=BaseVLM) - # VLM OCR returns something very different from ground truth (high edit distance) - mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello"], images) - result = metric.compute() - - assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" - assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" - - -@pytest.mark.cpu -def test_text_score_perfect_match_is_one() -> None: - """TextScoreMetric: identical OCR and GT -> score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" - assert result.higher_is_better is True - - -@pytest.mark.cpu -def test_text_score_registry_aliases() -> None: - """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" - from pruna.evaluation.metrics.registry import MetricRegistry - - lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") - comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") - assert type(lev).__name__ == "TextScoreMetric" - assert type(comp).__name__ == "OneIGTextScoreMetric" - assert lev.metric_name == "text_score" - assert comp.metric_name == "oneig_text_score" - - -@pytest.mark.cpu -def test_oneig_text_score_utils_golden_composite() -> None: - """oneig_mean_text_score returns expected component values for a known input.""" - from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score - - ed, cr, wac, composite = oneig_mean_text_score( - edit_distances=[10.0], - completion_ratios=[0.0], - match_counts=[2], - gt_totals=[4], - language_mode="EN", - ) - assert ed == 10.0 - assert cr == 0.0 - assert wac == 0.5 - assert composite == pytest.approx(0.95) - - _, _, _, zh = oneig_mean_text_score( - edit_distances=[30.0], - completion_ratios=[0.0], - match_counts=[0], - gt_totals=[1], - language_mode="ZH", - ) - assert zh == pytest.approx(0.4) - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_partial_fail() -> None: - """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" - mock_vlm = MagicMock(spec=BaseVLM) - # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 - mock_vlm.score.return_value = [1.0, 0.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_all_yes() -> None: - """all_or_nothing: all Yes → score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [1.0, 1.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_invalid_aggregation_raises() -> None: - """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" - mock_vlm = MagicMock(spec=BaseVLM) - with pytest.raises(ValueError, match="aggregation"): - QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") - - -@pytest.mark.cpu -def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: - """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" - from io import BytesIO - - from PIL import Image - - buf = BytesIO() - Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") - src_bytes = buf.getvalue() - - mock_vlm = MagicMock() - mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] - mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - pred = _dummy_image(batch=1) - metric.update( - ["make the sky purple"], - [{"source_image_bytes": src_bytes}], - pred, - ) - result = metric.compute() - - assert mock_vlm.generate_with_image_lists.called - assert mock_vlm.generate.called - assert 0.0 <= result.result <= 1.0 - - -@pytest.mark.cpu -def test_vie_score_uses_get_score_from_response() -> None: - """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" - mock_vlm = MagicMock(spec=BaseVLM) - # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) - mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) - result = metric.compute() - - # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 - assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" - - -@pytest.mark.cpu -def test_img_edit_score_negative_response_clamped() -> None: - """img_edit_score must be non-negative even when the VLM generates a negative JSON score. - - Regression for: Outlines constrained decoding can emit {"score": -10} despite the - FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric - bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. - """ - mock_vlm = MagicMock(spec=BaseVLM) - # Simulate Outlines generating a negative value (the bug scenario) - mock_vlm.generate.return_value = ['{"score": -10.0}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) - result = metric.compute() - - assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: - """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [0.5] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" - - -@pytest.mark.cpu -@pytest.mark.slow -def test_yes_no_token_ids_smolvlm_nonempty() -> None: - """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" - assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" - assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" - - -@pytest.mark.cpu -def test_img_edit_score_uses_prompt_from_x() -> None: - """img_edit_score must score the edited image against the instruction from x, not gt.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ['{"score": 9}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") - pred = _dummy_image(batch=1) - metric.update( - ["replace the cat with a dog"], # x = instruction - pred, # gt = unused for y_x - pred, # outputs = edited image - ) - result = metric.compute() - - call_args = mock_vlm.generate.call_args - prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item - assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" - assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" - - -@pytest.mark.cpu -def test_vie_score_geditbench_gap_documented() -> None: - """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). - - This test fails if a ``task_type`` parameter is added to ``__init__`` without updating - GEditBench integration tests and benchmark copy accordingly. - """ - import inspect - - sig = inspect.signature(VieScoreMetric.__init__) - assert "task_type" not in sig.parameters, ( - "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." - ) - - @pytest.mark.cpu def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" pytest.importorskip("litellm") import math - from unittest.mock import MagicMock, patch import numpy as np from PIL import Image - from pruna.evaluation.metrics.vlm_base import LitellmVLM - - # Simulate top_logprobs for first output token: - # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 - # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 - # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 def make_top_logprob(token, logprob): t = MagicMock() t.token = token @@ -510,175 +92,17 @@ def make_top_logprob(token, logprob): img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") - # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" @pytest.mark.cpu @pytest.mark.slow -def test_vqa_probability_score_normalized() -> None: - """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" pytest.importorskip("transformers") - import numpy as np - from PIL import Image - - from pruna.evaluation.metrics.vlm_base import TransformersVLM - - vlm = TransformersVLM( - model_name="HuggingFaceTB/SmolVLM-256M-Instruct", - device="cpu", - use_outlines=False, - ) - img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) - scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) - assert len(scores) == 1 - assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" - - -# --------------------------------------------------------------------------- -# vlm_benchmark_batch_to_json_record serialization tests -# --------------------------------------------------------------------------- - - -def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: - """Record includes prompts, pred shape, and metric fields.""" - mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["prompt"], - auxiliaries=[{"path": "/tmp/x.png"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="GenEval", - benchmark_name="GenEval", - metric_name="qa_accuracy", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - assert rec["inputs"]["prompts"] == ["prompt"] - assert rec["pred"]["shape"] == [1, 3, 8, 8] - assert rec["metric_result"]["result"] == 0.25 - - -def test_safe_json_handles_bytes_without_expanding() -> None: - """Bytes values in aux (e.g. source_image_bytes) are summarized, not expanded to str repr.""" - result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) - assert result["source_image_bytes"] == {"bytes_len": 3000} - assert result["name"] == "test" - - -def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: - """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" - mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["p"], - auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="OneIGAnimeStylization", - benchmark_name="OneIG Anime Stylization", - metric_name="oneig_alignment", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - qs = rec["inputs"]["auxiliary_0"]["questions"] - assert qs["1"] == "Are there boys?" - assert qs["21"] is None - - -# --------------------------------------------------------------------------- -# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests -# --------------------------------------------------------------------------- - - -def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: - """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" - import io - - import numpy as np - from PIL import Image - - arr = (np.random.rand(h, w, 3) * 255).astype("uint8") - buf = io.BytesIO() - Image.fromarray(arr).save(buf, format="JPEG") - return buf.getvalue() - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: - """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=64) - - assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" - assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: - """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" - aux = [{"category": "single_object"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_mixed_batch() -> None: - """Batch with one source image and one missing falls back per-item.""" - src_bytes = _make_jpeg_bytes() - aux = [ - {"source_image_bytes": src_bytes, "category": "color_alter"}, - {"category": "style_change"}, # no source image - ] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (2, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_generic_bytes_scan() -> None: - """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" - src_bytes = _make_jpeg_bytes() - aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_known_names_take_priority() -> None: - """Known field names are resolved before the generic bytes scan.""" - src_bytes_known = _make_jpeg_bytes(16, 16) - src_bytes_unknown = _make_jpeg_bytes(32, 32) - first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] - aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] - pred = pred_tensor_from_auxiliaries(aux, size=16) - # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 - assert pred.shape == (1, 3, 16, 16) - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: - """require_source_image=True raises ValueError instead of silently returning noise.""" - aux = [{"category": "replace"}] # no image bytes - with pytest.raises(ValueError, match="require_source_image=True"): - pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - + from transformers import AutoTokenizer -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: - """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "replace"}] - pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids + assert no_ids