Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions olive/data/component/pre_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,92 @@ def collate_fn(batch):
return (audios, texts)

return SpeechTranscriptionDataset(dataset, audio_col, text_col)


@Registry.register_pre_process()
def vision_vqa_pre_process(
dataset,
image_col: str = "image",
question_col: str = "question",
answer_col: str = "answer",
max_samples: Optional[int] = None,
limit: Optional[float] = None,
seed: int = 42,
**kwargs,
):
"""Pre-process data for vision VQA evaluation.

Loads image, question, and ground truth answer from a HuggingFace dataset.
Returns a dataset of ({"image": image, "question": question}, answer) pairs.

Note: This returns raw PIL images and question strings. For the PyTorch evaluator,
the model's own processor/tokenizer should be applied in the post_func or within
the model's forward method. For the ONNX evaluator, provide a custom pre-process
component that applies the appropriate processor/tokenizer to produce numeric
tensors matching the model's io_config.

Args:
dataset: HuggingFace dataset with image, question, and answer columns.
image_col: Name of the image column. Defaults to "image".
question_col: Name of the question column. Defaults to "question".
answer_col: Name of the answer column. Defaults to "answer".
max_samples: Maximum number of samples (deprecated, use limit). Defaults to None.
limit: Sampling limit following Olive convention:
If >= 1: use first N samples.
If 0 < limit < 1: randomly sample that percentage.
If 0 or None: use all samples.
seed: Random seed for percentage-based sampling. Defaults to 42.
**kwargs: Additional arguments.

"""
# Apply sampling: prefer limit over max_samples
effective_limit = limit if limit is not None else (max_samples if max_samples else 0)
if effective_limit and effective_limit != 0:
from random import Random

total = len(dataset)
if 0 < effective_limit < 1:
n = max(1, int(total * effective_limit))
rng = Random(seed)
indices = sorted(rng.sample(range(total), min(n, total)))
dataset = dataset.select(indices)
elif effective_limit >= 1:
n = min(int(effective_limit), total)
dataset = dataset.select(range(n))

class VisionVQADataset:
"""Dataset that returns (input_dict, answer_text) pairs for VQA evaluation.

Note: Use batch_size=1 in dataloader config as images have variable sizes.
"""

def __init__(self, hf_dataset, image_column, question_column, answer_column):
self.dataset = hf_dataset
self.image_column = image_column
self.question_column = question_column
self.answer_column = answer_column

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
item = self.dataset[idx]
image = item[self.image_column]
question = item[self.question_column]
answer = item[self.answer_column]
# Handle list answers (some datasets have multiple valid answers)
if isinstance(answer, list):
Comment on lines +452 to +456
answer = answer[0] if answer else ""
return {"image": image, "question": question}, str(answer)

@staticmethod
def collate_fn(batch):
"""Collate VQA batches. Use with batch_size=1 for variable-size images."""
if len(batch) == 1:
input_dict, answer = batch[0]
return (input_dict, [answer])
inputs = [item[0] for item in batch]
answers = [item[1] for item in batch]
return (inputs, answers)

return VisionVQADataset(dataset, image_col, question_col, answer_col)
9 changes: 9 additions & 0 deletions olive/data/container/huggingface_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,13 @@ class HuggingfaceContainer(DataContainer):
"speech-transcription": {
DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process",
},
"vision-vqa": {
DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process",
},
"vision-chart-qa": {
DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process",
},
"vision-ocr": {
DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process",
},
}
172 changes: 172 additions & 0 deletions olive/evaluator/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,175 @@ def measure(self, model_output, target):
if total_inference == 0:
return float("inf")
return round(total_audio / total_inference, 2)


class ExactMatch(AccuracyBase):
"""Exact match metric for vision VQA evaluation.

Compares predicted answer strings to ground truth answers using
case-insensitive, whitespace-normalized string equality.
Returns the fraction of samples with an exact match.

Suitable for benchmarks: AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS.
"""

name: Optional[str] = "exact_match"

@classmethod
def _default_config(cls) -> dict[str, ConfigParam]:
return {}

@staticmethod
def _normalize(text: str) -> str:
"""Normalize text for comparison: lowercase and collapse whitespace."""
return " ".join(text.strip().lower().split())

def measure(self, model_output, target):
preds = model_output.preds
refs = target
if isinstance(preds, str):
preds = [preds]
elif not isinstance(preds, list):
preds = list(preds)
if isinstance(refs, str):
refs = [refs]
elif not isinstance(refs, list):
refs = list(refs)

if len(preds) != len(refs):
raise ValueError(
f"Number of predictions ({len(preds)}) does not match "
f"number of references ({len(refs)}) for exact_match metric."
)

correct = sum(1 for p, r in zip(preds, refs) if self._normalize(str(p)) == self._normalize(str(r)))
return correct / len(refs) if refs else 0.0


class RelaxedAccuracy(AccuracyBase):
"""Relaxed accuracy metric for chart/math VQA evaluation.

For numeric answers, allows a ±5% tolerance (standard for ChartQA).
For non-numeric answers, falls back to exact string match.
Returns the fraction of samples that match within tolerance.

Suitable for benchmarks: ChartQA.
"""

name: Optional[str] = "relaxed_accuracy"

@classmethod
def _default_config(cls) -> dict[str, ConfigParam]:
return {
"tolerance": ConfigParam(type_=float, required=False, default_value=0.05),
}

@staticmethod
def _try_parse_number(text: str):
"""Try to parse text as a number. Returns (True, value) or (False, None)."""
text = text.strip().replace(",", "").replace("%", "")
try:
return True, float(text)
except ValueError:
return False, None

@staticmethod
def _normalize(text: str) -> str:
return " ".join(text.strip().lower().split())

def measure(self, model_output, target):
preds = model_output.preds
refs = target
if isinstance(preds, str):
preds = [preds]
elif not isinstance(preds, list):
preds = list(preds)
if isinstance(refs, str):
refs = [refs]
elif not isinstance(refs, list):
refs = list(refs)

if len(preds) != len(refs):
raise ValueError(
f"Number of predictions ({len(preds)}) does not match "
f"number of references ({len(refs)}) for relaxed_accuracy metric."
)

tolerance = self.config_dict.get("tolerance", 0.05)
correct = 0
for pred, ref in zip(preds, refs):
pred_str = str(pred)
ref_str = str(ref)
pred_is_num, pred_val = self._try_parse_number(pred_str)
ref_is_num, ref_val = self._try_parse_number(ref_str)

if pred_is_num and ref_is_num:
# Numeric comparison with tolerance
if ref_val == 0:
if pred_val == 0:
correct += 1
elif abs(pred_val - ref_val) / abs(ref_val) <= tolerance:
correct += 1
else:
# String comparison (exact match, case-insensitive)
if self._normalize(pred_str) == self._normalize(ref_str):
correct += 1

return correct / len(refs) if refs else 0.0


class WordSortRatio(AccuracyBase):
"""Word sort ratio metric for OCR evaluation.

Computes the ratio of matching words between prediction and reference
after sorting words alphabetically. This measures word-level overlap
regardless of word order.
Returns the average ratio across all samples.

Suitable for benchmarks: OCR.
"""

name: Optional[str] = "word_sort_ratio"

@classmethod
def _default_config(cls) -> dict[str, ConfigParam]:
return {}

@staticmethod
def _compute_word_sort_ratio(pred: str, ref: str) -> float:
"""Compute word sort ratio between two strings."""
pred_words = sorted(pred.strip().lower().split())
ref_words = sorted(ref.strip().lower().split())

if not ref_words:
return 1.0 if not pred_words else 0.0

# Count matching words using multiset intersection
from collections import Counter

pred_counter = Counter(pred_words)
ref_counter = Counter(ref_words)
intersection = sum((pred_counter & ref_counter).values())
total = max(len(pred_words), len(ref_words))
return intersection / total if total > 0 else 0.0

def measure(self, model_output, target):
preds = model_output.preds
refs = target
if isinstance(preds, str):
preds = [preds]
elif not isinstance(preds, list):
preds = list(preds)
if isinstance(refs, str):
refs = [refs]
elif not isinstance(refs, list):
refs = list(refs)

if len(preds) != len(refs):
raise ValueError(
f"Number of predictions ({len(preds)}) does not match "
f"number of references ({len(refs)}) for word_sort_ratio metric."
)

total_ratio = sum(self._compute_word_sort_ratio(str(p), str(r)) for p, r in zip(preds, refs))
return total_ratio / len(refs) if refs else 0.0
4 changes: 4 additions & 0 deletions olive/evaluator/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class AccuracySubType(StrEnumBase):
PERPLEXITY = "perplexity"
WER = "wer"
RTFX = "rtfx"
# Vision metrics (aligned with standard public vision benchmarks)
EXACT_MATCH = "exact_match" # AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS
RELAXED_ACCURACY = "relaxed_accuracy" # ChartQA (±5% numeric tolerance)
WORD_SORT_RATIO = "word_sort_ratio" # OCR (word-level overlap)


class LatencySubType(StrEnumBase):
Expand Down
Loading
Loading