-
Notifications
You must be signed in to change notification settings - Fork 298
Add ONNX Runtime GenAI generation comparison in OnnxDiscrepancyCheck #2487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xadupre
wants to merge
10
commits into
main
Choose a base branch
from
xadupre/genai
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+265
−0
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
333131e
add onnxruntime-genai in OnnxDiscrepancyCheck to measure the first to…
xadupre 091df17
Potential fix for pull request finding 'CodeQL / Unused import'
xadupre 020a4ef
Merge branch 'main' of https://github.com/microsoft/Olive into xadupr…
xadupre 8054f90
lint
xadupre c0ea083
Fix compare_generation test decorator argument mismatch
Copilot 4e82f66
Fix discrepancy_check tests to mock local onnxruntime_genai import
Copilot 4b44c0e
Set discrepancy check generation default to 32 tokens
Copilot 38aa991
Switch discrepancy check to latest onnxruntime-genai generator API
Copilot 929000a
Merge branch 'main' into xadupre/genai
xadupre 8801866
Merge branch 'main' into xadupre/genai
xadupre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import sys | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from olive.passes.onnx.discrepancy_check import _longest_common_token_sequence | ||
|
|
||
|
|
||
class TestLongestCommonTokenSequence:
    """Checks for the _longest_common_token_sequence helper.

    The expectations below treat the helper as returning how many leading
    tokens two sequences share before they first differ.
    """

    def test_identical_sequences(self):
        """Two equal sequences match over their whole length."""
        reference = [1, 2, 3, 4]
        candidate = [1, 2, 3, 4]
        assert _longest_common_token_sequence(reference, candidate) == 4

    def test_empty_sequences(self):
        """An empty sequence on either side yields a match length of zero."""
        empty_cases = (
            ([], []),
            ([1, 2], []),
            ([], [1, 2]),
        )
        for reference, candidate in empty_cases:
            assert _longest_common_token_sequence(reference, candidate) == 0

    def test_no_common_prefix(self):
        """Sequences that differ on the first token share nothing."""
        reference = [1, 2, 3]
        candidate = [4, 5, 6]
        assert _longest_common_token_sequence(reference, candidate) == 0

    def test_partial_common_prefix(self):
        """Only the tokens before the first mismatch are counted."""
        reference = [1, 2, 3, 4, 5]
        candidate = [1, 2, 3, 9, 10]
        assert _longest_common_token_sequence(reference, candidate) == 3

    def test_one_is_prefix_of_other(self):
        """The shorter sequence bounds the result, whichever side it is on."""
        short = [1, 2, 3]
        extended = [1, 2, 3, 4, 5]
        assert _longest_common_token_sequence(short, extended) == 3
        assert _longest_common_token_sequence(extended, short) == 3

    def test_single_element_match(self):
        """One-token sequences that agree give a length of one."""
        matched = _longest_common_token_sequence([7], [7])
        assert matched == 1

    def test_single_element_no_match(self):
        """One-token sequences that disagree give a length of zero."""
        matched = _longest_common_token_sequence([7], [8])
        assert matched == 0

    def test_diverges_after_first(self):
        """A mismatch at the second position leaves a single shared token."""
        reference = [1, 99, 99]
        candidate = [1, 2, 3]
        assert _longest_common_token_sequence(reference, candidate) == 1

    def test_common_tokens_later_not_counted(self):
        """Agreement after an initial mismatch does not contribute."""
        # The tails are identical, but the very first tokens already differ.
        reference = [10, 1, 2, 3]
        candidate = [20, 1, 2, 3]
        assert _longest_common_token_sequence(reference, candidate) == 0
|
|
||
|
|
||
class TestCompareGeneration:
    """Unit tests for OnnxDiscrepancyCheck.compare_generation."""

    @staticmethod
    def _run_compare_generation(*, prompt, max_new_tokens, prompt_ids, reference_output, genai_new_tokens):
        """Call compare_generation with every external backend replaced by a mock.

        Both tests in this class need the same wiring (config, transformers
        tokenizer, reference model, and an ``onnxruntime_genai`` stand-in), so it
        lives here once and the tests only state the data that differs.

        Args:
            prompt: Value assigned to ``config.generate_prompt``.
            max_new_tokens: Value assigned to ``config.generate_max_new_tokens``.
            prompt_ids: Token ids both mocked tokenizers return for the prompt.
            reference_output: Full token sequence returned by the mocked
                reference model's ``generate``.
            genai_new_tokens: Tokens the mocked GenAI generator hands out, one
                per ``get_next_tokens`` call, before ``is_done`` turns true.

        Returns:
            A ``(result, mock_generator)`` tuple: the value returned by
            ``compare_generation`` and the generator mock, so callers can assert
            on how it was driven.
        """
        import torch

        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        # Mock config
        config = MagicMock()
        config.reference_model_path = "mock_model"
        config.genai_model_path = "mock_genai_model"
        config.generate_prompt = prompt
        config.generate_max_new_tokens = max_new_tokens

        # Mock transformers tokenizer and reference model
        mock_tokenizer = MagicMock()
        mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([prompt_ids]))

        mock_ref_model = MagicMock()
        mock_ref_model.generate.return_value = torch.tensor([reference_output])

        # Mock onnxruntime_genai module: model, tokenizer, params
        mock_og = MagicMock()
        mock_og.Model.return_value = MagicMock()
        mock_genai_tokenizer = MagicMock()
        mock_og.Tokenizer.return_value = mock_genai_tokenizer
        mock_genai_tokenizer.encode.return_value = list(prompt_ids)
        mock_og.GeneratorParams.return_value = MagicMock()

        # Simulate a generator that yields the scripted tokens and then reports done.
        remaining = list(genai_new_tokens)

        def is_done():
            return not remaining

        def get_next_tokens():
            return [remaining.pop(0)]

        mock_generator = MagicMock()
        mock_generator.is_done = is_done
        mock_generator.get_next_tokens = get_next_tokens
        mock_og.Generator.return_value = mock_generator

        with (
            # The mock is injected via sys.modules so an import of onnxruntime_genai resolves to it.
            patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),
            patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
        ):
            # __new__ skips __init__, so no pass configuration is required.
            pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
            result = pass_instance.compare_generation(config, mock_ref_model)

        return result, mock_generator

    def test_compare_generation_returns_common_prefix_length(self):
        """Test that compare_generation correctly computes the common prefix length."""
        # Transformers generates [1, 2, 3, 10, 11, 12, 13];
        # GenAI generates [1, 2, 3, 10, 11, 99, 99] (diverges at index 5).
        result, mock_generator = self._run_compare_generation(
            prompt="Hello world",
            max_new_tokens=10,
            prompt_ids=[1, 2, 3],
            reference_output=[1, 2, 3, 10, 11, 12, 13],
            genai_new_tokens=[10, 11, 99, 99],
        )

        mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]])
        # Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence
        assert result == 5

    def test_compare_generation_fully_matching(self):
        """Test when both outputs are identical."""
        result, mock_generator = self._run_compare_generation(
            prompt="Test",
            max_new_tokens=5,
            prompt_ids=[10, 20],
            reference_output=[10, 20, 30, 40, 50],
            genai_new_tokens=[30, 40, 50],
        )

        mock_generator.append_tokens.assert_called_once_with([[10, 20]])
        # All 5 tokens match
        assert result == 5
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.