From 6258d232e3c083f6bfc89536d2f3ac4ccf7b9c06 Mon Sep 17 00:00:00 2001 From: David Fan Date: Wed, 27 May 2026 18:13:44 +0000 Subject: [PATCH 1/4] Add vision evaluation metrics (exact_match, relaxed_accuracy, word_sort_ratio) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add vision evaluation metrics to the Olive evaluator framework, enabling VQA, ChartQA, and OCR model evaluation. - exact_match: case-insensitive string equality for VQA tasks - relaxed_accuracy: ±5% numeric tolerance for ChartQA - word_sort_ratio: word-level overlap ratio for OCR Changes: - olive/evaluator/metric.py: Add EXACT_MATCH, RELAXED_ACCURACY, WORD_SORT_RATIO to AccuracySubType - olive/evaluator/accuracy.py: Add ExactMatch, RelaxedAccuracy, WordSortRatio classes - olive/evaluator/olive_evaluator.py: Add _inference_vision() path and task-metric validation - olive/data/component/pre_process_data.py: Add vision_vqa_pre_process data component - olive/data/container/huggingface_container.py: Add vision-vqa, vision-chart-qa, vision-ocr tasks - olive/olive_config.json: Add vision extra dependencies (Pillow) - test/evaluator/test_accuracy.py: Add 20 unit tests for vision metrics Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/data/component/pre_process_data.py | 83 ++++++++ olive/data/container/huggingface_container.py | 9 + olive/evaluator/accuracy.py | 166 +++++++++++++++ olive/evaluator/metric.py | 4 + olive/evaluator/olive_evaluator.py | 194 ++++++++++++++++- olive/olive_config.json | 3 +- test/evaluator/test_accuracy.py | 201 ++++++++++++++++++ 7 files changed, 657 insertions(+), 3 deletions(-) diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index 2881ea891..8da36c247 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -379,3 +379,86 @@ def collate_fn(batch): return (audios, texts) return SpeechTranscriptionDataset(dataset, audio_col, text_col) + + +@Registry.register_pre_process() +def vision_vqa_pre_process( + dataset, + image_col: str = "image", + question_col: str = "question", + answer_col: str = "answer", + max_samples: Optional[int] = None, + limit: Optional[float] = None, + seed: int = 42, + **kwargs, +): + """Pre-process data for vision VQA evaluation. + + Loads image, question, and ground truth answer from a HuggingFace dataset. + Returns a dataset of ({"image": image, "question": question}, answer) pairs. + + Args: + dataset: HuggingFace dataset with image, question, and answer columns. + image_col: Name of the image column. Defaults to "image". + question_col: Name of the question column. Defaults to "question". + answer_col: Name of the answer column. Defaults to "answer". + max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. + limit: Sampling limit following Olive convention: + If >= 1: use first N samples. + If 0 < limit < 1: randomly sample that percentage. + If 0 or None: use all samples. + seed: Random seed for percentage-based sampling. Defaults to 42. + **kwargs: Additional arguments. + + """ + # Apply sampling: prefer limit over max_samples + effective_limit = limit if limit is not None else (max_samples if max_samples else 0) + if effective_limit and effective_limit != 0: + from random import Random + + total = len(dataset) + if 0 < effective_limit < 1: + n = max(1, int(total * effective_limit)) + rng = Random(seed) + indices = sorted(rng.sample(range(total), min(n, total))) + dataset = dataset.select(indices) + elif effective_limit >= 1: + n = min(int(effective_limit), total) + dataset = dataset.select(range(n)) + + class VisionVQADataset: + """Dataset that returns (input_dict, answer_text) pairs for VQA evaluation. + + Note: Use batch_size=1 in dataloader config as images have variable sizes. + """ + + def __init__(self, hf_dataset, image_column, question_column, answer_column): + self.dataset = hf_dataset + self.image_column = image_column + self.question_column = question_column + self.answer_column = answer_column + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + item = self.dataset[idx] + image = item[self.image_column] + question = item[self.question_column] + answer = item[self.answer_column] + # Handle list answers (some datasets have multiple valid answers) + if isinstance(answer, list): + answer = answer[0] if answer else "" + return {"image": image, "question": question}, str(answer) + + @staticmethod + def collate_fn(batch): + """Collate VQA batches. Use with batch_size=1 for variable-size images.""" + if len(batch) == 1: + input_dict, answer = batch[0] + return (input_dict, [answer]) + inputs = [item[0] for item in batch] + answers = [item[1] for item in batch] + return (inputs, answers) + + return VisionVQADataset(dataset, image_col, question_col, answer_col) diff --git a/olive/data/container/huggingface_container.py b/olive/data/container/huggingface_container.py index 9a6bb81e3..ab504d7a6 100644 --- a/olive/data/container/huggingface_container.py +++ b/olive/data/container/huggingface_container.py @@ -41,4 +41,13 @@ class HuggingfaceContainer(DataContainer): "speech-transcription": { DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process", }, + "vision-vqa": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, + "vision-chart-qa": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, + "vision-ocr": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, } diff --git a/olive/evaluator/accuracy.py b/olive/evaluator/accuracy.py index db5e29775..82b7b6418 100644 --- a/olive/evaluator/accuracy.py +++ b/olive/evaluator/accuracy.py @@ -217,3 +217,169 @@ def measure(self, model_output, target): if total_inference == 0: return float("inf") return round(total_audio / total_inference, 2) + + +class ExactMatch(AccuracyBase): + """Exact match metric for vision VQA evaluation. + + Compares predicted answer strings to ground truth answers using + case-insensitive, whitespace-normalized string equality. + Returns the fraction of samples with an exact match. + """ + + name: Optional[str] = "exact_match" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + @staticmethod + def _normalize(text: str) -> str: + """Normalize text for comparison: lowercase and collapse whitespace.""" + return " ".join(text.strip().lower().split()) + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for exact_match metric." + ) + + correct = sum(1 for p, r in zip(preds, refs) if self._normalize(str(p)) == self._normalize(str(r))) + return correct / len(refs) if refs else 0.0 + + +class RelaxedAccuracy(AccuracyBase): + """Relaxed accuracy metric for chart/math VQA evaluation. + + For numeric answers, allows a ±5% tolerance (standard for ChartQA). + For non-numeric answers, falls back to exact string match. + Returns the fraction of samples that match within tolerance. + """ + + name: Optional[str] = "relaxed_accuracy" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return { + "tolerance": ConfigParam(type_=float, required=False, default_value=0.05), + } + + @staticmethod + def _try_parse_number(text: str): + """Try to parse text as a number. Returns (True, value) or (False, None).""" + text = text.strip().replace(",", "").replace("%", "") + try: + return True, float(text) + except ValueError: + return False, None + + @staticmethod + def _normalize(text: str) -> str: + return " ".join(text.strip().lower().split()) + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for relaxed_accuracy metric." + ) + + tolerance = self.config_dict.get("tolerance", 0.05) + correct = 0 + for pred, ref in zip(preds, refs): + pred_str = str(pred) + ref_str = str(ref) + pred_is_num, pred_val = self._try_parse_number(pred_str) + ref_is_num, ref_val = self._try_parse_number(ref_str) + + if pred_is_num and ref_is_num: + # Numeric comparison with tolerance + if ref_val == 0: + if pred_val == 0: + correct += 1 + elif abs(pred_val - ref_val) / abs(ref_val) <= tolerance: + correct += 1 + else: + # String comparison (exact match, case-insensitive) + if self._normalize(pred_str) == self._normalize(ref_str): + correct += 1 + + return correct / len(refs) if refs else 0.0 + + +class WordSortRatio(AccuracyBase): + """Word sort ratio metric for OCR evaluation. + + Computes the ratio of matching words between prediction and reference + after sorting words alphabetically. This measures word-level overlap + regardless of word order. + Returns the average ratio across all samples. + """ + + name: Optional[str] = "word_sort_ratio" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + @staticmethod + def _compute_word_sort_ratio(pred: str, ref: str) -> float: + """Compute word sort ratio between two strings.""" + pred_words = sorted(pred.strip().lower().split()) + ref_words = sorted(ref.strip().lower().split()) + + if not ref_words: + return 1.0 if not pred_words else 0.0 + + # Count matching words using multiset intersection + from collections import Counter + + pred_counter = Counter(pred_words) + ref_counter = Counter(ref_words) + intersection = sum((pred_counter & ref_counter).values()) + total = max(len(pred_words), len(ref_words)) + return intersection / total if total > 0 else 0.0 + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for word_sort_ratio metric." + ) + + total_ratio = sum(self._compute_word_sort_ratio(str(p), str(r)) for p, r in zip(preds, refs)) + return total_ratio / len(refs) if refs else 0.0 diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 24ae3f88d..fe26c8937 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -40,6 +40,10 @@ class AccuracySubType(StrEnumBase): PERPLEXITY = "perplexity" WER = "wer" RTFX = "rtfx" + # Vision metrics + EXACT_MATCH = "exact_match" + RELAXED_ACCURACY = "relaxed_accuracy" + WORD_SORT_RATIO = "word_sort_ratio" class LatencySubType(StrEnumBase): diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index eb88b1597..10df8351b 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -61,6 +61,14 @@ class OliveModelOutput(NamedTuple): # Text-based accuracy sub-types that work with string predictions/targets _TEXT_BASED_ACCURACY_SUBTYPES = {AccuracySubType.WER, AccuracySubType.RTFX} +_VISION_ACCURACY_SUBTYPES = {AccuracySubType.EXACT_MATCH, AccuracySubType.RELAXED_ACCURACY, AccuracySubType.WORD_SORT_RATIO} + +# Task-to-metric validation: maps data task types to their allowed vision metrics +_VISION_TASK_METRIC_MAP = { + "vision-vqa": {AccuracySubType.EXACT_MATCH}, + "vision-chart-qa": {AccuracySubType.RELAXED_ACCURACY}, + "vision-ocr": {AccuracySubType.WORD_SORT_RATIO}, +} def _is_text_based_metric(metric: "Metric") -> bool: @@ -80,6 +88,56 @@ def _is_text_based_metric(metric: "Metric") -> bool: return all(text_based) +def _is_vision_metric(metric: "Metric") -> bool: + """Check if metric uses vision accuracy sub-types (exact_match, relaxed_accuracy, word_sort_ratio). + + Raises ValueError if vision sub-types are mixed with non-vision sub-types, + as they require different inference paths. + """ + if metric.type != MetricType.ACCURACY: + return False + vision_based = [sub.name in _VISION_ACCURACY_SUBTYPES for sub in metric.sub_types] + if any(vision_based) and not all(vision_based): + raise ValueError( + "Cannot mix vision accuracy sub-types (exact_match, relaxed_accuracy, word_sort_ratio) " + "with other sub-types in the same metric. Please define them as separate metrics." + ) + return all(vision_based) + + +def _validate_vision_task_metric(metric: "Metric") -> None: + """Validate that the vision metric sub-types are compatible with the data task type. + + Raises ValueError if the metric is not compatible with the task. + """ + if not _is_vision_metric(metric): + return + + task_type = None + if metric.data_config and hasattr(metric.data_config, "task_type"): + task_type = metric.data_config.task_type + elif metric.data_config and hasattr(metric.data_config, "params_config"): + task_type = getattr(metric.data_config.params_config, "task_type", None) + + if task_type is None: + # No task type specified, allow any vision metric + return + + allowed_metrics = _VISION_TASK_METRIC_MAP.get(task_type) + if allowed_metrics is None: + raise ValueError( + f"Unknown vision task type '{task_type}'. " + f"Supported task types: {list(_VISION_TASK_METRIC_MAP.keys())}." + ) + + for sub in metric.sub_types: + if sub.name not in allowed_metrics: + raise ValueError( + f"Metric sub-type '{sub.name}' is not compatible with task type '{task_type}'. " + f"Allowed metrics for '{task_type}': {[m.value for m in allowed_metrics]}." + ) + + class OliveEvaluator(ABC): def __init__(self, **kwargs): super().__init__() @@ -518,7 +576,12 @@ def _evaluate_onnx_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - if _is_text_based_metric(metric): + if _is_vision_metric(metric): + _validate_vision_task_metric(metric) + inference_output, targets = self._inference_vision( + model, metric, dataloader, post_func, device, execution_providers + ) + elif _is_text_based_metric(metric): # Auto-detect genai model by checking for genai_config.json genai_config_path = Path(model.model_path).parent / "genai_config.json" if genai_config_path.exists(): @@ -646,6 +709,74 @@ def _inference_text( } return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + def _inference_vision( + self, + model: ONNXModelHandler, + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Vision-based inference for VQA/OCR metrics (exact_match, relaxed_accuracy, word_sort_ratio). + + The post_func must return predicted answer strings per batch. + Labels from the dataloader must be reference answer strings. + """ + session, inference_settings = OnnxEvaluator.get_session_wrapper( + model, metric, dataloader, device, execution_providers + ) + io_config = model.io_config + run_kwargs = metric.get_run_kwargs() + + all_preds = [] + all_targets = [] + output_names = io_config["output_names"] + is_single_tensor_output = len(output_names) == 1 + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + input_feed = format_data(input_data, io_config) + result = model.run_session(session, input_feed, **run_kwargs) + if is_single_tensor_output: + result = torch.from_numpy(result[0]) if hasattr(result[0], "__array__") else torch.tensor(result[0]) + else: + result = { + name: torch.from_numpy(result[i]) if hasattr(result[i], "__array__") else torch.tensor(result[i]) + for i, name in enumerate(output_names) + } + # post_func must decode model output to answer strings + outputs = post_func(result) if post_func else result + if isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + # labels should be reference answer strings + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + tuning_result_file = inference_settings.get("tuning_result_file") + if tuning_result_file: + dump_tuning_result(session.session, tuning_result_file) + + return OliveModelOutput(preds=all_preds, logits=None), all_targets + def _inference_text_genai( self, model: ONNXModelHandler, @@ -1216,7 +1347,12 @@ def _evaluate_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - if _is_text_based_metric(metric): + if _is_vision_metric(metric): + _validate_vision_task_metric(metric) + inference_output, targets = self._inference_vision( + model, metric, dataloader, post_func, device, execution_providers + ) + elif _is_text_based_metric(metric): inference_output, targets = self._inference_text( model, metric, dataloader, post_func, device, execution_providers ) @@ -1302,6 +1438,60 @@ def _inference_text( } return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + @torch.no_grad() + def _inference_vision( + self, + model: "PyTorchModelHandler", + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Vision-based inference for VQA/OCR metrics (exact_match, relaxed_accuracy, word_sort_ratio).""" + session = model.prepare_session() + all_preds = [] + all_targets = [] + torch_device = _OliveEvaluator.device_string_to_torch_device(device) + run_kwargs = metric.get_run_kwargs() + session.to(torch_device) + + for batch in dataloader: + input_data_i, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + input_data = tensor_data_to_device(input_data_i, torch_device) + result = model.run_session(session, input_data, **run_kwargs) + outputs = post_func(result) if post_func else result + + if isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + if torch_device: + session.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return OliveModelOutput(preds=all_preds, logits=None), all_targets + @torch.no_grad() def _evaluate_raw_latency( self, diff --git a/olive/olive_config.json b/olive/olive_config.json index 8e6ae9274..84efb5256 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -715,6 +715,7 @@ "speech": [ "jiwer", "librosa", "soundfile" ], "tf": [ "tensorflow==1.15.0" ], "torch-tensorrt": [ "torch-tensorrt" ], - "tune-session-params": [ "psutil" ] + "tune-session-params": [ "psutil" ], + "vision": [ "Pillow" ] } } diff --git a/test/evaluator/test_accuracy.py b/test/evaluator/test_accuracy.py index f2c2754dd..401438793 100644 --- a/test/evaluator/test_accuracy.py +++ b/test/evaluator/test_accuracy.py @@ -213,3 +213,204 @@ def test_rtfx_missing_metadata(self): model_output = OliveModelOutput(preds=["text"], logits=None) with pytest.raises(ValueError, match="RTFx metric requires timing metadata"): rtfx.measure(model_output, ["text"]) + + +class TestExactMatch: + def test_perfect_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "B", "C"], logits=None) + targets = ["A", "B", "C"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_case_insensitive(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["Hello World", "TEST"], logits=None) + targets = ["hello world", "test"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_whitespace_normalization(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=[" hello world "], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_partial_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "wrong", "C"], logits=None) + targets = ["A", "B", "C"] + result = metric.measure(model_output, targets) + assert abs(result - 2.0 / 3.0) < 1e-6 + + def test_no_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["X", "Y"], logits=None) + targets = ["A", "B"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_single_string_input(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds="answer", logits=None) + targets = "answer" + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_length_mismatch_raises(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "B"], logits=None) + targets = ["A"] + with pytest.raises(ValueError, match="does not match"): + metric.measure(model_output, targets) + + +class TestRelaxedAccuracy: + def test_exact_numeric_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["42.0"], logits=None) + targets = ["42.0"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_within_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # 41 is within 5% of 42 (42 * 0.05 = 2.1, |42-41| = 1 < 2.1) + model_output = OliveModelOutput(preds=["41"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_outside_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # 35 is outside 5% of 42 (42 * 0.05 = 2.1, |42-35| = 7 > 2.1) + model_output = OliveModelOutput(preds=["35"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_string_exact_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["yes"], logits=None) + targets = ["yes"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_string_no_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["yes"], logits=None) + targets = ["no"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_percentage_values(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # 50% parsed as 50, within 5% of 50 + model_output = OliveModelOutput(preds=["51%"], logits=None) + targets = ["50%"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_zero_target(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["0"], logits=None) + targets = ["0"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_custom_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({"tolerance": 0.1}) + # 38 is within 10% of 42 (42 * 0.1 = 4.2, |42-38| = 4 < 4.2) + model_output = OliveModelOutput(preds=["38"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + +class TestWordSortRatio: + def test_perfect_match(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["hello world"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_reordered_words(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + # Same words, different order → ratio = 1.0 (sorted comparison) + model_output = OliveModelOutput(preds=["world hello"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_partial_overlap(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + # "hello" matches, "earth" doesn't match "world" + model_output = OliveModelOutput(preds=["hello earth"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert 0.0 < result < 1.0 + + def test_no_overlap(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["foo bar"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_case_insensitive(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["HELLO WORLD"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_empty_reference(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=[""], logits=None) + targets = [""] + result = metric.measure(model_output, targets) + assert result == 1.0 From f8047857e18d506622ffbc6362cee27b82a8b099 Mon Sep 17 00:00:00 2001 From: David Fan Date: Wed, 27 May 2026 18:43:46 +0000 Subject: [PATCH 2/4] Add benchmark documentation to vision metric classes and enum Document which datasets each vision metric is suitable for: - exact_match: AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS - relaxed_accuracy: ChartQA - word_sort_ratio: OCR Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/accuracy.py | 6 ++++++ olive/evaluator/metric.py | 8 ++++---- olive/evaluator/olive_evaluator.py | 6 +++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/olive/evaluator/accuracy.py b/olive/evaluator/accuracy.py index 82b7b6418..2a9fb4894 100644 --- a/olive/evaluator/accuracy.py +++ b/olive/evaluator/accuracy.py @@ -225,6 +225,8 @@ class ExactMatch(AccuracyBase): Compares predicted answer strings to ground truth answers using case-insensitive, whitespace-normalized string equality. Returns the fraction of samples with an exact match. + + Suitable for benchmarks: AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS. """ name: Optional[str] = "exact_match" @@ -266,6 +268,8 @@ class RelaxedAccuracy(AccuracyBase): For numeric answers, allows a ±5% tolerance (standard for ChartQA). For non-numeric answers, falls back to exact string match. Returns the fraction of samples that match within tolerance. + + Suitable for benchmarks: ChartQA. """ name: Optional[str] = "relaxed_accuracy" @@ -337,6 +341,8 @@ class WordSortRatio(AccuracyBase): after sorting words alphabetically. This measures word-level overlap regardless of word order. Returns the average ratio across all samples. + + Suitable for benchmarks: OCR. """ name: Optional[str] = "word_sort_ratio" diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index fe26c8937..860a33ea7 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -40,10 +40,10 @@ class AccuracySubType(StrEnumBase): PERPLEXITY = "perplexity" WER = "wer" RTFX = "rtfx" - # Vision metrics - EXACT_MATCH = "exact_match" - RELAXED_ACCURACY = "relaxed_accuracy" - WORD_SORT_RATIO = "word_sort_ratio" + # Vision metrics (aligned with LITE babelbench benchmarks) + EXACT_MATCH = "exact_match" # AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS + RELAXED_ACCURACY = "relaxed_accuracy" # ChartQA (±5% numeric tolerance) + WORD_SORT_RATIO = "word_sort_ratio" # OCR (word-level overlap) class LatencySubType(StrEnumBase): diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 10df8351b..7f3178922 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -63,7 +63,11 @@ class OliveModelOutput(NamedTuple): _TEXT_BASED_ACCURACY_SUBTYPES = {AccuracySubType.WER, AccuracySubType.RTFX} _VISION_ACCURACY_SUBTYPES = {AccuracySubType.EXACT_MATCH, AccuracySubType.RELAXED_ACCURACY, AccuracySubType.WORD_SORT_RATIO} -# Task-to-metric validation: maps data task types to their allowed vision metrics +# Task-to-metric validation: maps data task types to their allowed vision metrics. +# Metrics are aligned with standard vision benchmarks used in LITE (babelbench): +# - vision-vqa (exact_match): AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS +# - vision-chart-qa (relaxed_accuracy): ChartQA (±5% numeric tolerance) +# - vision-ocr (word_sort_ratio): OCR (word-level overlap) _VISION_TASK_METRIC_MAP = { "vision-vqa": {AccuracySubType.EXACT_MATCH}, "vision-chart-qa": {AccuracySubType.RELAXED_ACCURACY}, From 28d81108b16f5be5e903e2284fcdd361a2915f74 Mon Sep 17 00:00:00 2001 From: David Fan Date: Wed, 27 May 2026 18:46:29 +0000 Subject: [PATCH 3/4] Remove internal project references from comments Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/metric.py | 2 +- olive/evaluator/olive_evaluator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 860a33ea7..5d19b132e 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -40,7 +40,7 @@ class AccuracySubType(StrEnumBase): PERPLEXITY = "perplexity" WER = "wer" RTFX = "rtfx" - # Vision metrics (aligned with LITE babelbench benchmarks) + # Vision metrics (aligned with standard public vision benchmarks) EXACT_MATCH = "exact_match" # AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS RELAXED_ACCURACY = "relaxed_accuracy" # ChartQA (±5% numeric tolerance) WORD_SORT_RATIO = "word_sort_ratio" # OCR (word-level overlap) diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 7f3178922..e0c298580 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -64,7 +64,7 @@ class OliveModelOutput(NamedTuple): _VISION_ACCURACY_SUBTYPES = {AccuracySubType.EXACT_MATCH, AccuracySubType.RELAXED_ACCURACY, AccuracySubType.WORD_SORT_RATIO} # Task-to-metric validation: maps data task types to their allowed vision metrics. -# Metrics are aligned with standard vision benchmarks used in LITE (babelbench): +# Metrics are aligned with standard public vision benchmarks: # - vision-vqa (exact_match): AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS # - vision-chart-qa (relaxed_accuracy): ChartQA (±5% numeric tolerance) # - vision-ocr (word_sort_ratio): OCR (word-level overlap) From 883d4c8bd761927fd4b4c9d227017d482e689442 Mon Sep 17 00:00:00 2001 From: David Fan Date: Wed, 27 May 2026 18:59:57 +0000 Subject: [PATCH 4/4] Address review comments: fix task type extraction, formatting, and docs - Fix _validate_vision_task_metric to extract task from pre_process_data_config.params['task'] instead of non-existent DataConfig attributes - Wrap _VISION_ACCURACY_SUBTYPES across multiple lines for lint compliance - Use lowercase 'pillow' in olive_config.json for consistency - Add docstring note about ONNX vs PyTorch path for vision_vqa_pre_process Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/data/component/pre_process_data.py | 6 ++++++ olive/evaluator/olive_evaluator.py | 19 ++++++++++++------- olive/olive_config.json | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index 8da36c247..31cf59f66 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -397,6 +397,12 @@ def vision_vqa_pre_process( Loads image, question, and ground truth answer from a HuggingFace dataset. Returns a dataset of ({"image": image, "question": question}, answer) pairs. + Note: This returns raw PIL images and question strings. For the PyTorch evaluator, + the model's own processor/tokenizer should be applied in the post_func or within + the model's forward method. For the ONNX evaluator, provide a custom pre-process + component that applies the appropriate processor/tokenizer to produce numeric + tensors matching the model's io_config. + Args: dataset: HuggingFace dataset with image, question, and answer columns. image_col: Name of the image column. Defaults to "image". diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index e0c298580..74e865370 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -61,7 +61,11 @@ class OliveModelOutput(NamedTuple): # Text-based accuracy sub-types that work with string predictions/targets _TEXT_BASED_ACCURACY_SUBTYPES = {AccuracySubType.WER, AccuracySubType.RTFX} -_VISION_ACCURACY_SUBTYPES = {AccuracySubType.EXACT_MATCH, AccuracySubType.RELAXED_ACCURACY, AccuracySubType.WORD_SORT_RATIO} +_VISION_ACCURACY_SUBTYPES = { + AccuracySubType.EXACT_MATCH, + AccuracySubType.RELAXED_ACCURACY, + AccuracySubType.WORD_SORT_RATIO, +} # Task-to-metric validation: maps data task types to their allowed vision metrics. # Metrics are aligned with standard public vision benchmarks: @@ -118,10 +122,12 @@ def _validate_vision_task_metric(metric: "Metric") -> None: return task_type = None - if metric.data_config and hasattr(metric.data_config, "task_type"): - task_type = metric.data_config.task_type - elif metric.data_config and hasattr(metric.data_config, "params_config"): - task_type = getattr(metric.data_config.params_config, "task_type", None) + if metric.data_config: + # Extract task from pre_process_data_config params, which is how HuggingfaceContainer + # maps task types (e.g., "vision-vqa", "vision-chart-qa", "vision-ocr") to components. + pre_process_config = metric.data_config.pre_process_data_config + if pre_process_config and pre_process_config.params: + task_type = pre_process_config.params.get("task") if task_type is None: # No task type specified, allow any vision metric @@ -130,8 +136,7 @@ def _validate_vision_task_metric(metric: "Metric") -> None: allowed_metrics = _VISION_TASK_METRIC_MAP.get(task_type) if allowed_metrics is None: raise ValueError( - f"Unknown vision task type '{task_type}'. " - f"Supported task types: {list(_VISION_TASK_METRIC_MAP.keys())}." + f"Unknown vision task type '{task_type}'. Supported task types: {list(_VISION_TASK_METRIC_MAP.keys())}." ) for sub in metric.sub_types: diff --git a/olive/olive_config.json b/olive/olive_config.json index 84efb5256..c7da5c7cb 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -716,6 +716,6 @@ "tf": [ "tensorflow==1.15.0" ], "torch-tensorrt": [ "torch-tensorrt" ], "tune-session-params": [ "psutil" ], - "vision": [ "Pillow" ] + "vision": [ "pillow" ] } }