diff --git a/docs/Privacy.md b/docs/Privacy.md index 95aee00b0b..b49ddbd6ce 100644 --- a/docs/Privacy.md +++ b/docs/Privacy.md @@ -13,4 +13,4 @@ In addition, Olive may collect additional telemetry data such as: - Performance data - Exception information -Collection of this additional telemetry can be disabled by adding the `--disable_telemetry` flag to any Olive CLI command, or by setting the `OLIVE_DISABLE_TELEMETRY` environment variable to `1` before running. Telemetry is also automatically disabled when a CI/CD environment is detected (e.g., GitHub Actions, Azure Pipelines, Jenkins). If telemetry is enabled, but cannot be sent to Microsoft, it will be stored locally and sent when a connection is available. You can override the default cache location by setting the `OLIVE_TELEMETRY_CACHE_DIR` environment variable to a valid directory path. +Collection of this additional telemetry can be disabled by adding the `--disable_telemetry` flag to any Olive CLI command, or by setting the `OLIVE_DISABLE_TELEMETRY` environment variable to `1` before running. In CI/CD environments (e.g., GitHub Actions, Azure Pipelines, Jenkins), Olive suppresses the general heartbeat/action/error events and only emits the `OliveRecipe` event. The `OliveRecipe` event may include recipe metadata such as pass types, explicitly configured target settings, the host system type (including the default `LocalSystem` host) and any explicitly configured host accelerator settings, whether a custom package config was provided, a redacted snapshot of custom package-config overrides, and a redacted snapshot of explicitly supplied config overrides. Outside CI/CD environments, if telemetry is enabled but cannot be sent to Microsoft, it will be stored locally and sent when a connection is available. You can override the default cache location by setting the `OLIVE_TELEMETRY_CACHE_DIR` environment variable to a valid directory path. diff --git a/docs/source/how-to/extending/custom-scripts.md b/docs/source/how-to/extending/custom-scripts.md index 8e8961a5b6..5e78149fe1 100644 --- a/docs/source/how-to/extending/custom-scripts.md +++ b/docs/source/how-to/extending/custom-scripts.md @@ -36,7 +36,7 @@ class MyDataLoader: @Registry.register_dataloader() def my_dataloader(dataset, batch_size): - return MyDataloader(dataset, batch_size) + return MyDataLoader(dataset, batch_size) @Registry.register_post_process() def my_post_process(output): diff --git a/mcp/uv.lock b/mcp/uv.lock index b7995e5533..0ef5b5a3dc 100644 --- a/mcp/uv.lock +++ b/mcp/uv.lock @@ -275,11 +275,11 @@ wheels = [ [[package]] name = "idna" -version = "3.11" +version = "3.15" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, + { url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" }, ] [[package]] @@ -594,11 +594,11 @@ wheels = [ [[package]] name = "python-multipart" -version = "0.0.26" +version = "0.0.27" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/71/b145a380824a960ebd60e1014256dbb7d2253f2316ff2d73dfd8928ec2c3/python_multipart-0.0.26.tar.gz", hash = "sha256:08fadc45918cd615e26846437f50c5d6d23304da32c341f289a617127b081f17", size = 43501, upload-time = "2026-04-10T14:09:59.473Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/9b/f23807317a113dc36e74e75eb265a02dd1a4d9082abc3c1064acd22997c4/python_multipart-0.0.27.tar.gz", hash = "sha256:9870a6a8c5a20a5bf4f07c017bd1489006ff8836cff097b6933355ee2b49b602", size = 44043, upload-time = "2026-04-27T10:51:26.649Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/22/f1925cdda983ab66fc8ec6ec8014b959262747e58bdca26a4e3d1da29d56/python_multipart-0.0.26-py3-none-any.whl", hash = "sha256:c0b169f8c4484c13b0dcf2ef0ec3a4adb255c4b7d18d8e420477d2b1dd03f185", size = 28847, upload-time = "2026-04-10T14:09:58.131Z" }, + { url = "https://files.pythonhosted.org/packages/99/78/4126abcbdbd3c559d43e0db7f7b9173fc6befe45d39a2856cc0b8ec2a5a6/python_multipart-0.0.27-py3-none-any.whl", hash = "sha256:6fccfad17a27334bd0193681b369f476eda3409f17381a2d65aa7df3f7275645", size = 29254, upload-time = "2026-04-27T10:51:24.997Z" }, ] [[package]] diff --git a/olive/cache.py b/olive/cache.py index 22b13eae5b..ceb64e1528 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -439,13 +439,19 @@ def save_model( else: from olive.passes.onnx.common import resave_model + component_output_name = ( + component_name + if Path(component_name).suffix == ".onnx" + else f"{component_name}.onnx" + ) + resave_model( ModelConfig.model_validate(component_model_json).create_model().model_path, - actual_output_dir / f"{component_name}.onnx", + actual_output_dir / component_output_name, saved_external_files=saved_external_files, ) component_model_json["config"][resource_name] = str(actual_output_dir) - component_model_json["config"]["onnx_file_name"] = f"{component_name}.onnx" + component_model_json["config"]["onnx_file_name"] = component_output_name copied_components.append(component_model_json) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 2e0f73444f..ef8f5fed8d 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import logging from argparse import ArgumentParser from collections import OrderedDict from copy import deepcopy @@ -25,13 +26,19 @@ from olive.package_config import OlivePackageConfig from olive.telemetry import action +logger = logging.getLogger(__name__) + class AutoOptCommand(BaseOliveCLICommand): @staticmethod def register_subcommand(parser: ArgumentParser): sub_parser = parser.add_parser( "auto-opt", - help="Automatically optimize the performance of the input model.", + help=( + "Automatically optimize the performance of the input model.\n" + "**** DEPRECATION WARNING ****\n" + '"auto-opt" command is deprecated in favor of "optimize".' + ), ) # Model options @@ -174,6 +181,11 @@ def register_subcommand(parser: ArgumentParser): @action def run(self): + logger.warning( + "**** DEPRECATION WARNING ****\n" + '"auto-opt" command is deprecated in favor of "optimize". Please switch to using "optimize".\n' + "Deprecated commands will be removed entirely in future release." + ) return self._run_workflow() def _get_run_config(self, tempdir) -> dict: diff --git a/olive/cli/base.py b/olive/cli/base.py index 75fd2816c7..26bd042107 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -42,13 +42,24 @@ def _run_workflow(self): if self.args.dry_run: print("Dry run mode enabled. Configuration file is generated but no optimization is performed.") return None - workflow_output = olive_run(run_config) + workflow_output = olive_run(run_config, recipe_telemetry_metadata=self._get_recipe_telemetry_metadata()) if not workflow_output.has_output_model(): print("No output model produced. Please check the log for details.") else: print(f"Model is saved at {self.args.output_path}") return workflow_output + def _get_recipe_telemetry_metadata(self) -> dict[str, str]: + recipe_name = self.__class__.__name__ + if recipe_name.endswith("Command"): + recipe_name = recipe_name[: -len("Command")] + return { + "recipe_name": recipe_name, + "recipe_command": recipe_name, + "recipe_source": "generated_cli", + "recipe_format": "generated", + } + @staticmethod def _parse_extra_options(kv_items): from onnxruntime_genai import __version__ as OrtGenaiVersion diff --git a/olive/cli/model_package.py b/olive/cli/model_package.py index 4481df1b98..5513c2d719 100644 --- a/olive/cli/model_package.py +++ b/olive/cli/model_package.py @@ -2,12 +2,52 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +"""``olive generate-model-package`` CLI command. + +Assemble one or more Olive output directories into a proposal-shaped ORT +model package. + +Each ``--source`` directory is one Olive output (an ``ONNXModel`` or a +``CompositeModel`` with ONNX components). Single-source packages are +allowed: a single variant under one component is a normal, valid package. + +Output layout (per the ORT model-package proposal):: + + / + ├── manifest.json + ├── configs/ + │ └── # tokenizer, genai_config, ... + └── / + ├── metadata.json + ├── shared_weights/ + │ └── / # opt-in cross-variant dedup + └── / + ├── variant.json + ├── model.onnx + └── ... + +Notes: +- ``shared_weights`` is opt-in per blob. A blob whose SHA-256 appears in only + one variant stays inline next to its ONNX file in the variant directory, + keeping the single-variant case loadable by stock ORT. +- Cross-variant dedup moves a duplicated blob to + ``/shared_weights//`` and records the mapping + in the per-file ``shared_files`` map of the variant's ``variant.json``. + Loading such a variant requires a model-package-aware consumer. +- ``genai_config.json`` is copied verbatim into ``/configs/``; + per-variant overlays are ORT-GenAI's responsibility, not Olive's. + +""" + +import hashlib import json import logging +import re import shutil from argparse import ArgumentParser +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import Any, Optional from olive.cli.base import ( BaseOliveCLICommand, @@ -18,18 +58,34 @@ logger = logging.getLogger(__name__) -# Model file suffixes that belong in the models/ directory, not configs/ +# Files inside an Olive output dir that always belong next to the ONNX model +# rather than under /configs/. _MODEL_SUFFIXES = {".onnx", ".bin", ".data", ".xml"} +# Schema version emitted in manifest.json. Keep in sync with the proposal. +_MANIFEST_SCHEMA_VERSION = 1 + +# Hash chunk size for SHA-256 over external-data blobs. +_HASH_CHUNK = 1024 * 1024 + +# Disallow path separators / traversal in component and variant names so a +# producer can't write files outside the package directory. +_NAME_RE = re.compile(r"^[A-Za-z0-9._-][A-Za-z0-9._\- ]*$") + + +# --------------------------------------------------------------------------- +# CLI command +# --------------------------------------------------------------------------- + class ModelPackageCommand(BaseOliveCLICommand): - """Merge multiple Olive output directories into a model package with manifest.""" + """Merge one or more Olive output directories into a model package.""" @staticmethod def register_subcommand(parser: ArgumentParser): sub_parser = parser.add_parser( "generate-model-package", - help="Merge multiple model outputs into a model package with manifest", + help="Merge one or more Olive output directories into a model package", ) sub_parser.add_argument( @@ -38,7 +94,10 @@ def register_subcommand(parser: ArgumentParser): type=str, action="append", required=True, - help="Source Olive output directory. Can be specified multiple times.", + help=( + "Source Olive output directory. Repeat to add multiple variants. " + "A single source is allowed (single-variant package)." + ), ) sub_parser.add_argument( @@ -46,21 +105,21 @@ def register_subcommand(parser: ArgumentParser): "--output_path", type=str, required=True, - help="Output directory for the merged model package.", + help="Output directory for the model package. Must be empty or non-existent.", ) sub_parser.add_argument( "--model_name", type=str, default=None, - help="Model name for the manifest. If not set, derived from the output directory name.", + help="Optional model name recorded under manifest.producer.", ) sub_parser.add_argument( "--model_version", type=str, default="1.0", - help="Model version string for the manifest. Default: 1.0", + help="Optional model version recorded under manifest.producer. Default: 1.0", ) add_logging_options(sub_parser) @@ -71,141 +130,136 @@ def register_subcommand(parser: ArgumentParser): def run(self): sources = self._parse_sources() output_dir = Path(self.args.output_path) - output_dir.mkdir(parents=True, exist_ok=True) - - model_name = self.args.model_name or output_dir.name - model_version = self.args.model_version - # Read model configs from each source targets = [] for target_name, source_path in sources: model_config = self._read_model_config(source_path) targets.append((target_name, source_path, model_config)) - is_composite = targets[0][2].get("type") == "CompositeModel" + types = {targets[i][2].get("type") for i in range(len(targets))} + if types - {"ONNXModel", "CompositeModel"}: + unsupported = sorted(types - {"ONNXModel", "CompositeModel"}) + raise ValueError( + f"Unsupported source model type(s) {unsupported!r}. " + "generate-model-package supports ONNXModel and CompositeModel only." + ) + if len(types) > 1: + raise ValueError( + f"Sources mix model types {sorted(types)!r}. All sources must share the same type " + "(all ONNXModel or all CompositeModel)." + ) + is_composite = next(iter(types)) == "CompositeModel" + if is_composite: - self._package_composite(targets, output_dir, model_name, model_version) + variants = self._build_composite_variants(targets) else: - self._package_single(targets, output_dir, model_name, model_version) + variants = self._build_single_variants(targets) + + config_files = self._collect_config_files(targets) + + task = self._extract_task(targets) + producer_info: dict[str, str] = {"tool": "olive-ai"} + try: + from olive import __version__ as _olive_version + + producer_info["tool_version"] = _olive_version + except Exception: + logger.debug("Could not read olive.__version__", exc_info=True) + producer_info["model_name"] = self.args.model_name or output_dir.name + producer_info["model_version"] = self.args.model_version + if task: + producer_info["task"] = task + + write_model_package( + output_dir=output_dir, + variants=variants, + config_files=config_files, + producer_info=producer_info, + ) logger.info("Model package generated at %s", output_dir) print(f"Model package generated at {output_dir}") # ------------------------------------------------------------------ - # Single-component packaging + # VariantSpec construction # ------------------------------------------------------------------ - def _package_single( - self, - targets: list[tuple[str, Path, dict]], - output_dir: Path, - model_name: str, - model_version: str, - ) -> None: - """Package non-composite models (single ONNX per target).""" - config_file_names = self._copy_config_files(targets, output_dir) + def _build_single_variants(self, targets: list[tuple[str, Path, dict]]) -> list["VariantSpec"]: task = self._extract_task(targets) component_name = _task_to_component_name(task) - - component_dir = output_dir / "models" / component_name - component_dir.mkdir(parents=True, exist_ok=True) - - model_variants = {} - for target_name, _source_path, model_config in targets: + variants: list[VariantSpec] = [] + for target_name, _src, model_config in targets: attrs = _get_model_attributes(model_config) - model_path = Path(model_config["config"]["model_path"]) - - target_dir = component_dir / target_name - _copy_model_files_single(model_path, target_dir) - - constraints = _build_constraints(attrs, model_path) - model_variants[target_name] = {"file": model_path.name, "constraints": constraints} - - _remove_config_files(component_dir, config_file_names) - - metadata = {"name": component_name, "model_variants": model_variants} - _write_json(component_dir / "metadata.json", metadata) - - manifest = { - "name": model_name, - "model_version": model_version, - "task": task, - "component_models": [component_name], - } - _write_json(output_dir / "manifest.json", manifest) - - # ------------------------------------------------------------------ - # Composite-model packaging - # ------------------------------------------------------------------ + onnx_path = _resolve_onnx_path(model_config) + ep, device, compatibility = _ep_device_compatibility(attrs, onnx_path) + variants.append( + VariantSpec( + component_name=component_name, + variant_name=target_name, + onnx_files=[onnx_path], + ep=ep, + device=device, + compatibility=compatibility, + inference_settings=model_config.get("config", {}).get("inference_settings") or {}, + ) + ) + return variants - def _package_composite( - self, - targets: list[tuple[str, Path, dict]], - output_dir: Path, - model_name: str, - model_version: str, - ) -> None: - """Package composite models with per-component directory layout.""" - config_file_names = self._copy_config_files(targets, output_dir) - - # Collect component info: component_data[comp_name][target_name] = (comp_config, target_attrs) + def _build_composite_variants(self, targets: list[tuple[str, Path, dict]]) -> list["VariantSpec"]: from collections import OrderedDict - component_data: dict[str, dict] = OrderedDict() + # Track per-component variants in source insertion order. + component_variants: dict[str, list[VariantSpec]] = OrderedDict() - for target_name, _source_path, model_config in targets: + for target_name, _src, model_config in targets: target_attrs = _get_model_attributes(model_config) + target_inference = model_config.get("config", {}).get("inference_settings") or {} components = model_config["config"].get("model_components", []) component_names = model_config["config"].get("component_names", []) - for comp_config, comp_name in zip(components, component_names): - if comp_name not in component_data: - component_data[comp_name] = OrderedDict() - component_data[comp_name][target_name] = (comp_config, target_attrs) - - models_dir = output_dir / "models" - comp_names_list = list(component_data.keys()) - - for comp_name in comp_names_list: - comp_dir = models_dir / comp_name - comp_dir.mkdir(parents=True, exist_ok=True) - - model_variants = {} - for target_name, (comp_config, target_attrs) in component_data[comp_name].items(): - comp_model_path = Path(comp_config["config"]["model_path"]) - target_dir = comp_dir / target_name - _copy_component_files(comp_model_path, target_dir) - - constraints = _build_constraints(target_attrs, comp_model_path) - model_variants[target_name] = {"file": comp_model_path.name, "constraints": constraints} + if not components: + raise ValueError(f"Composite source {target_name!r} declares no model_components.") - _remove_config_files(comp_dir, config_file_names) - - metadata = {"name": comp_name, "model_variants": model_variants} - _write_json(comp_dir / "metadata.json", metadata) + for comp_config, comp_name in zip(components, component_names): + # Component-level inference_settings overrides target-level if present. + comp_inference = comp_config.get("config", {}).get("inference_settings") or target_inference + # Component-level model_attributes overlay target-level. + comp_attrs = dict(target_attrs) + comp_attrs.update(_get_model_attributes(comp_config)) + + onnx_path = _resolve_onnx_path(comp_config) + ep, device, compatibility = _ep_device_compatibility(comp_attrs, onnx_path) + + spec = VariantSpec( + component_name=comp_name, + variant_name=target_name, + onnx_files=[onnx_path], + ep=ep, + device=device, + compatibility=compatibility, + inference_settings=comp_inference, + ) + component_variants.setdefault(comp_name, []).append(spec) - task = self._extract_task(targets) - manifest = { - "name": model_name, - "model_version": model_version, - "task": task, - "component_models": comp_names_list, - } - _write_json(output_dir / "manifest.json", manifest) + flat: list[VariantSpec] = [] + for comp_specs in component_variants.values(): + flat.extend(comp_specs) + return flat # ------------------------------------------------------------------ # Config file handling # ------------------------------------------------------------------ @staticmethod - def _copy_config_files( - targets: list[tuple[str, Path, dict]], - output_dir: Path, - ) -> set[str]: - """Copy non-model config files (genai_config, tokenizer, etc.) to configs/.""" + def _collect_config_files(targets: list[tuple[str, Path, dict]]) -> dict[str, Path]: + """Pick consumer-shared config files (genai_config, tokenizer, ...). + + Source-of-truth order: + 1. ``model_attributes.additional_files`` of any source that has it. + 2. Otherwise, the first source's non-model files. + """ config_entries: dict[str, Path] = {} - # Collect from the first target's additional_files or source directory for _target_name, _source_path, model_config in targets: attrs = _get_model_attributes(model_config) for fp in attrs.get("additional_files", []): @@ -215,7 +269,6 @@ def _copy_config_files( if config_entries: break - # Fall back to scanning the source directory for non-model files if not config_entries: for _target_name, source_path, _model_config in targets: for f in sorted(source_path.iterdir()): @@ -226,69 +279,47 @@ def _copy_config_files( if config_entries: break - if not config_entries: - return set() - - configs_dir = output_dir / "configs" - configs_dir.mkdir(parents=True, exist_ok=True) - - for name, src_path in config_entries.items(): - dest = configs_dir / name - if src_path.is_dir(): - if not dest.exists(): - shutil.copytree(str(src_path), str(dest)) - else: - shutil.copy2(str(src_path), str(dest)) - logger.info("Copied %s to %s", name, configs_dir) - - return set(config_entries.keys()) + return config_entries # ------------------------------------------------------------------ - # Source validation and reading + # Source validation / reading # ------------------------------------------------------------------ def _parse_sources(self) -> list[tuple[str, Path]]: - sources = [] + sources: list[tuple[str, Path]] = [] + seen_names: set[str] = set() for source in self.args.source: path = Path(source) if not path.is_dir(): raise ValueError(f"Source path does not exist or is not a directory: {path}") - if not (path / "model_config.json").exists(): raise ValueError( f"No model_config.json found in {path}. " "Source must be an Olive output directory with model_config.json." ) - - sources.append((path.name, path)) - - if len(sources) < 2: - raise ValueError("At least two --source directories are required to merge.") - + name = path.name + if name in seen_names: + raise ValueError( + f"Two sources share the directory name {name!r}. Variant names are derived from " + "the source directory name; please rename so each source is unique." + ) + seen_names.add(name) + sources.append((name, path)) + if not sources: + raise ValueError("At least one --source directory is required.") return sources @staticmethod def _read_model_config(source_path: Path) -> dict: - config_path = source_path / "model_config.json" - with open(config_path) as f: + with (source_path / "model_config.json").open() as f: return json.load(f) - @staticmethod - def _extract_accelerator_info(target_models: list[dict]) -> tuple[str, str]: - for model_config in target_models: - attrs = model_config.get("config", {}).get("model_attributes") or {} - ep = attrs.get("ep", "CPUExecutionProvider") - device = attrs.get("device", "cpu") - return ep, device.lower() - return "CPUExecutionProvider", "cpu" - # ------------------------------------------------------------------ # Task extraction # ------------------------------------------------------------------ @staticmethod def _extract_task(targets: list[tuple[str, Path, dict]]) -> str: - """Extract the HuggingFace pipeline task for the model.""" model_name_or_path = "" for _target_name, _source_path, model_config in targets: attrs = _get_model_attributes(model_config) @@ -310,40 +341,490 @@ def _extract_task(targets: list[tuple[str, Path, dict]]) -> str: return "" -# ------------------------------------------------------------------ -# Module-level helpers -# ------------------------------------------------------------------ +# --------------------------------------------------------------------------- +# Writer (CLI-private; kept here because only this command produces packages) +# --------------------------------------------------------------------------- + + +@dataclass +class VariantSpec: + """One variant of one component, ready to be packaged.""" + + component_name: str + variant_name: str + onnx_files: list[Path] + ep: str + device: Optional[str] = None + compatibility: list[str] = field(default_factory=list) + inference_settings: dict[str, Any] = field(default_factory=dict) + consumer_metadata: Optional[dict[str, Any]] = None + + +def write_model_package( + output_dir: Path, + variants: list[VariantSpec], + config_files: Optional[dict[str, Path]] = None, + producer_info: Optional[dict[str, Any]] = None, +) -> None: + """Materialize a model package on disk. + + :param output_dir: Target directory. Must be empty (or non-existent) so a + partial overwrite cannot mix the new layout with stale files from a + previous run. + :param variants: Ordered list of variants. Component insertion order is + the order each component first appears in this list. + :param config_files: Map from filename (basename) to source path; copied + into ``/configs/``. Same-named files contributed by + different sources should be byte-identical; the first wins on + conflict and a warning is logged. + :param producer_info: Olive-specific provenance recorded under + ``manifest.producer``. Schema-tolerated extra field (the proposal + defines only ``schema_version``, ``components``, and + ``merge_provenance``; producers may add namespaced extras). + """ + if not variants: + raise ValueError("write_model_package requires at least one variant.") + + output_dir = Path(output_dir) + _ensure_empty_output_dir(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Group by component while preserving insertion order. + components: dict[str, list[VariantSpec]] = {} + for v in variants: + _validate_name(v.component_name, "component") + _validate_name(v.variant_name, "variant") + components.setdefault(v.component_name, []).append(v) + + # Per component, fail fast on duplicate variant names. The caller is + # expected to disambiguate (e.g. with a rank suffix) before calling. + for comp_name, comp_variants in components.items(): + seen: set[str] = set() + for v in comp_variants: + if v.variant_name in seen: + raise ValueError( + f"Duplicate variant name '{v.variant_name}' under component " + f"'{comp_name}'. Variant names must be unique per component." + ) + seen.add(v.variant_name) + + for comp_name, comp_variants in components.items(): + _write_component(output_dir, comp_name, comp_variants) + + if config_files: + _copy_config_files(output_dir, config_files) + + _write_manifest(output_dir, list(components.keys()), producer_info) + + +def _write_component(output_dir: Path, component_name: str, comp_variants: list[VariantSpec]) -> None: + component_dir = output_dir / component_name + component_dir.mkdir(parents=True, exist_ok=True) + + # First pass: copy each variant's ONNX file(s) and discover external-data + # references. We hash blobs as we copy so multi-variant packages don't + # re-read the data later. + blob_index: dict[str, dict[str, Any]] = {} + variant_files: dict[str, list[tuple[str, list[tuple[str, str]]]]] = {} + + for v in comp_variants: + if not v.onnx_files: + raise ValueError(f"Variant '{v.variant_name}' under component '{component_name}' has no ONNX files.") + + variant_dir = component_dir / v.variant_name + variant_dir.mkdir(parents=True, exist_ok=True) + files_for_variant: list[tuple[str, list[tuple[str, str]]]] = [] + + for onnx_src in v.onnx_files: + onnx_src_path = Path(onnx_src) + if not onnx_src_path.is_file(): + raise FileNotFoundError(f"ONNX file not found: {onnx_src_path}") + + onnx_dst = variant_dir / onnx_src_path.name + shutil.copy2(str(onnx_src_path), str(onnx_dst)) + + ext_refs = _discover_external_data(onnx_src_path) + external_root = onnx_src_path.parent.resolve() + blob_records: list[tuple[str, str]] = [] + for graph_location in ext_refs: + blob_src = (onnx_src_path.parent / graph_location).resolve() + if not blob_src.is_relative_to(external_root): + logger.warning( + "External-data file referenced by %s resolves outside its source directory " + "(symlink escape?); skipping: %s", + onnx_src_path, + blob_src, + ) + continue + if not blob_src.is_file(): + logger.warning( + "External-data file referenced by %s but missing: %s", + onnx_src_path, + blob_src, + ) + continue + + blob_dst = variant_dir / graph_location + blob_dst.parent.mkdir(parents=True, exist_ok=True) + if not blob_dst.exists(): + shutil.copy2(str(blob_src), str(blob_dst)) + + sha = _sha256_file(blob_dst) + blob_records.append((graph_location, sha)) + + entry = blob_index.setdefault( + sha, {"first_path": blob_dst, "occurrences": 0, "basename": Path(graph_location).name} + ) + entry["occurrences"] += 1 + + files_for_variant.append((onnx_dst.name, blob_records)) + + variant_files[v.variant_name] = files_for_variant + + # Second pass: dedup any blob that appears in 2+ variants of this + # component into /shared_weights//. Single- + # occurrence blobs stay inline so single-variant packages remain + # loadable without the package API. + shared_weights_dir = component_dir / "shared_weights" + shared_blob_paths: dict[str, Path] = {} + for sha, entry in blob_index.items(): + if entry["occurrences"] < 2: + continue + sha_dir = shared_weights_dir / sha + sha_dir.mkdir(parents=True, exist_ok=True) + target = sha_dir / entry["basename"] + if not target.exists(): + shutil.copy2(str(entry["first_path"]), str(target)) + shared_blob_paths[sha] = target + + # Third pass: for each variant, remove deduped blobs from the variant + # directory and emit variant.json with the right shared_files map per + # files[i]. Then emit metadata.json for the component. + for v in comp_variants: + variant_dir = component_dir / v.variant_name + files_payload: list[dict[str, Any]] = [] + for onnx_filename, blob_records in variant_files[v.variant_name]: + shared_files: dict[str, str] = {} + for graph_location, sha in blob_records: + if sha in shared_blob_paths: + inline = variant_dir / graph_location + if inline.exists(): + inline.unlink() + # Clean up any now-empty parent directories created for + # nested graph_location paths, but stop at variant_dir. + parent = inline.parent + while parent != variant_dir and parent.is_dir() and not any(parent.iterdir()): + parent.rmdir() + parent = parent.parent + shared_files[graph_location] = sha + + file_entry: dict[str, Any] = {"filename": onnx_filename} + so = (v.inference_settings or {}).get("session_options") or {} + po = _provider_options_for_ep(v.inference_settings or {}, v.ep) + if so: + file_entry["session_options"] = so + if po: + file_entry["provider_options"] = po + if shared_files: + file_entry["shared_files"] = shared_files + files_payload.append(file_entry) + + variant_payload: dict[str, Any] = {"files": files_payload} + if v.consumer_metadata is not None: + variant_payload["consumer_metadata"] = v.consumer_metadata + _write_json(variant_dir / "variant.json", variant_payload) + + _write_metadata(component_dir, comp_variants) + + +def _write_metadata(component_dir: Path, comp_variants: list[VariantSpec]) -> None: + variants_payload: dict[str, Any] = {} + for v in comp_variants: + ep_entry: dict[str, Any] = {"ep": v.ep} + if v.device: + ep_entry["device"] = v.device + if v.compatibility: + ep_entry["compatibility"] = list(v.compatibility) + variants_payload[v.variant_name] = {"ep_compatibility": [ep_entry]} + _write_json(component_dir / "metadata.json", {"variants": variants_payload}) + + +def _write_manifest( + output_dir: Path, + components: list[str], + producer_info: Optional[dict[str, Any]], +) -> None: + manifest: dict[str, Any] = { + "schema_version": _MANIFEST_SCHEMA_VERSION, + "components": components, + } + if producer_info: + # Olive-specific provenance under a namespaced key so future schema + # evolution can't collide with it. + manifest["producer"] = producer_info + _write_json(output_dir / "manifest.json", manifest) + + +# --------------------------------------------------------------------------- +# configs/ handling +# --------------------------------------------------------------------------- + + +def _copy_config_files(output_dir: Path, config_files: dict[str, Path]) -> None: + configs_dir = output_dir / "configs" + configs_dir.mkdir(parents=True, exist_ok=True) + configs_root = configs_dir.resolve() + for name, src in config_files.items(): + if "/" in name or "\\" in name or name in ("", ".", ".."): + logger.warning("Skipping config file with unsafe name %r.", name) + continue + src_path = Path(src) + dest = configs_dir / name + # Belt-and-suspenders: even with the name check above, refuse a dest + # that doesn't land directly under configs/. + if dest.resolve().parent != configs_root: + logger.warning("Skipping config file %r: resolved path escapes configs/.", name) + continue + if dest.exists(): + if not _paths_equal(src_path, dest): + logger.warning( + "configs/%s already present and differs from %s; keeping the existing copy. " + "Per-variant config differences belong in variant.json's consumer_metadata, " + "which is consumer-defined and out of Olive's scope.", + name, + src_path, + ) + continue + if src_path.is_dir(): + shutil.copytree(str(src_path), str(dest)) + elif src_path.is_file(): + shutil.copy2(str(src_path), str(dest)) + else: + logger.warning("Config source %s does not exist; skipping.", src_path) + + +def _paths_equal(a: Path, b: Path) -> bool: + """Return True if a and b have identical content (file or directory).""" + if a.is_file() and b.is_file(): + if a.stat().st_size != b.stat().st_size: + return False + return _sha256_file(a) == _sha256_file(b) + if a.is_dir() and b.is_dir(): + a_entries = sorted(p.name for p in a.iterdir()) + b_entries = sorted(p.name for p in b.iterdir()) + if a_entries != b_entries: + return False + return all(_paths_equal(a / name, b / name) for name in a_entries) + return False + + +# --------------------------------------------------------------------------- +# ONNX external-data discovery +# --------------------------------------------------------------------------- + + +def _discover_external_data(onnx_path: Path) -> list[str]: + """Return the relative ``location`` strings of every external-data blob. + + Locations are validated as safe relative paths (no absolute paths, no + upward traversal). Unsafe references are dropped with a warning rather + than failing — better to package a slightly broken model than to refuse + progress on something the user can fix downstream. + """ + try: + import onnx + except ImportError: + logger.warning("onnx package not available; external-data discovery skipped.") + return [] + + try: + model = onnx.load(str(onnx_path), load_external_data=False) + except Exception: + logger.debug("Failed to parse %s; skipping external-data discovery.", onnx_path, exc_info=True) + return [] + + locations: list[str] = [] + seen: set[str] = set() + for init in model.graph.initializer: + if init.data_location != onnx.TensorProto.EXTERNAL: + continue + for entry in init.external_data: + if entry.key != "location": + continue + location = entry.value + if not _is_safe_relative_location(location): + logger.warning( + "Skipping unsafe external-data location %r in %s.", + location, + onnx_path, + ) + continue + if location not in seen: + locations.append(location) + seen.add(location) + return locations + + +def _is_safe_relative_location(location: str) -> bool: + if not location: + return False + p = Path(location) + if p.is_absolute(): + return False + parts = p.parts + if any(part in ("..", "") for part in parts): + return False + # Reject Windows-drive style paths that slip through is_absolute on POSIX. + return not (len(location) >= 2 and location[1] == ":") + + +# --------------------------------------------------------------------------- +# Helpers (module-level so tests can exercise them directly) +# --------------------------------------------------------------------------- + + +def _provider_options_for_ep(inference_settings: dict[str, Any], ep: str) -> dict[str, Any]: + """Return the provider_options dict that matches ``ep`` by name. + + Olive's inference_settings has ``execution_provider`` (list of EP names) + and ``provider_options`` (parallel list). Match by EP name; do not rely on + positional indexing. + """ + eps = inference_settings.get("execution_provider") or [] + pos = inference_settings.get("provider_options") or [] + for name, opts in zip(eps, pos): + if name == ep: + return opts or {} + return {} + + +def _sha256_file(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as fh: + while True: + chunk = fh.read(_HASH_CHUNK) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def _validate_name(name: str, kind: str) -> None: + if not name or not _NAME_RE.match(name): + raise ValueError( + f"Invalid {kind} name {name!r}: must be non-empty and contain only " + "alphanumerics, dot, underscore, hyphen, and space." + ) + if name in (".", "..") or "/" in name or "\\" in name: + raise ValueError(f"Invalid {kind} name {name!r}: path separators and traversal are not allowed.") + + +def _ensure_empty_output_dir(output_dir: Path) -> None: + if output_dir.exists(): + if not output_dir.is_dir(): + raise ValueError(f"Output path {output_dir} exists and is not a directory.") + if any(output_dir.iterdir()): + raise ValueError( + f"Output directory {output_dir} is not empty. Refusing to mix stale files with a new " + "package; please point at an empty (or non-existent) directory." + ) + + +def _write_json(path: Path, data: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as fh: + json.dump(data, fh, indent=2) + fh.write("\n") + logger.info("Wrote %s", path) + + +def parse_compatibility_strings(raw: Optional[str]) -> list[str]: + """Split Olive's ``ep_compatibility_info.`` ONNX metadata string. + + Producers store comma-delimited lists today (e.g. ``"sm_80,sm_86,sm_90"``); + the proposal expects a JSON list of opaque strings. Splitting here keeps + consumers from having to know Olive's convention. + """ + if not raw: + return [] + return [tok.strip() for tok in raw.split(",") if tok.strip()] + + +def disambiguate_variant_names(candidates: list[tuple[str, str]]) -> list[str]: + """Return per-candidate variant names with rank suffixes on collision. + + ``candidates`` is a list of ``(component_name, base_variant_name)`` + tuples; the function returns a parallel list of disambiguated variant + names (suffixing ``_rank{N}`` deterministically when two candidates land + on the same ``(component, base_variant)``). + """ + counts: dict[tuple[str, str], int] = {} + for key in candidates: + counts[key] = counts.get(key, 0) + 1 + + used: dict[tuple[str, str], int] = {} + result: list[str] = [] + for comp, base in candidates: + if counts[(comp, base)] == 1: + result.append(base) + continue + used[(comp, base)] = used.get((comp, base), 0) + 1 + result.append(f"{base}_rank{used[(comp, base)]}") + return result + + +# --------------------------------------------------------------------------- +# Olive model-config helpers +# --------------------------------------------------------------------------- def _get_model_attributes(model_config: dict) -> dict: return model_config.get("config", {}).get("model_attributes") or {} -def _write_json(path: Path, data: dict) -> None: - with open(path, "w") as f: - json.dump(data, f, indent=2) - logger.info("Generated %s", path) +def _resolve_onnx_path(model_config: dict) -> Path: + """Resolve the ONNX file path from an Olive model config. + The config's ``model_path`` may be either: + - the ONNX file itself (a ``LocalFile`` resource), + - a directory containing the ONNX file (a ``LocalFolder`` resource), + in which case ``onnx_file_name`` (or a single ``.onnx`` in the dir) + identifies the actual file. + """ + cfg = model_config.get("config", {}) or {} + raw = cfg.get("model_path") + if not raw: + raise ValueError("Model config has no model_path.") + p = Path(raw) + if p.is_file(): + return p + if p.is_dir(): + onnx_name = cfg.get("onnx_file_name") + if onnx_name: + candidate = p / onnx_name + if candidate.is_file(): + return candidate + onnx_files = list(p.glob("*.onnx")) + if len(onnx_files) == 1: + return onnx_files[0] + raise ValueError( + f"Cannot resolve a unique ONNX file under {p}; " + "set onnx_file_name in the model config or pass the file path directly." + ) + raise FileNotFoundError(f"model_path does not exist: {p}") -def _build_constraints(attrs: dict, model_path: Path) -> dict: - """Build variant constraints from model attributes and ONNX metadata.""" - constraints = {} - ep = attrs.get("ep") - if ep: - constraints["ep"] = ep - device = attrs.get("device") - if device: - constraints["device"] = device - ep_compat = _extract_ep_compatibility_from_onnx(model_path, ep or "") - constraints["ep_compatibility_info"] = ep_compat or "" - return constraints + +def _ep_device_compatibility(attrs: dict, onnx_path: Path) -> tuple[str, Optional[str], list[str]]: + """Extract (ep, device, compatibility[]) for one variant from Olive metadata.""" + ep = attrs.get("ep") or "CPUExecutionProvider" + device = attrs.get("device") or None + compatibility = parse_compatibility_strings(_extract_ep_compatibility_from_onnx(onnx_path, ep)) + return ep, device, compatibility def _extract_ep_compatibility_from_onnx(model_path: Path, ep: str = "") -> Optional[str]: - """Extract ep_compatibility_info from ONNX model custom metadata.""" + """Read ``ep_compatibility_info.`` from the ONNX model's metadata_props.""" if not model_path.is_file(): return None - try: import onnx @@ -365,74 +846,7 @@ def _extract_ep_compatibility_from_onnx(model_path: Path, ep: str = "") -> Optio return None -def _copy_model_files_single(model_path: Path, dest_dir: Path) -> None: - """Copy model files for a single ONNX model into dest_dir.""" - if dest_dir.exists(): - return - - src_dir = model_path.parent if model_path.is_file() else model_path - if src_dir.is_dir(): - shutil.copytree(str(src_dir), str(dest_dir)) - else: - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(str(model_path), str(dest_dir)) - - -def _copy_component_files(model_path: Path, dest_dir: Path) -> None: - """Copy files for a single ONNX component to dest_dir. - - Copies the .onnx file and its associated context binary (.bin) files - and external data files. - """ - if dest_dir.exists(): - return - - dest_dir.mkdir(parents=True, exist_ok=True) - src_dir = model_path.parent - - # Copy the ONNX file itself - shutil.copy2(str(model_path), str(dest_dir / model_path.name)) - - # Find associated files - associated_files: set[str] = set() - try: - from olive.passes.onnx.common import get_context_bin_file_names - - associated_files.update(get_context_bin_file_names(str(model_path))) - except Exception: - logger.debug("Could not read context binary file names from %s", model_path, exc_info=True) - - try: - import onnx - - onnx_model = onnx.load(str(model_path), load_external_data=False) - for init in onnx_model.graph.initializer: - if init.data_location == onnx.TensorProto.EXTERNAL: - for entry in init.external_data: - if entry.key == "location": - associated_files.add(entry.value) - except Exception: - logger.debug("Could not read ONNX external data from %s", model_path, exc_info=True) - - for file_name in associated_files: - src = src_dir / file_name - if src.is_file(): - shutil.copy2(str(src), str(dest_dir / file_name)) - - -def _remove_config_files(component_dir: Path, config_file_names: set[str]) -> None: - """Remove config files from variant subdirectories (they belong in configs/).""" - for name in config_file_names: - for p in component_dir.rglob(name): - if p.is_dir(): - shutil.rmtree(str(p)) - else: - p.unlink() - logger.debug("Removed duplicate config entry %s from variant directory", p) - - def _task_to_component_name(task: str) -> str: - """Map a task string to a component name for single-component models.""" task_component_map = { "text_generation": "decoder", "text2text_generation": "encoder_decoder", diff --git a/olive/cli/run.py b/olive/cli/run.py index 6d2a831aef..3b4c166b3e 100644 --- a/olive/cli/run.py +++ b/olive/cli/run.py @@ -49,16 +49,22 @@ def register_subcommand(parser: ArgumentParser): @action def run(self): + from copy import deepcopy + from pathlib import Path + from olive.common.config_utils import load_config_file from olive.workflows import run as olive_run # allow the run_config to be a dict already (for api use) - run_config = self.args.run_config - if not isinstance(run_config, dict): - run_config = load_config_file(run_config) + run_config_input = self.args.run_config + run_config = ( + deepcopy(run_config_input) if isinstance(run_config_input, dict) else load_config_file(run_config_input) + ) + config_overrides = {} if input_model_config := get_input_model_config(self.args, required=False): print("Replacing input model config in run config") run_config["input_model"] = input_model_config + config_overrides["input_model"] = input_model_config for arg_key, rc_key in [("output_path", "output_dir"), ("log_level", "log_severity_level")]: if (arg_value := getattr(self.args, arg_key)) is not None: @@ -67,12 +73,26 @@ def run(self): run_config.get("engine", {}).pop(rc_key, None) # add value to run config directly run_config[rc_key] = arg_value + config_overrides[rc_key] = arg_value + + recipe_telemetry_metadata = { + "recipe_command": "WorkflowRun", + "recipe_source": "config_dict" if isinstance(run_config_input, dict) else "config_file", + "recipe_format": "dict" + if isinstance(run_config_input, dict) + else Path(run_config_input).suffix.lstrip(".").lower() or "unknown", + "execution_mode": "list_required_packages" if self.args.list_required_packages else "run", + "package_config_provided": bool(self.args.package_config), + } + if config_overrides: + recipe_telemetry_metadata["config_overrides"] = config_overrides workflow_output = olive_run( run_config, list_required_packages=self.args.list_required_packages, tempdir=self.args.tempdir, package_config=self.args.package_config, + recipe_telemetry_metadata=recipe_telemetry_metadata, ) if self.args.list_required_packages is True: diff --git a/olive/common/onnx_io.py b/olive/common/onnx_io.py index 42db080a35..5a47c9058a 100644 --- a/olive/common/onnx_io.py +++ b/olive/common/onnx_io.py @@ -89,20 +89,20 @@ def get_kv_info(io_config: dict) -> dict | None: if kv_format is None: return None - # find the number of layers - num_layers = 0 + # find the actual layer indices (may be non-contiguous after pruning) + layer_indices = set() for i_name in io_config["input_names"]: - num_layers += int(re.match(kv_format, i_name) is not None) + m = re.match(kv_format, i_name) + if m: + layer_indices.add(int(m.group(1))) + layer_indices = sorted(layer_indices) past_names = [] present_to_past = {} for k in ["key", "value"]: - past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in range(num_layers)]) + past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in layer_indices]) present_to_past.update( - { - kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i - for i in range(num_layers) - } + {kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i for i in layer_indices} ) past_shape = io_config["input_shapes"][io_config["input_names"].index(past_names[0])] diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 50d1f1289d..aaa4b690f5 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -293,6 +293,21 @@ def __init__(self, model_path: str, ep: str | None = None, ep_options: dict | No if self.kv_info is None: raise ValueError("Invalid io_config: kv_info not found") + # detect position_ids rank, hybrid state inputs and outputs in a single pass + self.position_ids_rank = 2 + self.hybrid_states = {} + self.hybrid_state_outputs = {} + for prefix in ("input", "output"): + names = self.io_config[f"{prefix}_names"] + shapes = self.io_config[f"{prefix}_shapes"] + types = self.io_config[f"{prefix}_types"] + target = self.hybrid_states if prefix == "input" else self.hybrid_state_outputs + for idx, name in enumerate(names): + if name == "position_ids": + self.position_ids_rank = len(shapes[idx]) + elif "conv_state" in name or "recurrent_state" in name: + target[name] = {"shape": shapes[idx], "dtype": types[idx]} + self._session = None self._batch_size = None self._buffers = None @@ -331,17 +346,29 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor: inputs_to_bind[name] = (self._buffers["inputs"][name], self.io_dtypes[name], shape) if "position_ids" in self._buffers["inputs"]: # need to reallocate since the position_ids tensor may be sliced - inputs_to_bind["position_ids"] = ( - self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(), - self.io_dtypes["position_ids"], - (batch_size, seqlen), - ) + if self.position_ids_rank == 3: + inputs_to_bind["position_ids"] = ( + self._buffers["inputs"]["position_ids"][:, :batch_size, :seqlen].contiguous(), + self.io_dtypes["position_ids"], + (self._buffers["inputs"]["position_ids"].shape[0], batch_size, seqlen), + ) + else: + inputs_to_bind["position_ids"] = ( + self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(), + self.io_dtypes["position_ids"], + (batch_size, seqlen), + ) for name in self._buffers["kv_inputs"]: inputs_to_bind[name] = ( self._buffers["kv_inputs"][name], self.kv_info["dtype"], (batch_size, self.kv_info["num_kv_heads"], 0, self.kv_info["head_size"]), ) + # hybrid state inputs (conv_state, recurrent_state) + for name, buf in self._buffers["hybrid_inputs"].items(): + shape = list(buf.shape) + shape[0] = batch_size + inputs_to_bind[name] = (buf, self.hybrid_states[name]["dtype"], tuple(shape)) for name, (buffer, dtype, shape) in inputs_to_bind.items(): io_binding.bind_input( name, @@ -363,6 +390,11 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor: self.kv_info["dtype"], (batch_size, self.kv_info["num_kv_heads"], seqlen, self.kv_info["head_size"]), ) + # hybrid state outputs (conv_state, recurrent_state) + for name, buf in self._buffers["hybrid_outputs"].items(): + shape = list(buf.shape) + shape[0] = batch_size + outputs_to_bind[name] = (buf, self.hybrid_state_outputs[name]["dtype"], tuple(shape)) for name, (buffer, dtype, shape) in outputs_to_bind.items(): io_binding.bind_output( name, @@ -418,11 +450,16 @@ def initialize_buffers(self, batch_size: int, max_length: int): ) } if self.io_dtypes.get("position_ids") is not None: - inputs["position_ids"] = ( + pos_ids = ( torch.arange(max_length, dtype=getattr(torch, self.io_dtypes["position_ids"]), device=self.device) .unsqueeze(0) .expand(batch_size, -1) ) + if self.position_ids_rank == 3: + # mRoPE: expand to [mrope_sections, batch_size, seq_len] + mrope_sections = self.io_config["input_shapes"][self.io_config["input_names"].index("position_ids")][0] + pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1) + inputs["position_ids"] = pos_ids if self.io_dtypes.get("past_seq_len") is not None: inputs["past_seq_len"] = ( torch.tensor(max_length - 1, dtype=getattr(torch, self.io_dtypes["past_seq_len"]), device=self.device) @@ -457,6 +494,20 @@ def initialize_buffers(self, batch_size: int, max_length: int): } self._buffers = {"inputs": inputs, "outputs": outputs, "kv_inputs": kv_inputs, "kv_outputs": kv_outputs} + + # hybrid state buffers (conv_state, recurrent_state) - zero-initialized + hybrid_inputs = {} + for name, info in self.hybrid_states.items(): + # Replace symbolic 'batch_size' with actual batch_size + shape = [batch_size if s == "batch_size" else s for s in info["shape"]] + hybrid_inputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device) + hybrid_outputs = {} + for name, info in self.hybrid_state_outputs.items(): + shape = [batch_size if s == "batch_size" else s for s in info["shape"]] + hybrid_outputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device) + self._buffers["hybrid_inputs"] = hybrid_inputs + self._buffers["hybrid_outputs"] = hybrid_outputs + self._batch_size = batch_size @@ -539,7 +590,7 @@ def _detect_full_logits(self) -> bool: def eot_token_id(self): return self._eot_token_id - def tok_encode(self, string: str, **kwargs) -> list[int]: + def tok_encode(self, string: str, add_special_tokens: bool | None = None, **kwargs) -> list[int]: """Tokenize a string using the model's tokenizer and return a list of token IDs.""" return self.tokenizer.encode(string).tolist() @@ -551,27 +602,25 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor self.params.set_search_options(batch_size=batch_size) generator = og.Generator(self.model, self.params) - if self._returns_full_logits: - generator.append_tokens(input_ids.tolist()) - return torch.from_numpy(generator.get_output("logits")).to(self.device) - - # Model only returns logits for the last appended position. if batch_size > 1 and cont_len > 1: raise ValueError( - "batch_size > 1 is not supported when the model returns single-position logits" + "batch_size > 1 is not supported when using incremental get_logits() retrieval" " and continuation length > 1. Right-padding misaligns continuation positions across" " batch elements. Use batch_size=1 instead." ) - # Bulk-append context tokens, then step through the last cont_len tokens - # one at a time to collect only the logits we actually need. + # Use incremental token appending with get_logits() to avoid copying + # the full logits tensor from GPU to CPU. get_output("logits") copies + # seq_len * vocab_size * 2 bytes (e.g. 472MB for 900 tokens with + # 262K vocab), while get_logits() copies only vocab_size * 4 bytes + # (~1MB) per position. n_logits = max(cont_len, 1) prefix_len = seq_len - n_logits generator.append_tokens(input_ids[:, : prefix_len + 1].tolist()) - all_logits = [torch.from_numpy(generator.get_output("logits")).to(self.device)] + all_logits = [torch.from_numpy(generator.get_logits()).to(self.device)] for i in range(prefix_len + 1, seq_len): generator.append_tokens(input_ids[:, i : i + 1].tolist()) - all_logits.append(torch.from_numpy(generator.get_output("logits")).to(self.device)) + all_logits.append(torch.from_numpy(generator.get_logits()).to(self.device)) # No need to pad to [batch, seq_len, vocab]. The slicing in _loglikelihood_tokens computes # ctx_len = inplen + (logits.shape[0] - padding_len_inp), which adjusts for the shorter diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 05933b8b64..eb88b1597c 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1612,10 +1612,16 @@ def evaluate( task_metrics = {} for mf, v in metric_items: - if mf != "alias": + if mf == "alias": + continue + if not isinstance(v, (int, float)): + continue + if "," in mf: m, _ = mf.split(",", 1) - if not m.endswith("_stderr"): - task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) + else: + m = mf + if not m.endswith("_stderr"): + task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) metrics[task_name] = MetricResult.model_validate(task_metrics) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index db32c99b70..61f80a47ff 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -105,6 +105,14 @@ def get_initializer_by_name(model, name: str): return initializer return None + @staticmethod + def find_node(dag: OnnxDAG, op_type: str, name_substr: str) -> str | None: + """Find the first node matching an op_type and name substring.""" + for name in dag.get_node_names(): + if dag.get_node_op_type(name) == op_type and name_substr in name: + return name + return None + @staticmethod def create_new_name(name: str, old_op: str, new_op: str) -> str: return name.replace(old_op, new_op) if old_op in name else f"{name}_{new_op}" @@ -850,6 +858,225 @@ def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> list[str] | None: return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else [] +class SimplifiedLayerNormToRMSNorm(ProtoSurgeon): + """Replace SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization with an RMSNorm subgraph built from elementwise ops. + + RMS(x) = sqrt(mean(x^2, axis=-1, keepdims=1) + eps) + y = (x / RMS(x)) * gamma + + For SkipSimplifiedLayerNormalization, we first do: + s = input + skip + and use 's' as x for RMSNorm. If the original node exposes a second output + (residual sum), we rewire its consumers to 's' to preserve graph behavior. + + IMPORTANT: ReduceMean schema change across opsets: + - opset < 18: axes is an ATTRIBUTE + - opset >=18: axes is an INPUT tensor (int64), keepdims remains an attribute. + """ + + def __call__(self, model: onnx.ModelProto): + from onnx import numpy_helper + from onnx.helper import tensor_dtype_to_np_dtype + + dag = OnnxDAG(model) + + # Determine the default ONNX opset for the main domain ("", "ai.onnx"). + # We'll use this to decide how to build ReduceMean. + default_opset = None + for imp in model.opset_import: + if imp.domain in ("", "ai.onnx"): + default_opset = imp.version + break + if default_opset is None: + # Fall back defensively; most models have a default import. + default_opset = 13 + + use_axes_input_for_reduce_mean = default_opset >= 18 + + modified = 0 + + for node_name in dag.get_node_names(): + op_type = dag.get_node_op_type(node_name) + if op_type not in {"SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"}: + continue + + graph_idx = dag.get_graph_idx(node_name) + inputs = dag.get_node_inputs(node_name, True) + outputs = dag.get_node_outputs(node_name, True) + + # --------------------------- + # Build the input to be normalized: ln_input + # --------------------------- + if op_type == "SkipSimplifiedLayerNormalization": + # Expect inputs: [input, skip, gamma] + if len(inputs) != 3: + continue + root1, root2, gamma = inputs + + # Add(input, skip) => skip_add_out + skip_add_name = self.create_new_name(node_name, op_type, "Add") + skip_add_out = f"{skip_add_name}_out" + skip_add_node = onnx.helper.make_node( + "Add", + inputs=[root1, root2], + outputs=[skip_add_out], + name=skip_add_name, + ) + dag.add_node(skip_add_node, graph_idx) + + ln_input = skip_add_out + else: + # SimplifiedLayerNormalization: inputs = [x, gamma] + if len(inputs) != 2: + continue + ln_input, gamma = inputs + + # The original primary output (normalized tensor) + ln_output = outputs[0] + + ln_elem_type = dag.get_io_elem_type(inputs[0]) or onnx.TensorProto.FLOAT + ln_np_dtype = tensor_dtype_to_np_dtype(ln_elem_type) + + # --------------------------- + # Step 1: Pow(x, 2) + # --------------------------- + pow_name = self.create_new_name(node_name, op_type, "Pow") + pow_out = f"{pow_name}_out" + pow_const = numpy_helper.from_array(np.array([2.0], dtype=ln_np_dtype), name=f"{pow_name}_const") + dag.add_initializer(pow_const, graph_idx) + pow_node = onnx.helper.make_node( + "Pow", + inputs=[ln_input, pow_const.name], + outputs=[pow_out], + name=pow_name, + ) + dag.add_node(pow_node, graph_idx) + + # --------------------------- + # Step 2: ReduceMean over last dim, keepdims=1 + # - opset < 18 : axes is an attribute + # - opset >= 18: axes is an input tensor (INT64) + # --------------------------- + mean_name = self.create_new_name(node_name, op_type, "ReduceMean") + mean_out = f"{mean_name}_out" + + if use_axes_input_for_reduce_mean: + axes_init = numpy_helper.from_array(np.array([-1], dtype=np.int64), name=f"{mean_name}_axes") + dag.add_initializer(axes_init, graph_idx) + + mean_node = onnx.helper.make_node( + "ReduceMean", + inputs=[pow_out, axes_init.name], + outputs=[mean_out], + name=mean_name, + keepdims=1, + ) + else: + # Older schema: axes is an attribute + mean_node = onnx.helper.make_node( + "ReduceMean", + inputs=[pow_out], + outputs=[mean_out], + name=mean_name, + axes=[-1], + keepdims=1, + ) + dag.add_node(mean_node, graph_idx) + + # --------------------------- + # Step 3: Add epsilon + # --------------------------- + eps_value = 1e-06 + add_eps_name = self.create_new_name(node_name, op_type, "AddEps") + add_eps_out = f"{add_eps_name}_out" + + eps_const = numpy_helper.from_array(np.array([eps_value], dtype=ln_np_dtype), name=f"{add_eps_name}_const") + dag.add_initializer(eps_const, graph_idx) + + add_eps_node = onnx.helper.make_node( + "Add", + inputs=[mean_out, eps_const.name], + outputs=[add_eps_out], + name=add_eps_name, + ) + dag.add_node(add_eps_node, graph_idx) + + # --------------------------- + # Step 4: Sqrt + # --------------------------- + sqrt_name = self.create_new_name(node_name, op_type, "Sqrt") + sqrt_out = f"{sqrt_name}_out" + sqrt_node = onnx.helper.make_node( + "Sqrt", + inputs=[add_eps_out], + outputs=[sqrt_out], + name=sqrt_name, + ) + dag.add_node(sqrt_node, graph_idx) + + # --------------------------- + # Step 5: Div (x / sqrt(...)) + # --------------------------- + div_name = self.create_new_name(node_name, op_type, "Div") + div_out = f"{div_name}_out" + div_node = onnx.helper.make_node( + "Div", + inputs=[ln_input, sqrt_out], + outputs=[div_out], + name=div_name, + ) + dag.add_node(div_node, graph_idx) + + # --------------------------- + # Step 6: Mul with gamma + # --------------------------- + mul_name = self.create_new_name(node_name, op_type, "Mul") + mul_out = f"{mul_name}_out" + mul_node = onnx.helper.make_node( + "Mul", + inputs=[div_out, gamma], + outputs=[mul_out], + name=mul_name, + ) + dag.add_node(mul_node, graph_idx) + + # --------------------------- + # Rewire consumers of the original main output + # --------------------------- + for consumer in dag.get_consumers(ln_output): + dag.replace_node_input(consumer, ln_output, mul_out) + + # --------------------------- + # For SkipSimplifiedLayerNormalization that had two outputs: + # - Output 1 is typically residual sum (input_skip_bias_sum) + # - Redirect its consumers to the skip-sum Add output + # --------------------------- + if op_type == "SkipSimplifiedLayerNormalization" and len(outputs) == 2: + second_output = outputs[1] + + second_vi = dag.get_value_info_proto(second_output) + if second_vi is not None: + new_vi = onnx.ValueInfoProto() + new_vi.CopyFrom(second_vi) + new_vi.name = skip_add_out + dag.add_value_info(new_vi, graph_idx) + + # Redirect all consumers of the second output + for consumer in dag.get_consumers(second_output): + dag.replace_node_input(consumer, second_output, skip_add_out) + + dag.remove_node(node_name) + modified += 1 + + if modified > 0: + logger.debug( + "Replaced %d Simplified/SkipSimplifiedLayerNormalization nodes with RMSNorm subgraphs", modified + ) + + dag.update() + return dag.model + + class SimplifiedLayerNormToL2Norm(ProtoSurgeon): """Replace Skip/SimplifiedLayerNormalization node with L2Norm subgraph. @@ -1953,7 +2180,7 @@ def __call__(self, model: onnx.ModelProto): ]: return dag.model - if embed_op_type == "Gather": + if embed_op_type == "Gather" and lm_head_op_type == "MatMul": return self.handle_unquantized(dag, embed_name, lm_head_name) return self.handle_quantized(dag, embed_name, lm_head_name) @@ -2152,6 +2379,260 @@ def equal_weights(self, dag: OnnxDAG, init0: str, init1: str, transpose: bool = return np.array_equal(arr0.ravel(), arr1.ravel()) +class QuantizeEmbeddingInt8(ProtoSurgeon): + """Quantize FP16 embedding to INT8 using GatherBlockQuantized. + + Replaces the Gather op for embed_tokens with a GatherBlockQuantized op + that uses per-block INT8 quantization (block_size=32). + """ + + def __call__(self, model: onnx.ModelProto): + from onnx import numpy_helper + + dag = OnnxDAG(model) + + # Find embedding Gather node + gather_name = self.find_node(dag, "Gather", "embed_tokens") + if gather_name is None: + logger.warning("No embed_tokens Gather node found, skipping QuantizeEmbeddingInt8") + return model + + embed_weight_name = dag.get_node_inputs(gather_name)[0] + if not dag.is_initializer(embed_weight_name): + logger.warning("Embedding weight initializer not found, skipping QuantizeEmbeddingInt8") + return model + + embed_init = dag.get_initializer_proto(embed_weight_name) + + # Check if already quantized + if embed_init.data_type not in (onnx.TensorProto.FLOAT16, onnx.TensorProto.FLOAT): + logger.info("Embedding is not FP16/FP32, skipping QuantizeEmbeddingInt8") + return model + + embed = dag.get_initializer_np_array(embed_weight_name).astype(np.float32) + vocab_size, hidden_size = embed.shape + block_size = 32 + + if hidden_size % block_size != 0: + logger.warning("hidden_size %d not divisible by block_size %d, skipping", hidden_size, block_size) + return model + + num_blocks = hidden_size // block_size + + # Preserve the model's float dtype for scales so downstream ops (LayerNorm, MatMul, ...) + # receive the dtype they expect. FP16 model -> FP16 scales; FP32 model -> FP32 scales. + scales_dtype = np.float16 if embed_init.data_type == onnx.TensorProto.FLOAT16 else np.float32 + + logger.info( + "Quantizing embedding %s (%dx%d) from %s to INT8 (block_size=%d)", + embed_weight_name, + vocab_size, + hidden_size, + "FP16" if scales_dtype == np.float16 else "FP32", + block_size, + ) + + # Per-block INT8 quantization (asymmetric with zero_point=128 for GatherBlockQuantized) + blocked = embed.reshape(vocab_size, num_blocks, block_size) + scales = (np.abs(blocked).max(axis=2) / 127.0).astype(scales_dtype) + scales_f32 = scales.astype(np.float32) + # Avoid division by zero + scales_f32 = np.where(scales_f32 == 0, 1.0, scales_f32) + q = np.clip(np.round(blocked / scales_f32[:, :, None]), -128, 127).astype(np.int8) + # GatherBlockQuantized expects unsigned uint8 with zero_point offset + q_uint8 = (q.astype(np.int16) + 128).astype(np.uint8) + q_flat = q_uint8.reshape(vocab_size, hidden_size) + # Zero point tensor: 128 for all blocks (symmetric around 128) + zero_points = np.full((vocab_size, num_blocks), 128, dtype=np.uint8) + + old_size_mb = embed.nbytes / (1024 * 1024) + new_size_mb = (q_flat.nbytes + scales.nbytes + zero_points.nbytes) / (1024 * 1024) + logger.info( + "Embedding: %.0f MB -> %.0f MB (saved %.0f MB)", old_size_mb, new_size_mb, old_size_mb - new_size_mb + ) + + graph_idx = dag.get_graph_idx(gather_name) + + # Create new initializers + qweight_name = embed_weight_name + "_Q8" + scales_name = embed_weight_name + "_scales" + zp_name = embed_weight_name + "_zp" + dag.add_initializer(numpy_helper.from_array(q_flat, name=qweight_name), graph_idx) + dag.add_initializer(numpy_helper.from_array(scales, name=scales_name), graph_idx) + dag.add_initializer(numpy_helper.from_array(zero_points, name=zp_name), graph_idx) + + # Ensure com.microsoft opset is declared + dag.set_opset_import("com.microsoft", 1) + + # Create GatherBlockQuantized node + gather_inputs = dag.get_node_inputs(gather_name) + gather_output = dag.get_node_outputs(gather_name)[0] + gbq_output = gather_output + "_gbq" + gbq_name = gather_name.replace("Gather", "GatherBlockQuantized") + gbq_node = onnx.helper.make_node( + "GatherBlockQuantized", + inputs=[qweight_name, gather_inputs[1], scales_name, zp_name], + outputs=[gbq_output], + name=gbq_name, + domain="com.microsoft", + bits=8, + block_size=block_size, + gather_axis=0, + quantize_axis=1, + ) + dag.add_node(gbq_node, graph_idx) + + # Rewire consumers from old Gather output to new GBQ output and remove old node + for consumer in dag.get_consumers(gather_output): + dag.replace_node_input(consumer, gather_output, gbq_output) + dag.remove_node(gather_name) + # Old FP16 embedding weight is auto-cleaned by update() since no consumers remain + + logger.info("Replaced Gather with GatherBlockQuantized (INT8)") + dag.update() + return dag.model + + +class ShareEmbeddingLmHead(ProtoSurgeon): + """Share INT8 embedding weight with lm_head by converting lm_head to INT8 MatMulNBits. + + Must be applied AFTER QuantizeEmbeddingInt8. Replaces the lm_head's INT4 + MatMulNBits with an INT8 MatMulNBits that references the same quantized + weight as the embedding's GatherBlockQuantized, eliminating duplicate storage. + """ + + def __call__(self, model: onnx.ModelProto): + from onnx import numpy_helper + + dag = OnnxDAG(model) + + # Find embedding GatherBlockQuantized + gbq_name = self.find_node(dag, "GatherBlockQuantized", "embed_tokens") + if gbq_name is None: + logger.warning("No embed_tokens GatherBlockQuantized node found, skipping ShareEmbeddingLmHead") + return model + + attrs = dag.get_node_attributes(gbq_name) + gbq_bits = attrs.get("bits", 8) + gbq_block_size = attrs.get("block_size", 32) + + if gbq_bits != 8: + logger.warning("Embedding is not INT8, cannot share with lm_head") + return model + + # Get embedding weight, scales, zero_points names + gbq_inputs = dag.get_node_inputs(gbq_name) + embed_weight_name = gbq_inputs[0] + embed_scales_name = gbq_inputs[2] + embed_zp_name = gbq_inputs[3] if len(gbq_inputs) > 3 else None + + # Get embedding weight shape to determine K and N + if not dag.is_initializer(embed_weight_name): + logger.warning("Could not find embedding weight initializer") + return model + + embed_weight = dag.get_initializer_np_array(embed_weight_name) + + vocab_size, hidden_size = embed_weight.shape # [V, H] for INT8 + num_blocks = hidden_size // gbq_block_size + + # Find lm_head MatMulNBits node + lm_head_name = self.find_node(dag, "MatMulNBits", "lm_head") + if lm_head_name is None: + logger.warning("No lm_head MatMulNBits found") + return model + + lm_head_inputs = dag.get_node_inputs(lm_head_name) + + # Check if already shared (idempotency): lm_head weight input references embedding weight + if embed_weight_name in lm_head_inputs[1] or lm_head_inputs[2] == embed_scales_name: + logger.info("lm_head already shares weights with embedding, skipping ShareEmbeddingLmHead") + return model + + # Get old lm_head attributes + old_attrs = dag.get_node_attributes(lm_head_name) + + logger.info( + "Sharing embedding with lm_head: lm_head INT%d (%dx%d, bs=%d) -> INT8 (shared with embedding)", + old_attrs.get("bits", 0), + old_attrs.get("N", 0), + old_attrs.get("K", 0), + old_attrs.get("block_size", 0), + ) + + graph_idx = dag.get_graph_idx(lm_head_name) + + # MatMulNBits needs [N, K_blocks, block_size] but GBQ weight is [V, H]. + # Add a Reshape node to convert, referencing the SAME embedding weight. + reshape_shape_name = "lm_head.MatMulNBits.reshape_shape" + reshape_shape = np.array([vocab_size, num_blocks, gbq_block_size], dtype=np.int64) + dag.add_initializer(numpy_helper.from_array(reshape_shape, name=reshape_shape_name), graph_idx) + + reshape_output = "lm_head.MatMulNBits.reshaped_weight" + reshape_node = onnx.helper.make_node( + "Reshape", + inputs=[embed_weight_name, reshape_shape_name], + outputs=[reshape_output], + name="lm_head/Reshape_shared_weight", + ) + dag.add_node(reshape_node, graph_idx) + + # Scales and zp: reuse embedding's directly + inputs = [lm_head_inputs[0], reshape_output, embed_scales_name] + if embed_zp_name: + inputs.append(embed_zp_name) + + # Ensure com.microsoft opset is declared + dag.set_opset_import("com.microsoft", 1) + + # Create new INT8 MatMulNBits node + lm_head_output = dag.get_node_outputs(lm_head_name)[0] + new_lm_head_output = lm_head_output + "_shared" + new_lm_head_name = lm_head_name + "_shared" + lm_head_proto = dag.get_node_proto(lm_head_name) + new_lm_head = onnx.helper.make_node( + "MatMulNBits", + inputs=inputs, + outputs=[new_lm_head_output], + name=new_lm_head_name, + domain="com.microsoft", + bits=8, + block_size=gbq_block_size, + K=hidden_size, + N=vocab_size, + ) + # Copy accuracy_level if present + for attr in lm_head_proto.attribute: + if attr.name == "accuracy_level": + new_lm_head.attribute.append(attr) + + dag.add_node(new_lm_head, graph_idx) + + # Copy value info from old output to new output (needed for graph output serialization) + old_vi = dag.get_value_info_proto(lm_head_output) + if old_vi is not None: + new_vi = onnx.helper.make_tensor_value_info(new_lm_head_output, old_vi.type.tensor_type.elem_type, []) + new_vi.CopyFrom(old_vi) + new_vi.name = new_lm_head_output + dag.add_value_info(new_vi, graph_idx) + + # Rewire consumers and remove old node + for consumer in dag.get_consumers(lm_head_output): + dag.replace_node_input(consumer, lm_head_output, new_lm_head_output) + if dag.is_output(lm_head_output): + dag.remove_output(lm_head_output) + dag.remove_node(lm_head_name) + dag.rename_node_output(new_lm_head_name, new_lm_head_output, lm_head_output) + dag.make_output(lm_head_output) + else: + dag.remove_node(lm_head_name) + # Old lm_head initializers are auto-cleaned by update() since no consumers remain + + logger.info("lm_head now uses INT8 MatMulNBits (shared quantization with embedding)") + dag.update() + return dag.model + + class ReciprocalMulToDiv(ProtoSurgeon): """Replace Reciprocal(x) * a with Div(a, x). diff --git a/olive/passes/onnx/kquant_quantization.py b/olive/passes/onnx/kquant_quantization.py index 5d75016ecc..4263406afd 100644 --- a/olive/passes/onnx/kquant_quantization.py +++ b/olive/passes/onnx/kquant_quantization.py @@ -256,7 +256,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon def _run_for_config( self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: - output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + # For composite model components (e.g., Whisper encoder.onnx/decoder.onnx), + # output_model_path already includes .onnx extension. Strip it so ir.save doesn't + # create a double extension (.onnx.onnx). For other cases, resolve normally. + output_path_obj = Path(output_model_path) + if output_path_obj.suffix == ".onnx": + output_model_path = str(output_path_obj.with_suffix("")) + else: + output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + ir_model = model.load_ir_model() ir.external_data.load_to_model(ir_model) ir_model.graph.opset_imports[MSFT_DOMAIN] = 1 diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index f704579ba2..1d49319de7 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -20,7 +20,7 @@ from olive.constants import Precision from olive.hardware.accelerator import AcceleratorSpec, Device from olive.hardware.constants import ExecutionProvider -from olive.model import HfModelHandler, ONNXModelHandler +from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam @@ -264,8 +264,9 @@ def _run_for_config( if config.extra_options: extra_args.update(config.extra_options) - # Ensure output_model_filepath matches the final filename in extra_args - output_model_filepath = Path(output_model_path) / extra_args["filename"] + # Ensure output_model_filepath matches the final filename in extra_args while preserving + # the resolved output directory selected above. + output_model_filepath = output_model_filepath.parent / extra_args["filename"] model_attributes = copy.deepcopy(model.model_attributes or {}) @@ -283,26 +284,6 @@ def _run_for_config( **extra_args, ) - # Apply post-processing annotations (split assignments and/or layer annotations) - # in a single load/save cycle to avoid redundant disk I/O. - split_assignments = model_attributes.get("split_assignments") if not metadata_only else None - layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None - - if split_assignments or layer_annotations: - model_proto = onnx.load(output_model_filepath, load_external_data=False) - - if split_assignments: - # NOTE: currently the model builder renames modules to it's own naming convention - # so the assignments for the renamed modules won't match - split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) - onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) - - if layer_annotations: - from olive.passes.onnx.layer_annotation import annotate_proto_model - - annotate_proto_model(model_proto, layer_annotations) - - onnx.save(model_proto, output_model_filepath) except Exception: # if model building fails, clean up the intermediate files in the cache_dir cache_dir = Path(HF_HUB_CACHE) @@ -328,6 +309,58 @@ def _run_for_config( # tokenizer and generation configs are skipped since they are already saved by the model builder model.save_metadata(output_model_filepath.parent) + generated_onnx_files = sorted(output_model_filepath.parent.glob("*.onnx")) if not metadata_only else [] + + # For multi-file models (e.g., Whisper), preserve component file names and process each file independently + # in subsequent passes by returning a CompositeModelHandler. + is_multi_file_model = not metadata_only and len(generated_onnx_files) > 1 + resolved_single_model_filepath = output_model_filepath + if ( + not metadata_only + and not is_multi_file_model + and not output_model_filepath.exists() + and len(generated_onnx_files) == 1 + ): + logger.info( + "ONNX model file %s does not exist, using %s instead", + output_model_filepath, + generated_onnx_files[0].name, + ) + resolved_single_model_filepath = generated_onnx_files[0] + + # Apply post-processing annotations (split assignments and/or layer annotations) + # in a single load/save cycle to avoid redundant disk I/O. + split_assignments = model_attributes.get("split_assignments") if not metadata_only else None + layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None + if is_multi_file_model: + primary_onnx_files = generated_onnx_files + elif resolved_single_model_filepath.exists(): + primary_onnx_files = [resolved_single_model_filepath] + else: + primary_onnx_files = [] + if split_assignments or layer_annotations: + if primary_onnx_files: + for primary_onnx_file in primary_onnx_files: + model_proto = onnx.load(primary_onnx_file, load_external_data=False) + + if split_assignments: + # NOTE: currently the model builder renames modules to it's own naming convention + # so the assignments for the renamed modules won't match + split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) + onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) + + if layer_annotations: + from olive.passes.onnx.layer_annotation import annotate_proto_model + + annotate_proto_model(model_proto, layer_annotations) + + onnx.save(model_proto, primary_onnx_file) + else: + logger.warning( + "Skipping split_assignments/layer_annotations because no ONNX file was generated in %s.", + output_model_filepath.parent, + ) + # add additional files generated by model builder to model_attributes additional_files = model_attributes.get("additional_files") or [] if metadata_only: @@ -338,20 +371,36 @@ def _run_for_config( str(output_model_filepath.parent / "genai_config.json"), ] else: + primary_model_paths = {str(fp) for fp in primary_onnx_files} model_attributes["additional_files"] = sorted( set(additional_files) # all files in the output directory except the model and model.data files | {str(fp) for fp in output_model_filepath.parent.iterdir()} - - {str(output_model_filepath), str(output_model_filepath) + ".data"} + - primary_model_paths + - {f"{path}.data" for path in primary_model_paths} ) if metadata_only: output_model = copy.copy(model) output_model.model_attributes = model_attributes + elif is_multi_file_model: + # Use the ONNX filenames as component names so child passes write back to encoder.onnx/decoder.onnx + # instead of defaulting to model.onnx. + component_names = [fp.name for fp in generated_onnx_files] + components = [ + ONNXModelHandler(output_model_filepath.parent, onnx_file_name=component_name) + for component_name in component_names + ] + output_model = CompositeModelHandler( + components, + component_names, + model_path=output_model_filepath.parent, + model_attributes=model_attributes, + ) else: output_model = ONNXModelHandler( output_model_filepath.parent, - onnx_file_name=output_model_filepath.name, + onnx_file_name=resolved_single_model_filepath.name, model_attributes=model_attributes, ) @@ -526,11 +575,6 @@ def patched_make_embedding(self, embedding): import onnx_ir as ir basename = "/model/embed_tokens" - if getattr(self, "int4_tied_embeddings", False) or getattr(self, "shared_embeddings", False): - logger.debug( - "int4_tied_embedding/shared_embeddings is set to True but will be ignored. Use TieWordEmbeddings graph surgery to tie" - " embeddings." - ) if hasattr(embedding, "qweight"): qweight = "model.embed_tokens.qweight" diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index 8371cccd44..2a479ec690 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -232,6 +232,8 @@ def _prepare_run_params(self) -> dict: def _prepare_environment(self, base_env) -> dict: """Prepare environment variables for container.""" + from olive.telemetry.telemetry import is_ci_environment + # Convert list to dict if needed if isinstance(base_env, list): environment = {env.split("=")[0]: env.split("=")[1] for env in base_env} @@ -241,6 +243,8 @@ def _prepare_environment(self, base_env) -> dict: # Add default environment variables environment.setdefault("PYTHONPYCACHEPREFIX", "/tmp") environment["OLIVE_LOG_LEVEL"] = logging.getLevelName(logger.getEffectiveLevel()) + if is_ci_environment(): + environment["CI"] = "1" # Add HuggingFace token if needed if self.hf_token: diff --git a/olive/systems/docker/workflow_runner.py b/olive/systems/docker/workflow_runner.py index 5842d0bd49..be0d59d671 100644 --- a/olive/systems/docker/workflow_runner.py +++ b/olive/systems/docker/workflow_runner.py @@ -20,7 +20,7 @@ def runner_entry(config): config = json.load(f) logger.info("Running workflow with config: %s", config) - olive_run(config) + olive_run(config, emit_recipe_telemetry=False) if __name__ == "__main__": diff --git a/olive/telemetry/constants.py b/olive/telemetry/constants.py index ca9e150b1b..25a60e813e 100644 --- a/olive/telemetry/constants.py +++ b/olive/telemetry/constants.py @@ -3,6 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""OneCollector connection string.""" +"""Telemetry constants.""" -CONNECTION_STRING = "SW5zdHJ1bWVudGF0aW9uS2V5PTlkNWRkYWVjNjFlMjQ1NjdiNzg4YTIwYWVhMzI0NjMxLTcyMzdkN2M2LWVlNjEtNGNmZC1iYjdiLTU5MDNhOTcyYzJlNC03MDQ3" +CONNECTION_STRING = "SW5zdHJ1bWVudGF0aW9uS2V5PTYyMTUwOTExZGMwMDRmYzliYjY3YmE5NjA2NDI3ZTU2LWVjNjFmOWFmLTVkN2EtNGQxOS1hZjMxLWI5Y2Q2OWU5ODdmMS02OTE1" diff --git a/olive/telemetry/library/options.py b/olive/telemetry/library/options.py index dd934cad2d..31fd1ba195 100644 --- a/olive/telemetry/library/options.py +++ b/olive/telemetry/library/options.py @@ -62,6 +62,7 @@ class OneCollectorExporterOptions: """Configuration options for OneCollector exporter.""" connection_string: Optional[str] = None + service_name: Optional[str] = None transport_options: OneCollectorTransportOptions = field(default_factory=OneCollectorTransportOptions) # Internal fields populated during validation diff --git a/olive/telemetry/library/telemetry_logger.py b/olive/telemetry/library/telemetry_logger.py index 7eb236e759..d3b98fd4bf 100644 --- a/olive/telemetry/library/telemetry_logger.py +++ b/olive/telemetry/library/telemetry_logger.py @@ -6,6 +6,7 @@ """High-level telemetry logger facade for easy usage.""" import logging +import threading import uuid from typing import Any, Callable, Optional @@ -28,6 +29,8 @@ class TelemetryLogger: _instance: Optional["TelemetryLogger"] = None _default_logger: Optional["TelemetryLogger"] = None + _instance_lock = threading.RLock() + _default_logger_lock = threading.RLock() _logger: Optional[logging.Logger] = None _logger_exporter: Optional[OneCollectorLogExporter] = None _logger_provider: Optional[LoggerProvider] = None @@ -40,8 +43,10 @@ def __new__(cls, options: Optional[OneCollectorExporterOptions] = None): """ if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialize(options) + with cls._instance_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialize(options) return cls._instance @@ -57,10 +62,13 @@ def _initialize(self, options: Optional[OneCollectorExporterOptions]) -> None: self._logger_exporter = OneCollectorLogExporter(options=options) # Create logger provider + service_name = ( + options.service_name if options and options.service_name else __name__.split(".", maxsplit=1)[0] + ) self._logger_provider = LoggerProvider( resource=Resource.create( { - "service.name": __name__.split(".", maxsplit=1)[0], + "service.name": service_name, "service.version": VERSION, "service.instance.id": str(uuid.uuid4()), # Unique instance ID; can double as session ID } @@ -141,43 +149,54 @@ def shutdown(self) -> None: self._logger_provider.shutdown() @classmethod - def get_default_logger(cls, connection_string: Optional[str] = None) -> "TelemetryLogger": + def get_default_logger( + cls, connection_string: Optional[str] = None, service_name: Optional[str] = None + ) -> "TelemetryLogger": """Get or create the default telemetry logger. Args: connection_string: OneCollector connection string (only used on first call) + service_name: Logical application/service name for emitted telemetry (only used on first call) Returns: TelemetryLogger instance """ if cls._default_logger is None: - options = None - if connection_string: - options = OneCollectorExporterOptions(connection_string=connection_string) - cls._default_logger = cls(options=options) + with cls._default_logger_lock: + if cls._default_logger is None: + options = None + if connection_string: + options = OneCollectorExporterOptions( + connection_string=connection_string, service_name=service_name + ) + cls._default_logger = cls(options=options) return cls._default_logger @classmethod def shutdown_default_logger(cls) -> None: """Shutdown the default telemetry logger.""" - if cls._default_logger: - cls._default_logger.shutdown() - cls._default_logger = None + with cls._default_logger_lock: + if cls._default_logger: + cls._default_logger.shutdown() + cls._default_logger = None -def get_telemetry_logger(connection_string: Optional[str] = None) -> TelemetryLogger: +def get_telemetry_logger( + connection_string: Optional[str] = None, service_name: Optional[str] = None +) -> TelemetryLogger: """Get or create the default telemetry logger. Args: connection_string: OneCollector connection string (only used on first call) + service_name: Logical application/service name for emitted telemetry (only used on first call) Returns: TelemetryLogger instance """ - return TelemetryLogger.get_default_logger(connection_string=connection_string) + return TelemetryLogger.get_default_logger(connection_string=connection_string, service_name=service_name) def log_event(event_name: str, attributes: Optional[dict[str, Any]] = None) -> None: diff --git a/olive/telemetry/recipe_telemetry.py b/olive/telemetry/recipe_telemetry.py new file mode 100644 index 0000000000..baf735e9e3 --- /dev/null +++ b/olive/telemetry/recipe_telemetry.py @@ -0,0 +1,432 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import functools +import json +import re +from copy import deepcopy +from os import PathLike +from pathlib import Path, PurePosixPath, PureWindowsPath +from typing import TYPE_CHECKING, Any, Optional, Union + +from olive.common.config_utils import load_config_file +from olive.common.utils import hash_dict +from olive.package_config import OlivePackageConfig +from olive.systems.common import SystemType +from olive.telemetry.telemetry import is_ci_environment + +if TYPE_CHECKING: + from olive.workflows.run.config import RunConfig + +RECIPE_HASH_REDACTED_VALUE = "" +CONFIG_REFERENCE_REDACTED_VALUE = "" +CONFIG_CALLABLE_REDACTED_VALUE = "" +RECIPE_HASH_REDACTED_KEYS = { + "output_dir", + "cache_dir", + "tempdir", + "additional_files", + "dockerfile", + "build_context_path", + "python_environment_path", + "prepend_to_path", + "script_dir", + "model_script", + # package_config is tracked separately via package_config_provided and + # package_config_overrides, but excluded from recipe_hash because it is an + # environment/infrastructure path. + "package_config", + "work_dir", +} +CONFIG_SNAPSHOT_REDACTED_KEYS = RECIPE_HASH_REDACTED_KEYS | { + "model_path", + "_name_or_path", + "adapter_path", + "user_script", +} +HF_MODEL_IDENTIFIER_KEYS = {"model_path", "_name_or_path"} +CONFIG_REFERENCE_KEYS = {"host", "target", "evaluator"} +LOCAL_MODEL_FILE_SUFFIXES = {".bin", ".model", ".onnx", ".pb", ".pt", ".pth", ".safetensors", ".tflite"} +HF_CACHE_MODEL_PATTERN = re.compile(r"(?:^|[\\/])models--([^\\/]+)--([^\\/]+)(?:[\\/]|$)") +HF_REPO_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9][A-Za-z0-9._-]*)?$") +_NO_OVERRIDE = object() + + +def _build_recipe_result_metadata( + run_config_input: Union[str, Path, dict], + run_config_telemetry_input: Optional[Any], + run_config: Optional["RunConfig"], + recipe_telemetry_metadata: Optional[dict[str, Any]], + *, + list_required_packages: bool, + package_config_input: Optional[Union[str, Path, dict]], + package_config_provided: bool, +) -> dict[str, Any]: + metadata = dict(recipe_telemetry_metadata or {}) + default_source, default_format = _classify_run_config_source(run_config_input) + metadata.setdefault("recipe_source", default_source) + metadata.setdefault("recipe_format", default_format) + metadata.setdefault("execution_mode", "list_required_packages" if list_required_packages else "run") + metadata.setdefault("package_config_provided", package_config_provided) + config_overrides = metadata.pop("config_overrides", _NO_OVERRIDE) + if config_overrides is _NO_OVERRIDE: + config_overrides = _build_config_overrides(run_config_telemetry_input) + elif not isinstance(config_overrides, str): + config_overrides = _build_config_overrides(config_overrides) + if config_overrides is not None: + metadata["config_overrides"] = config_overrides + if package_config_provided: + package_config_overrides = _build_package_config_overrides(package_config_input) + if package_config_overrides is not None: + metadata.setdefault("package_config_overrides", package_config_overrides) + metadata["is_ci"] = is_ci_environment() + + if run_config is None: + metadata.setdefault("recipe_name", metadata.get("recipe_command") or "WorkflowRun") + return metadata + + run_config_json = run_config.to_json(make_absolute=False) + model_metadata = _extract_input_model_metadata(run_config_json["input_model"]) + target_metadata = _extract_target_metadata(run_config) + host_metadata = _extract_host_metadata(run_config) + pass_types = _get_used_pass_types(run_config) + + metadata.setdefault("recipe_name", metadata.get("recipe_command") or run_config.workflow_id) + metadata.setdefault("workflow_id", run_config.workflow_id) + metadata.setdefault("recipe_hash", _build_recipe_hash(run_config_json)) + metadata.setdefault("input_model_type", run_config.input_model.type) + metadata.setdefault("input_model_source", model_metadata["input_model_source"]) + metadata.setdefault("model_task", model_metadata["model_task"]) + _set_metadata_if_present(metadata, target_metadata) + _set_metadata_if_present(metadata, host_metadata) + metadata.setdefault("pass_types", ";".join(pass_types)) + metadata.setdefault("pass_count", len(pass_types)) + metadata.setdefault("data_config_count", len(run_config.data_configs)) + metadata.setdefault("search_enabled", bool(run_config.engine.search_strategy)) + return metadata + + +def _classify_run_config_source(run_config_input: Any) -> tuple[str, str]: + if isinstance(run_config_input, dict): + return "config_dict", "dict" + + if isinstance(run_config_input, (str, PathLike)): + suffix = Path(run_config_input).suffix.lstrip(".").lower() + return "config_file", suffix or "unknown" + + return "config_object", "object" + + +def _build_config_overrides(config_input: Any) -> Optional[str]: + try: + config_data = _load_config_input_for_telemetry(config_input) + if config_data is None: + return None + + snapshot = _sanitize_config_snapshot(config_data) + if snapshot in (None, {}, []): + return None + + return json.dumps(snapshot, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + except Exception: + return None + + +def _build_package_config_overrides(config_input: Any) -> Optional[str]: + try: + config_data = _load_config_input_for_telemetry(config_input) + if not isinstance(config_data, dict): + return None + + default_config = _load_default_package_config_for_telemetry() + baseline = ( + _normalize_package_config_snapshot(default_config) if isinstance(default_config, dict) else _NO_OVERRIDE + ) + overrides = _extract_config_overrides(_normalize_package_config_snapshot(config_data), baseline) + if overrides is _NO_OVERRIDE: + return None + + snapshot = _sanitize_config_snapshot(overrides) + if not isinstance(snapshot, dict): + return None + + return json.dumps(snapshot, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + except Exception: + return None + + +@functools.lru_cache +def _load_default_package_config_for_telemetry() -> Optional[dict[str, Any]]: + try: + default_config = load_config_file(OlivePackageConfig.get_default_config_path()) + except Exception: + return None + + return default_config if isinstance(default_config, dict) else None + + +def _normalize_package_config_snapshot(config_data: Any) -> Any: + if not isinstance(config_data, dict): + return config_data + + normalized = deepcopy(config_data) + passes = normalized.get("passes") + if isinstance(passes, dict): + normalized["passes"] = {str(pass_name).lower(): pass_config for pass_name, pass_config in passes.items()} + return normalized + + +def _extract_config_overrides(value: Any, baseline: Any = _NO_OVERRIDE) -> Any: + if baseline is _NO_OVERRIDE: + return deepcopy(value) + + if isinstance(value, dict) and isinstance(baseline, dict): + overrides = {} + for key, child_value in value.items(): + child_override = _extract_config_overrides(child_value, baseline.get(key, _NO_OVERRIDE)) + if child_override is not _NO_OVERRIDE: + overrides[key] = child_override + if overrides: + return overrides + return _NO_OVERRIDE if value == baseline else {} + + if isinstance(value, list): + if isinstance(baseline, list) and value == baseline: + return _NO_OVERRIDE + return deepcopy(value) + + if isinstance(value, tuple): + value_list = list(value) + baseline_list = list(baseline) if isinstance(baseline, tuple) else baseline + if isinstance(baseline_list, list) and value_list == baseline_list: + return _NO_OVERRIDE + return value_list + + return deepcopy(value) if value != baseline else _NO_OVERRIDE + + +def _load_config_input_for_telemetry(config_input: Any) -> Optional[Any]: + if config_input is None: + return None + if isinstance(config_input, dict): + return deepcopy(config_input) + if isinstance(config_input, (str, PathLike)): + return load_config_file(config_input) + + model_dump = getattr(config_input, "model_dump", None) + if callable(model_dump): + return model_dump(exclude_defaults=True, exclude_none=True, by_alias=True) + return None + + +def _sanitize_config_snapshot(value: Any, key: Optional[str] = None, model_type: Optional[str] = None) -> Any: + if key in HF_MODEL_IDENTIFIER_KEYS: + if str(model_type).lower() == "hfmodel": + hf_model_id = _extract_huggingface_model_id(value) + if hf_model_id: + return hf_model_id + return RECIPE_HASH_REDACTED_VALUE + if key in CONFIG_SNAPSHOT_REDACTED_KEYS or _is_path_like_key(key): + return RECIPE_HASH_REDACTED_VALUE + if key in CONFIG_REFERENCE_KEYS and isinstance(value, str): + return CONFIG_REFERENCE_REDACTED_VALUE + + if isinstance(value, dict): + child_model_type = _get_model_type(value) or model_type + if key == "systems": + return [_sanitize_config_snapshot(system, "system", child_model_type) for system in value.values()] + if key == "passes": + passes = [] + for pass_configs in value.values(): + if isinstance(pass_configs, list): + passes.extend(pass_configs) + else: + passes.append(pass_configs) + return [_sanitize_config_snapshot(pass_config, "pass", child_model_type) for pass_config in passes] + if key == "evaluators": + return [ + _sanitize_config_snapshot(evaluator, "evaluator_config", child_model_type) + for evaluator in value.values() + ] + return { + child_key: _sanitize_config_snapshot(child_value, child_key, child_model_type) + for child_key, child_value in value.items() + if child_value is not None + } + if isinstance(value, list): + return [_sanitize_config_snapshot(item, key, model_type) for item in value] + if isinstance(value, tuple): + return [_sanitize_config_snapshot(item, key, model_type) for item in value] + if isinstance(value, Path): + return RECIPE_HASH_REDACTED_VALUE + if callable(value): + return CONFIG_CALLABLE_REDACTED_VALUE + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if hasattr(value, "value") and isinstance(value.value, (str, int, float, bool)): + return value.value + return f"<{type(value).__name__}>" + + +def _is_path_like_key(key: Optional[str]) -> bool: + if key is None: + return False + return key in {"path", "paths", "dir", "dirs", "file", "files"} or key.endswith( + ("_path", "_paths", "_dir", "_dirs", "_file", "_files") + ) + + +def _get_model_type(config: dict[str, Any]) -> Optional[str]: + model_type = config.get("type") + return str(model_type).lower() if model_type is not None else None + + +def _extract_huggingface_model_id(model_identifier: Any) -> Optional[str]: + if not isinstance(model_identifier, str): + return None + + identifier = model_identifier.strip() + if not identifier: + return None + + if identifier.startswith("https://huggingface.co/"): + parts = identifier.removeprefix("https://huggingface.co/").strip("/").split("/") + if len(parts) >= 2: + return f"{parts[0]}/{parts[1]}" + if parts and parts[0]: + return parts[0] + + if match := HF_CACHE_MODEL_PATTERN.search(identifier): + return f"{match.group(1)}/{match.group(2)}" + + if HF_REPO_ID_PATTERN.match(identifier) and not _has_local_model_file_suffix(identifier): + return identifier + + return None + + +def _extract_input_model_metadata(input_model_config: dict[str, Any]) -> dict[str, Optional[str]]: + model_config = input_model_config.get("config", {}) + model_attributes = model_config.get("model_attributes", {}) + model_task = model_attributes.get("hf_task") or model_config.get("task") + raw_identifier = model_attributes.get("_name_or_path") or model_config.get("model_path") + return { + "input_model_source": _classify_input_model_source(raw_identifier), + "model_task": str(model_task) if model_task is not None else None, + } + + +def _classify_input_model_source(model_identifier: Any) -> str: + if model_identifier is None: + return "unknown" + if isinstance(model_identifier, dict): + resource_type = model_identifier.get("type") + if resource_type == "azureml_registry_model": + return "azureml" + return "structured_resource" + + identifier = str(model_identifier) + if identifier.startswith("azureml://"): + return "azureml" + if identifier.startswith("https://huggingface.co/"): + return "huggingface_url" + if identifier.startswith(("http://", "https://")): + return "url" + + if _is_explicit_local_model_path(identifier): + suffix = PureWindowsPath(identifier).suffix or PurePosixPath(identifier).suffix + return "local_file" if suffix else "local_folder" + return "string_name" + + +def _is_explicit_local_model_path(identifier: str) -> bool: + if _has_local_model_file_suffix(identifier): + return True + return ( + identifier.startswith(("./", "../", ".\\", "..\\", "~/", "~\\", "/", "\\\\")) + or PureWindowsPath(identifier).is_absolute() + or PurePosixPath(identifier).is_absolute() + ) + + +def _has_local_model_file_suffix(identifier: str) -> bool: + suffix = PureWindowsPath(identifier).suffix or PurePosixPath(identifier).suffix + return suffix.lower() in LOCAL_MODEL_FILE_SUFFIXES + + +def _extract_target_metadata(run_config: "RunConfig") -> dict[str, Optional[str]]: + target_system = run_config.engine.target + return _extract_system_metadata(target_system, "target") + + +def _extract_host_metadata(run_config: "RunConfig") -> dict[str, Optional[str]]: + host_system = run_config.engine.host + if host_system is None: + return { + "host_system_type": SystemType.Local.value, + } + return _extract_system_metadata(host_system, "host") + + +def _extract_system_metadata(system_config: Optional[Any], field_prefix: str) -> dict[str, Optional[str]]: + system_type = system_config.type.value if system_config is not None else None + device = None + execution_provider = None + execution_providers = None + + accelerators = system_config.config.accelerators if system_config and system_config.config else None + if accelerators: + accelerator = accelerators[0] + device = str(accelerator.device) if accelerator.device is not None else None + ep_values = accelerator.get_ep_strs() or [] + if ep_values: + execution_provider = ep_values[0] + execution_providers = ";".join(ep_values) + + return { + f"{field_prefix}_system_type": system_type, + f"{field_prefix}_device": device, + f"{field_prefix}_execution_provider": execution_provider, + f"{field_prefix}_execution_providers": execution_providers, + } + + +def _set_metadata_if_present(metadata: dict[str, Any], values: dict[str, Optional[str]]) -> None: + for key, value in values.items(): + if value is not None: + metadata.setdefault(key, value) + + +def _get_used_pass_types(run_config: "RunConfig") -> list[str]: + return ( + [pass_config.type for _, pass_configs in run_config.passes.items() for pass_config in pass_configs] + if run_config.passes + else [] + ) + + +def _build_recipe_hash(run_config_json: dict[str, Any]) -> str: + sanitized = deepcopy(run_config_json) + _redact_recipe_hash_keys(sanitized) + return hash_dict(sanitized)[:16] + + +def _redact_recipe_hash_keys(value: Any, key: Optional[str] = None) -> Any: + if key in RECIPE_HASH_REDACTED_KEYS or _is_path_like_key(key): + return RECIPE_HASH_REDACTED_VALUE + if isinstance(value, dict): + for child_key in list(value): + value[child_key] = _redact_recipe_hash_keys(value[child_key], child_key) + elif isinstance(value, list): + for index, item in enumerate(value): + value[index] = _redact_recipe_hash_keys(item, key) + elif isinstance(value, tuple): + return [_redact_recipe_hash_keys(item, key) for item in value] + elif isinstance(value, Path): + return RECIPE_HASH_REDACTED_VALUE + elif callable(value): + return CONFIG_CALLABLE_REDACTED_VALUE + elif hasattr(value, "value") and isinstance(value.value, (str, int, float, bool)): + return value.value + return value diff --git a/olive/telemetry/telemetry.py b/olive/telemetry/telemetry.py index 0ddb690e2a..97d287ba7d 100644 --- a/olive/telemetry/telemetry.py +++ b/olive/telemetry/telemetry.py @@ -5,7 +5,6 @@ """Thin wrapper around the OneCollector telemetry logger with event helpers.""" import base64 -import errno import json import os import platform @@ -19,8 +18,6 @@ from olive.telemetry.library.event_source import event_source from olive.telemetry.library.telemetry_logger import TelemetryLogger, get_telemetry_logger from olive.telemetry.utils import ( - _decode_cache_line, - _encode_cache_line, _exclusive_file_lock, get_telemetry_base_dir, ) @@ -30,6 +27,7 @@ # Default event names used by the high-level telemetry helpers. HEARTBEAT_EVENT_NAME = "OliveHeartbeat" +RECIPE_EVENT_NAME = "OliveRecipe" # CI/CD environment variables whose presence indicates an automated pipeline. _CI_ENV_VARS = ( @@ -43,6 +41,7 @@ ) ACTION_EVENT_NAME = "OliveAction" ERROR_EVENT_NAME = "OliveError" +APP_NAME = "Olive" ALLOWED_KEYS = { HEARTBEAT_EVENT_NAME: { @@ -72,6 +71,38 @@ "app_instance_id", "initTs", }, + RECIPE_EVENT_NAME: { + "recipe_name", + "recipe_hash", + "recipe_source", + "recipe_format", + "recipe_command", + "execution_mode", + "workflow_id", + "config_overrides", + "success", + "input_model_type", + "input_model_source", + "model_task", + "target_system_type", + "target_device", + "target_execution_provider", + "target_execution_providers", + "host_system_type", + "host_device", + "host_execution_provider", + "host_execution_providers", + "pass_types", + "pass_count", + "data_config_count", + "search_enabled", + "package_config_provided", + "package_config_overrides", + "is_ci", + "app_version", + "app_instance_id", + "initTs", + }, } CRITICAL_EVENTS = {HEARTBEAT_EVENT_NAME} @@ -80,6 +111,11 @@ CACHE_FILE_NAME = "olive.json" +def is_ci_environment() -> bool: + """Detect CI/CD environments by checking well-known environment variables.""" + return any(os.environ.get(var) for var in _CI_ENV_VARS) + + class TelemetryCacheHandler: """Handles caching of failed telemetry events for offline resilience. @@ -103,13 +139,17 @@ def __init__(self, telemetry: "Telemetry") -> None: # Single shared cache file for all processes self._cache_file_name = CACHE_FILE_NAME self._shutdown = False - # Protects all shared state to prevent race conditions - self._lock = threading.Lock() - self._callback_condition = threading.Condition() + # Single condition protects all shared state: _shutdown, _is_flushing, + # _callbacks_item_count, _events_logged. Using one lock eliminates + # lock ordering issues that arise with separate locks. + self._condition = threading.Condition() self._callbacks_item_count = 0 self._events_logged = 0 # Prevents concurrent flush operations self._is_flushing = False + # Tracks whether any replayed event failed during the current flush + # so the flush file can be preserved instead of silently dropped. + self._flush_failed = False def shutdown(self) -> None: """Signal shutdown to prevent new operations. @@ -118,7 +158,7 @@ def shutdown(self) -> None: offline resilience. If network is working, success callbacks already flushed. If network is down, flushing would fail anyway. """ - with self._lock: + with self._condition: self._shutdown = True def __del__(self): @@ -150,18 +190,19 @@ def on_payload_transmitted(self, args: "PayloadTransmittedCallbackArgs") -> None payload = None should_flush = False - with self._lock: + with self._condition: if self._shutdown: return - # Skip callbacks from replayed events during flush - # If a flush is in progress it means we successfully sent an event, - # so it's unlikely that an event would suddenly fail and need to be cached - # and we don't need to flush again. + # Callbacks for replayed events: don't trigger a new flush or + # re-cache, but record whether the replay actually succeeded so + # _flush_cache_file can decide whether to delete or restore the + # flush file. Falling through to the finally block still + # increments _callbacks_item_count so wait_for_callbacks can + # complete. if self._is_flushing: - with self._callback_condition: - self._callbacks_item_count += args.item_count - self._callback_condition.notify_all() + if not args.succeeded: + self._flush_failed = True return if args.succeeded: @@ -182,26 +223,25 @@ def on_payload_transmitted(self, args: "PayloadTransmittedCallbackArgs") -> None # Fail silently - telemetry should never crash the application pass finally: - with self._callback_condition: + with self._condition: self._callbacks_item_count += args.item_count - self._callback_condition.notify_all() + # Wake threads waiting for flush/shutdown callback accounting. + self._condition.notify_all() def wait_for_callbacks(self, timeout_sec: float, during_flush: bool = False) -> bool: + """Wait until callbacks have caught up with logged telemetry items.""" deadline = time.time() + timeout_sec - while True: - with self._callback_condition: - callbacks_item_count = self._callbacks_item_count - expected_items = self._events_logged - if (during_flush or not self.is_flushing) and callbacks_item_count >= expected_items: + with self._condition: + while True: + if (during_flush or not self._is_flushing) and self._callbacks_item_count >= self._events_logged: return True - remaining = deadline - time.time() - if remaining <= 0: - return False - with self._callback_condition: - self._callback_condition.wait(timeout=remaining) + remaining = deadline - time.time() + if remaining <= 0: + return False + self._condition.wait(timeout=remaining) def record_event_logged(self, count: int = 1) -> None: - with self._callback_condition: + with self._condition: self._events_logged += count def _schedule_flush(self) -> None: @@ -218,7 +258,7 @@ def _schedule_flush(self) -> None: - Daemon thread is acceptable (flush is best-effort) """ # Check before spawning thread to avoid unnecessary thread creation - with self._lock: + with self._condition: if self._shutdown or self._is_flushing: return self._is_flushing = True @@ -230,9 +270,11 @@ def flush_task(): # Fail silently pass finally: - # Always clear flag, even on exception - with self._lock: + # Always clear flag, even on exception, and wake any waiters + # (e.g. shutdown) that are blocked on _is_flushing becoming False. + with self._condition: self._is_flushing = False + self._condition.notify_all() thread = threading.Thread(target=flush_task, daemon=True) thread.start() @@ -246,8 +288,9 @@ def cache_path(self) -> Optional[Path]: """ telemetry_cache_dir = None - if "OLIVE_TELEMETRY_CACHE_DIR" in os.environ: - telemetry_cache_dir = os.environ["OLIVE_TELEMETRY_CACHE_DIR"] + telemetry_cache_dir_override = (os.environ.get("OLIVE_TELEMETRY_CACHE_DIR") or "").strip() + if telemetry_cache_dir_override: + telemetry_cache_dir = Path(telemetry_cache_dir_override).expanduser() if not telemetry_cache_dir: telemetry_cache_dir = get_telemetry_base_dir() / "cache" return telemetry_cache_dir / self._cache_file_name @@ -258,13 +301,11 @@ def _write_payload_to_cache(self, payload: bytes) -> None: Design decisions: - Parse payload to extract individual events (allows filtering) - Filter to only critical events near size limit (preserves important data) - - Use file locking for multi-process safety (prevents corruption) - - Use exponential backoff for file contention (avoids spinning) + - Use exclusive file lock to serialize concurrent writers - Fail silently on errors (telemetry should never crash app) Assumptions: - JSON operations are fast enough for synchronous execution - - File contention is rare and transient (retry a few times) - Cache size limits prevent unbounded growth - Critical events (heartbeat) are more important than others """ @@ -280,36 +321,25 @@ def _write_payload_to_cache(self, payload: bytes) -> None: cache_path.parent.mkdir(parents=True, exist_ok=True) - max_retries = 3 - for attempt in range(max_retries + 1): - try: - cache_size = cache_path.stat().st_size if cache_path.exists() else 0 - - # Hard limit: stop caching entirely to prevent unbounded growth - if cache_size >= HARD_MAX_CACHE_SIZE_BYTES: - return - - # Soft limit: keep only critical events to preserve space - if cache_size >= MAX_CACHE_SIZE_BYTES: - entries = [entry for entry in entries if entry["event_name"] in CRITICAL_EVENTS] - if not entries: - return - - # Append base64-encoded newline-delimited entries - # Use exclusive file lock for multi-process safety - with _exclusive_file_lock(cache_path, mode="a") as cache_file: - for entry in entries: - plain = json.dumps(entry, ensure_ascii=False, separators=(",", ":")) - cache_file.write(_encode_cache_line(plain) + "\n") + cache_size = cache_path.stat().st_size if cache_path.exists() else 0 + + # Hard limit: stop caching entirely to prevent unbounded growth + if cache_size >= HARD_MAX_CACHE_SIZE_BYTES: + return + + # Soft limit: keep only critical events to preserve space + if cache_size >= MAX_CACHE_SIZE_BYTES: + entries = [entry for entry in entries if entry["event_name"] in CRITICAL_EVENTS] + if not entries: return - except OSError as exc: - # Retry only on transient access errors (file locked by another process) - if exc.errno not in {errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK, errno.EBUSY}: - return - if attempt >= max_retries: - return - # Exponential backoff: 50ms, 100ms, 200ms (aligned with C# implementation) - time.sleep(0.05 * (2**attempt)) + + # Append newline-delimited JSON entries. The exclusive file lock + # blocks until the previous writer releases, which serializes + # concurrent writers across processes without an explicit retry + # loop. + with _exclusive_file_lock(cache_path, mode="a") as cache_file: + for entry in entries: + cache_file.write(json.dumps(entry, ensure_ascii=False, separators=(",", ":")) + "\n") except Exception: # Fail silently - telemetry errors should not crash the application return @@ -322,62 +352,64 @@ def _flush_cache(self) -> None: self._flush_cache_file(cache_path) + def _restore_flush_file(self, flush_path: Optional[Path], cache_path: Path) -> None: + """Restore a claimed flush file back into the cache without overwriting new entries. + + Another process may create a fresh cache file while this process is flushing. + Appending the old flush contents preserves both sets of entries. + """ + if not flush_path or not flush_path.exists(): + return + + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with ( + _exclusive_file_lock(cache_path, mode="a") as cache_file, + _exclusive_file_lock(flush_path, mode="r") as flush_file, + ): + for raw_line in flush_file: + line = raw_line.rstrip("\n") + if line: + cache_file.write(line + "\n") + flush_path.unlink(missing_ok=True) + except Exception: + # Best-effort cache restore must never interrupt telemetry flow. + # Leave the flush file in place so a later retry can attempt recovery again. + return + def _flush_cache_file(self, cache_path: Path) -> None: """Flush cached events back to telemetry service. - Approach: - 1. Atomically rename cache → .flush (claims ownership, prevents concurrent flushes) - 2. Read all events from .flush file - 3. Queue all events for sending via telemetry logger - 4. Force flush with 2-second timeout - 5. On success: delete .flush file - 6. On failure: restore .flush → cache for retry - - Multi-process coordination: - - `replace()` is atomic; only one process can successfully rename the cache file - - If another process already renamed it, we get FileNotFoundError and abort - - Stale .flush files from crashes are overwritten by the atomic rename - - Shutdown handling: - - If shutdown flag set during flush, restore cache before returning - - This preserves events even if callbacks don't fire during shutdown - - Callback behavior: - - Queued events trigger callbacks with success/failure - - Failed events are automatically re-cached via callbacks (unless shutting down) - - The _is_flushing flag prevents re-caching of replayed events during flush + Uses atomic rename to claim the cache file, preventing duplicate + sends when multiple processes flush concurrently. """ flush_path = None try: - # Check shutdown before starting (under lock to prevent race) - with self._lock: + with self._condition: if self._shutdown: return + # Reset failure tracking so this flush only observes failures + # for events it actually replays. _schedule_flush guards + # against concurrent flushes, so resetting here is safe. + self._flush_failed = False - if not cache_path.exists(): - return - - # Atomically rename to .flush file to claim ownership - # Overwrite any stale .flush file from crashed process (C# pattern) + # Atomically rename to claim ownership — only one process can succeed flush_path = cache_path.with_name(f"{cache_path.name}.flush") try: - # On Windows/POSIX, replace() overwrites existing files atomically cache_path.replace(flush_path) except FileNotFoundError: - # Cache already claimed by another flush or doesn't exist return - # Read all cached entries (base64-decoded) entries = _read_cache_entries(flush_path) - if not entries: - # Empty cache, just delete the flush file - flush_path.unlink(missing_ok=True) + if flush_path.stat().st_size == 0: + flush_path.unlink(missing_ok=True) + else: + self._restore_flush_file(flush_path, cache_path) return - # Replay all events through telemetry logger - # Note: _is_flushing flag (set by caller) prevents these callbacks from re-caching or triggering nested flushes - # (unlikely since we just successfully sent an event, indicating network is available) + # Replay cached events — _is_flushing flag prevents re-caching but + # callbacks still update _flush_failed so we can detect failures. for entry in entries: try: event_name = entry["event_name"] @@ -387,57 +419,55 @@ def _flush_cache_file(self, cache_path: Path) -> None: attributes = json.loads(event_data) if not isinstance(attributes, dict): continue - # Preserve original timestamp attributes["initTs"] = entry.get("initTs", entry["ts"]) self._telemetry.log(event_name, attributes, None) except Exception: - # Skip malformed entries continue - # Check if shutdown happened during flush - with self._lock: - if self._shutdown: - # Restore cache to avoid data loss during shutdown - if flush_path and flush_path.exists(): - try: - cache_path.parent.mkdir(parents=True, exist_ok=True) - flush_path.replace(cache_path) - except Exception: - # Silently ignore errors during cleanup - pass - return + callbacks_completed = self.wait_for_callbacks(timeout_sec=5.0, during_flush=True) + with self._condition: + replay_failed = self._flush_failed - # Wait for in-flight callbacks to complete before deciding success/failure - flush_success = self.wait_for_callbacks(timeout_sec=5.0, during_flush=True) - if flush_success: - # Success: delete the flush file (events were sent) - if flush_path: - flush_path.unlink(missing_ok=True) - elif flush_path and flush_path.exists(): - # Failure: restore cache for retry later - cache_path.parent.mkdir(parents=True, exist_ok=True) - flush_path.replace(cache_path) + # Only delete the flush file when every replayed event was acknowledged + # AND none of them failed. Otherwise preserve the cache so a later + # flush can retry, guaranteeing we never silently drop events. + if callbacks_completed and not replay_failed: + flush_path.unlink(missing_ok=True) + else: + self._restore_flush_file(flush_path, cache_path) except Exception: - # Best-effort restore on any exception to prevent data loss - if flush_path and flush_path.exists(): - try: - cache_path.parent.mkdir(parents=True, exist_ok=True) - flush_path.replace(cache_path) - except Exception: - # If restore fails, we lose the data (acceptable for telemetry) - pass - return + # Best-effort restore on failure + self._restore_flush_file(flush_path, cache_path) @property def is_flushing(self) -> bool: - with self._lock: + with self._condition: return self._is_flushing + def wait_until_flush_complete(self, timeout_sec: float) -> bool: + """Block until any in-progress flush has finished. + + Returns True if no flush was running (or it finished within the + timeout), False if the timeout elapsed while a flush was still in + progress. Uses condition-variable signalling rather than polling so + the caller wakes immediately when the flush thread clears the flag. + """ + deadline = time.time() + timeout_sec + with self._condition: + while self._is_flushing: + remaining = deadline - time.time() + if remaining <= 0: + return False + self._condition.wait(timeout=remaining) + return True + class Telemetry: """Wrapper that wires environment configuration into the library logger. - This is a singleton class - all instances share the same state. + This is a per-process singleton class - all instances in a process share the same state. + Separate processes get separate in-memory singleton instances and coordinate only through + the shared telemetry cache file lock. Use Telemetry() to get the singleton instance. """ @@ -445,13 +475,9 @@ class Telemetry: _lock = threading.Lock() def __new__(cls): - """Create or return the singleton instance. - - Thread-safe singleton implementation using double-checked locking. - """ + """Create or return the singleton instance.""" if cls._instance is None: with cls._lock: - # Double-check pattern to prevent race conditions if cls._instance is None: instance = super().__new__(cls) instance._initialized = False @@ -466,18 +492,18 @@ def __init__(self): self._logger = None self._cache_handler = None + self._recipe_only_ci_telemetry = False try: self._logger = self._create_logger() event_source.disable() - self._cache_handler = TelemetryCacheHandler(self) - self._setup_payload_callbacks() - if self._is_ci_environment(): - self.disable_telemetry() - self._initialized = True - return - self._log_heartbeat() + is_ci = is_ci_environment() + self._recipe_only_ci_telemetry = is_ci + if not is_ci: + self._cache_handler = TelemetryCacheHandler(self) + self._setup_payload_callbacks() + self._log_heartbeat() if os.environ.get("OLIVE_DISABLE_TELEMETRY") == "1": self.disable_telemetry() self._initialized = True @@ -485,21 +511,16 @@ def __init__(self): # Fail silently — telemetry must never crash the host application self._initialized = True - @staticmethod - def _is_ci_environment() -> bool: - """Detect CI/CD environments by checking well-known environment variables.""" - return any(os.environ.get(var) for var in _CI_ENV_VARS) - def _create_logger(self) -> Optional[TelemetryLogger]: try: - return get_telemetry_logger(base64.b64decode(CONNECTION_STRING).decode()) + return get_telemetry_logger(base64.b64decode(CONNECTION_STRING).decode(), service_name=APP_NAME) except Exception: return None def _setup_payload_callbacks(self) -> None: # Register callback for payload transmission events # No need to store unregister function - logger shutdown will clean up callbacks - if self._logger is None: + if self._logger is None or self._cache_handler is None: return self._logger.register_payload_transmitted_callback( self._cache_handler.on_payload_transmitted, @@ -545,6 +566,8 @@ def log( """ try: + if self._recipe_only_ci_telemetry and event_name != RECIPE_EVENT_NAME: + return attrs = _merge_metadata(attributes, metadata) if self._logger is None: return @@ -605,12 +628,11 @@ def shutdown(self, timeout_millis: float = 10_000, callback_timeout_millis: floa 3. Shutdown logger (cleans up callbacks automatically) """ try: - # Step 1: Wait for pending flush to complete (matches C# 1-second timeout) - start_time = time.time() - while time.time() - start_time < 1.0: - if not self._cache_handler or not self._cache_handler.is_flushing: - break - time.sleep(0.05) + # Step 1: Wait for any in-flight flush to complete (matches C# 1-second timeout). + # Uses condition-variable signalling instead of polling so the wait wakes up + # immediately when the flush thread clears _is_flushing. + if self._cache_handler: + self._cache_handler.wait_until_flush_complete(1.0) # Step 2: Wait for callbacks/flush to complete before shutting down cache handler if self._cache_handler: @@ -743,18 +765,10 @@ def _set_nested_value(data: dict[str, Any], key: str, value: Any) -> None: def _read_cache_entries(cache_path: Path) -> list[dict[str, Any]]: - """Read all entries from a cache file, decoding each line. + """Read all JSON-line entries from a cache file. - Design decisions: - - Use file locking for multi-process safety - - Continue reading past malformed entries (partial data recovery) - - Return empty list on complete read failure (fail gracefully) - - Each line is base64-decoded before JSON parsing. - - Assumptions: - - Cache file contains newline-delimited base64-encoded entries (one per line) - - Each line is independent (one malformed line doesn't affect others) - - Empty or whitespace-only lines are skipped + Each line is independent — malformed lines are skipped without + affecting other entries. Returns empty list on read failure. """ entries = [] try: @@ -764,13 +778,11 @@ def _read_cache_entries(cache_path: Path) -> list[dict[str, Any]]: if not line: continue try: - line = json.loads(_decode_cache_line(line)) - if isinstance(line, dict): - entries.append(line) + parsed = json.loads(line) + if isinstance(parsed, dict): + entries.append(parsed) except Exception: - # Malformed line, skip and continue continue except Exception: - # If file cannot be opened or read, return empty list return [] return entries diff --git a/olive/telemetry/telemetry_extensions.py b/olive/telemetry/telemetry_extensions.py index e5b13395d0..068aa9dd1b 100644 --- a/olive/telemetry/telemetry_extensions.py +++ b/olive/telemetry/telemetry_extensions.py @@ -6,11 +6,11 @@ import functools import inspect import time +import traceback from types import TracebackType from typing import Any, Callable, Optional, TypeVar -from olive.telemetry.telemetry import ACTION_EVENT_NAME, ERROR_EVENT_NAME, _get_logger -from olive.telemetry.utils import _format_exception_message +from olive.telemetry.telemetry import ACTION_EVENT_NAME, ERROR_EVENT_NAME, RECIPE_EVENT_NAME, _get_logger _TFunc = TypeVar("_TFunc", bound=Callable[..., Any]) @@ -45,6 +45,38 @@ def log_error( telemetry.log(ERROR_EVENT_NAME, attributes, metadata) +def log_recipe_result( + recipe_name: str, + success: bool, + metadata: Optional[dict[str, Any]] = None, +) -> None: + telemetry = _get_logger() + attributes = { + "recipe_name": recipe_name, + "success": success, + } + telemetry.log(RECIPE_EVENT_NAME, attributes, metadata) + + +def _format_exception_message(ex: BaseException, tb: Optional[TracebackType] = None) -> str: + """Format an exception and trim local paths for readability.""" + folder = "Olive" + file_line = 'File "' + formatted = traceback.format_exception(type(ex), ex, tb, limit=5) + lines = [] + for line in formatted: + line_trunc = line.strip() + if line_trunc.startswith(file_line) and folder in line_trunc: + idx = line_trunc.find(folder) + if idx != -1: + line_trunc = line_trunc[idx + len(folder) :] + elif line_trunc.startswith(file_line): + idx = line_trunc[len(file_line) :].find('"') + line_trunc = line_trunc[idx + len(file_line) :] + lines.append(line_trunc) + return "\n".join(lines) + + def _resolve_invoked_from(skip_frames: int = 0) -> str: """Resolve how Olive was invoked by examining the call stack. diff --git a/olive/telemetry/utils.py b/olive/telemetry/utils.py index 52a39acded..806f5f93da 100644 --- a/olive/telemetry/utils.py +++ b/olive/telemetry/utils.py @@ -2,17 +2,62 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import base64 import functools import os import platform import tempfile -import traceback from pathlib import Path -from types import TracebackType -from typing import Optional +from typing import ClassVar + +if os.name == "nt": + import ctypes + import ctypes.wintypes as wintypes + + _LOCKFILE_EXCLUSIVE_LOCK = 0x00000002 + + class _Overlapped(ctypes.Structure): + _fields_: ClassVar[list[tuple[str, object]]] = [ + ("Internal", ctypes.c_void_p), + ("InternalHigh", ctypes.c_void_p), + ("Offset", wintypes.DWORD), + ("OffsetHigh", wintypes.DWORD), + ("hEvent", wintypes.HANDLE), + ] + + _kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + _lock_file_ex = _kernel32.LockFileEx + _lock_file_ex.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD, + wintypes.DWORD, + wintypes.DWORD, + wintypes.DWORD, + ctypes.POINTER(_Overlapped), + ] + _lock_file_ex.restype = wintypes.BOOL + _unlock_file_ex = _kernel32.UnlockFileEx + _unlock_file_ex.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD, + wintypes.DWORD, + wintypes.DWORD, + ctypes.POINTER(_Overlapped), + ] + _unlock_file_ex.restype = wintypes.BOOL +else: + ctypes = None + wintypes = None + _lock_file_ex = None + _unlock_file_ex = None + _Overlapped = None ORT_SUPPORT_DIR = r"Microsoft/DeveloperTools/.onnxruntime" +_WINDOWS_FILE_LOCK_LENGTH = 0x7FFFFFFF + + +def _raise_windows_lock_error(message: str) -> None: + error_code = ctypes.get_last_error() if ctypes is not None else 0 + raise OSError(error_code, message) def _resolve_home_dir() -> Path: @@ -50,29 +95,10 @@ def get_telemetry_base_dir() -> Path: return Path(cache_dir).expanduser() / ORT_SUPPORT_DIR -def _format_exception_message(ex: BaseException, tb: Optional[TracebackType] = None) -> str: - """Format an exception and trim local paths for readability.""" - folder = "Olive" - file_line = 'File "' - formatted = traceback.format_exception(type(ex), ex, tb, limit=5) - lines = [] - for line in formatted: - line_trunc = line.strip() - if line_trunc.startswith(file_line) and folder in line_trunc: - idx = line_trunc.find(folder) - if idx != -1: - line_trunc = line_trunc[idx + len(folder) :] - elif line_trunc.startswith(file_line): - idx = line_trunc[len(file_line) :].find('"') - line_trunc = line_trunc[idx + len(file_line) :] - lines.append(line_trunc) - return "\n".join(lines) - - class _ExclusiveFileLock: """Cross-platform exclusive file lock context manager. - Uses fcntl on Unix/Linux/macOS, msvcrt on Windows. + Uses fcntl on Unix/Linux/macOS and LockFileEx on Windows. Prevents cache corruption when multiple processes access the same file. Design decisions: @@ -89,6 +115,7 @@ def __init__(self, file_path: Path, mode: str): self.file_path = file_path self.mode = mode self.file = None + self._windows_overlapped = None def __enter__(self): self.file = open(self.file_path, self.mode, encoding="utf-8") @@ -102,19 +129,44 @@ def __enter__(self): elif os.name == "nt": import msvcrt - # Lock 1 byte at position 0 - msvcrt.locking(self.file.fileno(), msvcrt.LK_LOCK, 1) + self._windows_overlapped = _Overlapped() + handle = msvcrt.get_osfhandle(self.file.fileno()) + if not _lock_file_ex( + handle, + _LOCKFILE_EXCLUSIVE_LOCK, + 0, + _WINDOWS_FILE_LOCK_LENGTH, + _WINDOWS_FILE_LOCK_LENGTH, + ctypes.byref(self._windows_overlapped), + ): + _raise_windows_lock_error("Failed to lock telemetry cache file") except Exception: self.file.close() self.file = None + self._windows_overlapped = None raise return self.file def __exit__(self, exc_type, exc_val, exc_tb): if self.file: - # Unlock happens automatically on close - self.file.close() + try: + if os.name == "nt" and self._windows_overlapped is not None: + import msvcrt + + handle = msvcrt.get_osfhandle(self.file.fileno()) + if not _unlock_file_ex( + handle, + 0, + _WINDOWS_FILE_LOCK_LENGTH, + _WINDOWS_FILE_LOCK_LENGTH, + ctypes.byref(self._windows_overlapped), + ): + _raise_windows_lock_error("Failed to unlock telemetry cache file") + finally: + self.file.close() + self.file = None + self._windows_overlapped = None def _exclusive_file_lock(file_path: Path, mode: str): @@ -125,21 +177,3 @@ def _exclusive_file_lock(file_path: Path, mode: str): :return: Context manager that returns an open file handle. """ return _ExclusiveFileLock(file_path, mode) - - -def _encode_cache_line(plaintext: str) -> str: - """Encode a single cache line using base64. - - :param plaintext: The plaintext string to encode. - :return: Base64-encoded string (safe for a single text line). - """ - return base64.b64encode(plaintext.encode("utf-8")).decode("ascii") - - -def _decode_cache_line(encoded: str) -> str: - """Decode a single base64-encoded cache line. - - :param encoded: The base64-encoded string. - :return: The decoded plaintext string. - """ - return base64.b64decode(encoded.encode("ascii")).decode("utf-8") diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index 89100e1c1c..b997fcdc8b 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -5,7 +5,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from olive.common.utils import set_tempdir from olive.hardware.constants import ExecutionProvider @@ -13,6 +13,8 @@ from olive.package_config import OlivePackageConfig from olive.systems.accelerator_creator import create_accelerator from olive.systems.common import SystemType +from olive.telemetry.recipe_telemetry import _build_recipe_result_metadata, _load_config_input_for_telemetry +from olive.telemetry.telemetry_extensions import _format_exception_message, log_error, log_recipe_result from olive.workflows.run.config import RunConfig if TYPE_CHECKING: @@ -152,30 +154,71 @@ def run( list_required_packages: bool = False, package_config: Optional[Union[str, Path, dict]] = None, tempdir: Optional[Union[str, Path]] = None, + recipe_telemetry_metadata: Optional[dict[str, Any]] = None, + emit_recipe_telemetry: bool = True, ): # set tempdir set_tempdir(tempdir) + package_config_input = package_config + try: + package_config_telemetry_input = ( + _load_config_input_for_telemetry(package_config_input) if package_config_input is not None else None + ) + except Exception: + package_config_telemetry_input = None + + package_config_provided = package_config is not None if package_config is None: package_config = OlivePackageConfig.get_default_config_path() - package_config = OlivePackageConfig.parse_file_or_obj(package_config) - run_config: RunConfig = RunConfig.parse_file_or_obj(run_config) - - if list_required_packages: - # set the log level to INFO for packages - set_verbosity_info() - required_packages = get_required_packages(package_config, run_config) - generate_files_from_packages(required_packages, "olive_requirements.txt") - return None - - if run_config.engine.host and run_config.engine.host.type == SystemType.Docker: - docker_system = run_config.engine.host.create_system() - return docker_system.run_workflow(run_config) - - # set log level for olive - set_default_logger_severity(run_config.engine.log_severity_level) - return run_engine(package_config, run_config) + parsed_run_config = None + success = False + exception = None + try: + package_config = OlivePackageConfig.parse_file_or_obj(package_config) + parsed_run_config = RunConfig.parse_file_or_obj(run_config) + + if list_required_packages: + # set the log level to INFO for packages + set_verbosity_info() + required_packages = get_required_packages(package_config, parsed_run_config) + generate_files_from_packages(required_packages, "olive_requirements.txt") + success = True + return None + + if parsed_run_config.engine.host and parsed_run_config.engine.host.type == SystemType.Docker: + docker_system = parsed_run_config.engine.host.create_system() + workflow_output = docker_system.run_workflow(deepcopy(parsed_run_config)) + success = True + return workflow_output + + # set log level for olive + set_default_logger_severity(parsed_run_config.engine.log_severity_level) + workflow_output = run_engine(package_config, parsed_run_config) + success = True + return workflow_output + except Exception as exc: + exception = exc + raise + finally: + if exception is not None: + log_error( + exception_type=type(exception).__name__, + exception_message=_format_exception_message(exception, exception.__traceback__), + ) + if emit_recipe_telemetry: + metadata = _build_recipe_result_metadata( + run_config, + None, + parsed_run_config, + recipe_telemetry_metadata, + list_required_packages=list_required_packages, + package_config_input=package_config_telemetry_input, + package_config_provided=package_config_provided, + ) + recipe_name = metadata.pop("recipe_name") + log_recipe_result(recipe_name, success=success, metadata=metadata) def generate_files_from_packages(packages, file_name): diff --git a/test/cli/test_cli.py b/test/cli/test_cli.py index a7cb39e244..ae0e90b092 100644 --- a/test/cli/test_cli.py +++ b/test/cli/test_cli.py @@ -107,7 +107,17 @@ def test_workflow_run_command(mock_run, tempdir, list_required_packages, tmp_pat # assert mock_run.assert_called_once_with( - {"key": "value"}, package_config=None, tempdir=tempdir, list_required_packages=list_required_packages + {"key": "value"}, + package_config=None, + tempdir=tempdir, + list_required_packages=list_required_packages, + recipe_telemetry_metadata={ + "recipe_command": "WorkflowRun", + "recipe_source": "config_file", + "recipe_format": "json", + "execution_mode": "list_required_packages" if list_required_packages else "run", + "package_config_provided": False, + }, ) @@ -147,6 +157,22 @@ def test_workflow_run_command_with_overrides(mock_run, tmp_path): list_required_packages=False, package_config=None, tempdir=None, + recipe_telemetry_metadata={ + "recipe_command": "WorkflowRun", + "recipe_source": "config_file", + "recipe_format": "json", + "execution_mode": "run", + "package_config_provided": False, + "config_overrides": { + "input_model": { + "type": "HfModel", + "model_path": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "load_kwargs": {"attn_implementation": "eager", "trust_remote_code": False}, + }, + "output_dir": str(Path("new_output_path").resolve()), + "log_severity_level": 2, + }, + }, ) diff --git a/test/cli/test_model_package.py b/test/cli/test_model_package.py index 458337e50a..3f25496346 100644 --- a/test/cli/test_model_package.py +++ b/test/cli/test_model_package.py @@ -3,24 +3,104 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=protected-access +"""Tests for ``olive generate-model-package``. + +Covers both the CLI argument-parsing / source-validation surface and the +underlying writer (``write_model_package`` and helpers); they live in the +same module (``olive.cli.model_package``). +""" + import json from argparse import ArgumentParser +from pathlib import Path +import onnx import pytest +from onnx import TensorProto, helper + +from olive.cli.model_package import ( + ModelPackageCommand, + VariantSpec, + disambiguate_variant_names, + parse_compatibility_strings, + write_model_package, +) + +# --------------------------------------------------------------------------- +# ONNX fixture helpers +# --------------------------------------------------------------------------- + -from olive.cli.model_package import ModelPackageCommand +def _make_onnx_inline(onnx_path: Path, metadata_props: dict[str, str] | None = None) -> Path: + """Write a minimal ONNX file with no external data.""" + onnx_path.parent.mkdir(parents=True, exist_ok=True) + init = helper.make_tensor("weight", TensorProto.FLOAT, [1], [1.0]) + output = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None]) + node = helper.make_node("Identity", inputs=["weight"], outputs=["y"]) + graph = helper.make_graph([node], "test", inputs=[], outputs=[output], initializer=[init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + if metadata_props: + for k, v in metadata_props.items(): + entry = model.metadata_props.add() + entry.key = k + entry.value = v + onnx.save(model, str(onnx_path)) + return onnx_path -def _create_source_dir(tmp_path, name, model_attributes): - """Create a fake Olive output directory with model_config.json and a dummy .onnx file.""" +def _make_onnx_with_external( + onnx_path: Path, + blob_relpath: str, + blob_bytes: bytes, + metadata_props: dict[str, str] | None = None, +) -> Path: + """Write a minimal ONNX file whose only initializer points at an external-data blob.""" + onnx_path.parent.mkdir(parents=True, exist_ok=True) + blob_path = onnx_path.parent / blob_relpath + blob_path.parent.mkdir(parents=True, exist_ok=True) + blob_path.write_bytes(blob_bytes) + + init = TensorProto() + init.name = "weight" + init.data_type = TensorProto.FLOAT + init.dims.extend([max(1, len(blob_bytes) // 4)]) + init.data_location = TensorProto.EXTERNAL + for k, v in (("location", blob_relpath), ("offset", "0"), ("length", str(len(blob_bytes)))): + entry = init.external_data.add() + entry.key = k + entry.value = v + + output = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None]) + node = helper.make_node("Identity", inputs=["weight"], outputs=["y"]) + graph = helper.make_graph([node], "test", inputs=[], outputs=[output], initializer=[init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + if metadata_props: + for k, v in metadata_props.items(): + entry = model.metadata_props.add() + entry.key = k + entry.value = v + onnx.save(model, str(onnx_path)) + return onnx_path + + +def _create_source_dir( + tmp_path: Path, + name: str, + model_attributes: dict, + *, + onnx_metadata: dict[str, str] | None = None, + inference_settings: dict | None = None, +) -> Path: + """Create a fake Olive output directory with model_config.json and a real ONNX file.""" source_dir = tmp_path / name source_dir.mkdir(parents=True) - model_config = { - "type": "ONNXModel", - "config": {"model_path": str(source_dir / "model.onnx"), "model_attributes": model_attributes}, - } + onnx_path = source_dir / "model.onnx" + _make_onnx_inline(onnx_path, metadata_props=onnx_metadata) + cfg: dict = {"model_path": str(onnx_path), "model_attributes": model_attributes} + if inference_settings is not None: + cfg["inference_settings"] = inference_settings + model_config = {"type": "ONNXModel", "config": cfg} (source_dir / "model_config.json").write_text(json.dumps(model_config)) - (source_dir / "model.onnx").write_text("dummy") return source_dir @@ -33,20 +113,21 @@ def _make_command(args_list): return parsed_args.func(parser, parsed_args, unknown) -class TestSourceValidation: - """Tests for _parse_sources validation logic.""" +# --------------------------------------------------------------------------- +# CLI: source validation +# --------------------------------------------------------------------------- - def test_rejects_single_source(self, tmp_path): - # setup + +class TestSourceValidation: + def test_accepts_single_source(self, tmp_path): src = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider"}) cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(tmp_path / "out")]) - # execute + assert - with pytest.raises(ValueError, match="At least two"): - cmd._parse_sources() + sources = cmd._parse_sources() + + assert sources == [("soc_60", src)] def test_rejects_missing_model_config(self, tmp_path): - # setup no_config = tmp_path / "no_config" no_config.mkdir() valid = _create_source_dir(tmp_path, "valid", {"ep": "QNNExecutionProvider"}) @@ -54,45 +135,50 @@ def test_rejects_missing_model_config(self, tmp_path): ["generate-model-package", "-s", str(no_config), "-s", str(valid), "-o", str(tmp_path / "out")] ) - # execute + assert with pytest.raises(ValueError, match=r"model_config\.json"): cmd._parse_sources() def test_rejects_nonexistent_path(self, tmp_path): - # setup valid = _create_source_dir(tmp_path, "valid", {"ep": "QNNExecutionProvider"}) cmd = _make_command( ["generate-model-package", "-s", "/nonexistent/path", "-s", str(valid), "-o", str(tmp_path / "out")] ) - # execute + assert with pytest.raises(ValueError, match="does not exist"): cmd._parse_sources() + def test_rejects_duplicate_source_basenames(self, tmp_path): + # Two source dirs share basename "soc_60" — variant names would collide. + src_a = _create_source_dir(tmp_path / "a", "soc_60", {"ep": "QNNExecutionProvider"}) + src_b = _create_source_dir(tmp_path / "b", "soc_60", {"ep": "QNNExecutionProvider"}) + cmd = _make_command(["generate-model-package", "-s", str(src_a), "-s", str(src_b), "-o", str(tmp_path / "out")]) + + with pytest.raises(ValueError, match="share the directory name"): + cmd._parse_sources() + def test_parses_two_valid_sources(self, tmp_path): - # setup src1 = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider"}) src2 = _create_source_dir(tmp_path, "soc_73", {"ep": "QNNExecutionProvider"}) cmd = _make_command(["generate-model-package", "-s", str(src1), "-s", str(src2), "-o", str(tmp_path / "out")]) - # execute sources = cmd._parse_sources() - # assert assert len(sources) == 2 assert sources[0] == ("soc_60", src1) assert sources[1] == ("soc_73", src2) -class TestGeneratePackageSingle: - """Tests for single-component model package generation.""" +# --------------------------------------------------------------------------- +# CLI: end-to-end (single component, multi-variant) +# --------------------------------------------------------------------------- + - def test_generates_manifest_and_metadata(self, tmp_path): - """Package output should have manifest.json and metadata.json.""" +class TestGeneratePackageMultiVariant: + def test_writes_proposal_layout(self, tmp_path): # setup src1 = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider", "device": "NPU"}) src2 = _create_source_dir(tmp_path, "soc_73", {"ep": "QNNExecutionProvider", "device": "NPU"}) - out_dir = tmp_path / "out" + out = tmp_path / "out" cmd = _make_command( [ "generate-model-package", @@ -101,7 +187,7 @@ def test_generates_manifest_and_metadata(self, tmp_path): "-s", str(src2), "-o", - str(out_dir), + str(out), "--model_name", "test_model", "--model_version", @@ -112,36 +198,627 @@ def test_generates_manifest_and_metadata(self, tmp_path): # execute cmd.run() - # assert: manifest - manifest_path = out_dir / "manifest.json" - assert manifest_path.exists() - manifest = json.loads(manifest_path.read_text()) - assert manifest["name"] == "test_model" - assert manifest["model_version"] == "2.0" - assert "component_models" in manifest - - # assert: metadata in component dir - component_name = manifest["component_models"][0] - metadata_path = out_dir / "models" / component_name / "metadata.json" - assert metadata_path.exists() - metadata = json.loads(metadata_path.read_text()) - assert "soc_60" in metadata["model_variants"] - assert "soc_73" in metadata["model_variants"] - - # assert: constraints - for variant in metadata["model_variants"].values(): - assert variant["constraints"]["ep"] == "QNNExecutionProvider" - assert variant["constraints"]["device"] == "NPU" - - -class TestAcceleratorInfo: - """Test accelerator info extraction.""" - - def test_defaults_accelerator_when_no_attributes(self): - """Falls back to CPUExecutionProvider/cpu when model_attributes is empty.""" - # setup + execute - ep, device = ModelPackageCommand._extract_accelerator_info([{"type": "ONNXModel", "config": {}}]) - - # assert - assert ep == "CPUExecutionProvider" - assert device == "cpu" + # assert: top-level layout (no models/ wrapper) + assert (out / "manifest.json").is_file() + assert not (out / "models").exists() + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["schema_version"] == 1 + assert manifest["components"] == ["model"] + assert manifest["producer"]["model_name"] == "test_model" + assert manifest["producer"]["model_version"] == "2.0" + + # metadata uses ep_compatibility[] + metadata = json.loads((out / "model" / "metadata.json").read_text()) + assert set(metadata["variants"]) == {"soc_60", "soc_73"} + for variant_payload in metadata["variants"].values(): + ep_compat = variant_payload["ep_compatibility"] + assert ep_compat == [{"ep": "QNNExecutionProvider", "device": "NPU"}] + + # variant.json contains files[] with filename + for v in ("soc_60", "soc_73"): + variant_json = json.loads((out / "model" / v / "variant.json").read_text()) + assert variant_json["files"][0]["filename"] == "model.onnx" + assert (out / "model" / v / "model.onnx").is_file() + + +class TestGeneratePackageSingleSource: + def test_single_source_is_valid_package(self, tmp_path): + src = _create_source_dir(tmp_path, "cpu_x64", {"ep": "CPUExecutionProvider"}) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)]) + + cmd.run() + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["components"] == ["model"] + metadata = json.loads((out / "model" / "metadata.json").read_text()) + assert "cpu_x64" in metadata["variants"] + assert metadata["variants"]["cpu_x64"]["ep_compatibility"] == [{"ep": "CPUExecutionProvider"}] + # No shared_weights because nothing to dedup. + assert not (out / "model" / "shared_weights").exists() + + +# --------------------------------------------------------------------------- +# Writer: layout + manifest + metadata + variant.json +# --------------------------------------------------------------------------- + + +class TestWriteModelPackageLayout: + def test_writes_proposal_shape_for_single_variant(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + device="cpu", + ) + ], + producer_info={"tool": "olive-ai", "model_name": "demo"}, + ) + + assert (out / "manifest.json").is_file() + assert (out / "decoder" / "metadata.json").is_file() + assert (out / "decoder" / "cpu" / "variant.json").is_file() + assert (out / "decoder" / "cpu" / "model.onnx").is_file() + assert not (out / "models").exists() + + def test_manifest_uses_proposal_schema(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + producer_info={"tool": "olive-ai", "tool_version": "1.2.3", "model_name": "demo"}, + ) + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["schema_version"] == 1 + assert manifest["components"] == ["decoder"] + assert manifest["producer"] == { + "tool": "olive-ai", + "tool_version": "1.2.3", + "model_name": "demo", + } + # No legacy fields + assert "name" not in manifest + assert "component_models" not in manifest + assert "model_version" not in manifest + + def test_metadata_uses_ep_compatibility_array(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="qnn-npu", + onnx_files=[onnx_path], + ep="QNNExecutionProvider", + device="NPU", + compatibility=["soc_60", "soc_69"], + ) + ], + ) + + metadata = json.loads((out / "decoder" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["qnn-npu"]["ep_compatibility"] + assert ep_compat == [{"ep": "QNNExecutionProvider", "device": "NPU", "compatibility": ["soc_60", "soc_69"]}] + assert "model_variants" not in metadata + + def test_metadata_omits_optional_fields_when_unset(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + metadata = json.loads((out / "decoder" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["cpu"]["ep_compatibility"][0] + assert ep_compat == {"ep": "CPUExecutionProvider"} + + def test_variant_json_carries_session_and_provider_options(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + inference = { + "session_options": {"graph_optimization_level": 3}, + "execution_provider": ["CPUExecutionProvider"], + "provider_options": [{"intra_op_num_threads": 4}], + } + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + inference_settings=inference, + ) + ], + ) + + variant = json.loads((out / "decoder" / "cpu" / "variant.json").read_text()) + assert variant["files"] == [ + { + "filename": "model.onnx", + "session_options": {"graph_optimization_level": 3}, + "provider_options": {"intra_op_num_threads": 4}, + } + ] + + def test_provider_options_match_ep_by_name(self, tmp_path): + """When inference_settings has multiple EPs, pick the one whose name matches VariantSpec.ep.""" + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + inference = { + "session_options": {}, + "execution_provider": ["CPUExecutionProvider", "QNNExecutionProvider"], + "provider_options": [{"cpu_only": "1"}, {"backend_path": "QnnHtp.so"}], + } + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="qnn", + onnx_files=[onnx_path], + ep="QNNExecutionProvider", + inference_settings=inference, + ) + ], + ) + + variant = json.loads((out / "decoder" / "qnn" / "variant.json").read_text()) + assert variant["files"][0].get("provider_options") == {"backend_path": "QnnHtp.so"} + assert "session_options" not in variant["files"][0] + + +# --------------------------------------------------------------------------- +# Writer: shared_weights / external-data dedup +# --------------------------------------------------------------------------- + + +class TestSharedWeightsDedup: + def test_dedups_identical_external_data_across_variants(self, tmp_path): + blob = b"\x00\x01\x02\x03" * 64 + a = _make_onnx_with_external(tmp_path / "a" / "model.onnx", "model.onnx.data", blob) + b = _make_onnx_with_external(tmp_path / "b" / "model.onnx", "model.onnx.data", blob) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="v1", + onnx_files=[a], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="v2", + onnx_files=[b], + ep="CPUExecutionProvider", + ), + ], + ) + + shared_root = out / "decoder" / "shared_weights" + assert shared_root.is_dir() + sha_dirs = list(shared_root.iterdir()) + assert len(sha_dirs) == 1 + sha = sha_dirs[0].name + assert (shared_root / sha / "model.onnx.data").is_file() + assert not (out / "decoder" / "v1" / "model.onnx.data").exists() + assert not (out / "decoder" / "v2" / "model.onnx.data").exists() + + for v in ("v1", "v2"): + variant = json.loads((out / "decoder" / v / "variant.json").read_text()) + entry = variant["files"][0] + assert entry["filename"] == "model.onnx" + assert entry["shared_files"] == {"model.onnx.data": sha} + + def test_keeps_external_data_inline_when_unique(self, tmp_path): + a = _make_onnx_with_external(tmp_path / "a" / "model.onnx", "model.onnx.data", b"a-bytes" * 32) + b = _make_onnx_with_external(tmp_path / "b" / "model.onnx", "model.onnx.data", b"b-bytes" * 32) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="v1", + onnx_files=[a], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="v2", + onnx_files=[b], + ep="CPUExecutionProvider", + ), + ], + ) + + assert not (out / "decoder" / "shared_weights").exists() + assert (out / "decoder" / "v1" / "model.onnx.data").is_file() + assert (out / "decoder" / "v2" / "model.onnx.data").is_file() + + for v in ("v1", "v2"): + variant = json.loads((out / "decoder" / v / "variant.json").read_text()) + assert "shared_files" not in variant["files"][0] + + def test_single_variant_keeps_blob_inline(self, tmp_path): + onnx_path = _make_onnx_with_external(tmp_path / "src" / "model.onnx", "model.onnx.data", b"x" * 128) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + assert (out / "decoder" / "cpu" / "model.onnx.data").is_file() + assert not (out / "decoder" / "shared_weights").exists() + variant = json.loads((out / "decoder" / "cpu" / "variant.json").read_text()) + assert "shared_files" not in variant["files"][0] + + +# --------------------------------------------------------------------------- +# Writer: configs/ + safety +# --------------------------------------------------------------------------- + + +class TestConfigsAndSafety: + def test_copies_config_files_into_configs_dir(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + cfg_a = tmp_path / "configs_src" / "tokenizer.json" + cfg_a.parent.mkdir(parents=True) + cfg_a.write_text("{}") + cfg_b = tmp_path / "configs_src" / "genai_config.json" + cfg_b.write_text("{}") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + config_files={"tokenizer.json": cfg_a, "genai_config.json": cfg_b}, + ) + + assert (out / "configs" / "tokenizer.json").is_file() + assert (out / "configs" / "genai_config.json").is_file() + + def test_rejects_non_empty_output_dir(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + out.mkdir() + (out / "stale.txt").write_text("stale") + + with pytest.raises(ValueError, match="not empty"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_invalid_component_name(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="component name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="../escape", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_invalid_variant_name(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="variant name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="bad/name", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_duplicate_variant_names_per_component(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="Duplicate variant name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ), + ], + ) + + def test_rejects_empty_variants(self, tmp_path): + with pytest.raises(ValueError, match="at least one variant"): + write_model_package(output_dir=tmp_path / "package", variants=[]) + + def test_skips_config_file_with_unsafe_key(self, tmp_path): + # setup: a real source plus a config_files map with a path-escaping key. + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + bad = tmp_path / "configs_src" / "evil.txt" + bad.parent.mkdir(parents=True) + bad.write_text("oops") + out = tmp_path / "package" + + # execute + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + config_files={"../escape.txt": bad, "subdir/nested.txt": bad, "ok.txt": bad}, + ) + + # assert: unsafe keys are dropped, safe key copied + assert not (out.parent / "escape.txt").exists() + assert not (out / "configs" / "subdir").exists() + assert not (out / "configs" / "..").is_dir() or not (out / ".." / "escape.txt").exists() + assert (out / "configs" / "ok.txt").exists() + # configs/ should contain only the one safe entry + assert sorted(p.name for p in (out / "configs").iterdir()) == ["ok.txt"] + + +# --------------------------------------------------------------------------- +# CLI: mixed source types +# --------------------------------------------------------------------------- + + +class TestMixedSourceTypes: + def test_rejects_mixed_onnx_and_composite(self, tmp_path): + # setup: one ONNXModel source, one CompositeModel source + onnx_src = _create_source_dir(tmp_path, "onnx_src", {"ep": "CPUExecutionProvider"}) + comp_src = tmp_path / "comp_src" + comp_src.mkdir() + comp_onnx = _make_onnx_inline(comp_src / "comp.onnx") + (comp_src / "model_config.json").write_text( + json.dumps( + { + "type": "CompositeModel", + "config": { + "model_components": [{"type": "ONNXModel", "config": {"model_path": str(comp_onnx)}}], + "component_names": ["decoder"], + }, + } + ) + ) + cmd = _make_command( + ["generate-model-package", "-s", str(onnx_src), "-s", str(comp_src), "-o", str(tmp_path / "out")] + ) + + # execute + assert + with pytest.raises(ValueError, match="mix model types"): + cmd.run() + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +class TestParseCompatibilityStrings: + def test_splits_comma_delimited_string(self): + assert parse_compatibility_strings("sm_80,sm_86,sm_90") == ["sm_80", "sm_86", "sm_90"] + + def test_strips_whitespace_and_drops_empty(self): + assert parse_compatibility_strings(" sm_80 , , sm_86 ") == ["sm_80", "sm_86"] + + def test_returns_empty_for_none_or_empty(self): + assert parse_compatibility_strings(None) == [] + assert parse_compatibility_strings("") == [] + + +class TestDisambiguateVariantNames: + def test_passes_unique_names_through(self): + assert disambiguate_variant_names([("c", "a"), ("c", "b")]) == ["a", "b"] + + def test_appends_rank_suffix_on_collision(self): + out = disambiguate_variant_names([("c", "a"), ("c", "a"), ("c", "a")]) + assert out == ["a_rank1", "a_rank2", "a_rank3"] + + def test_isolates_collisions_per_component(self): + out = disambiguate_variant_names([("c1", "a"), ("c2", "a")]) + assert out == ["a", "a"] + + +# --------------------------------------------------------------------------- +# CLI: comma-delimited compatibility from ONNX metadata +# --------------------------------------------------------------------------- + + +class TestCompatibilityFromOnnxMetadata: + def test_splits_comma_delimited_metadata(self, tmp_path): + # setup: source with QNNExecutionProvider compat info in ONNX metadata_props + src = _create_source_dir( + tmp_path, + "soc_60", + {"ep": "QNNExecutionProvider", "device": "NPU"}, + onnx_metadata={"ep_compatibility_info.QNNExecutionProvider": "soc_60,soc_69,soc_73"}, + ) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)]) + + # execute + cmd.run() + + # assert: compatibility array reflects the comma-split list + metadata = json.loads((out / "model" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["soc_60"]["ep_compatibility"][0] + assert ep_compat["ep"] == "QNNExecutionProvider" + assert ep_compat["compatibility"] == ["soc_60", "soc_69", "soc_73"] + + +# --------------------------------------------------------------------------- +# CLI: composite (per-component inference_settings precedence) +# --------------------------------------------------------------------------- + + +def _create_composite_source( + tmp_path: Path, + name: str, + components: list[dict], + component_names: list[str], + *, + target_inference: dict | None = None, + target_attrs: dict | None = None, +) -> Path: + """Create an Olive-style composite source dir.""" + source_dir = tmp_path / name + source_dir.mkdir(parents=True) + cfg = {"model_components": components, "component_names": component_names} + if target_inference is not None: + cfg["inference_settings"] = target_inference + if target_attrs is not None: + cfg["model_attributes"] = target_attrs + (source_dir / "model_config.json").write_text(json.dumps({"type": "CompositeModel", "config": cfg})) + return source_dir + + +class TestCompositeBuild: + def test_per_component_inference_settings_wins(self, tmp_path): + # setup: component-level inference_settings should override target-level + comp_a_onnx = _make_onnx_inline(tmp_path / "comp_a" / "model.onnx") + comp_b_onnx = _make_onnx_inline(tmp_path / "comp_b" / "model.onnx") + + target_inference = { + "session_options": {"graph_optimization_level": 1}, + "execution_provider": ["CPUExecutionProvider"], + "provider_options": [{}], + } + comp_b_inference = { + "session_options": {"graph_optimization_level": 99}, + "execution_provider": ["CPUExecutionProvider"], + "provider_options": [{}], + } + components = [ + {"type": "ONNXModel", "config": {"model_path": str(comp_a_onnx)}}, + { + "type": "ONNXModel", + "config": {"model_path": str(comp_b_onnx), "inference_settings": comp_b_inference}, + }, + ] + src = _create_composite_source( + tmp_path, + "soc_60", + components, + ["encoder", "decoder"], + target_inference=target_inference, + target_attrs={"ep": "CPUExecutionProvider"}, + ) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)]) + + # execute + cmd.run() + + # assert: encoder uses target-level, decoder uses component-level + encoder_v = json.loads((out / "encoder" / "soc_60" / "variant.json").read_text()) + assert encoder_v["files"][0]["session_options"] == {"graph_optimization_level": 1} + + decoder_v = json.loads((out / "decoder" / "soc_60" / "variant.json").read_text()) + assert decoder_v["files"][0]["session_options"] == {"graph_optimization_level": 99} + + +# --------------------------------------------------------------------------- +# CLI: unsupported model type +# --------------------------------------------------------------------------- + + +class TestUnsupportedModelType: + def test_rejects_pytorch_model(self, tmp_path): + # setup: a source whose model_config declares an unsupported type + source_dir = tmp_path / "pytorch_src" + source_dir.mkdir() + (source_dir / "model_config.json").write_text( + json.dumps({"type": "PyTorchModel", "config": {"model_path": "pt"}}) + ) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(source_dir), "-o", str(out)]) + + # execute + assert + with pytest.raises(ValueError, match="Unsupported source model type"): + cmd.run() diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index 337ee9139b..ad44db040f 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -588,6 +588,155 @@ def test_simplifiedlayernorm_to_l2norm_skip(tmp_path, all_ones, output_skip_sum) ) +def check_rmsnorm( + original_model_path: str, + modified_model_path: str, + hidden_size: int, + expected_num_nodes: int, + has_skip: bool = False, +): + # check output values match + input_session = InferenceSession(original_model_path) + output_session = InferenceSession(modified_model_path) + input_feed = {"x": np.random.randn(1, hidden_size).astype(np.float32)} + if has_skip: + input_feed["skip"] = np.random.randn(1, hidden_size).astype(np.float32) + input_result = input_session.run(None, input_feed) + output_result = output_session.run(None, input_feed) + for i_r, o_r in zip(input_result, output_result): + np.testing.assert_allclose(i_r, o_r, rtol=1e-3, atol=1e-3) + + # count nodes and verify expected op types are present + dag = OnnxDAG.from_model_path(modified_model_path) + assert len(dag.nodes) == expected_num_nodes + op_types = dag.get_node_op_types() + assert "Pow" in op_types + assert "ReduceMean" in op_types + assert "Sqrt" in op_types + assert "Div" in op_types + assert "Mul" in op_types + assert "SimplifiedLayerNormalization" not in op_types + assert "SkipSimplifiedLayerNormalization" not in op_types + + +@pytest.mark.parametrize("all_ones", [True, False]) +def test_simplifiedlayernorm_to_rmsnorm(tmp_path, all_ones): + # setup + hidden_size = 3 + inputs = [ + onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]), + ] + outputs = [ + onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]), + ] + weight = (np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32) + initializers = [onnx.numpy_helper.from_array(weight, name="weight")] + nodes = [ + onnx.helper.make_node( + "SimplifiedLayerNormalization", + inputs=["x", "weight"], + outputs=["layernorm_output"], + name="layernorm/LayerNorm", + ), + onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"), + ] + graph = helper.make_graph( + nodes=nodes, + name="TestGraph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)]) + model.ir_version = 10 + onnx.save(model, str(tmp_path / "input_model.onnx")) + input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx")) + + output_folder = str(tmp_path / "output") + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]}, + disable_search=True, + ) + + # execute + onnx_model = p.run(input_model, output_folder) + + # assert + # Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity = 7 nodes + check_rmsnorm(str(tmp_path / "input_model.onnx"), onnx_model.model_path, hidden_size, 7) + + +@pytest.mark.parametrize("all_ones", [True, False]) +@pytest.mark.parametrize("output_skip_sum", [True, False]) +def test_simplifiedlayernorm_to_rmsnorm_skip(tmp_path, all_ones, output_skip_sum): + # setup + hidden_size = 3 + inputs = [ + onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]), + onnx.helper.make_tensor_value_info("skip", TensorProto.FLOAT, [1, hidden_size]), + ] + outputs = [ + onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]), + ] + if output_skip_sum: + outputs.append( + onnx.helper.make_tensor_value_info("skip_sum", TensorProto.FLOAT, [1, hidden_size]), + ) + initializers = [ + onnx.numpy_helper.from_array( + (np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32), name="weight" + ) + ] + nodes = [ + onnx.helper.make_node( + "SkipSimplifiedLayerNormalization", + inputs=["x", "skip", "weight"], + outputs=["layernorm_output"] if not output_skip_sum else ["layernorm_output", "", "", "layernorm_skip_sum"], + name="layernorm/LayerNorm", + domain=MSFT_DOMAIN, + ), + onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"), + ] + if output_skip_sum: + nodes.append( + onnx.helper.make_node( + "Identity", inputs=["layernorm_skip_sum"], outputs=["skip_sum"], name="Identity_skip_sum" + ) + ) + graph = helper.make_graph( + nodes=nodes, + name="TestGraph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)]) + model.ir_version = 10 + onnx.save(model, str(tmp_path / "input_model.onnx")) + input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx")) + + output_folder = str(tmp_path / "output") + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]}, + disable_search=True, + ) + + # execute + output_model = p.run(input_model, output_folder) + + # assert + # Add(skip), Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity[, Identity_skip_sum] = 8 or 9 nodes + check_rmsnorm( + str(tmp_path / "input_model.onnx"), + output_model.model_path, + hidden_size, + 8 + int(output_skip_sum), + has_skip=True, + ) + + @pytest.mark.parametrize("use_large_cache", [True, False]) def test_remove_rope_multi_cache(tmp_path, use_large_cache): # setup diff --git a/test/passes/onnx/test_model_builder.py b/test/passes/onnx/test_model_builder.py index ba62005e4b..be5b728c65 100644 --- a/test/passes/onnx/test_model_builder.py +++ b/test/passes/onnx/test_model_builder.py @@ -2,18 +2,45 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import json +import sys +import types from pathlib import Path +from unittest.mock import Mock import onnx import pytest -from olive.model import ONNXModelHandler +from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.model_builder import ModelBuilder from olive.passes.pytorch.rtn import Rtn from test.utils import make_local_tiny_llama +def _create_test_onnx_model(model_path: Path, node_name: str): + input_info = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1]) + output_info = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 1]) + node = onnx.helper.make_node("Identity", ["input"], ["output"], name=node_name) + graph = onnx.helper.make_graph([node], "test_graph", [input_info], [output_info]) + model = onnx.helper.make_model(graph) + onnx.save(model, model_path) + + +def _mock_genai_builder(monkeypatch, create_model_fn): + builder_module = types.ModuleType("onnxruntime_genai.models.builder") + builder_module.create_model = create_model_fn + models_module = types.ModuleType("onnxruntime_genai.models") + models_module.builder = builder_module + genai_module = types.ModuleType("onnxruntime_genai") + genai_module.__version__ = "0.8.0" + genai_module.models = models_module + monkeypatch.setitem(sys.modules, "onnxruntime_genai", genai_module) + monkeypatch.setitem(sys.modules, "onnxruntime_genai.models", models_module) + monkeypatch.setitem(sys.modules, "onnxruntime_genai.models.builder", builder_module) + monkeypatch.setattr(ModelBuilder, "maybe_patch_quant", staticmethod(lambda: None)) + + @pytest.mark.parametrize("metadata_only", [True, False]) def test_model_builder(tmp_path, metadata_only): input_model = make_local_tiny_llama(tmp_path / "input_model", "onnx" if metadata_only else "hf") @@ -100,3 +127,72 @@ def test_model_builder_layer_annotations(tmp_path, layer_annotations): assert len(node_names_with_metadata) > 0, ( "Expected nodes with metadata_props when layer_annotations are provided" ) + + +def test_model_builder_apply_annotations_on_single_file_fallback(tmp_path, monkeypatch): + def fake_create_model( + model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs + ): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + _create_test_onnx_model(output_dir / "actual.onnx", "test_node") + (output_dir / "actual.onnx.data").write_text("external_data") + (output_dir / "tokenizer.json").write_text("{}") + (output_dir / "genai_config.json").write_text(json.dumps({"search": {}})) + + _mock_genai_builder(monkeypatch, fake_create_model) + input_model = Mock(spec=HfModelHandler) + input_model.model_name_or_path = "dummy-model" + input_model.adapter_path = None + input_model.model_attributes = {"split_assignments": {"model.layers.0": 1}} + + p = create_pass_from_dict( + ModelBuilder, {"precision": "fp32", "extra_options": {"filename": "expected.onnx"}}, disable_search=True + ) + output_folder = tmp_path / "output_model" + output_model = p.run(input_model, output_folder) + + assert isinstance(output_model, ONNXModelHandler) + assert output_model.onnx_file_name == "actual.onnx" + model_proto = onnx.load(output_folder / "actual.onnx", load_external_data=False) + metadata_props = {prop.key: prop.value for prop in model_proto.metadata_props} + assert metadata_props["split_assignments"] == "model.layers.0=1" + assert str(output_folder / "actual.onnx") not in output_model.model_attributes["additional_files"] + assert str(output_folder / "actual.onnx.data") not in output_model.model_attributes["additional_files"] + assert str(output_folder / "tokenizer.json") in output_model.model_attributes["additional_files"] + + +def test_model_builder_multi_file_output_preserves_component_filenames(tmp_path, monkeypatch): + def fake_create_model( + model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs + ): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + _create_test_onnx_model(output_dir / "encoder.onnx", "encoder_node") + _create_test_onnx_model(output_dir / "decoder.onnx", "decoder_node") + (output_dir / "encoder.onnx.data").write_text("encoder_data") + (output_dir / "decoder.onnx.data").write_text("decoder_data") + (output_dir / "tokenizer.json").write_text("{}") + (output_dir / "genai_config.json").write_text(json.dumps({"search": {}})) + + _mock_genai_builder(monkeypatch, fake_create_model) + input_model = Mock(spec=HfModelHandler) + input_model.model_name_or_path = "dummy-model" + input_model.adapter_path = None + input_model.model_attributes = {} + + p = create_pass_from_dict(ModelBuilder, {"precision": "fp32"}, disable_search=True) + output_folder = tmp_path / "output_model" + output_model = p.run(input_model, output_folder) + + assert isinstance(output_model, CompositeModelHandler) + expected_component_names = sorted(["encoder.onnx", "decoder.onnx"]) + assert output_model.model_component_names == expected_component_names + component_onnx_files = [component.onnx_file_name for component in output_model.model_components] + assert component_onnx_files == output_model.model_component_names + additional_files = output_model.model_attributes["additional_files"] + assert str(output_folder / "encoder.onnx") not in additional_files + assert str(output_folder / "decoder.onnx") not in additional_files + assert str(output_folder / "encoder.onnx.data") not in additional_files + assert str(output_folder / "decoder.onnx.data") not in additional_files + assert str(output_folder / "tokenizer.json") in additional_files diff --git a/test/passes/onnx/test_quantize_embedding.py b/test/passes/onnx/test_quantize_embedding.py new file mode 100644 index 0000000000..db3f2e83f7 --- /dev/null +++ b/test/passes/onnx/test_quantize_embedding.py @@ -0,0 +1,195 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import numpy as np +from onnx import TensorProto, helper, numpy_helper + +from olive.passes.onnx.graph_surgeries import QuantizeEmbeddingInt8, ShareEmbeddingLmHead + + +def _make_model_with_fp16_embed(vocab_size=64, hidden_size=64, block_size=32): + """Create a minimal ONNX model with FP16 Gather embedding and INT4 MatMulNBits lm_head.""" + # Embedding: Gather with FP16 weight + embed_weight = np.random.randn(vocab_size, hidden_size).astype(np.float16) + embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") + + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch_size", "seq_len"]) + + gather_node = helper.make_node( + "Gather", + inputs=["model.embed_tokens.weight", "input_ids"], + outputs=["embed_output"], + name="/model/embed_tokens/Gather", + ) + + # lm_head: MatMulNBits with INT4 weight + num_blocks = hidden_size // block_size + lm_weight = np.random.randint(0, 255, (vocab_size, num_blocks, block_size // 2), dtype=np.uint8) + lm_scales = np.random.randn(vocab_size, num_blocks).astype(np.float16) * 0.01 + lm_zp = np.full((vocab_size, num_blocks), 8, dtype=np.uint8) + + lm_weight_init = numpy_helper.from_array(lm_weight, name="lm_head.MatMul_Q4.qweight") + lm_scales_init = numpy_helper.from_array(lm_scales, name="lm_head.MatMul_Q4.scales") + lm_zp_init = numpy_helper.from_array(lm_zp, name="lm_head.MatMul_Q4.zp") + + lm_head_node = helper.make_node( + "MatMulNBits", + inputs=["embed_output", "lm_head.MatMul_Q4.qweight", "lm_head.MatMul_Q4.scales", "lm_head.MatMul_Q4.zp"], + outputs=["logits"], + name="/lm_head/MatMulNBits", + domain="com.microsoft", + bits=4, + block_size=block_size, + K=hidden_size, + N=vocab_size, + ) + + graph = helper.make_graph( + [gather_node, lm_head_node], + "test", + [input_ids], + [helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["batch_size", "seq_len", vocab_size])], + initializer=[embed_init, lm_weight_init, lm_scales_init, lm_zp_init], + ) + return helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)], + ) + + +class TestQuantizeEmbeddingInt8: + def test_replaces_gather_with_gbq(self): + model = _make_model_with_fp16_embed() + surgery = QuantizeEmbeddingInt8() + result = surgery(model) + + # Verify Gather is replaced with GatherBlockQuantized + node_types = [n.op_type for n in result.graph.node] + assert "Gather" not in node_types or all( + "embed_tokens" not in n.name for n in result.graph.node if n.op_type == "Gather" + ) + gbq_nodes = [n for n in result.graph.node if n.op_type == "GatherBlockQuantized"] + assert len(gbq_nodes) == 1 + + gbq = gbq_nodes[0] + attrs = {a.name: a.i for a in gbq.attribute} + assert attrs["bits"] == 8 + assert attrs["block_size"] == 32 + + # Verify zero_point input exists (4 inputs: weight, input_ids, scales, zp) + assert len(gbq.input) == 4 + + def test_reduces_weight_size(self): + model = _make_model_with_fp16_embed(vocab_size=256, hidden_size=128) + surgery = QuantizeEmbeddingInt8() + + result = surgery(model) + + # FP16 weight should be removed + fp16_names = [init.name for init in result.graph.initializer if init.name == "model.embed_tokens.weight"] + assert len(fp16_names) == 0 + + # INT8 weight should exist + int8_names = [init.name for init in result.graph.initializer if "_Q8" in init.name] + assert len(int8_names) == 1 + + def test_skips_non_fp16(self): + model = _make_model_with_fp16_embed() + surgery = QuantizeEmbeddingInt8() + + # First pass: quantize to INT8 + result1 = surgery(model) + # Second pass: should skip (already quantized) + result2 = surgery(result1) + + # Should still have exactly 1 GBQ node + gbq_count = sum(1 for n in result2.graph.node if n.op_type == "GatherBlockQuantized") + assert gbq_count == 1 + + def test_skips_when_hidden_not_divisible(self): + # hidden_size=33, not divisible by block_size=32 + embed_weight = np.random.randn(64, 33).astype(np.float16) + embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, [1]) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 33]) + gather = helper.make_node( + "Gather", ["model.embed_tokens.weight", "input_ids"], ["out"], name="/model/embed_tokens/Gather" + ) + graph = helper.make_graph([gather], "test", [input_ids], [output], initializer=[embed_init]) + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)] + ) + + surgery = QuantizeEmbeddingInt8() + result = surgery(model) + + # Should still have Gather (not replaced) + assert any(n.op_type == "Gather" for n in result.graph.node) + + +class TestShareEmbeddingLmHead: + def test_shares_weight(self): + model = _make_model_with_fp16_embed() + + # First quantize embedding to INT8 + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + # Then share + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + # lm_head should now be INT8 + lm_head = next(n for n in result.graph.node if n.op_type == "MatMulNBits" and "lm_head" in n.name) + attrs = {a.name: a.i for a in lm_head.attribute} + assert attrs["bits"] == 8 + + # Should have a Reshape node for weight sharing + reshape_nodes = [n for n in result.graph.node if "Reshape_shared" in n.name] + assert len(reshape_nodes) == 1 + + # Reshape should reference the embedding weight + reshape = reshape_nodes[0] + assert "embed_tokens" in reshape.input[0] + + # lm_head should use shared scales + assert "embed_tokens" in lm_head.input[2] # scales + + def test_idempotent(self): + model = _make_model_with_fp16_embed() + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + s_surgery = ShareEmbeddingLmHead() + result1 = s_surgery(model) + # Applying again should be a no-op + result2 = s_surgery(result1) + + # Should still have exactly 1 Reshape_shared node + reshape_count = sum(1 for n in result2.graph.node if "Reshape_shared" in n.name) + assert reshape_count == 1 + + def test_skips_without_gbq(self): + model = _make_model_with_fp16_embed() + # Don't quantize embedding first + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + # Should be unchanged — still has Gather + assert any(n.op_type == "Gather" for n in result.graph.node) + + def test_removes_old_lm_head_weights(self): + model = _make_model_with_fp16_embed() + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + new_init_names = {init.name for init in result.graph.initializer} + + # Old lm_head weights should be removed + assert "lm_head.MatMul_Q4.qweight" not in new_init_names + assert "lm_head.MatMul_Q4.scales" not in new_init_names + assert "lm_head.MatMul_Q4.zp" not in new_init_names diff --git a/test/systems/docker/test_docker_system.py b/test/systems/docker/test_docker_system.py index 5430b68587..ef20a43d18 100644 --- a/test/systems/docker/test_docker_system.py +++ b/test/systems/docker/test_docker_system.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import json from unittest.mock import MagicMock, patch import pytest @@ -140,6 +141,37 @@ def test_run_workflow(self, mock_find_resources, mock_tempdir, mock_from_env, tm # Verify cleanup mock_container.remove.assert_called_once() + @patch("olive.systems.docker.docker_system.docker.from_env") + def test_prepare_environment_forwards_ci_to_workflow_container(self, mock_from_env, monkeypatch): + mock_docker_client = MagicMock() + mock_from_env.return_value = mock_docker_client + mock_docker_client.images.get.return_value = MagicMock() + monkeypatch.setenv("TF_BUILD", "True") + docker_config = self.get_default_docker_config() + docker_system = DockerSystem( + image_name=docker_config.image_name, + build_context_path=docker_config.build_context_path, + dockerfile=docker_config.dockerfile, + work_dir=docker_config.work_dir, + ) + + environment = docker_system._prepare_environment({}) + + assert environment["CI"] == "1" + + def test_workflow_runner_disables_inner_recipe_telemetry(self, tmp_path, monkeypatch): + from olive.systems.docker import workflow_runner + + monkeypatch.delenv("HF_TOKEN", raising=False) + config = {"input_model": {"type": "ONNXModel", "model_path": "model.onnx"}} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + with patch.object(workflow_runner, "olive_run") as mock_olive_run: + workflow_runner.runner_entry(config_path) + + mock_olive_run.assert_called_once_with(config, emit_recipe_telemetry=False) + @patch("olive.systems.docker.docker_system.docker.from_env") @patch("olive.systems.docker.docker_system.tempfile.TemporaryDirectory") @patch("olive.systems.docker.docker_system.find_all_resources") diff --git a/test/test_telemetry.py b/test/test_telemetry.py new file mode 100644 index 0000000000..c9885163ff --- /dev/null +++ b/test/test_telemetry.py @@ -0,0 +1,226 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=protected-access +import json +import os +import subprocess +import sys +import threading +import time +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from olive.telemetry.telemetry import ( + ACTION_EVENT_NAME, + CACHE_FILE_NAME, + RECIPE_EVENT_NAME, + Telemetry, + TelemetryCacheHandler, +) +from olive.telemetry.utils import _exclusive_file_lock + + +def test_cache_path_uses_env_override(tmp_path, monkeypatch): + cache_dir = tmp_path / "telemetry-cache" + monkeypatch.setenv("OLIVE_TELEMETRY_CACHE_DIR", str(cache_dir)) + + handler = TelemetryCacheHandler(Mock()) + + assert handler.cache_path == cache_dir / CACHE_FILE_NAME + assert isinstance(handler.cache_path, Path) + + +def test_cache_path_ignores_empty_env_override(tmp_path, monkeypatch): + monkeypatch.setenv("OLIVE_TELEMETRY_CACHE_DIR", " ") + + with patch("olive.telemetry.telemetry.get_telemetry_base_dir", return_value=tmp_path): + handler = TelemetryCacheHandler(Mock()) + assert handler.cache_path == tmp_path / "cache" / CACHE_FILE_NAME + + +def test_telemetry_only_logs_recipe_events_in_ci(monkeypatch): + monkeypatch.setenv("CI", "1") + Telemetry._instance = None + + mock_logger = Mock() + mock_logger.register_payload_transmitted_callback.return_value = lambda: None + + try: + with patch("olive.telemetry.telemetry.get_telemetry_logger", return_value=mock_logger): + telemetry = Telemetry() + telemetry.log(ACTION_EVENT_NAME, {"action_name": "WorkflowRun", "duration_ms": 1, "success": False}) + telemetry.log(RECIPE_EVENT_NAME, {"recipe_name": "WorkflowRun", "success": False}) + + assert mock_logger.log.call_count == 1 + assert mock_logger.log.call_args.args[0] == RECIPE_EVENT_NAME + assert telemetry._cache_handler is None + mock_logger.register_payload_transmitted_callback.assert_not_called() + finally: + Telemetry._instance = None + + +def test_flush_cache_preserves_nonempty_unreadable_file(tmp_path): + handler = TelemetryCacheHandler(Mock()) + cache_path = tmp_path / CACHE_FILE_NAME + flush_path = cache_path.with_name(f"{cache_path.name}.flush") + cache_path.write_text("not-json\n", encoding="utf-8") + + handler._flush_cache_file(cache_path) + + assert cache_path.exists() + assert cache_path.read_text(encoding="utf-8") == "not-json\n" + assert not flush_path.exists() + + +def _write_cache_entry(cache_path, event_name="TestEvent", payload=None): + cache_path.parent.mkdir(parents=True, exist_ok=True) + entry = { + "event_name": event_name, + "event_data": json.dumps(payload if payload is not None else {"key": "value"}), + "ts": 12345, + "initTs": 12345, + } + cache_path.write_text(json.dumps(entry) + "\n", encoding="utf-8") + return entry + + +def _make_replay_handler(success): + telemetry = Mock() + handler = TelemetryCacheHandler(telemetry) + # Pretend we're already in a flush so callbacks are treated as replays. + handler._is_flushing = True + + def fake_log(_event_name, _attrs, _metadata): + handler.record_event_logged() + handler.on_payload_transmitted(SimpleNamespace(succeeded=success, item_count=1, payload_bytes=b"")) + + telemetry.log.side_effect = fake_log + return handler, telemetry + + +def test_flush_deletes_cache_when_replay_succeeds(tmp_path): + handler, _ = _make_replay_handler(success=True) + cache_path = tmp_path / CACHE_FILE_NAME + flush_path = cache_path.with_name(f"{cache_path.name}.flush") + _write_cache_entry(cache_path) + + handler._flush_cache_file(cache_path) + + assert not cache_path.exists() + assert not flush_path.exists() + + +def test_flush_restores_cache_when_replay_fails(tmp_path): + handler, _ = _make_replay_handler(success=False) + cache_path = tmp_path / CACHE_FILE_NAME + flush_path = cache_path.with_name(f"{cache_path.name}.flush") + _write_cache_entry(cache_path, event_name="ReplayedEvent") + + handler._flush_cache_file(cache_path) + + # Failed replay must preserve the cached event so a later flush can retry, + # rather than silently dropping it. + assert cache_path.exists() + assert "ReplayedEvent" in cache_path.read_text(encoding="utf-8") + assert not flush_path.exists() + + +def test_flush_restores_cache_when_callbacks_timeout(tmp_path, monkeypatch): + telemetry = Mock() + handler = TelemetryCacheHandler(telemetry) + handler._is_flushing = True + cache_path = tmp_path / CACHE_FILE_NAME + flush_path = cache_path.with_name(f"{cache_path.name}.flush") + _write_cache_entry(cache_path, event_name="OrphanedEvent") + + # Simulate replay that logs the event but never fires the callback + # (e.g. exporter dropped or stalled). wait_for_callbacks should time out. + def fake_log(_event_name, _attrs, _metadata): + handler.record_event_logged() + + telemetry.log.side_effect = fake_log + monkeypatch.setattr(handler, "wait_for_callbacks", lambda **_: False) + + handler._flush_cache_file(cache_path) + + assert cache_path.exists() + assert "OrphanedEvent" in cache_path.read_text(encoding="utf-8") + assert not flush_path.exists() + + +def test_wait_until_flush_complete_wakes_when_flush_clears(): + handler = TelemetryCacheHandler(Mock()) + handler._is_flushing = True + + def clear_flag(): + time.sleep(0.05) + with handler._condition: + handler._is_flushing = False + handler._condition.notify_all() + + threading.Thread(target=clear_flag, daemon=True).start() + + start = time.perf_counter() + completed = handler.wait_until_flush_complete(1.0) + elapsed = time.perf_counter() - start + + assert completed is True + # Should wake on notify, not poll the full timeout + assert elapsed < 0.5 + + +def test_wait_until_flush_complete_returns_false_on_timeout(): + handler = TelemetryCacheHandler(Mock()) + handler._is_flushing = True + + assert handler.wait_until_flush_complete(0.05) is False + + +@pytest.mark.skipif(os.name != "nt", reason="Windows locking behavior is specific to Windows.") +def test_exclusive_file_lock_blocks_second_append_on_windows(tmp_path): + file_path = tmp_path / "olive.json" + child_code = """ +import sys +import time +from pathlib import Path +from olive.telemetry.utils import _exclusive_file_lock + +path = Path(sys.argv[1]) +path.write_text("payload", encoding="utf-8") +with _exclusive_file_lock(path, "a") as locked_file: + locked_file.write("child") + locked_file.flush() + print("locked", flush=True) + time.sleep(2) +""" + + with subprocess.Popen( + [sys.executable, "-c", child_code, str(file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) as process: + assert process.stdout is not None + assert process.stdout.readline().strip() == "locked" + + start = time.perf_counter() + with _exclusive_file_lock(file_path, mode="a") as locked_file: + wait_time = time.perf_counter() - start + locked_file.write("parent") + + assert wait_time >= 1.0 + + try: + stdout, stderr = process.communicate(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + pytest.fail(f"child lock process timed out: stdout={stdout!r} stderr={stderr!r}") + + assert process.returncode == 0, stderr + assert file_path.read_text(encoding="utf-8") == "payloadchildparent" diff --git a/test/workflows/test_workflow_run.py b/test/workflows/test_workflow_run.py index 82cc4980bf..6af0118374 100644 --- a/test/workflows/test_workflow_run.py +++ b/test/workflows/test_workflow_run.py @@ -1,10 +1,16 @@ +import json import sys from copy import deepcopy from pathlib import Path -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest +from olive.telemetry.recipe_telemetry import ( + _build_recipe_hash, + _classify_input_model_source, + _classify_run_config_source, +) from olive.workflows import run as olive_run from test.utils import ( get_pytorch_model, @@ -125,3 +131,313 @@ def test_run_packages(): # cleanup requirements_file_path.unlink() + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +@patch("olive.telemetry.recipe_telemetry.is_ci_environment", return_value=False) +def test_run_logs_recipe_result_success(_, mock_run_engine, mock_log_recipe_result): + config = { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + "load_kwargs": {"attn_implementation": "eager"}, + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [{"device": "gpu", "execution_providers": ["CUDAExecutionProvider"]}], + } + }, + "engine": {"target": "local_system"}, + "passes": {"dynamic_quant": {"type": "OnnxDynamicQuantization"}}, + } + expected_output = object() + mock_run_engine.return_value = expected_output + + output = olive_run( + config, + recipe_telemetry_metadata={ + "recipe_name": "Quantize", + "recipe_command": "Quantize", + "recipe_source": "generated_cli", + "recipe_format": "generated", + }, + ) + + assert output is expected_output + mock_log_recipe_result.assert_called_once() + assert mock_log_recipe_result.call_args.args[0] == "Quantize" + assert mock_log_recipe_result.call_args.kwargs["success"] is True + + metadata = mock_log_recipe_result.call_args.kwargs["metadata"] + assert metadata["recipe_command"] == "Quantize" + assert metadata["recipe_source"] == "generated_cli" + assert metadata["recipe_format"] == "generated" + assert metadata["workflow_id"] == "default_workflow" + assert metadata["input_model_type"] == "hfmodel" + assert metadata["input_model_source"] == "string_name" + assert metadata["model_task"] == "text-generation" + assert metadata["target_system_type"] == "LocalSystem" + assert metadata["target_device"] == "gpu" + assert metadata["target_execution_provider"] == "CUDAExecutionProvider" + assert metadata["target_execution_providers"] == "CUDAExecutionProvider" + assert metadata["host_system_type"] == "LocalSystem" + assert "host_device" not in metadata + assert "host_execution_provider" not in metadata + assert "host_execution_providers" not in metadata + assert metadata["pass_types"] == "onnxdynamicquantization" + assert metadata["pass_count"] == 1 + assert metadata["data_config_count"] == 0 + assert metadata["search_enabled"] is False + assert metadata["package_config_provided"] is False + assert metadata["is_ci"] is False + assert metadata["recipe_hash"] + assert "input_model_name_hash" not in metadata + assert "config_overrides" not in metadata + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +def test_run_logs_config_overrides_when_recipe_metadata_provides_overrides(mock_run_engine, mock_log_recipe_result): + config = { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + } + } + mock_run_engine.return_value = object() + + olive_run( + config, + recipe_telemetry_metadata={ + "recipe_name": "WorkflowRun", + "config_overrides": { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + }, + "engine": {"target": "local_system"}, + "data_path": Path("data"), + }, + }, + ) + + metadata = mock_log_recipe_result.call_args.kwargs["metadata"] + config_overrides = json.loads(metadata["config_overrides"]) + assert config_overrides["input_model"]["model_path"] == "Qwen/Qwen2.5-0.5B-Instruct" + assert config_overrides["engine"]["target"] == "" + assert config_overrides["data_path"] == "" + + +@patch("olive.workflows.run.run.log_error") +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +def test_run_logs_recipe_result_failure(mock_run_engine, mock_log_recipe_result, mock_log_error): + config = { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + "load_kwargs": {"attn_implementation": "eager"}, + }, + "passes": {"dynamic_quant": {"type": "OnnxDynamicQuantization"}}, + } + mock_run_engine.side_effect = ValueError("recipe failed") + + with pytest.raises(ValueError, match="recipe failed"): + olive_run( + config, + recipe_telemetry_metadata={ + "recipe_name": "Quantize", + "recipe_command": "Quantize", + "recipe_source": "generated_cli", + "recipe_format": "generated", + }, + ) + + mock_log_recipe_result.assert_called_once() + assert mock_log_recipe_result.call_args.args[0] == "Quantize" + assert mock_log_recipe_result.call_args.kwargs["success"] is False + assert "exception_type" not in mock_log_recipe_result.call_args.kwargs + mock_log_error.assert_called_once() + assert mock_log_error.call_args.kwargs["exception_type"] == "ValueError" + assert "recipe failed" in mock_log_error.call_args.kwargs["exception_message"] + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +def test_run_skips_recipe_result_when_recipe_telemetry_is_not_emitted(mock_run_engine, mock_log_recipe_result): + expected_output = object() + mock_run_engine.return_value = expected_output + + output = olive_run( + { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + } + }, + emit_recipe_telemetry=False, + ) + + assert output is expected_output + mock_log_recipe_result.assert_not_called() + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.systems.system_config.SystemConfig.create_system") +def test_run_logs_single_parent_recipe_result_for_docker_host(mock_create_system, mock_log_recipe_result): + expected_output = object() + docker_system = Mock() + + def run_workflow(container_run_config): + container_run_config.engine.host = container_run_config.engine.target + return expected_output + + docker_system.run_workflow.side_effect = run_workflow + mock_create_system.return_value = docker_system + config = { + "input_model": {"type": "ONNXModel", "model_path": "model.onnx"}, + "systems": { + "docker_system": { + "type": "Docker", + "config": { + "dockerfile": "Dockerfile", + "build_context_path": "build_context", + "image_name": "test-image:latest", + "work_dir": "/olive-ws", + }, + }, + "local_system": {"type": "LocalSystem"}, + }, + "engine": {"host": "docker_system", "target": "local_system"}, + } + + output = olive_run(config) + + assert output is expected_output + mock_log_recipe_result.assert_called_once() + metadata = mock_log_recipe_result.call_args.kwargs["metadata"] + assert metadata["host_system_type"] == "Docker" + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +def test_run_logs_recipe_host_metadata_without_explicit_target(mock_run_engine, mock_log_recipe_result): + config = { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + "load_kwargs": {"attn_implementation": "eager"}, + }, + "systems": { + "host_system": { + "type": "LocalSystem", + "accelerators": [{"device": "cpu", "execution_providers": ["CPUExecutionProvider"]}], + } + }, + "engine": {"host": "host_system"}, + } + mock_run_engine.return_value = object() + + olive_run( + config, + recipe_telemetry_metadata={ + "recipe_name": "Quantize", + "recipe_command": "Quantize", + "recipe_source": "generated_cli", + "recipe_format": "generated", + }, + ) + + metadata = mock_log_recipe_result.call_args.kwargs["metadata"] + assert "target_system_type" not in metadata + assert "target_device" not in metadata + assert "target_execution_provider" not in metadata + assert "target_execution_providers" not in metadata + assert metadata["host_system_type"] == "LocalSystem" + assert metadata["host_device"] == "cpu" + assert metadata["host_execution_provider"] == "CPUExecutionProvider" + assert metadata["host_execution_providers"] == "CPUExecutionProvider" + + +@patch("olive.workflows.run.run.log_recipe_result") +@patch("olive.workflows.run.run.run_engine") +def test_run_logs_package_config_overrides_when_package_config_provided(mock_run_engine, mock_log_recipe_result): + config = { + "input_model": { + "type": "HfModel", + "model_path": "Qwen/Qwen2.5-0.5B-Instruct", + "task": "text-generation", + "load_kwargs": {"attn_implementation": "eager"}, + } + } + mock_run_engine.return_value = object() + + olive_run( + config, + package_config={ + "passes": { + "AddOliveMetadata": { + "module_path": "olive.passes.onnx.add_metadata.AddOliveMetadata", + "supported_providers": ["CPUExecutionProvider"], + } + }, + "extra_dependencies": {"custom_accelerator": ["custom-package"]}, + }, + recipe_telemetry_metadata={ + "recipe_name": "Quantize", + "recipe_command": "Quantize", + "recipe_source": "generated_cli", + "recipe_format": "generated", + }, + ) + + metadata = mock_log_recipe_result.call_args.kwargs["metadata"] + assert metadata["package_config_provided"] is True + package_config_overrides = json.loads(metadata["package_config_overrides"]) + assert package_config_overrides["passes"][0]["supported_providers"] == ["CPUExecutionProvider"] + assert "module_path" not in package_config_overrides["passes"][0] + assert package_config_overrides["extra_dependencies"]["custom_accelerator"] == ["custom-package"] + + +def test_classify_run_config_source_handles_non_pathlike_object(): + assert _classify_run_config_source(object()) == ("config_object", "object") + + +def test_classify_input_model_source_does_not_depend_on_local_filesystem(tmp_path, monkeypatch): + assert _classify_input_model_source("Qwen/Qwen2.5-0.5B-Instruct") == "string_name" + + monkeypatch.chdir(tmp_path) + (tmp_path / "bert-base-uncased").mkdir() + + assert _classify_input_model_source("bert-base-uncased") == "string_name" + assert _classify_input_model_source("./model.onnx") == "local_file" + assert _classify_input_model_source("model.onnx") == "local_file" + + +def test_recipe_hash_does_not_depend_on_local_model_path_presence(tmp_path, monkeypatch): + config = { + "input_model": {"type": "HfModel", "config": {"model_path": "bert-base-uncased"}}, + "engine": {"output_dir": "output"}, + } + recipe_hash = _build_recipe_hash(config) + + monkeypatch.chdir(tmp_path) + (tmp_path / "bert-base-uncased").mkdir() + + assert _build_recipe_hash(config) == recipe_hash + + +def test_recipe_hash_handles_path_values(): + config = { + "input_model": {"type": "HfModel", "config": {"model_path": Path("model")}}, + "custom_value": Path("custom"), + } + + assert _build_recipe_hash(config)