diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py
index c6ce019fe..711d91643 100644
--- a/olive/passes/onnx/discrepancy_check.py
+++ b/olive/passes/onnx/discrepancy_check.py
@@ -38,6 +38,21 @@ def _infer_shape(dynamic_shape, known_values=None):
     return tuple(inferred_shape)
 
 
+def _longest_common_token_sequence(seq_a: list[int], seq_b: list[int]) -> int:
+    """Return the length of the common prefix shared by two token sequences.
+
+    ``seq_a`` and ``seq_b`` are compared position by position from index 0.
+    Counting stops at the first mismatch or when the shorter sequence ends,
+    so tokens that match again after a divergence are not counted.
+    """
+    length = 0
+    for a, b in zip(seq_a, seq_b):
+        if a != b:
+            break
+        length += 1
+    return length
+
+
 class OnnxDiscrepancyCheck(Pass):
     """Validates ONNX model outputs against a reference PyTorch model.
 
@@ -47,6 +62,8 @@ class OnnxDiscrepancyCheck(Pass):
     - Maximum absolute error (MaxAE)
     - Number of elements where the absolute difference exceeds 0.1
     - Number of elements where the absolute difference exceeds 0.01
+    - Longest common token sequence from the beginning between transformers
+      generate and ONNX Runtime GenAI generate (when enabled)
 
     The pass fails if any configured threshold is exceeded.
     """
@@ -82,6 +99,34 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
                     "If exceeded, the pass fails."
                 ),
             ),
+            "genai_model_path": PassConfigParam(
+                type_=Optional[str],
+                default_value=None,
+                description=(
+                    "Path to the ONNX Runtime GenAI model directory. When provided, the pass "
+                    "runs token generation using both transformers and ONNX Runtime GenAI, then "
+                    "computes the longest common token sequence from the beginning of their outputs."
+                ),
+            ),
+            "generate_prompt": PassConfigParam(
+                type_=str,
+                default_value="The capital of France is",
+                description="Text prompt used for generation comparison between transformers and GenAI.",
+            ),
+            "generate_max_new_tokens": PassConfigParam(
+                type_=int,
+                default_value=32,
+                description="Maximum number of new tokens to generate for the token sequence comparison.",
+            ),
+            "min_longest_common_tokens": PassConfigParam(
+                type_=Optional[int],
+                default_value=None,
+                description=(
+                    "Minimum acceptable length of the longest common token sequence from the "
+                    "beginning between transformers and GenAI outputs. If the actual value is "
+                    "below this threshold, the pass fails."
+                ),
+            ),
         }
 
     def _run_for_config(
@@ -197,5 +242,92 @@ def _run_for_config(
         if failures:
             raise RuntimeError("ONNX model discrepancy check failed:\n" + "\n".join(f" - {f}" for f in failures))
 
+        # Generation token sequence comparison (transformers vs ONNX Runtime GenAI)
+        if config.genai_model_path:
+            longest_common = self.compare_generation(config, ref_model)
+            if config.min_longest_common_tokens is not None and longest_common < config.min_longest_common_tokens:
+                raise RuntimeError(
+                    f"ONNX model discrepancy check failed:\n"
+                    f" - Longest common token sequence length {longest_common} is below "
+                    f"threshold {config.min_longest_common_tokens}"
+                )
+        elif config.min_longest_common_tokens is not None:
+            # Without a GenAI model there is nothing to compare, so say that the threshold is not enforced.
+            logger.warning(
+                "min_longest_common_tokens=%s is ignored because genai_model_path is not set.",
+                config.min_longest_common_tokens,
+            )
+
         # Return the model unchanged
         return model
+
+    def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
+        """Run generation on both transformers and GenAI, return longest common token sequence length.
+
+        The prompt ``config.generate_prompt`` is decoded with ``do_sample=False`` for up to
+        ``config.generate_max_new_tokens`` new tokens by ``ref_model.generate`` and by the GenAI model
+        loaded from ``config.genai_model_path``. The GenAI token list is seeded with the prompt tokens and
+        the transformers output is assumed to start with them too, so matching prompt tokens are counted.
+
+        Raises ImportError when ``onnxruntime-genai`` is not installed.
+        """
+        try:
+            import onnxruntime_genai as og
+        except ImportError as exc:
+            raise ImportError("Please install `onnxruntime-genai` to enable generation comparison.") from exc
+        import torch
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path)
+
+        # Transformers generation
+        input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids
+        with torch.no_grad():
+            transformers_output = ref_model.generate(
+                input_ids,
+                max_new_tokens=config.generate_max_new_tokens,
+                do_sample=False,
+            )
+        transformers_tokens = transformers_output[0].tolist()
+
+        # ONNX Runtime GenAI generation
+        genai_model = og.Model(config.genai_model_path)
+        genai_tokenizer = og.Tokenizer(genai_model)
+        genai_input_ids = genai_tokenizer.encode(config.generate_prompt)
+        # encode() may yield non-builtin integer types; plain ints match the helper's list[int] contract.
+        genai_tokens = [int(token) for token in genai_input_ids]
+
+        prompt_tokens = input_ids[0].tolist()
+        if prompt_tokens != genai_tokens:
+            # A differing prompt encoding shortens the common prefix regardless of model accuracy.
+            logger.warning(
+                "OnnxDiscrepancyCheck generation comparison: prompt tokenization differs, "
+                "transformers=%s, genai=%s",
+                prompt_tokens,
+                genai_tokens,
+            )
+
+        params = og.GeneratorParams(genai_model)
+        params.set_search_options(max_length=len(genai_input_ids) + config.generate_max_new_tokens, do_sample=False)
+
+        generator = og.Generator(genai_model, params)
+        try:
+            generator.append_tokens([genai_input_ids])
+            while not generator.is_done():
+                generator.generate_next_token()
+                genai_tokens.append(int(generator.get_next_tokens()[0]))
+        finally:
+            # Drop the generator reference even when generation raises.
+            del generator
+
+        longest_common = _longest_common_token_sequence(transformers_tokens, genai_tokens)
+
+        logger.info(
+            "OnnxDiscrepancyCheck generation comparison: "
+            "transformers_len=%d, genai_len=%d, longest_common_token_sequence=%d",
+            len(transformers_tokens),
+            len(genai_tokens),
+            longest_common,
+        )
+
+        return longest_common
diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py
new file mode 100644
index 000000000..8fe9c677f
--- /dev/null
+++ b/test/passes/onnx/test_discrepancy_check.py
@@ -0,0 +1,161 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import sys
+from unittest.mock import MagicMock, patch
+
+from olive.passes.onnx.discrepancy_check import _longest_common_token_sequence
+
+
+class TestLongestCommonTokenSequence:
+    """Unit tests for _longest_common_token_sequence helper."""
+
+    def test_identical_sequences(self):
+        assert _longest_common_token_sequence([1, 2, 3, 4], [1, 2, 3, 4]) == 4
+
+    def test_empty_sequences(self):
+        assert _longest_common_token_sequence([], []) == 0
+        assert _longest_common_token_sequence([1, 2], []) == 0
+        assert _longest_common_token_sequence([], [1, 2]) == 0
+
+    def test_no_common_prefix(self):
+        assert _longest_common_token_sequence([1, 2, 3], [4, 5, 6]) == 0
+
+    def test_partial_common_prefix(self):
+        assert _longest_common_token_sequence([1, 2, 3, 4, 5], [1, 2, 3, 9, 10]) == 3
+
+    def test_one_is_prefix_of_other(self):
+        assert _longest_common_token_sequence([1, 2, 3], [1, 2, 3, 4, 5]) == 3
+        assert _longest_common_token_sequence([1, 2, 3, 4, 5], [1, 2, 3]) == 3
+
+    def test_single_element_match(self):
+        assert _longest_common_token_sequence([7], [7]) == 1
+
+    def test_single_element_no_match(self):
+        assert _longest_common_token_sequence([7], [8]) == 0
+
+    def test_diverges_after_first(self):
+        assert _longest_common_token_sequence([1, 99, 99], [1, 2, 3]) == 1
+
+    def test_common_tokens_later_not_counted(self):
+        # The first tokens differ, so the identical tails [1, 2, 3] must not contribute to the count
+        assert _longest_common_token_sequence([10, 1, 2, 3], [20, 1, 2, 3]) == 0
+
+
+class TestCompareGeneration:
+    """Unit tests for OnnxDiscrepancyCheck.compare_generation."""
+
+    def test_compare_generation_returns_common_prefix_length(self):
+        """Test that compare_generation correctly computes the common prefix length."""
+        import torch
+
+        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck
+
+        # Mock config carrying the four attributes that compare_generation reads
+        config = MagicMock()
+        config.reference_model_path = "mock_model"
+        config.genai_model_path = "mock_genai_model"
+        config.generate_prompt = "Hello world"
+        config.generate_max_new_tokens = 10
+
+        # Mock transformers tokenizer: calling it yields an object whose input_ids is the prompt [1, 2, 3]
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[1, 2, 3]]))
+
+        # Transformers generates [1, 2, 3, 10, 11, 12, 13]
+        mock_ref_model = MagicMock()
+        mock_ref_model.generate.return_value = torch.tensor([[1, 2, 3, 10, 11, 12, 13]])
+
+        # GenAI generates [1, 2, 3, 10, 11, 99, 99] (diverges at index 5)
+        mock_og = MagicMock()
+        mock_genai_model = MagicMock()
+        mock_og.Model.return_value = mock_genai_model
+        mock_genai_tokenizer = MagicMock()
+        mock_og.Tokenizer.return_value = mock_genai_tokenizer
+        mock_genai_tokenizer.encode.return_value = [1, 2, 3]
+
+        mock_params = MagicMock()
+        mock_og.GeneratorParams.return_value = mock_params
+
+        # Simulate generator producing tokens: 10, 11, 99, 99 then done
+        mock_generator = MagicMock()
+        genai_new_tokens = [10, 11, 99, 99]
+        call_count = [0]  # one-element list so the nested functions below can mutate the counter
+
+        def is_done_side_effect():
+            return call_count[0] >= len(genai_new_tokens)
+
+        def get_next_tokens_side_effect():
+            token = genai_new_tokens[call_count[0]]
+            call_count[0] += 1
+            return [token]
+
+        mock_generator.is_done = is_done_side_effect
+        mock_generator.get_next_tokens = get_next_tokens_side_effect
+        mock_og.Generator.return_value = mock_generator
+
+        with (
+            patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),  # picked up by the in-function import
+            patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
+        ):
+            pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)  # bypass __init__
+            result = pass_instance.compare_generation(config, mock_ref_model)
+
+        mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]])
+        # Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence (3 prompt tokens + 2 generated)
+        assert result == 5
+
+    def test_compare_generation_fully_matching(self):
+        """Test when both outputs are identical."""
+        import torch
+
+        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck
+
+        config = MagicMock()
+        config.reference_model_path = "mock_model"
+        config.genai_model_path = "mock_genai_model"
+        config.generate_prompt = "Test"
+        config.generate_max_new_tokens = 5
+
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]]))
+
+        mock_ref_model = MagicMock()
+        mock_ref_model.generate.return_value = torch.tensor([[10, 20, 30, 40, 50]])
+
+        mock_og = MagicMock()
+        mock_og.Model.return_value = MagicMock()
+        mock_genai_tokenizer = MagicMock()
+        mock_og.Tokenizer.return_value = mock_genai_tokenizer
+        mock_genai_tokenizer.encode.return_value = [10, 20]
+
+        mock_params = MagicMock()
+        mock_og.GeneratorParams.return_value = mock_params
+
+        mock_generator = MagicMock()
+        genai_new_tokens = [30, 40, 50]  # same continuation as the transformers mock above
+        call_count = [0]  # one-element list so the nested functions below can mutate the counter
+
+        def is_done_side_effect():
+            return call_count[0] >= len(genai_new_tokens)
+
+        def get_next_tokens_side_effect():
+            token = genai_new_tokens[call_count[0]]
+            call_count[0] += 1
+            return [token]
+
+        mock_generator.is_done = is_done_side_effect
+        mock_generator.get_next_tokens = get_next_tokens_side_effect
+        mock_og.Generator.return_value = mock_generator
+
+        with (
+            patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),  # picked up by the in-function import
+            patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
+        ):
+            pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)  # bypass __init__
+            result = pass_instance.compare_generation(config, mock_ref_model)
+
+        mock_generator.append_tokens.assert_called_once_with([[10, 20]])
+        # All 5 tokens match (2 prompt tokens + 3 generated)
+        assert result == 5