Skip to content
Open
104 changes: 104 additions & 0 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def _infer_shape(dynamic_shape, known_values=None):
return tuple(inferred_shape)


def _longest_common_token_sequence(seq_a: list[int], seq_b: list[int]) -> int:
"""Compute the length of the longest common token sequence starting from the beginning.

Counts how many tokens match consecutively from the start of both sequences
before the first divergence.
"""
length = 0
for a, b in zip(seq_a, seq_b):
if a != b:
break
length += 1
return length


class OnnxDiscrepancyCheck(Pass):
"""Validates ONNX model outputs against a reference PyTorch model.

Expand All @@ -47,6 +61,8 @@ class OnnxDiscrepancyCheck(Pass):
- Maximum absolute error (MaxAE)
- Number of elements where the absolute difference exceeds 0.1
- Number of elements where the absolute difference exceeds 0.01
- Longest common token sequence from the beginning between transformers
generate and ONNX Runtime GenAI generate (when enabled)

The pass fails if any configured threshold is exceeded.
"""
Expand Down Expand Up @@ -82,6 +98,34 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"If exceeded, the pass fails."
),
),
"genai_model_path": PassConfigParam(
type_=Optional[str],
default_value=None,
description=(
"Path to the ONNX Runtime GenAI model directory. When provided, the pass "
"runs token generation using both transformers and ONNX Runtime GenAI, then "
"computes the longest common token sequence from the beginning of their outputs."
),
),
"generate_prompt": PassConfigParam(
type_=str,
default_value="The capital of France is",
description="Text prompt used for generation comparison between transformers and GenAI.",
),
"generate_max_new_tokens": PassConfigParam(
type_=int,
default_value=32,
description="Maximum number of new tokens to generate for the token sequence comparison.",
),
"min_longest_common_tokens": PassConfigParam(
type_=Optional[int],
default_value=None,
description=(
"Minimum acceptable length of the longest common token sequence from the "
"beginning between transformers and GenAI outputs. If the actual value is "
"below this threshold, the pass fails."
),
),
}

def _run_for_config(
Expand Down Expand Up @@ -197,5 +241,65 @@ def _run_for_config(
if failures:
raise RuntimeError("ONNX model discrepancy check failed:\n" + "\n".join(f" - {f}" for f in failures))

# Generation token sequence comparison (transformers vs ONNX Runtime GenAI)
if config.genai_model_path:
longest_common = self.compare_generation(config, ref_model)
if config.min_longest_common_tokens is not None and longest_common < config.min_longest_common_tokens:
raise RuntimeError(
f"ONNX model discrepancy check failed:\n"
f" - Longest common token sequence length {longest_common} is below "
f"threshold {config.min_longest_common_tokens}"
)

# Return the model unchanged
return model

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
"""Run generation on both transformers and GenAI, return longest common token sequence length."""
try:
import onnxruntime_genai as og
except ImportError as exc:
raise ImportError("Please install `onnxruntime-genai` to enable generation comparison.") from exc
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path)

# Transformers generation
input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids
import torch

with torch.no_grad():
transformers_output = ref_model.generate(
input_ids,
max_new_tokens=config.generate_max_new_tokens,
do_sample=False,
)
transformers_tokens = transformers_output[0].tolist()

# ONNX Runtime GenAI generation
genai_model = og.Model(config.genai_model_path)
genai_tokenizer = og.Tokenizer(genai_model)
genai_input_ids = genai_tokenizer.encode(config.generate_prompt)

params = og.GeneratorParams(genai_model)
params.set_search_options(max_length=len(genai_input_ids) + config.generate_max_new_tokens, do_sample=False)

generator = og.Generator(genai_model, params)
generator.append_tokens([genai_input_ids])
genai_tokens = list(genai_input_ids)
while not generator.is_done():
generator.generate_next_token()
genai_tokens.append(generator.get_next_tokens()[0])
del generator
Comment thread
xadupre marked this conversation as resolved.

longest_common = _longest_common_token_sequence(transformers_tokens, genai_tokens)

logger.info(
"OnnxDiscrepancyCheck generation comparison: "
"transformers_len=%d, genai_len=%d, longest_common_token_sequence=%d",
len(transformers_tokens),
len(genai_tokens),
longest_common,
)

return longest_common
161 changes: 161 additions & 0 deletions test/passes/onnx/test_discrepancy_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import sys
from unittest.mock import MagicMock, patch

from olive.passes.onnx.discrepancy_check import _longest_common_token_sequence


class TestLongestCommonTokenSequence:
"""Unit tests for _longest_common_token_sequence helper."""

def test_identical_sequences(self):
assert _longest_common_token_sequence([1, 2, 3, 4], [1, 2, 3, 4]) == 4

def test_empty_sequences(self):
assert _longest_common_token_sequence([], []) == 0
assert _longest_common_token_sequence([1, 2], []) == 0
assert _longest_common_token_sequence([], [1, 2]) == 0

def test_no_common_prefix(self):
assert _longest_common_token_sequence([1, 2, 3], [4, 5, 6]) == 0

def test_partial_common_prefix(self):
assert _longest_common_token_sequence([1, 2, 3, 4, 5], [1, 2, 3, 9, 10]) == 3

def test_one_is_prefix_of_other(self):
assert _longest_common_token_sequence([1, 2, 3], [1, 2, 3, 4, 5]) == 3
assert _longest_common_token_sequence([1, 2, 3, 4, 5], [1, 2, 3]) == 3

def test_single_element_match(self):
assert _longest_common_token_sequence([7], [7]) == 1

def test_single_element_no_match(self):
assert _longest_common_token_sequence([7], [8]) == 0

def test_diverges_after_first(self):
assert _longest_common_token_sequence([1, 99, 99], [1, 2, 3]) == 1

def test_common_tokens_later_not_counted(self):
# Tokens match later but not from the beginning
assert _longest_common_token_sequence([10, 1, 2, 3], [20, 1, 2, 3]) == 0


class TestCompareGeneration:
"""Unit tests for OnnxDiscrepancyCheck.compare_generation."""

def test_compare_generation_returns_common_prefix_length(self):
"""Test that compare_generation correctly computes the common prefix length."""
import torch

from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

# Mock config
config = MagicMock()
config.reference_model_path = "mock_model"
config.genai_model_path = "mock_genai_model"
config.generate_prompt = "Hello world"
config.generate_max_new_tokens = 10

# Mock transformers tokenizer and model
mock_tokenizer = MagicMock()
mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[1, 2, 3]]))

# Transformers generates [1, 2, 3, 10, 11, 12, 13]
mock_ref_model = MagicMock()
mock_ref_model.generate.return_value = torch.tensor([[1, 2, 3, 10, 11, 12, 13]])

# GenAI generates [1, 2, 3, 10, 11, 99, 99] (diverges at index 5)
mock_og = MagicMock()
mock_genai_model = MagicMock()
mock_og.Model.return_value = mock_genai_model
mock_genai_tokenizer = MagicMock()
mock_og.Tokenizer.return_value = mock_genai_tokenizer
mock_genai_tokenizer.encode.return_value = [1, 2, 3]

mock_params = MagicMock()
mock_og.GeneratorParams.return_value = mock_params

# Simulate generator producing tokens: 10, 11, 99, 99 then done
mock_generator = MagicMock()
genai_new_tokens = [10, 11, 99, 99]
call_count = [0]

def is_done_side_effect():
return call_count[0] >= len(genai_new_tokens)

def get_next_tokens_side_effect():
token = genai_new_tokens[call_count[0]]
call_count[0] += 1
return [token]

mock_generator.is_done = is_done_side_effect
mock_generator.get_next_tokens = get_next_tokens_side_effect
mock_og.Generator.return_value = mock_generator

with (
patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),
patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
):
pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
result = pass_instance.compare_generation(config, mock_ref_model)

mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]])
# Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence
assert result == 5

def test_compare_generation_fully_matching(self):
"""Test when both outputs are identical."""
import torch

from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

config = MagicMock()
config.reference_model_path = "mock_model"
config.genai_model_path = "mock_genai_model"
config.generate_prompt = "Test"
config.generate_max_new_tokens = 5

mock_tokenizer = MagicMock()
mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]]))

mock_ref_model = MagicMock()
mock_ref_model.generate.return_value = torch.tensor([[10, 20, 30, 40, 50]])

mock_og = MagicMock()
mock_og.Model.return_value = MagicMock()
mock_genai_tokenizer = MagicMock()
mock_og.Tokenizer.return_value = mock_genai_tokenizer
mock_genai_tokenizer.encode.return_value = [10, 20]

mock_params = MagicMock()
mock_og.GeneratorParams.return_value = mock_params

mock_generator = MagicMock()
genai_new_tokens = [30, 40, 50]
call_count = [0]

def is_done_side_effect():
return call_count[0] >= len(genai_new_tokens)

def get_next_tokens_side_effect():
token = genai_new_tokens[call_count[0]]
call_count[0] += 1
return [token]

mock_generator.is_done = is_done_side_effect
mock_generator.get_next_tokens = get_next_tokens_side_effect
mock_og.Generator.return_value = mock_generator

with (
patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),
patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
):
pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
result = pass_instance.compare_generation(config, mock_ref_model)

mock_generator.append_tokens.assert_called_once_with([[10, 20]])
# All 5 tokens match
assert result == 5
Loading