diff --git a/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen.py b/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen.py new file mode 100644 index 00000000..008119cb --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen.py @@ -0,0 +1,56 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCODataset +from ais_bench.benchmark.datasets.refcoco import refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcoco_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcoco_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'file://{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcoco_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'test', + 'testA', + 'testB', +] + +refcoco_datasets = [ + dict( + abbr='RefCOCO_' + split, + type=RefCOCODataset, + path='ais_bench/datasets/RefCOCO/data', + split=split, + reader_cfg=refcoco_reader_cfg, + infer_cfg=refcoco_infer_cfg, + eval_cfg=refcoco_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen_base64.py b/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen_base64.py new file mode 100644 index 00000000..d807dbce --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcoco/refcoco_gen_base64.py @@ -0,0 +1,57 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCODataset +from ais_bench.benchmark.datasets.refcoco import IMAGE_BASE64_TYPE, refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcoco_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcoco_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcoco_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'test', + 'testA', + 'testB', +] + +refcoco_datasets = [ + dict( + abbr='RefCOCO_base64_' + split, + type=RefCOCODataset, + path='ais_bench/datasets/RefCOCO/data', + split=split, + image_type=IMAGE_BASE64_TYPE, + reader_cfg=refcoco_reader_cfg, + infer_cfg=refcoco_infer_cfg, + eval_cfg=refcoco_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen.py b/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen.py new file mode 100644 index 00000000..d505dd74 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen.py @@ -0,0 +1,55 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCOPlusDataset +from ais_bench.benchmark.datasets.refcoco import refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcoco_plus_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcoco_plus_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'file://{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcoco_plus_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'testA', + 'testB', +] + +refcoco_plus_datasets = [ + dict( + abbr='RefCOCOPlus_' + split, + type=RefCOCOPlusDataset, + path='ais_bench/datasets/RefCOCOplus/data', + split=split, + reader_cfg=refcoco_plus_reader_cfg, + infer_cfg=refcoco_plus_infer_cfg, + eval_cfg=refcoco_plus_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen_base64.py b/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen_base64.py new file mode 100644 index 00000000..5804cb61 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcoco_plus/refcoco_plus_gen_base64.py @@ -0,0 +1,56 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCOPlusDataset +from ais_bench.benchmark.datasets.refcoco import IMAGE_BASE64_TYPE, refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcoco_plus_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcoco_plus_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcoco_plus_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'testA', + 'testB', +] + +refcoco_plus_datasets = [ + dict( + abbr='RefCOCOPlus_base64_' + split, + type=RefCOCOPlusDataset, + path='ais_bench/datasets/RefCOCOplus/data', + split=split, + image_type=IMAGE_BASE64_TYPE, + reader_cfg=refcoco_plus_reader_cfg, + infer_cfg=refcoco_plus_infer_cfg, + eval_cfg=refcoco_plus_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen.py b/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen.py new file mode 100644 index 00000000..50cbb852 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen.py @@ -0,0 +1,54 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCOgDataset +from ais_bench.benchmark.datasets.refcoco import refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcocog_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcocog_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'file://{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcocog_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'test', +] + +refcocog_datasets = [ + dict( + abbr='RefCOCOg_' + split, + type=RefCOCOgDataset, + path='ais_bench/datasets/RefCOCOg/data', + split=split, + reader_cfg=refcocog_reader_cfg, + infer_cfg=refcocog_infer_cfg, + eval_cfg=refcocog_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen_base64.py b/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen_base64.py new file mode 100644 index 00000000..cf6eb915 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/refcocog/refcocog_gen_base64.py @@ -0,0 +1,55 @@ +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_prompt_template import MMPromptTemplate +from ais_bench.benchmark.datasets import RefCOCOgDataset +from ais_bench.benchmark.datasets.refcoco import IMAGE_BASE64_TYPE, refcoco_bbox_postprocess +from ais_bench.benchmark.openicl.icl_evaluator import BBoxIoUEvaluator + + +refcocog_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='answer' +) + +refcocog_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt_mm={ + 'text': { + 'type': 'text', + 'text': 'Locate every object that matches the description "{question}" in the image. Report bbox coordinates in JSON format.' 
+ }, + 'image': {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,{image}'}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +refcocog_eval_cfg = dict( + evaluator=dict(type=BBoxIoUEvaluator, iou_threshold=0.5, coord_scale=1000.0), + pred_postprocessor=dict(type=refcoco_bbox_postprocess), +) + +_splits = [ + 'val', + 'test', +] + +refcocog_datasets = [ + dict( + abbr='RefCOCOg_base64_' + split, + type=RefCOCOgDataset, + path='ais_bench/datasets/RefCOCOg/data', + split=split, + image_type=IMAGE_BASE64_TYPE, + reader_cfg=refcocog_reader_cfg, + infer_cfg=refcocog_infer_cfg, + eval_cfg=refcocog_eval_cfg, + ) + for split in _splits +] \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/__init__.py b/ais_bench/benchmark/datasets/__init__.py index 1581a2af..3f0d8425 100644 --- a/ais_bench/benchmark/datasets/__init__.py +++ b/ais_bench/benchmark/datasets/__init__.py @@ -53,3 +53,4 @@ from ais_bench.benchmark.datasets.mmstar import * # noqa: F401, F403 from ais_bench.benchmark.datasets.dapo_math import * # noqa: F401, F403 from ais_bench.benchmark.datasets.mooncake_trace import * # noqa: F401, F403 +from ais_bench.benchmark.datasets.refcoco import * # noqa: F401, F403 diff --git a/ais_bench/benchmark/datasets/refcoco/__init__.py b/ais_bench/benchmark/datasets/refcoco/__init__.py new file mode 100644 index 00000000..1590279f --- /dev/null +++ b/ais_bench/benchmark/datasets/refcoco/__init__.py @@ -0,0 +1,9 @@ +from ais_bench.benchmark.datasets.refcoco.refcoco import ( # noqa: F401 + IMAGE_BASE64_TYPE, + IMAGE_PATH_TYPE, + TEMP_IMAGE_STORE_DIR, + RefCOCODataset, + refcoco_bbox_postprocess, +) +from ais_bench.benchmark.datasets.refcoco.refcoco_g import RefCOCOgDataset # noqa: F401 +from ais_bench.benchmark.datasets.refcoco.refcoco_plus import RefCOCOPlusDataset # noqa: F401 \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/refcoco/refcoco.py 
b/ais_bench/benchmark/datasets/refcoco/refcoco.py new file mode 100644 index 00000000..319e4cfa --- /dev/null +++ b/ais_bench/benchmark/datasets/refcoco/refcoco.py @@ -0,0 +1,202 @@ +import glob +import io +import json +import os +import re + +from abc import ABC, abstractmethod +from typing import Any + +import pandas as pd +from PIL import Image + +from datasets import Dataset + +from ais_bench.benchmark.datasets.utils.datasets import get_content_str +from ais_bench.benchmark.registry import LOAD_DATASET, TEXT_POSTPROCESSORS +from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.utils.image_process import pil_to_base64 +from ais_bench.benchmark.utils.logging import AISLogger + +from ..base import BaseDataset + +logger = AISLogger() + +IMAGE_PATH_TYPE = 'image_path' +IMAGE_BASE64_TYPE = 'image_base64' + +TEMP_IMAGE_STORE_DIR = 'temp_save_images' + +def _parse_float_sequence_within(input_str: str) -> list[float]: + """Extract the first sequence of four floats inside square brackets.""" + pattern = r'\[\s*(-?\d+(?:\.\d+)?)\s*,\s*(-?\d+(?:\.\d+)?)\s*,\s*(-?\d+(?:\.\d+)?)\s*,\s*(-?\d+(?:\.\d+)?)\s*\]' + match = re.search(pattern, input_str) + if match: + return [float(match.group(i)) for i in range(1, 5)] + return [0.0, 0.0, 0.0, 0.0] + + +def _remove_leading_articles(text: str) -> str: + cleaned_text = re.sub(r'^(a|an|the)\s+', '', text.strip(), flags=re.IGNORECASE) + return cleaned_text or text.strip() + + +@TEXT_POSTPROCESSORS.register_module('refcoco_bbox_1000') +def refcoco_bbox_postprocess(text: str) -> list[float]: + stripped_text = text.strip() + bbox = _parse_float_sequence_within(stripped_text) + + logger.debug(f'refcoco_bbox_postprocess: bbox={bbox}') + return bbox + + +class ImageResolver(ABC): + """Strategy interface for converting a PIL image into a transport value.""" + + @abstractmethod + def setup(self, resolved_path: str, split: str) -> None: + ... 
+ + @abstractmethod + def resolve(self, pil_img: Image.Image, file_name: str) -> str: + ... + + +class PathImageResolver(ImageResolver): + def setup(self, resolved_path: str, split: str) -> None: + image_cache_path = os.path.join( + resolved_path, + TEMP_IMAGE_STORE_DIR, + split, + ) + logger.info(f'Saving RefCOCO images to {image_cache_path}') + os.makedirs(image_cache_path, exist_ok=True) + self._cache_dir = image_cache_path + + def resolve(self, pil_img: Image.Image, file_name: str) -> str: + image_path = os.path.join(self._cache_dir, file_name) + os.makedirs(os.path.dirname(image_path), exist_ok=True) + if not os.path.exists(image_path): + pil_img.save(image_path, format='JPEG') + return image_path + + +class Base64ImageResolver(ImageResolver): + def setup(self, resolved_path: str, split: str) -> None: + logger.info(f'Encoding RefCOCO images as base64 for split {split}') + + def resolve(self, pil_img: Image.Image, file_name: str) -> str: + return pil_to_base64(pil_img, format='JPEG') + + +IMAGE_RESOLVERS = { + IMAGE_PATH_TYPE: PathImageResolver, + IMAGE_BASE64_TYPE: Base64ImageResolver, +} + + +@LOAD_DATASET.register_module() +class RefCOCODataset(BaseDataset): + @staticmethod + def _load_split_dataframe(resolved_path: str, split: str) -> pd.DataFrame: + shard_paths = sorted(glob.glob(os.path.join(resolved_path, f'{split}-*.parquet'))) + if not shard_paths: + raise FileNotFoundError( + f'No RefCOCO parquet shards found for split {split} in {resolved_path}' + ) + + logger.info(f'Loading RefCOCO split {split} from {len(shard_paths)} shard(s) in {resolved_path}') + return pd.concat([pd.read_parquet(shard_path) for shard_path in shard_paths], ignore_index=True) + + @staticmethod + def _decode_image_payload(image_payload: Any, row_index: int) -> Image.Image: + if not isinstance(image_payload, dict) or 'bytes' not in image_payload: + raise ValueError(f'RefCOCO row {row_index} has invalid image payload: {type(image_payload)}') + + return 
Image.open(io.BytesIO(image_payload['bytes'])).convert('RGB') + + @staticmethod + def _build_pixel_bbox(raw_bbox: Any) -> list[float]: + x_coord, y_coord, bbox_width, bbox_height = [float(value) for value in raw_bbox] + return [x_coord, y_coord, x_coord + bbox_width, y_coord + bbox_height] + + @staticmethod + def _build_rows( + sample: pd.Series, + image_value: str, + width: int, + height: int, + ) -> list[dict[str, str]]: + reference_answer = json.dumps({ + 'question_id': int(sample['question_id']), + 'bbox': RefCOCODataset._build_pixel_bbox(sample['bbox']), + 'image_width': width, + 'image_height': height, + }) + + rows: list[dict[str, str]] = [] + for answer_text in sample['answer']: + content = get_content_str([ + {'type': 'image_url', 'image_url': image_value}, + {'type': 'text', 'text': answer_text}, + ]) + rows.append({ + 'content': content, + 'answer': reference_answer, + }) + return rows + + @staticmethod + def load(path: str, split: str, **kwargs: Any) -> Dataset: # pyright: ignore[reportIncompatibleMethodOverride] + """Load a RefCOCO split and normalize it into benchmark rows. + + The source data is stored as parquet shards under ``path`` with shard + names matching ``-*.parquet``. Each source row contains an image + payload, a ground-truth bounding box in ``[x, y, w, h]`` format, and a + list of referring expressions. This loader can either persist each image + to a split-specific cache directory or encode it as base64, converts the + bbox to ``[x_min, y_min, x_max, y_max]``, and expands the answer list + into one benchmark row per referring expression. + + Each output row has a ``content`` field that encodes the image and + referring expression together using ``AIS_CONTENT_TAG`` delimiters + (via :func:`get_content_str`). 
During inference the + :meth:`PromptList.format_mm` method splits ``content`` on + ``AIS_CONTENT_TAG`` and uses the ``AIS_IMAGE_START`` / + ``AIS_TEXT_START`` prefixes to populate the ``prompt_mm`` template + with the image URL and question text respectively. + + Args: + path: Dataset root containing RefCOCO parquet shards. + split: Split prefix to load, for example ``val`` or ``testA``. + **kwargs: Extra keyword arguments passed by the dataset builder. + Supported key: ``image_type`` with values ``IMAGE_PATH_TYPE`` or + ``IMAGE_BASE64_TYPE``. + + Returns: + A HuggingFace ``Dataset`` with columns: + - content: encoded multimodal string consumed by + ``format_mm`` to fill the ``prompt_mm`` template. + - answer: JSON-serialized reference bbox payload used by + evaluation. + """ + resolved_path = get_data_path(path) + image_type = kwargs.get('image_type', IMAGE_PATH_TYPE) + if image_type not in IMAGE_RESOLVERS: + raise ValueError( + f'Unsupported image_type: {image_type}. Expected one of {sorted(IMAGE_RESOLVERS)}' + ) + data = RefCOCODataset._load_split_dataframe(resolved_path, split) + resolver = IMAGE_RESOLVERS[image_type]() + resolver.setup(resolved_path, split) + + rows: list[dict[str, str]] = [] + for row_index, (_, sample) in enumerate(data.iterrows()): + pil_img = RefCOCODataset._decode_image_payload(sample['image'], row_index) + image_value = resolver.resolve(pil_img, sample['file_name']) + + width, height = pil_img.width, pil_img.height + sample_rows = RefCOCODataset._build_rows(sample, image_value, width, height) + rows.extend(sample_rows) + + return Dataset.from_list(rows) diff --git a/ais_bench/benchmark/datasets/refcoco/refcoco_g.py b/ais_bench/benchmark/datasets/refcoco/refcoco_g.py new file mode 100644 index 00000000..6e32abb7 --- /dev/null +++ b/ais_bench/benchmark/datasets/refcoco/refcoco_g.py @@ -0,0 +1,15 @@ +from ais_bench.benchmark.registry import LOAD_DATASET + +from ais_bench.benchmark.datasets.refcoco.refcoco import RefCOCODataset + + 
+@LOAD_DATASET.register_module() +class RefCOCOgDataset(RefCOCODataset): + """ + RefCOCOg is a variant of RefCOCO with longer, more complex referring expressions. + Because the dataset fields are the same as in the RefCOCO dataset, we can reuse the loading and evaluation code. + The only difference is that refcoco_g has only two splits: + - `val`: 7.57k rows + - `test`: 5.02k rows + """ + pass diff --git a/ais_bench/benchmark/datasets/refcoco/refcoco_plus.py b/ais_bench/benchmark/datasets/refcoco/refcoco_plus.py new file mode 100644 index 00000000..026c222a --- /dev/null +++ b/ais_bench/benchmark/datasets/refcoco/refcoco_plus.py @@ -0,0 +1,16 @@ +from ais_bench.benchmark.registry import LOAD_DATASET + +from ais_bench.benchmark.datasets.refcoco.refcoco import RefCOCODataset + + +@LOAD_DATASET.register_module() +class RefCOCOPlusDataset(RefCOCODataset): + """ + RefCOCOplus is a variant of RefCOCO whose referring expressions avoid absolute location words and focus on appearance. + Because the dataset fields are the same as in the RefCOCO dataset, we can reuse the loading and evaluation code. 
+ The only difference is that refcoco_plus has only three splits: + - `val`: 3.81k rows + - `testA`: 1.98k rows + - `testB`: 1.8k rows + """ + pass diff --git a/ais_bench/benchmark/openicl/icl_evaluator/__init__.py b/ais_bench/benchmark/openicl/icl_evaluator/__init__.py index 6a622d2a..8e72243b 100644 --- a/ais_bench/benchmark/openicl/icl_evaluator/__init__.py +++ b/ais_bench/benchmark/openicl/icl_evaluator/__init__.py @@ -1,4 +1,5 @@ from ais_bench.benchmark.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator # noqa +from ais_bench.benchmark.openicl.icl_evaluator.bbox_iou_evaluator import BBoxIoUEvaluator # noqa from ais_bench.benchmark.openicl.icl_evaluator.icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa from ais_bench.benchmark.openicl.icl_evaluator.math_evaluator import MATHEvaluator # noqa from ais_bench.benchmark.openicl.icl_evaluator.icl_hf_evaluator import * # noqa diff --git a/ais_bench/benchmark/openicl/icl_evaluator/bbox_iou_evaluator.py b/ais_bench/benchmark/openicl/icl_evaluator/bbox_iou_evaluator.py new file mode 100644 index 00000000..f514b2e1 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_evaluator/bbox_iou_evaluator.py @@ -0,0 +1,106 @@ +import json + +from ais_bench.benchmark.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator +from ais_bench.benchmark.registry import ICL_EVALUATORS + + +def _compute_iou(box1: list, box2: list) -> float: + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + inter = max(0.0, x_right - x_left) * max(0.0, y_bottom - y_top) + area1 = max(0.0, box1[2] - box1[0]) * max(0.0, box1[3] - box1[1]) + area2 = max(0.0, box2[2] - box2[0]) * max(0.0, box2[3] - box2[1]) + union = area1 + area2 - inter + return inter / union if union > 0 else 0.0 + + +@ICL_EVALUATORS.register_module() +class BBoxIoUEvaluator(BaseEvaluator): + + def __init__(self, + iou_threshold: float = 0.5, + coord_scale: float = 1000.0, + 
reference_bbox_key: str = 'bbox', + image_width_key: str = 'image_width', + image_height_key: str = 'image_height', + metric_prefix: str = 'Accuracy', + clip_to_image: bool = True) -> None: + super().__init__() + self.iou_threshold = iou_threshold + self.coord_scale = coord_scale + self.reference_bbox_key = reference_bbox_key + self.image_width_key = image_width_key + self.image_height_key = image_height_key + self.metric_prefix = metric_prefix + self.clip_to_image = clip_to_image + + def _scale_prediction(self, pred_box: list, image_width: float, image_height: float) -> list: + if len(pred_box) != 4: + raise ValueError('Predicted bbox must contain four coordinates') + + scaled_box = [ + float(pred_box[0]) / self.coord_scale * float(image_width), + float(pred_box[1]) / self.coord_scale * float(image_height), + float(pred_box[2]) / self.coord_scale * float(image_width), + float(pred_box[3]) / self.coord_scale * float(image_height), + ] + + if self.clip_to_image: + scaled_box = [ + min(max(scaled_box[0], 0.0), float(image_width)), + min(max(scaled_box[1], 0.0), float(image_height)), + min(max(scaled_box[2], 0.0), float(image_width)), + min(max(scaled_box[3], 0.0), float(image_height)), + ] + + if scaled_box[2] <= scaled_box[0] or scaled_box[3] <= scaled_box[1]: + raise ValueError('Predicted bbox is reversed or empty after scaling') + return scaled_box + + def score(self, predictions, references): # pyright: ignore[reportIncompatibleMethodOverride] + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different ' + f'length. 
len(predictions): {len(predictions)}, ' + f'len(references): {len(references)}' + } + + details = [] + scores = [] + for pred, ref in zip(predictions, references): + detail = { + 'pred': pred, + 'answer': ref, + 'correct': False, + 'coord_mode': f'0-{int(self.coord_scale)}', + } + + try: + refer = json.loads(ref) if isinstance(ref, str) else ref + image_width = float(refer[self.image_width_key]) + image_height = float(refer[self.image_height_key]) + pred_box_pixel = self._scale_prediction(pred, image_width, image_height) + gt_box = [float(value) for value in refer[self.reference_bbox_key]] + + iou = _compute_iou(pred_box_pixel, gt_box) + correct = iou >= self.iou_threshold + detail['correct'] = correct + detail['iou'] = iou + detail['pred_bbox_pixel'] = pred_box_pixel + scores.append(1 if correct else 0) + except (TypeError, ValueError, KeyError, json.JSONDecodeError, IndexError) as error: + detail['iou'] = 0.0 + detail['pred_bbox_pixel'] = None + detail['invalid'] = True + detail['error'] = str(error) + scores.append(0) + + details.append(detail) + + return { + f'{self.metric_prefix}@{self.iou_threshold}': 100 * sum(scores) / len(scores) if scores else 0.0, + 'details': details, + } \ No newline at end of file