diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index 2881ea891..31cf59f66 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -379,3 +379,92 @@ def collate_fn(batch): return (audios, texts) return SpeechTranscriptionDataset(dataset, audio_col, text_col) + + +@Registry.register_pre_process() +def vision_vqa_pre_process( + dataset, + image_col: str = "image", + question_col: str = "question", + answer_col: str = "answer", + max_samples: Optional[int] = None, + limit: Optional[float] = None, + seed: int = 42, + **kwargs, +): + """Pre-process data for vision VQA evaluation. + + Loads image, question, and ground truth answer from a HuggingFace dataset. + Returns a dataset of ({"image": image, "question": question}, answer) pairs. + + Note: This returns raw PIL images and question strings. For the PyTorch evaluator, + the model's own processor/tokenizer should be applied in the post_func or within + the model's forward method. For the ONNX evaluator, provide a custom pre-process + component that applies the appropriate processor/tokenizer to produce numeric + tensors matching the model's io_config. + + Args: + dataset: HuggingFace dataset with image, question, and answer columns. + image_col: Name of the image column. Defaults to "image". + question_col: Name of the question column. Defaults to "question". + answer_col: Name of the answer column. Defaults to "answer". + max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. + limit: Sampling limit following Olive convention: + If >= 1: use first N samples. + If 0 < limit < 1: randomly sample that percentage. + If 0 or None: use all samples. + seed: Random seed for percentage-based sampling. Defaults to 42. + **kwargs: Additional arguments. 
+ + """ + # Apply sampling: prefer limit over max_samples + effective_limit = limit if limit is not None else (max_samples if max_samples else 0) + if effective_limit and effective_limit != 0: + from random import Random + + total = len(dataset) + if 0 < effective_limit < 1: + n = max(1, int(total * effective_limit)) + rng = Random(seed) + indices = sorted(rng.sample(range(total), min(n, total))) + dataset = dataset.select(indices) + elif effective_limit >= 1: + n = min(int(effective_limit), total) + dataset = dataset.select(range(n)) + + class VisionVQADataset: + """Dataset that returns (input_dict, answer_text) pairs for VQA evaluation. + + Note: Use batch_size=1 in dataloader config as images have variable sizes. + """ + + def __init__(self, hf_dataset, image_column, question_column, answer_column): + self.dataset = hf_dataset + self.image_column = image_column + self.question_column = question_column + self.answer_column = answer_column + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + item = self.dataset[idx] + image = item[self.image_column] + question = item[self.question_column] + answer = item[self.answer_column] + # Handle list answers (some datasets have multiple valid answers) + if isinstance(answer, list): + answer = answer[0] if answer else "" + return {"image": image, "question": question}, str(answer) + + @staticmethod + def collate_fn(batch): + """Collate VQA batches. 
Use with batch_size=1 for variable-size images.""" + if len(batch) == 1: + input_dict, answer = batch[0] + return (input_dict, [answer]) + inputs = [item[0] for item in batch] + answers = [item[1] for item in batch] + return (inputs, answers) + + return VisionVQADataset(dataset, image_col, question_col, answer_col) diff --git a/olive/data/container/huggingface_container.py b/olive/data/container/huggingface_container.py index 9a6bb81e3..ab504d7a6 100644 --- a/olive/data/container/huggingface_container.py +++ b/olive/data/container/huggingface_container.py @@ -41,4 +41,13 @@ class HuggingfaceContainer(DataContainer): "speech-transcription": { DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process", }, + "vision-vqa": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, + "vision-chart-qa": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, + "vision-ocr": { + DataComponentType.PRE_PROCESS_DATA.value: "vision_vqa_pre_process", + }, } diff --git a/olive/evaluator/accuracy.py b/olive/evaluator/accuracy.py index db5e29775..2a9fb4894 100644 --- a/olive/evaluator/accuracy.py +++ b/olive/evaluator/accuracy.py @@ -217,3 +217,175 @@ def measure(self, model_output, target): if total_inference == 0: return float("inf") return round(total_audio / total_inference, 2) + + +class ExactMatch(AccuracyBase): + """Exact match metric for vision VQA evaluation. + + Compares predicted answer strings to ground truth answers using + case-insensitive, whitespace-normalized string equality. + Returns the fraction of samples with an exact match. + + Suitable for benchmarks: AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS. 
+ """ + + name: Optional[str] = "exact_match" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + @staticmethod + def _normalize(text: str) -> str: + """Normalize text for comparison: lowercase and collapse whitespace.""" + return " ".join(text.strip().lower().split()) + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for exact_match metric." + ) + + correct = sum(1 for p, r in zip(preds, refs) if self._normalize(str(p)) == self._normalize(str(r))) + return correct / len(refs) if refs else 0.0 + + +class RelaxedAccuracy(AccuracyBase): + """Relaxed accuracy metric for chart/math VQA evaluation. + + For numeric answers, allows a relative tolerance (`tolerance`, default ±5% as in ChartQA). + For non-numeric answers, falls back to case-insensitive, whitespace-normalized string match. + Returns the fraction of samples that match within tolerance. + + Suitable for benchmarks: ChartQA. + """ + + name: Optional[str] = "relaxed_accuracy" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return { + "tolerance": ConfigParam(type_=float, required=False, default_value=0.05), + } + + @staticmethod + def _try_parse_number(text: str): + """Try to parse text as a number. 
Returns (True, value) or (False, None).""" + text = text.strip().replace(",", "").replace("%", "") + try: + return True, float(text) + except ValueError: + return False, None + + @staticmethod + def _normalize(text: str) -> str: + return " ".join(text.strip().lower().split()) + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for relaxed_accuracy metric." + ) + + tolerance = self.config_dict.get("tolerance", 0.05) + correct = 0 + for pred, ref in zip(preds, refs): + pred_str = str(pred) + ref_str = str(ref) + pred_is_num, pred_val = self._try_parse_number(pred_str) + ref_is_num, ref_val = self._try_parse_number(ref_str) + + if pred_is_num and ref_is_num: + # Numeric comparison with tolerance + if ref_val == 0: + if pred_val == 0: + correct += 1 + elif abs(pred_val - ref_val) / abs(ref_val) <= tolerance: + correct += 1 + else: + # String comparison (exact match, case-insensitive) + if self._normalize(pred_str) == self._normalize(ref_str): + correct += 1 + + return correct / len(refs) if refs else 0.0 + + +class WordSortRatio(AccuracyBase): + """Word sort ratio metric for OCR evaluation. + + Computes the ratio of matching words between prediction and reference + after sorting words alphabetically. This measures word-level overlap + regardless of word order. + Returns the average ratio across all samples. + + Suitable for benchmarks: OCR. 
+ """ + + name: Optional[str] = "word_sort_ratio" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + @staticmethod + def _compute_word_sort_ratio(pred: str, ref: str) -> float: + """Compute word sort ratio between two strings.""" + pred_words = sorted(pred.strip().lower().split()) + ref_words = sorted(ref.strip().lower().split()) + + if not ref_words: + return 1.0 if not pred_words else 0.0 + + # Count matching words using multiset intersection + from collections import Counter + + pred_counter = Counter(pred_words) + ref_counter = Counter(ref_words) + intersection = sum((pred_counter & ref_counter).values()) + total = max(len(pred_words), len(ref_words)) + return intersection / total if total > 0 else 0.0 + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + if len(preds) != len(refs): + raise ValueError( + f"Number of predictions ({len(preds)}) does not match " + f"number of references ({len(refs)}) for word_sort_ratio metric." 
+ ) + + total_ratio = sum(self._compute_word_sort_ratio(str(p), str(r)) for p, r in zip(preds, refs)) + return total_ratio / len(refs) if refs else 0.0 diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 24ae3f88d..5d19b132e 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -40,6 +40,10 @@ class AccuracySubType(StrEnumBase): PERPLEXITY = "perplexity" WER = "wer" RTFX = "rtfx" + # Vision metrics (aligned with standard public vision benchmarks) + EXACT_MATCH = "exact_match" # AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS + RELAXED_ACCURACY = "relaxed_accuracy" # ChartQA (±5% numeric tolerance) + WORD_SORT_RATIO = "word_sort_ratio" # OCR (word-level overlap) class LatencySubType(StrEnumBase): diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index eb88b1597..74e865370 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -61,6 +61,22 @@ class OliveModelOutput(NamedTuple): # Text-based accuracy sub-types that work with string predictions/targets _TEXT_BASED_ACCURACY_SUBTYPES = {AccuracySubType.WER, AccuracySubType.RTFX} +_VISION_ACCURACY_SUBTYPES = { + AccuracySubType.EXACT_MATCH, + AccuracySubType.RELAXED_ACCURACY, + AccuracySubType.WORD_SORT_RATIO, +} + +# Task-to-metric validation: maps data task types to their allowed vision metrics. 
+# Metrics are aligned with standard public vision benchmarks: +# - vision-vqa (exact_match): AI2D, ScienceQA, TextVQA, MathVista, MMMU, InterGPS +# - vision-chart-qa (relaxed_accuracy): ChartQA (±5% numeric tolerance) +# - vision-ocr (word_sort_ratio): OCR (word-level overlap) +_VISION_TASK_METRIC_MAP = { + "vision-vqa": {AccuracySubType.EXACT_MATCH}, + "vision-chart-qa": {AccuracySubType.RELAXED_ACCURACY}, + "vision-ocr": {AccuracySubType.WORD_SORT_RATIO}, +} def _is_text_based_metric(metric: "Metric") -> bool: @@ -80,6 +96,57 @@ def _is_text_based_metric(metric: "Metric") -> bool: return all(text_based) +def _is_vision_metric(metric: "Metric") -> bool: + """Check if metric uses vision accuracy sub-types (exact_match, relaxed_accuracy, word_sort_ratio). + + Raises ValueError if vision sub-types are mixed with non-vision sub-types, + as they require different inference paths. + """ + if metric.type != MetricType.ACCURACY: + return False + vision_based = [sub.name in _VISION_ACCURACY_SUBTYPES for sub in metric.sub_types] + if any(vision_based) and not all(vision_based): + raise ValueError( + "Cannot mix vision accuracy sub-types (exact_match, relaxed_accuracy, word_sort_ratio) " + "with other sub-types in the same metric. Please define them as separate metrics." + ) + return all(vision_based) + + +def _validate_vision_task_metric(metric: "Metric") -> None: + """Validate that the vision metric sub-types are compatible with the data task type. + + Raises ValueError if the metric is not compatible with the task. + """ + if not _is_vision_metric(metric): + return + + task_type = None + if metric.data_config: + # Extract task from pre_process_data_config params, which is how HuggingfaceContainer + # maps task types (e.g., "vision-vqa", "vision-chart-qa", "vision-ocr") to components. 
+ pre_process_config = metric.data_config.pre_process_data_config + if pre_process_config and pre_process_config.params: + task_type = pre_process_config.params.get("task") + + if task_type is None: + # No task type specified, allow any vision metric + return + + allowed_metrics = _VISION_TASK_METRIC_MAP.get(task_type) + if allowed_metrics is None: + raise ValueError( + f"Unknown vision task type '{task_type}'. Supported task types: {list(_VISION_TASK_METRIC_MAP.keys())}." + ) + + for sub in metric.sub_types: + if sub.name not in allowed_metrics: + raise ValueError( + f"Metric sub-type '{sub.name}' is not compatible with task type '{task_type}'. " + f"Allowed metrics for '{task_type}': {[m.value for m in allowed_metrics]}." + ) + + class OliveEvaluator(ABC): def __init__(self, **kwargs): super().__init__() @@ -518,7 +585,12 @@ def _evaluate_onnx_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - if _is_text_based_metric(metric): + if _is_vision_metric(metric): + _validate_vision_task_metric(metric) + inference_output, targets = self._inference_vision( + model, metric, dataloader, post_func, device, execution_providers + ) + elif _is_text_based_metric(metric): # Auto-detect genai model by checking for genai_config.json genai_config_path = Path(model.model_path).parent / "genai_config.json" if genai_config_path.exists(): @@ -646,6 +718,74 @@ def _inference_text( } return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + def _inference_vision( + self, + model: ONNXModelHandler, + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Vision-based inference for VQA/OCR metrics (exact_match, relaxed_accuracy, word_sort_ratio). + + The post_func must return predicted answer strings per batch. 
+ Labels from the dataloader must be reference answer strings. + """ + session, inference_settings = OnnxEvaluator.get_session_wrapper( + model, metric, dataloader, device, execution_providers + ) + io_config = model.io_config + run_kwargs = metric.get_run_kwargs() + + all_preds = [] + all_targets = [] + output_names = io_config["output_names"] + is_single_tensor_output = len(output_names) == 1 + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + input_feed = format_data(input_data, io_config) + result = model.run_session(session, input_feed, **run_kwargs) + if is_single_tensor_output: + result = torch.from_numpy(result[0]) if hasattr(result[0], "__array__") else torch.tensor(result[0]) + else: + result = { + name: torch.from_numpy(result[i]) if hasattr(result[i], "__array__") else torch.tensor(result[i]) + for i, name in enumerate(output_names) + } + # post_func must decode model output to answer strings + outputs = post_func(result) if post_func else result + if isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to answer text." 
+ ) + # labels should be reference answer strings + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + tuning_result_file = inference_settings.get("tuning_result_file") + if tuning_result_file: + dump_tuning_result(session.session, tuning_result_file) + + return OliveModelOutput(preds=all_preds, logits=None), all_targets + def _inference_text_genai( self, model: ONNXModelHandler, @@ -1216,7 +1356,12 @@ def _evaluate_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - if _is_text_based_metric(metric): + if _is_vision_metric(metric): + _validate_vision_task_metric(metric) + inference_output, targets = self._inference_vision( + model, metric, dataloader, post_func, device, execution_providers + ) + elif _is_text_based_metric(metric): inference_output, targets = self._inference_text( model, metric, dataloader, post_func, device, execution_providers ) @@ -1302,6 +1447,60 @@ def _inference_text( } return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + @torch.no_grad() + def _inference_vision( + self, + model: "PyTorchModelHandler", + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Vision-based inference for VQA/OCR metrics (exact_match, relaxed_accuracy, word_sort_ratio).""" + session = model.prepare_session() + all_preds = [] + all_targets = [] + torch_device = _OliveEvaluator.device_string_to_torch_device(device) + run_kwargs = metric.get_run_kwargs() + session.to(torch_device) + + for batch in dataloader: + input_data_i, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + input_data = tensor_data_to_device(input_data_i, torch_device) + result = model.run_session(session, input_data, **run_kwargs) + outputs = post_func(result) if post_func else result + + if 
isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for vision metrics, " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to answer text." + ) + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + if torch_device: + session.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return OliveModelOutput(preds=all_preds, logits=None), all_targets + @torch.no_grad() def _evaluate_raw_latency( self, diff --git a/olive/olive_config.json b/olive/olive_config.json index 8e6ae9274..c7da5c7cb 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -715,6 +715,7 @@ "speech": [ "jiwer", "librosa", "soundfile" ], "tf": [ "tensorflow==1.15.0" ], "torch-tensorrt": [ "torch-tensorrt" ], - "tune-session-params": [ "psutil" ] + "tune-session-params": [ "psutil" ], + "vision": [ "pillow" ] } } diff --git a/test/evaluator/test_accuracy.py b/test/evaluator/test_accuracy.py index f2c2754dd..401438793 100644 --- a/test/evaluator/test_accuracy.py +++ b/test/evaluator/test_accuracy.py @@ -213,3 +213,204 @@ def test_rtfx_missing_metadata(self): model_output = OliveModelOutput(preds=["text"], logits=None) with pytest.raises(ValueError, match="RTFx metric requires timing metadata"): rtfx.measure(model_output, ["text"]) + + +class TestExactMatch: + def test_perfect_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "B", "C"], logits=None) + targets = ["A", "B", "C"] + result = 
metric.measure(model_output, targets) + assert result == 1.0 + + def test_case_insensitive(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["Hello World", "TEST"], logits=None) + targets = ["hello world", "test"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_whitespace_normalization(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=[" hello world "], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_partial_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "wrong", "C"], logits=None) + targets = ["A", "B", "C"] + result = metric.measure(model_output, targets) + assert abs(result - 2.0 / 3.0) < 1e-6 + + def test_no_match(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["X", "Y"], logits=None) + targets = ["A", "B"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_single_string_input(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds="answer", logits=None) + targets = "answer" + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_length_mismatch_raises(self): + from olive.evaluator.accuracy import ExactMatch + + metric = ExactMatch({}) + model_output = OliveModelOutput(preds=["A", "B"], logits=None) + targets = ["A"] + with pytest.raises(ValueError, match="does not match"): + metric.measure(model_output, targets) + + +class TestRelaxedAccuracy: + def test_exact_numeric_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = 
OliveModelOutput(preds=["42.0"], logits=None) + targets = ["42.0"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_within_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # 41 is within 5% of 42 (42 * 0.05 = 2.1, |42-41| = 1 < 2.1) + model_output = OliveModelOutput(preds=["41"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_outside_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # 35 is outside 5% of 42 (42 * 0.05 = 2.1, |42-35| = 7 > 2.1) + model_output = OliveModelOutput(preds=["35"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_string_exact_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["yes"], logits=None) + targets = ["yes"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_string_no_match(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["yes"], logits=None) + targets = ["no"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_percentage_values(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + # "51%" parses as 51 and "50%" as 50 ("%" is stripped); |51-50| / 50 = 0.02 <= 0.05 + model_output = OliveModelOutput(preds=["51%"], logits=None) + targets = ["50%"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_zero_target(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({}) + model_output = OliveModelOutput(preds=["0"], logits=None) + targets = ["0"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def 
test_custom_tolerance(self): + from olive.evaluator.accuracy import RelaxedAccuracy + + metric = RelaxedAccuracy({"tolerance": 0.1}) + # 38 is within 10% of 42 (42 * 0.1 = 4.2, |42-38| = 4 < 4.2) + model_output = OliveModelOutput(preds=["38"], logits=None) + targets = ["42"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + +class TestWordSortRatio: + def test_perfect_match(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["hello world"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_reordered_words(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + # Same words, different order → ratio = 1.0 (sorted comparison) + model_output = OliveModelOutput(preds=["world hello"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_partial_overlap(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + # "hello" matches, "earth" doesn't match "world" + model_output = OliveModelOutput(preds=["hello earth"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert 0.0 < result < 1.0 + + def test_no_overlap(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["foo bar"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 0.0 + + def test_case_insensitive(self): + from olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=["HELLO WORLD"], logits=None) + targets = ["hello world"] + result = metric.measure(model_output, targets) + assert result == 1.0 + + def test_empty_reference(self): + from 
olive.evaluator.accuracy import WordSortRatio + + metric = WordSortRatio({}) + model_output = OliveModelOutput(preds=[""], logits=None) + targets = [""] + result = metric.measure(model_output, targets) + assert result == 1.0