Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions olive/cli/capture_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,30 @@ def register_subcommand(parser: ArgumentParser):
help="The device used to run the model to capture the ONNX graph.",
)

# PyTorch Exporter options
pte_group = sub_parser.add_argument_group("PyTorch Exporter options")
pte_group.add_argument(
# Mutually exclusive exporter flags
exporter_group = sub_parser.add_mutually_exclusive_group()
exporter_group.add_argument(
"--use_dynamo_exporter",
action="store_true",
help="Whether to use dynamo_export API to export ONNX model.",
)
exporter_group.add_argument(
"--use_model_builder",
action="store_true",
help="Whether to use Model Builder to capture ONNX model.",
)
exporter_group.add_argument(
"--use_mobius_builder",
action="store_true",
help=(
"Whether to use MobiusBuilder (mobius-ai) to capture ONNX model. "
"Supports multi-component multimodal models (VLMs). "
"Requires 'pip install mobius-ai'."
),
)

# PyTorch Exporter options
pte_group = sub_parser.add_argument_group("PyTorch Exporter options")
pte_group.add_argument(
"--fixed_param_dict",
type=parse_dim_dict,
Expand Down Expand Up @@ -108,11 +125,6 @@ def register_subcommand(parser: ArgumentParser):

# Model Builder options
mb_group = sub_parser.add_argument_group("Model Builder options")
mb_group.add_argument(
"--use_model_builder",
action="store_true",
help="Whether to use Model Builder to capture ONNX model.",
)
mb_group.add_argument(
"--precision",
type=str,
Expand Down Expand Up @@ -182,12 +194,16 @@ def run(self):
def _get_run_config(self, tempdir: str) -> dict:
config = deepcopy(TEMPLATE)

# Check if diffusers model detection is needed
is_diffusers = is_valid_diffusers_model(self.args.model_name_or_path) if self.args.model_name_or_path else False
if is_diffusers:
input_model_config = get_diffusers_input_model(self.args, self.args.model_name_or_path)
else:
if self.args.use_mobius_builder:
input_model_config = get_input_model_config(self.args)
else:
is_diffusers = (
is_valid_diffusers_model(self.args.model_name_or_path) if self.args.model_name_or_path else False
)
if is_diffusers:
input_model_config = get_diffusers_input_model(self.args, self.args.model_name_or_path)
else:
input_model_config = get_input_model_config(self.args)
assert input_model_config["type"].lower() in {
"hfmodel",
"pytorchmodel",
Expand All @@ -197,8 +213,14 @@ def _get_run_config(self, tempdir: str) -> dict:
is_diffusers_model = input_model_config["type"].lower() == "diffusersmodel"

# whether model is in fp16 or bf16 (currently not supported by CPU EP)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CPU supports fp16, right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all CPUs support fp16; we may need to revisit this later.

is_fp16_or_bf16 = (not self.args.use_model_builder and self.args.torch_dtype == "float16") or (
self.args.use_model_builder and self.args.precision in ("fp16", "bf16")
is_fp16_or_bf16 = (
(
not self.args.use_model_builder
and not self.args.use_mobius_builder
and self.args.torch_dtype == "float16"
)
or (self.args.use_model_builder and self.args.precision in ("fp16", "bf16"))
or (self.args.use_mobius_builder and self.args.precision in ("fp16", "bf16"))
)
to_replace = [
("input_model", input_model_config),
Expand All @@ -211,8 +233,26 @@ def _get_run_config(self, tempdir: str) -> dict:
),
]

if is_diffusers_model:
if self.args.use_mobius_builder:
if self.args.precision not in ("fp32", "fp16", "bf16"):
raise ValueError(
f"MobiusBuilder supports precisions fp32/fp16/bf16; got '{self.args.precision}'. "
"For INT4, capture in fp32/fp16/bf16 first and run a quantization pass afterwards."
)
del config["passes"]["c"]
del config["passes"]["m"]
to_replace.extend(
[
(("passes", "b", "precision"), self.args.precision),
(
("passes", "b", "runtime"),
"ort-genai" if self.args.use_ort_genai else "none",
),
]
)
elif is_diffusers_model:
del config["passes"]["m"]
del config["passes"]["b"]
to_replace.extend(
[
(
Expand All @@ -225,6 +265,7 @@ def _get_run_config(self, tempdir: str) -> dict:
)
elif self.args.use_model_builder:
del config["passes"]["c"]
del config["passes"]["b"]
to_replace.extend(
[
(("passes", "m", "precision"), self.args.precision),
Expand All @@ -245,6 +286,7 @@ def _get_run_config(self, tempdir: str) -> dict:
if self.args.int4_accuracy_level is not None:
to_replace.append((("passes", "m", "int4_accuracy_level"), self.args.int4_accuracy_level))
else:
del config["passes"]["b"]
to_replace.extend(
[
(
Expand Down Expand Up @@ -300,6 +342,7 @@ def _get_run_config(self, tempdir: str) -> dict:
"type": "OnnxConversion",
},
"m": {"type": "ModelBuilder", "metadata_only": False},
"b": {"type": "MobiusBuilder"},
"f": {"type": "DynamicToFixedShape"},
},
"host": "local_system",
Expand Down
85 changes: 85 additions & 0 deletions test/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,91 @@ def test_capture_onnx_command_fix_shape(_, mock_run, use_model_builder, tmp_path
assert mock_run.call_count == 1


@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
@pytest.mark.parametrize(
    ("precision", "use_ort_genai"),
    [
        ("fp16", True),
        ("fp32", False),
        ("bf16", True),
    ],
)
def test_capture_onnx_command_use_mobius_builder(_, mock_run, precision, use_ort_genai, tmp_path):
    """Check the run config produced by ``capture-onnx-graph --use_mobius_builder``.

    For each supported precision, with and without ``--use_ort_genai``, the config
    handed to the (mocked) workflow runner must keep only the MobiusBuilder pass
    ("b") as the conversion pass and forward the precision and runtime to it.
    """
    # setup
    output_dir = tmp_path / "output_dir"
    model_id = "dummy-model-id"
    command_args = [
        "capture-onnx-graph",
        "-m",
        model_id,
        "-o",
        str(output_dir),
        "--use_mobius_builder",
        "--precision",
        precision,
    ]
    if use_ort_genai:
        command_args.append("--use_ort_genai")

    # execute
    cli_main(command_args)

    # Check the call count *before* indexing into call_args: call_args is None when the
    # mock was never called, which would surface as an opaque TypeError instead of a
    # clear assertion failure.
    mock_run.assert_called_once()
    config = mock_run.call_args[0][0]

    assert config["input_model"]["model_path"] == model_id
    # MobiusBuilder ("b") is the only conversion pass; "c" (OnnxConversion) and "m" (ModelBuilder) are removed.
    assert "b" in config["passes"]
    assert "c" not in config["passes"]
    assert "m" not in config["passes"]
    assert config["passes"]["b"]["type"] == "MobiusBuilder"
    assert config["passes"]["b"]["precision"] == precision
    assert config["passes"]["b"]["runtime"] == ("ort-genai" if use_ort_genai else "none")


@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
def test_capture_onnx_command_use_mobius_builder_rejects_int4(_, __, tmp_path):
    """Requesting ``--precision int4`` together with ``--use_mobius_builder`` raises ValueError."""
    # setup
    destination = str(tmp_path / "output_dir")
    argv = ["capture-onnx-graph", "-m", "dummy-model-id", "-o", destination]
    argv += ["--use_mobius_builder", "--precision", "int4"]
    expected_message = "MobiusBuilder supports precisions fp32/fp16/bf16"

    # execute / verify
    with pytest.raises(ValueError, match=expected_message):
        cli_main(argv)


@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
@pytest.mark.parametrize("conflicting_flag", ["--use_model_builder", "--use_dynamo_exporter"])
def test_capture_onnx_command_use_mobius_builder_rejects_conflicts(_, __, conflicting_flag, tmp_path):
    """Combining ``--use_mobius_builder`` with another exporter flag exits with status 2."""
    # setup
    destination = str(tmp_path / "output_dir")
    argv = ["capture-onnx-graph", "-m", "dummy-model-id", "-o", destination]
    argv.extend(["--use_mobius_builder", conflicting_flag])

    # execute
    with pytest.raises(SystemExit) as raised:
        cli_main(argv)

    # verify: checked after the context manager exits, since nothing following the
    # raising call inside the `with` body would run.
    exit_status = raised.value.code
    assert exit_status == 2


@patch("olive.cli.shared_cache.AzureContainerClientFactory")
def test_shared_cache_command(mock_AzureContainerClientFactory):
# setup
Expand Down
Loading