Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
fffa1b8
Add vision genai inference path for multi-file VLM evaluation
jiafatom Jun 1, 2026
120d0e7
Fix task_type_components_map to apply all component overrides
jiafatom Jun 1, 2026
5bf9c89
Address review comments on vision genai inference
jiafatom Jun 1, 2026
bf8ca82
Fix EP mapping: skip CPUExecutionProvider for genai
jiafatom Jun 1, 2026
398d655
Fix lint: remove unused import, unused loop var, use .values()
jiafatom Jun 1, 2026
ab380c3
Fix genai provider: use device field instead of ORT EP names
jiafatom Jun 1, 2026
e25cd2e
Cap max_length to 128 for vision VQA generation
jiafatom Jun 1, 2026
b7f46c9
Increase max_length cap to 4096 for vision genai inference
jiafatom Jun 1, 2026
df8756b
Address all review comments and fix lint errors
jiafatom Jun 1, 2026
23fa91b
Add system_prompt support for vision VQA evaluation
jiafatom Jun 1, 2026
a144c5e
Add options_col support and extract leading number from predictions
jiafatom Jun 1, 2026
d51af37
Address review comments: extract helper, remove debug code, fix lint
jiafatom Jun 1, 2026
7d0738e
Add opt-in number extraction for multiple-choice VQA tasks
jiafatom Jun 1, 2026
f3524f6
Re-trigger CI: flaky test_mnb_to_qdq failure unrelated to PR changes
jiafatom Jun 2, 2026
fa7a239
Address Copilot review: fix vision detection for empty dict, add unit…
jiafatom Jun 2, 2026
57fb8b3
Address review: add JSON error handling, guard PIL import, fix file h…
jiafatom Jun 2, 2026
9e6156e
Fix formatting: use parenthesized context managers
jiafatom Jun 2, 2026
d1b55a7
Fix formatting: collapse ImportError raise to single line
jiafatom Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions olive/data/component/pre_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ def vision_vqa_pre_process(
image_col: str = "image",
question_col: str = "question",
answer_col: str = "answer",
options_col: str = "",
system_prompt: str = "",
max_samples: Optional[int] = None,
limit: Optional[float] = None,
seed: int = 42,
Expand All @@ -408,6 +410,10 @@ def vision_vqa_pre_process(
image_col: Name of the image column. Defaults to "image".
question_col: Name of the question column. Defaults to "question".
answer_col: Name of the answer column. Defaults to "answer".
options_col: Name of the options column for multiple-choice questions. If specified,
options are formatted as numbered choices and appended to the question. Defaults to "".
system_prompt: System prompt to guide model responses (e.g., "Reply with only the
option number"). Passed through to the evaluator. Defaults to "".
max_samples: Maximum number of samples (deprecated, use limit). Defaults to None.
limit: Sampling limit following Olive convention:
If >= 1: use first N samples.
Expand Down Expand Up @@ -438,11 +444,13 @@ class VisionVQADataset:
Note: Use batch_size=1 in dataloader config as images have variable sizes.
"""

def __init__(self, hf_dataset, image_column, question_column, answer_column):
def __init__(self, hf_dataset, image_column, question_column, answer_column, options_column="", sys_prompt=""):
self.dataset = hf_dataset
self.image_column = image_column
self.question_column = question_column
self.answer_column = answer_column
self.options_column = options_column
self.system_prompt = sys_prompt

def __len__(self):
return len(self.dataset)
Expand All @@ -452,11 +460,27 @@ def __getitem__(self, idx):
image = item[self.image_column]
question = item[self.question_column]
answer = item[self.answer_column]

# Format options into the question if options_col is specified
has_options = False
if self.options_column and self.options_column in item:
options = item[self.options_column]
if isinstance(options, (list, tuple)):
options_text = "\n".join(f"{i}. {opt}" for i, opt in enumerate(options))
question = f"{question}\n{options_text}"
has_options = True

# Handle list/tuple answers (some datasets have multiple valid answers)
# Join with | separator so metrics can match against any valid answer
if isinstance(answer, (list, tuple)):
answer = "|".join(str(a) for a in answer) if answer else ""
return {"image": image, "question": question}, str(answer)
input_dict = {
"image": image,
"question": question,
"system_prompt": self.system_prompt,
"extract_number": has_options,
}
return input_dict, str(answer)

@staticmethod
def collate_fn(batch):
Expand All @@ -472,4 +496,4 @@ def collate_fn(batch):
answers = [item[1] for item in batch]
return (inputs, answers)

return VisionVQADataset(dataset, image_col, question_col, answer_col)
return VisionVQADataset(dataset, image_col, question_col, answer_col, options_col, system_prompt)
11 changes: 5 additions & 6 deletions olive/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,11 @@ def to_data_container(self) -> "DataContainer":
return dc_cls(config=self)

def _update_default_component_type_with_task_type(self, dc_cls, default_components_type):
for component_name, config in self.components.items():
for config in self.components.values():
if config and config.params:
task_type = config.params.get("task")
if task_type:
task_specific_override = dc_cls.task_type_components_map.get(
task_type.replace("-with-past", ""), {}
).get(component_name)
if task_specific_override:
default_components_type[component_name] = task_specific_override
task_overrides = dc_cls.task_type_components_map.get(task_type.replace("-with-past", ""), {})
# Apply all component overrides for this task type
for override_component, override_type in task_overrides.items():
default_components_type[override_component] = override_type
173 changes: 163 additions & 10 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,20 @@ def _inference(
dump_tuning_result(session.session, tuning_result_file)
return OliveModelOutput(preds=preds, logits=logits), targets

@staticmethod
def _load_genai_config(model: ONNXModelHandler) -> Optional[dict]:
"""Load genai_config.json from the model directory, or return None if not found."""
genai_config_path = Path(model.model_path).parent / "genai_config.json"
if not genai_config_path.exists():
return None
import json

try:
with genai_config_path.open(encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in genai config file: {genai_config_path}") from e

def _evaluate_onnx_accuracy(
self,
model: ONNXModelHandler,
Expand All @@ -593,18 +607,21 @@ def _evaluate_onnx_accuracy(
) -> MetricResult:
if _is_vision_metric(metric):
_validate_vision_task_metric(metric)
inference_output, targets = self._inference_vision(
model, metric, dataloader, post_func, device, execution_providers
)
# Auto-detect genai vision model by checking for genai_config.json with vision field
genai_cfg = self._load_genai_config(model)
use_genai_vision = genai_cfg is not None and "vision" in genai_cfg.get("model", {})

Comment thread
jiafatom marked this conversation as resolved.
if use_genai_vision:
inference_output, targets = self._inference_vision_genai(model, dataloader, device)
else:
inference_output, targets = self._inference_vision(
model, metric, dataloader, post_func, device, execution_providers
)
Comment thread
jiafatom marked this conversation as resolved.
Comment thread
jiafatom marked this conversation as resolved.
Comment thread
jiafatom marked this conversation as resolved.
elif _is_text_based_metric(metric):
# Auto-detect genai model by checking for genai_config.json
genai_config_path = Path(model.model_path).parent / "genai_config.json"
if genai_config_path.exists():
import json

with genai_config_path.open() as f:
genai_config = json.load(f)
model_type = genai_config.get("model", {}).get("type", "")
genai_cfg = self._load_genai_config(model)
if genai_cfg:
model_type = genai_cfg.get("model", {}).get("type", "")

if model_type == "whisper":
inference_output, targets = self._inference_text_genai(
Expand Down Expand Up @@ -795,6 +812,142 @@ def _inference_vision(

return OliveModelOutput(preds=all_preds, logits=None), all_targets

def _inference_vision_genai(
self,
model: ONNXModelHandler,
dataloader: "DataLoader",
device: Device = Device.CPU,
) -> tuple[OliveModelOutput, Any]:
"""Vision-based inference for VQA/OCR metrics using onnxruntime-genai.

Auto-detected when the model directory contains genai_config.json with a vision field.
Uses og.Model with multimodal processor for vision-language models (e.g., Qwen3-VL).
The dataloader must yield (input_dict, labels) where input_dict contains
'image' (PIL Image) and 'question' (str), and labels are reference answer strings.

Note: GPU/CPU selection is driven by the `device` parameter. onnxruntime-genai uses
short provider names internally (e.g., "cuda") which differ from ORT-style EP names.
"""
try:
import onnxruntime_genai as og
except ImportError as e:
raise ImportError(
"onnxruntime-genai is required for genai-based vision evaluation. "
"Install it with: pip install onnxruntime-genai"
) from e

import json
import re
import tempfile

try:
from PIL import Image
except ImportError as e:
raise ImportError("Pillow is required for vision evaluation. Install it with: pip install Pillow") from e

model_dir = str(Path(model.model_path).parent)
Comment on lines +839 to +848

# max_length in genai is total sequence length (input + output).
# Default to 1028 which accommodates image/prompt tokens (~200-500) plus answer tokens.
# Note: genai_config.json's search.max_length is typically the full context window
# (e.g., 262144) which is too large — the model will stop at EOS well before this cap.
max_length = 1028
Comment thread
jiafatom marked this conversation as resolved.

# Build og.Model with appropriate execution provider
# Note: onnxruntime-genai uses CPU by default when no provider is appended.
# Only non-CPU providers need to be explicitly added using short names (e.g., "cuda").
# This follows the same pattern as _inference_text_genai and _inference_text_genai_streaming.
config = og.Config(model_dir)
config.clear_providers()
if device == Device.GPU:
config.append_provider("cuda")
og_model = og.Model(config)
Comment thread
jiafatom marked this conversation as resolved.
Comment thread
jiafatom marked this conversation as resolved.
processor = og_model.create_multimodal_processor()
tokenizer = og.Tokenizer(og_model)

all_preds = []
all_targets = []

# Use a temporary directory for image files to avoid per-file create/delete overhead
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_img_path = Path(tmp_dir) / "input.png"

for batch in dataloader:
input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch)

# input_data is a dict with 'image' (PIL) and 'question' (str)
# or a list of such dicts for batch_size > 1
items = [input_data] if isinstance(input_data, dict) else input_data

for item in items:
pil_image = item.get("image")
question = item.get("question", "")
sys_prompt = item.get("system_prompt", "")
extract_number = item.get("extract_number", False)

if pil_image is None:
# Append empty pred to maintain alignment with targets
all_preds.append("")
continue

# Ensure PIL Image
if not isinstance(pil_image, Image.Image):
with Image.open(pil_image) as img:
pil_image = img.convert("RGB")

Comment on lines +893 to +897
# Build chat messages for the vision-language model
messages = []
if sys_prompt:
messages.append({"role": "system", "content": sys_prompt})
messages.append(
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": question},
],
}
)
messages_json = json.dumps(messages)

# Save image to temp file for og.Images (reuse same path to minimize I/O)
pil_image.save(str(tmp_img_path), format="PNG")
images = og.Images.open(str(tmp_img_path))

prompt = tokenizer.apply_chat_template(messages_json, add_generation_prompt=True)
inputs = processor(prompt, images=images)

params = og.GeneratorParams(og_model)
params.set_search_options(max_length=max_length, do_sample=False)

generator = og.Generator(og_model, params)
generator.set_inputs(inputs)

tokens = []
while not generator.is_done():
generator.generate_next_token()
tokens.append(generator.get_next_tokens()[0])
del generator

pred = tokenizer.decode(tokens).strip()
# For multiple-choice tasks, extract leading number from responses
# like "1. D" or "0. krill" to match the expected answer format
if extract_number:
num_match = re.match(r"^(\d+)", pred)
if num_match:
pred = num_match.group(1)
all_preds.append(pred)

# Collect reference texts (aligned with preds including empty ones for None images)
if isinstance(labels, (list, tuple)):
all_targets.extend(labels)
else:
all_targets.append(labels)

del og_model

return OliveModelOutput(preds=all_preds, logits=None), all_targets

def _inference_text_genai(
self,
model: ONNXModelHandler,
Expand Down
Loading
Loading