diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index b61b39bf5..f818f1ac4 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -387,6 +387,8 @@ def vision_vqa_pre_process( image_col: str = "image", question_col: str = "question", answer_col: str = "answer", + options_col: str = "", + system_prompt: str = "", max_samples: Optional[int] = None, limit: Optional[float] = None, seed: int = 42, @@ -408,6 +410,10 @@ def vision_vqa_pre_process( image_col: Name of the image column. Defaults to "image". question_col: Name of the question column. Defaults to "question". answer_col: Name of the answer column. Defaults to "answer". + options_col: Name of the options column for multiple-choice questions. If specified, + options are formatted as numbered choices and appended to the question. Defaults to "". + system_prompt: System prompt to guide model responses (e.g., "Reply with only the + option number"). Passed through to the evaluator. Defaults to "". max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. limit: Sampling limit following Olive convention: If >= 1: use first N samples. @@ -438,11 +444,13 @@ class VisionVQADataset: Note: Use batch_size=1 in dataloader config as images have variable sizes. """ - def __init__(self, hf_dataset, image_column, question_column, answer_column): + def __init__(self, hf_dataset, image_column, question_column, answer_column, options_column="", sys_prompt=""): self.dataset = hf_dataset self.image_column = image_column self.question_column = question_column self.answer_column = answer_column + self.options_column = options_column + self.system_prompt = sys_prompt def __len__(self): return len(self.dataset) @@ -452,11 +460,27 @@ def __getitem__(self, idx): image = item[self.image_column] question = item[self.question_column] answer = item[self.answer_column] + + # Format options into the question if options_col is specified + has_options = False + if self.options_column and self.options_column in item: + options = item[self.options_column] + if isinstance(options, (list, tuple)): + options_text = "\n".join(f"{i}. {opt}" for i, opt in enumerate(options)) + question = f"{question}\n{options_text}" + has_options = True + # Handle list/tuple answers (some datasets have multiple valid answers) # Join with | separator so metrics can match against any valid answer if isinstance(answer, (list, tuple)): answer = "|".join(str(a) for a in answer) if answer else "" - return {"image": image, "question": question}, str(answer) + input_dict = { + "image": image, + "question": question, + "system_prompt": self.system_prompt, + "extract_number": has_options, + } + return input_dict, str(answer) @staticmethod def collate_fn(batch): @@ -472,4 +496,4 @@ def collate_fn(batch): answers = [item[1] for item in batch] return (inputs, answers) - return VisionVQADataset(dataset, image_col, question_col, answer_col) + return VisionVQADataset(dataset, image_col, question_col, answer_col, options_col, system_prompt) diff --git a/olive/data/config.py b/olive/data/config.py index 89ed9ee17..866438b37 100644 --- a/olive/data/config.py +++ b/olive/data/config.py @@ -213,12 +213,11 @@ def to_data_container(self) -> "DataContainer": return dc_cls(config=self) def _update_default_component_type_with_task_type(self, dc_cls, default_components_type): - for component_name, config in self.components.items(): + for config in self.components.values(): if config and config.params: task_type = config.params.get("task") if task_type: - task_specific_override = dc_cls.task_type_components_map.get( - task_type.replace("-with-past", ""), {} - ).get(component_name) - if task_specific_override: - default_components_type[component_name] = task_specific_override + task_overrides = dc_cls.task_type_components_map.get(task_type.replace("-with-past", ""), {}) + # Apply all component overrides for this task type + for override_component, override_type in task_overrides.items(): + default_components_type[override_component] = override_type diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index bfa684acb..d1238c037 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -582,6 +582,20 @@ def _inference( dump_tuning_result(session.session, tuning_result_file) return OliveModelOutput(preds=preds, logits=logits), targets + @staticmethod + def _load_genai_config(model: ONNXModelHandler) -> Optional[dict]: + """Load genai_config.json from the model directory, or return None if not found.""" + genai_config_path = Path(model.model_path).parent / "genai_config.json" + if not genai_config_path.exists(): + return None + import json + + try: + with genai_config_path.open(encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in genai config file: {genai_config_path}") from e + def _evaluate_onnx_accuracy( self, model: ONNXModelHandler, @@ -593,18 +607,21 @@ def _evaluate_onnx_accuracy( ) -> MetricResult: if _is_vision_metric(metric): _validate_vision_task_metric(metric) - inference_output, targets = self._inference_vision( - model, metric, dataloader, post_func, device, execution_providers - ) + # Auto-detect genai vision model by checking for genai_config.json with vision field + genai_cfg = self._load_genai_config(model) + use_genai_vision = genai_cfg is not None and "vision" in genai_cfg.get("model", {}) + + if use_genai_vision: + inference_output, targets = self._inference_vision_genai(model, dataloader, device) + else: + inference_output, targets = self._inference_vision( + model, metric, dataloader, post_func, device, execution_providers + ) elif _is_text_based_metric(metric): # Auto-detect genai model by checking for genai_config.json - genai_config_path = Path(model.model_path).parent / "genai_config.json" - if genai_config_path.exists(): - import json - - with genai_config_path.open() as f: - genai_config = json.load(f) - model_type = genai_config.get("model", {}).get("type", "") + genai_cfg = self._load_genai_config(model) + if genai_cfg: + model_type = genai_cfg.get("model", {}).get("type", "") if model_type == "whisper": inference_output, targets = self._inference_text_genai( @@ -795,6 +812,142 @@ def _inference_vision( return OliveModelOutput(preds=all_preds, logits=None), all_targets + def _inference_vision_genai( + self, + model: ONNXModelHandler, + dataloader: "DataLoader", + device: Device = Device.CPU, + ) -> tuple[OliveModelOutput, Any]: + """Vision-based inference for VQA/OCR metrics using onnxruntime-genai. + + Auto-detected when the model directory contains genai_config.json with a vision field. + Uses og.Model with multimodal processor for vision-language models (e.g., Qwen3-VL). + The dataloader must yield (input_dict, labels) where input_dict contains + 'image' (PIL Image) and 'question' (str), and labels are reference answer strings. + + Note: GPU/CPU selection is driven by the `device` parameter. onnxruntime-genai uses + short provider names internally (e.g., "cuda") which differ from ORT-style EP names. + """ + try: + import onnxruntime_genai as og + except ImportError as e: + raise ImportError( + "onnxruntime-genai is required for genai-based vision evaluation. " + "Install it with: pip install onnxruntime-genai" + ) from e + + import json + import re + import tempfile + + try: + from PIL import Image + except ImportError as e: + raise ImportError("Pillow is required for vision evaluation. Install it with: pip install Pillow") from e + + model_dir = str(Path(model.model_path).parent) + + # max_length in genai is total sequence length (input + output). + # Default to 1028 which accommodates image/prompt tokens (~200-500) plus answer tokens. + # Note: genai_config.json's search.max_length is typically the full context window + # (e.g., 262144) which is too large — the model will stop at EOS well before this cap. + max_length = 1028 + + # Build og.Model with appropriate execution provider + # Note: onnxruntime-genai uses CPU by default when no provider is appended. + # Only non-CPU providers need to be explicitly added using short names (e.g., "cuda"). + # This follows the same pattern as _inference_text_genai and _inference_text_genai_streaming. + config = og.Config(model_dir) + config.clear_providers() + if device == Device.GPU: + config.append_provider("cuda") + og_model = og.Model(config) + processor = og_model.create_multimodal_processor() + tokenizer = og.Tokenizer(og_model) + + all_preds = [] + all_targets = [] + + # Use a temporary directory for image files to avoid per-file create/delete overhead + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_img_path = Path(tmp_dir) / "input.png" + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + + # input_data is a dict with 'image' (PIL) and 'question' (str) + # or a list of such dicts for batch_size > 1 + items = [input_data] if isinstance(input_data, dict) else input_data + + for item in items: + pil_image = item.get("image") + question = item.get("question", "") + sys_prompt = item.get("system_prompt", "") + extract_number = item.get("extract_number", False) + + if pil_image is None: + # Append empty pred to maintain alignment with targets + all_preds.append("") + continue + + # Ensure PIL Image + if not isinstance(pil_image, Image.Image): + with Image.open(pil_image) as img: + pil_image = img.convert("RGB") + + # Build chat messages for the vision-language model + messages = [] + if sys_prompt: + messages.append({"role": "system", "content": sys_prompt}) + messages.append( + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": question}, + ], + } + ) + messages_json = json.dumps(messages) + + # Save image to temp file for og.Images (reuse same path to minimize I/O) + pil_image.save(str(tmp_img_path), format="PNG") + images = og.Images.open(str(tmp_img_path)) + + prompt = tokenizer.apply_chat_template(messages_json, add_generation_prompt=True) + inputs = processor(prompt, images=images) + + params = og.GeneratorParams(og_model) + params.set_search_options(max_length=max_length, do_sample=False) + + generator = og.Generator(og_model, params) + generator.set_inputs(inputs) + + tokens = [] + while not generator.is_done(): + generator.generate_next_token() + tokens.append(generator.get_next_tokens()[0]) + del generator + + pred = tokenizer.decode(tokens).strip() + # For multiple-choice tasks, extract leading number from responses + # like "1. D" or "0. krill" to match the expected answer format + if extract_number: + num_match = re.match(r"^(\d+)", pred) + if num_match: + pred = num_match.group(1) + all_preds.append(pred) + + # Collect reference texts (aligned with preds including empty ones for None images) + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + del og_model + + return OliveModelOutput(preds=all_preds, logits=None), all_targets + def _inference_text_genai( self, model: ONNXModelHandler, diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 88c4c0d72..1812f64dc 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -620,3 +620,167 @@ def test_validate_vision_task_metric_no_task_skips(self): metric = self._make_vision_metric(["exact_match"]) # No task specified, should not raise _validate_vision_task_metric(metric) + + +class TestOnnxEvaluatorGenaiVisionDetection: + """Tests for genai vision model detection and dispatch via the public evaluate() method.""" + + def _make_model_with_genai_config(self, tmp_path, genai_config_content): + """Create a mock ONNXModelHandler with a genai_config.json in its directory.""" + import json + + from olive.model.handler.onnx import ONNXModelHandler + + model_dir = tmp_path / "model" + model_dir.mkdir() + model_file = model_dir / "text.onnx" + model_file.write_text("") # dummy file + + if genai_config_content is not None: + config_path = model_dir / "genai_config.json" + config_path.write_text(json.dumps(genai_config_content)) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = str(model_file) + model.framework = "onnx" + return model + + def _make_vision_accuracy_metric(self): + """Create a metric that triggers the vision accuracy evaluation path.""" + metric = MagicMock() + metric.name = "accuracy" + metric.type = MetricType.ACCURACY + metric.sub_types = [MagicMock()] + metric.sub_types[0].name = "exact_match" + metric.data_config = None + metric.user_config = MagicMock() + metric.user_config.user_script = None + metric.user_config.script_dir = None + metric.user_config.data_dir = None + metric.user_config.batch_size = 1 + metric.user_config.dataloader_func = None + metric.user_config.post_processing_func = None + metric.user_config.evaluate_func = None + metric.user_config.input_names = None + metric.user_config.input_shapes = None + metric.backend = "huggingface_metrics" + return metric + + def test_genai_vision_detected_when_vision_field_present(self, tmp_path): + """Dispatch to genai vision path when genai_config.json has a vision field.""" + from olive.evaluator.olive_evaluator import OliveModelOutput + + config = {"model": {"vision": {"inputs": "pixel_values"}}} + model = self._make_model_with_genai_config(tmp_path, config) + + with ( + patch.object(OnnxEvaluator, "_inference_vision_genai") as mock_genai, + patch.object(OnnxEvaluator, "_inference_vision") as mock_vision, + patch("olive.evaluator.olive_evaluator.OliveEvaluator.compute_accuracy") as mock_compute, + patch("olive.evaluator.olive_evaluator._is_vision_metric", return_value=True), + patch("olive.evaluator.olive_evaluator._validate_vision_task_metric"), + patch("olive.evaluator.olive_evaluator.OliveEvaluator.get_user_config") as mock_get_cfg, + patch( + "olive.evaluator.olive_evaluator.OliveEvaluator.generate_metric_user_config_with_model_io" + ) as mock_gen, + ): + mock_genai.return_value = (OliveModelOutput(preds=["answer"], logits=None), ["answer"]) + mock_compute.return_value = MagicMock() + metric = self._make_vision_accuracy_metric() + mock_gen.return_value = metric + mock_get_cfg.return_value = (MagicMock(), None, None) + + evaluator = OnnxEvaluator() + evaluator.evaluate(model, [metric], Device.CPU, None) + + mock_genai.assert_called_once() + mock_vision.assert_not_called() + + def test_genai_vision_detected_with_empty_vision_object(self, tmp_path): + """Dispatch to genai vision path even when vision value is an empty dict.""" + from olive.evaluator.olive_evaluator import OliveModelOutput + + config = {"model": {"vision": {}}} + model = self._make_model_with_genai_config(tmp_path, config) + + with ( + patch.object(OnnxEvaluator, "_inference_vision_genai") as mock_genai, + patch.object(OnnxEvaluator, "_inference_vision") as mock_vision, + patch("olive.evaluator.olive_evaluator.OliveEvaluator.compute_accuracy") as mock_compute, + patch("olive.evaluator.olive_evaluator._is_vision_metric", return_value=True), + patch("olive.evaluator.olive_evaluator._validate_vision_task_metric"), + patch("olive.evaluator.olive_evaluator.OliveEvaluator.get_user_config") as mock_get_cfg, + patch( + "olive.evaluator.olive_evaluator.OliveEvaluator.generate_metric_user_config_with_model_io" + ) as mock_gen, + ): + mock_genai.return_value = (OliveModelOutput(preds=["answer"], logits=None), ["answer"]) + mock_compute.return_value = MagicMock() + metric = self._make_vision_accuracy_metric() + mock_gen.return_value = metric + mock_get_cfg.return_value = (MagicMock(), None, None) + + evaluator = OnnxEvaluator() + evaluator.evaluate(model, [metric], Device.CPU, None) + + mock_genai.assert_called_once() + mock_vision.assert_not_called() + + def test_standard_vision_when_no_vision_field(self, tmp_path): + """Dispatch to standard vision path when genai_config has no vision field.""" + from olive.evaluator.olive_evaluator import OliveModelOutput + + config = {"model": {"type": "whisper"}} + model = self._make_model_with_genai_config(tmp_path, config) + + with ( + patch.object(OnnxEvaluator, "_inference_vision_genai") as mock_genai, + patch.object(OnnxEvaluator, "_inference_vision") as mock_vision, + patch("olive.evaluator.olive_evaluator.OliveEvaluator.compute_accuracy") as mock_compute, + patch("olive.evaluator.olive_evaluator._is_vision_metric", return_value=True), + patch("olive.evaluator.olive_evaluator._validate_vision_task_metric"), + patch("olive.evaluator.olive_evaluator.OliveEvaluator.get_user_config") as mock_get_cfg, + patch( + "olive.evaluator.olive_evaluator.OliveEvaluator.generate_metric_user_config_with_model_io" + ) as mock_gen, + ): + mock_vision.return_value = (OliveModelOutput(preds=["answer"], logits=None), ["answer"]) + mock_compute.return_value = MagicMock() + metric = self._make_vision_accuracy_metric() + mock_gen.return_value = metric + mock_get_cfg.return_value = (MagicMock(), None, None) + + evaluator = OnnxEvaluator() + evaluator.evaluate(model, [metric], Device.CPU, None) + + mock_vision.assert_called_once() + mock_genai.assert_not_called() + + def test_standard_vision_when_no_genai_config(self, tmp_path): + """Dispatch to standard vision path when genai_config.json is missing.""" + from olive.evaluator.olive_evaluator import OliveModelOutput + + model = self._make_model_with_genai_config(tmp_path, None) + + with ( + patch.object(OnnxEvaluator, "_inference_vision_genai") as mock_genai, + patch.object(OnnxEvaluator, "_inference_vision") as mock_vision, + patch("olive.evaluator.olive_evaluator.OliveEvaluator.compute_accuracy") as mock_compute, + patch("olive.evaluator.olive_evaluator._is_vision_metric", return_value=True), + patch("olive.evaluator.olive_evaluator._validate_vision_task_metric"), + patch("olive.evaluator.olive_evaluator.OliveEvaluator.get_user_config") as mock_get_cfg, + patch( + "olive.evaluator.olive_evaluator.OliveEvaluator.generate_metric_user_config_with_model_io" + ) as mock_gen, + ): + mock_vision.return_value = (OliveModelOutput(preds=["answer"], logits=None), ["answer"]) + mock_compute.return_value = MagicMock() + metric = self._make_vision_accuracy_metric() + mock_gen.return_value = metric + mock_get_cfg.return_value = (MagicMock(), None, None) + + evaluator = OnnxEvaluator() + evaluator.evaluate(model, [metric], Device.CPU, None) + + mock_vision.assert_called_once() + mock_genai.assert_not_called()