From 96d08e9c5da356aabb57e7e11e68ce20863d49e6 Mon Sep 17 00:00:00 2001 From: Kyle Romero Date: Tue, 19 May 2026 23:35:05 +0000 Subject: [PATCH 1/3] Add genie_overrides to QairtEncapsulation for GenAIConfig customization Introduce a genie_overrides PassConfigParam that deep-merges user-supplied fields into the GenAIConfig before LLMContainer.export() bakes them into the Genie DLC. This allows callers to override any GenAIConfig field (engine config, positional encoding, etc.) without modifying QairtGenAIBuilder or QairtPipelinePass. Nested dicts are merged recursively so only the specified keys are changed; all other values set by the upstream builder pass are preserved. --- olive/passes/qairt/encapsulation.py | 37 +++++ test/passes/qairt/test_encapsulation.py | 191 ++++++++++++++++++++++++ 2 files changed, 228 insertions(+) diff --git a/olive/passes/qairt/encapsulation.py b/olive/passes/qairt/encapsulation.py index a6fa9ebae..306535fc7 100644 --- a/olive/passes/qairt/encapsulation.py +++ b/olive/passes/qairt/encapsulation.py @@ -24,6 +24,21 @@ MAX_GENIE_CONTEXT_LENGTH = 4096 +def _deep_merge(base: dict, overrides: dict) -> dict: + """Recursively merge *overrides* into *base*, returning a new dict. + + Nested dicts are merged rather than replaced, so only the keys present in + *overrides* are changed; all other keys from *base* are preserved. + """ + result = dict(base) + for k, v in overrides.items(): + if k in result and isinstance(result[k], dict) and isinstance(v, dict): + result[k] = _deep_merge(result[k], v) + else: + result[k] = v + return result + + class QairtEncapsulation(Pass): """Encapsulates a QAIRT DLC model with an onnx protobuf.""" @@ -49,6 +64,21 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon required=False, description="Opset name and version to be added in the generated context model", ), + "genie_overrides": PassConfigParam( + type_=dict, + default_value=None, + required=False, + description=( + "Deep-merged into the GenAIConfig before the Genie DLC is produced. " + "Use Python field names (underscores). Nested dicts are merged recursively — " + "only the specified keys are overridden; all other GenAIBuilder defaults are " + "preserved. Any field on GenAIConfig is valid: kv_dim, rope_theta, n_heads, " + "n_layer, n_embd, allow_async_init, enable_graph_switching, " + "positional_encoding (nested dict), etc. Note: top-level rope_theta and " + "rope_scaling are not forwarded by the Genie factory — use " + "positional_encoding.rope_theta to override RoPE theta in the DLC." + ), + ), } def _run_for_config( @@ -76,6 +106,13 @@ def _run_for_config( container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path) + if config.genie_overrides: + gen_ai_cfg = container._gen_ai_config + current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True) + merged = _deep_merge(current, config.genie_overrides) + container._gen_ai_config = gen_ai_cfg.model_validate(merged) + logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) + # Input/Output metadata container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])] container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])] diff --git a/test/passes/qairt/test_encapsulation.py b/test/passes/qairt/test_encapsulation.py index 0fb45a8b7..2ef566bfa 100644 --- a/test/passes/qairt/test_encapsulation.py +++ b/test/passes/qairt/test_encapsulation.py @@ -894,3 +894,194 @@ def test_create_genai_config_provider_options_key_lowercase(tmp_path): assert len(provider_options) == 1 assert "qnn" in provider_options[0] assert "QNN" not in provider_options[0] + + +# --------------------------------------------------------------------------- +# _deep_merge unit tests +# --------------------------------------------------------------------------- + + +def test_deep_merge_flat(): + """Flat keys in overrides replace or add keys in base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 1, "b": 2}, {"b": 99, "c": 3}) + assert result == {"a": 1, "b": 99, "c": 3} + + +def test_deep_merge_nested_dicts_are_merged_not_replaced(): + """Nested dicts are recursively merged, preserving keys not in overrides.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 10000.0}} + overrides = {"positional_encoding": {"rope_theta": 500000.0}} + result = _deep_merge(base, overrides) + assert result["positional_encoding"] == {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0} + + +def test_deep_merge_nested_override_replaces_non_dict(): + """A dict override replaces a non-dict base value at the same key.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 42}, {"a": {"nested": 1}}) + assert result == {"a": {"nested": 1}} + + +def test_deep_merge_base_unmodified(): + """_deep_merge does not mutate base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"a": {"b": 1}} + overrides = {"a": {"b": 2}} + _deep_merge(base, overrides) + assert base["a"]["b"] == 1 + + +# --------------------------------------------------------------------------- +# genie_overrides integration tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_default_config_includes_genie_overrides(mock_accelerator_spec): + """genie_overrides is present in _default_config with None default.""" + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "genie_overrides" in config + assert config["genie_overrides"].default_value is None + assert config["genie_overrides"].required is False + + +def test_encapsulation_genie_overrides_applied(tmp_path, mock_qairt_model, mock_qairt_modules): + """When genie_overrides is set, _gen_ai_config is deep-merged before export.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + # Represent the existing GenAIConfig state after LLMContainer.load() + initial_gen_ai_state = { + "context_length": 4096, + "n_vocab": 32000, + "bos_token": 1, + "eos_token": 2, + "tokenizer_path": str(tmp_path / "tokenizer.json"), + "kv_dim": None, + "positional_encoding": {"type": "rope", "rope_dim": 64}, + } + mock_container._gen_ai_config.model_dump.return_value = initial_gen_ai_state + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + overrides = {"kv_dim": 128, "positional_encoding": {"rope_theta": 500000.0}} + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "genie_overrides": overrides}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + # model_dump was called to capture current state + mock_container._gen_ai_config.model_dump.assert_called_once_with(mode="json", by_alias=False, exclude_none=True) + # model_validate was called with the deep-merged result + expected_merged = { + **initial_gen_ai_state, + "kv_dim": 128, + "positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0}, + } + mock_container._gen_ai_config.model_validate.assert_called_once_with(expected_merged) + # _gen_ai_config was reassigned to the validated result + assert ( + mock_container._gen_ai_config + is not mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value._gen_ai_config + ) + + +def test_encapsulation_no_genie_overrides_leaves_gen_ai_config_untouched( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """When genie_overrides is None, _gen_ai_config.model_dump is never called.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU"}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + mock_container._gen_ai_config.model_dump.assert_not_called() From 654709b7f02cdabba7678767c6d826e7c2e287c9 Mon Sep 17 00:00:00 2001 From: Kyle Romero Date: Tue, 26 May 2026 18:26:50 +0000 Subject: [PATCH 2/3] Add backend_extensions_override to QairtEncapsulation for backend extensions customization Adds a backend_extensions_override PassConfigParam that is deep-merged into the LLMContainer's existing _backend_extensions_config before the Genie DLC is produced. Uses the raw JSON key names (hyphens) as they appear in backend_extensions.json. Nested dicts are merged recursively so only specified keys are overridden; all other defaults set by the builder are preserved. If the container has no existing backend extensions config, the override becomes the entire config. Includes 4 tests covering default config presence, merge into existing config, merge from empty, and no-op when override is None. Also drops one fragile identity assertion from the existing genie_overrides test. --- olive/passes/qairt/encapsulation.py | 19 +++ test/passes/qairt/test_encapsulation.py | 178 +++++++++++++++++++++++- 2 files changed, 192 insertions(+), 5 deletions(-) diff --git a/olive/passes/qairt/encapsulation.py b/olive/passes/qairt/encapsulation.py index 306535fc7..8671db270 100644 --- a/olive/passes/qairt/encapsulation.py +++ b/olive/passes/qairt/encapsulation.py @@ -79,6 +79,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "positional_encoding.rope_theta to override RoPE theta in the DLC." ), ), + "backend_extensions_override": PassConfigParam( + type_=dict, + default_value=None, + required=False, + description=( + "Deep-merged into the backend extensions config before the Genie DLC is " + "produced. Use the raw JSON key names (hyphens) as they appear in " + "backend_extensions.json. Nested dicts are merged recursively — only the " + "specified keys are overridden; all other backend extension defaults set " + "by the builder are preserved. If the container has no existing backend " + "extensions config, the override is used as the entire config." + ), + ), } def _run_for_config( @@ -113,6 +126,12 @@ def _run_for_config( container._gen_ai_config = gen_ai_cfg.model_validate(merged) logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) + if config.backend_extensions_override: + container._backend_extensions_config = _deep_merge( + container._backend_extensions_config or {}, config.backend_extensions_override + ) + logger.info("Applied backend_extensions_override: %s", list(config.backend_extensions_override.keys())) + # Input/Output metadata container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])] container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])] diff --git a/test/passes/qairt/test_encapsulation.py b/test/passes/qairt/test_encapsulation.py index 2ef566bfa..15dfb62d6 100644 --- a/test/passes/qairt/test_encapsulation.py +++ b/test/passes/qairt/test_encapsulation.py @@ -1024,11 +1024,6 @@ def mock_save_func(model_def, path): "positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0}, } mock_container._gen_ai_config.model_validate.assert_called_once_with(expected_merged) - # _gen_ai_config was reassigned to the validated result - assert ( - mock_container._gen_ai_config - is not mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value._gen_ai_config - ) def test_encapsulation_no_genie_overrides_leaves_gen_ai_config_untouched( @@ -1085,3 +1080,176 @@ def mock_save_func(model_def, path): encap_pass.run(mock_qairt_model, str(output_path)) mock_container._gen_ai_config.model_dump.assert_not_called() + + +# --------------------------------------------------------------------------- +# backend_extensions_override integration tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_default_config_includes_backend_extensions_override(mock_accelerator_spec): + """backend_extensions_override is present in _default_config with None default.""" + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "backend_extensions_override" in config + assert config["backend_extensions_override"].default_value is None + assert config["backend_extensions_override"].required is False + + +def test_encapsulation_backend_extensions_override_merges_into_existing(tmp_path, mock_qairt_model, mock_qairt_modules): + """Override is deep-merged into the existing _backend_extensions_config.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container._backend_extensions_config = {"context": {"n-threads": 6, "kv-cache-size": 512}} + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "backend_extensions_override": {"context": {"n-threads": 4}}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + # n-threads overridden; kv-cache-size from original preserved + assert mock_container._backend_extensions_config == {"context": {"n-threads": 4, "kv-cache-size": 512}} + + +def test_encapsulation_backend_extensions_override_from_empty(tmp_path, mock_qairt_model, mock_qairt_modules): + """When the container has no existing backend extensions config, the override becomes the config.""" + mock_container = MagicMock() + mock_container._backend_extensions_config = None + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "backend_extensions_override": {"context": {"n-threads": 4}}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + assert mock_container._backend_extensions_config == {"context": {"n-threads": 4}} + + +def test_encapsulation_no_backend_extensions_override_leaves_config_untouched( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """When backend_extensions_override is None the container config is not touched.""" + original_ext_cfg = {"context": {"n-threads": 6}} + + mock_container = MagicMock() + mock_container._backend_extensions_config = original_ext_cfg + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict(QairtEncapsulation, {"backend": "CPU"}, disable_search=True) + encap_pass.run(mock_qairt_model, str(output_path)) + + assert mock_container._backend_extensions_config is original_ext_cfg From 174d4ffb5bda828877452c6fb39f1c7b68d1fe06 Mon Sep 17 00:00:00 2001 From: Kyle Romero Date: Thu, 28 May 2026 19:58:59 +0000 Subject: [PATCH 3/3] Add htp_execution_overrides to QairtEncapsulation with qairt-dev version compatibility --- olive/passes/qairt/encapsulation.py | 79 +++++- test/passes/qairt/test_encapsulation.py | 311 +++++++++++++++++++++++- 2 files changed, 366 insertions(+), 24 deletions(-) diff --git a/olive/passes/qairt/encapsulation.py b/olive/passes/qairt/encapsulation.py index 8671db270..66190c14f 100644 --- a/olive/passes/qairt/encapsulation.py +++ b/olive/passes/qairt/encapsulation.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT # -------------------------------------------------------------------------- +import inspect import logging import os from pathlib import Path @@ -17,6 +18,7 @@ from olive.model import ONNXModelHandler, QairtModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam +from olive.passes.qairt.gen_ai_builder import QairtBackend from olive.passes.qairt.utils import QairtLogLevel logger = logging.getLogger(__name__) @@ -64,6 +66,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon required=False, description="Opset name and version to be added in the generated context model", ), + "backend": PassConfigParam( + type_=QairtBackend, + default_value=QairtBackend.HTP, + description="Target accelerator backend for the DLC being encapsulated. Accepted values are 'CPU' and 'HTP'.", + ), "genie_overrides": PassConfigParam( type_=dict, default_value=None, @@ -79,7 +86,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "positional_encoding.rope_theta to override RoPE theta in the DLC." ), ), - "backend_extensions_override": PassConfigParam( + "backend_extensions_overrides": PassConfigParam( type_=dict, default_value=None, required=False, @@ -92,8 +99,33 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "extensions config, the override is used as the entire config." ), ), + "htp_execution_overrides": PassConfigParam( + type_=dict, + default_value=None, + required=False, + description=( + "HTP deployment parameters passed to LLMContainer.export() as an " + "HTPExecutionConfig. Dict keys map directly to HTPExecutionConfig fields: " + "cpu_mask, n_threads, use_mmap, spill_fill_bufsize, mmap_budget, " + "pos_id_dim, kv_update_method, allow_async_init, enable_graph_switching. " + "Requires backend='HTP'. Ignored with a warning if the installed qairt " + "does not support htp_execution_config in LLMContainer.export() " + "(requires qairt-dev with AISW-184594)." + ), + ), } + @classmethod + def validate_config( + cls, + config: type[BasePassConfig], + accelerator_spec: AcceleratorSpec, + ) -> bool: + if config.htp_execution_overrides and config.backend != QairtBackend.HTP: + logger.error("htp_execution_overrides is unsupported on non-HTP backends") + return False + return True + def _run_for_config( self, model: Union[QairtModelHandler], @@ -120,17 +152,38 @@ def _run_for_config( container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path) if config.genie_overrides: - gen_ai_cfg = container._gen_ai_config - current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True) - merged = _deep_merge(current, config.genie_overrides) - container._gen_ai_config = gen_ai_cfg.model_validate(merged) - logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) - - if config.backend_extensions_override: - container._backend_extensions_config = _deep_merge( - container._backend_extensions_config or {}, config.backend_extensions_override - ) - logger.info("Applied backend_extensions_override: %s", list(config.backend_extensions_override.keys())) + if hasattr(container, "_gen_ai_config"): + gen_ai_cfg = container._gen_ai_config + current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True) + merged = _deep_merge(current, config.genie_overrides) + container._gen_ai_config = gen_ai_cfg.model_validate(merged) + logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) + else: + logger.warning("genie_overrides ignored: _gen_ai_config not found in installed qairt version") + + if config.backend_extensions_overrides: + if hasattr(container, "_backend_extensions_config"): + container._backend_extensions_config = _deep_merge( + container._backend_extensions_config or {}, config.backend_extensions_overrides + ) + logger.info( + "Applied backend_extensions_overrides: %s", list(config.backend_extensions_overrides.keys()) + ) + else: + logger.warning( + "backend_extensions_overrides ignored: _backend_extensions_config not found in installed qairt version" + ) + + export_kwargs = {"export_format": qairt.ExportFormat.LM_EXECUTOR} + if config.htp_execution_overrides: + htp_exec_cfg = qairt_genai.HTPExecutionConfig(**config.htp_execution_overrides) + if "htp_execution_config" in inspect.signature(container.export).parameters: + export_kwargs["htp_execution_config"] = htp_exec_cfg + else: + logger.warning( + "htp_execution_overrides ignored: installed qairt does not support " + "htp_execution_config in LLMContainer.export()" + ) # Input/Output metadata container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])] @@ -149,7 +202,7 @@ def _run_for_config( for name, datatype, shape in container.outputs: outputs.append(helper.make_tensor_value_info(name, datatype, shape)) - container.export(output_model_path, export_format=qairt.ExportFormat.LM_EXECUTOR) + container.export(output_model_path, **export_kwargs) # Find the .dlc file in the output directory output_path_obj = Path(output_model_path) diff --git a/test/passes/qairt/test_encapsulation.py b/test/passes/qairt/test_encapsulation.py index 15dfb62d6..c45e34470 100644 --- a/test/passes/qairt/test_encapsulation.py +++ b/test/passes/qairt/test_encapsulation.py @@ -1087,15 +1087,32 @@ def mock_save_func(model_def, path): # --------------------------------------------------------------------------- -def test_encapsulation_default_config_includes_backend_extensions_override(mock_accelerator_spec): - """backend_extensions_override is present in _default_config with None default.""" +def test_encapsulation_default_config_includes_backend(mock_accelerator_spec): + """backend is present in _default_config with HTP as the default.""" + from olive.passes.qairt.gen_ai_builder import QairtBackend + + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "backend" in config + assert config["backend"].default_value == QairtBackend.HTP + + +def test_encapsulation_default_config_includes_backend_extensions_overrides(mock_accelerator_spec): + """backend_extensions_overrides is present in _default_config with None default.""" config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access - assert "backend_extensions_override" in config - assert config["backend_extensions_override"].default_value is None - assert config["backend_extensions_override"].required is False + assert "backend_extensions_overrides" in config + assert config["backend_extensions_overrides"].default_value is None + assert config["backend_extensions_overrides"].required is False -def test_encapsulation_backend_extensions_override_merges_into_existing(tmp_path, mock_qairt_model, mock_qairt_modules): +def test_encapsulation_default_config_includes_htp_execution_overrides(mock_accelerator_spec): + """htp_execution_overrides is present in _default_config with None default.""" + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "htp_execution_overrides" in config + assert config["htp_execution_overrides"].default_value is None + assert config["htp_execution_overrides"].required is False + + +def test_encapsulation_backend_extensions_overrides_merges_into_existing(tmp_path, mock_qairt_model, mock_qairt_modules): """Override is deep-merged into the existing _backend_extensions_config.""" output_path = tmp_path / "output" output_path.mkdir(parents=True, exist_ok=True) @@ -1140,7 +1157,7 @@ def mock_save_func(model_def, path): encap_pass = create_pass_from_dict( QairtEncapsulation, - {"backend": "CPU", "backend_extensions_override": {"context": {"n-threads": 4}}}, + {"backend": "CPU", "backend_extensions_overrides": {"context": {"n-threads": 4}}}, disable_search=True, ) encap_pass.run(mock_qairt_model, str(output_path)) @@ -1149,7 +1166,7 @@ def mock_save_func(model_def, path): assert mock_container._backend_extensions_config == {"context": {"n-threads": 4, "kv-cache-size": 512}} -def test_encapsulation_backend_extensions_override_from_empty(tmp_path, mock_qairt_model, mock_qairt_modules): +def test_encapsulation_backend_extensions_overrides_from_empty(tmp_path, mock_qairt_model, mock_qairt_modules): """When the container has no existing backend extensions config, the override becomes the config.""" mock_container = MagicMock() mock_container._backend_extensions_config = None @@ -1194,7 +1211,7 @@ def mock_save_func(model_def, path): encap_pass = create_pass_from_dict( QairtEncapsulation, - {"backend": "CPU", "backend_extensions_override": {"context": {"n-threads": 4}}}, + {"backend": "CPU", "backend_extensions_overrides": {"context": {"n-threads": 4}}}, disable_search=True, ) encap_pass.run(mock_qairt_model, str(output_path)) @@ -1202,10 +1219,10 @@ def mock_save_func(model_def, path): assert mock_container._backend_extensions_config == {"context": {"n-threads": 4}} -def test_encapsulation_no_backend_extensions_override_leaves_config_untouched( +def test_encapsulation_no_backend_extensions_overrides_leaves_config_untouched( tmp_path, mock_qairt_model, mock_qairt_modules ): - """When backend_extensions_override is None the container config is not touched.""" + """When backend_extensions_overrides is None the container config is not touched.""" original_ext_cfg = {"context": {"n-threads": 6}} mock_container = MagicMock() @@ -1253,3 +1270,275 @@ def mock_save_func(model_def, path): encap_pass.run(mock_qairt_model, str(output_path)) assert mock_container._backend_extensions_config is original_ext_cfg + + +# --------------------------------------------------------------------------- +# validate_config tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_validate_config_rejects_htp_execution_overrides_on_non_htp(mock_accelerator_spec): + """validate_config returns False when htp_execution_overrides is set on a non-HTP backend.""" + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "htp_execution_overrides": {"n_threads": 4}}, + disable_search=True, + ) + assert encap_pass.validate_config(encap_pass.config, mock_accelerator_spec) is False + + +def test_encapsulation_validate_config_accepts_htp_execution_overrides_on_htp(mock_accelerator_spec): + """validate_config returns True when htp_execution_overrides is set on HTP backend.""" + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "HTP", "htp_execution_overrides": {"n_threads": 4}}, + disable_search=True, + ) + assert encap_pass.validate_config(encap_pass.config, mock_accelerator_spec) is True + + +# --------------------------------------------------------------------------- +# htp_execution_overrides integration tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_htp_execution_overrides_applied_when_supported(tmp_path, mock_qairt_model, mock_qairt_modules): + """htp_execution_overrides is passed to export() when the installed qairt supports it.""" + import inspect + + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, export_format, htp_execution_config=None): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export = MagicMock(side_effect=mock_export) + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + fake_sig = inspect.signature(mock_export) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + patch("olive.passes.qairt.encapsulation.inspect.signature", return_value=fake_sig), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + fake_htp_cfg = MagicMock() + mock_qairt_modules["gen_ai_api"].HTPExecutionConfig.return_value = fake_htp_cfg + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "HTP", "htp_execution_overrides": {"n_threads": 4, "cpu_mask": "0x3"}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + mock_qairt_modules["gen_ai_api"].HTPExecutionConfig.assert_called_once_with(n_threads=4, cpu_mask="0x3") + _, call_kwargs = mock_container.export.call_args + assert call_kwargs.get("htp_execution_config") is fake_htp_cfg + + +def test_encapsulation_htp_execution_overrides_skipped_when_not_supported( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """htp_execution_overrides is ignored with a warning when export() has no htp_execution_config param.""" + import inspect + + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export_old(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export = MagicMock(side_effect=mock_export_old) + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + fake_sig = inspect.signature(mock_export_old) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + patch("olive.passes.qairt.encapsulation.inspect.signature", return_value=fake_sig), + patch("olive.passes.qairt.encapsulation.logger") as mock_logger, + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "HTP", "htp_execution_overrides": {"n_threads": 4}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + warning_messages = [str(call) for call in mock_logger.warning.call_args_list] + assert any("htp_execution_overrides ignored" in msg for msg in warning_messages) + _, call_kwargs = mock_container.export.call_args + assert "htp_execution_config" not in call_kwargs + + +# --------------------------------------------------------------------------- +# hasattr guard tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_genie_overrides_skipped_when_private_attr_missing( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """genie_overrides is ignored with a warning when _gen_ai_config is absent from the container.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock(spec=["inputs", "outputs", "export"]) + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, **kwargs): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export = MagicMock(side_effect=mock_export) + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + patch("olive.passes.qairt.encapsulation.logger") as mock_logger, + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "genie_overrides": {"kv_dim": 128}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + warning_messages = [str(call) for call in mock_logger.warning.call_args_list] + assert any("genie_overrides ignored" in msg for msg in warning_messages) + + +def test_encapsulation_backend_extensions_overrides_skipped_when_private_attr_missing( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """backend_extensions_overrides is ignored with a warning when _backend_extensions_config is absent.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock(spec=["inputs", "outputs", "export"]) + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, **kwargs): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export = MagicMock(side_effect=mock_export) + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + patch("olive.passes.qairt.encapsulation.logger") as mock_logger, + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "backend_extensions_overrides": {"context": {"n-threads": 4}}}, + disable_search=True, + ) + encap_pass.run(mock_qairt_model, str(output_path)) + + warning_messages = [str(call) for call in mock_logger.warning.call_args_list] + assert any("backend_extensions_overrides ignored" in msg for msg in warning_messages)