From ff1917ad9a36cb3502b262b67f5dc9dfe4cd65a1 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 25 Mar 2026 11:41:56 +1300 Subject: [PATCH 1/4] [ML] Add SqueezeBERT and TinyRoBERTa to graph validation allowlist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add aten::_convolution to the allowlist — SqueezeBERT replaces the position-wise fully-connected layers in its attention and feed-forward blocks with 1D grouped convolutions, which requires this op. Also add deepset/tinyroberta-squad2 and typeform/squeezebert-mnli to the validation and reference model configs. BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) cannot be TorchScript-traced due to their EncoderDecoderCache return type and are excluded from validation. Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 2 + .../testfiles/reference_model_ops.json | 71 +++++++++++++++++++ .../extract_model_ops/reference_models.json | 6 +- .../extract_model_ops/validation_models.json | 6 +- 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 250b86f79..821b9d176 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -38,6 +38,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // elastic/multilingual-e5-small-optimized, intfloat/multilingual-e5-small, // .multilingual-e5-small (prepacked), elastic/splade-v3, // elastic/test-elser-v2, .rerank-v1 (Elastic rerank model), +// deepset/tinyroberta-squad2, typeform/squeezebert-mnli, // distilbert-base-uncased-finetuned-sst-2-english, // sentence-transformers/all-distilroberta-v1. // Eland-deployed variants of the above models (with pooling/normalization layers). 
@@ -51,6 +52,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::IntImplicit"sv, "aten::ScalarImplicit"sv, "aten::__and__"sv, + "aten::_convolution"sv, "aten::abs"sv, "aten::add"sv, "aten::add_"sv, diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json index 30e985582..516723e87 100644 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -1006,6 +1006,77 @@ "prim::ListConstruct", "prim::NumToTensor" ] + }, + "qa-tinyroberta-squad2": { + "model_id": "deepset/tinyroberta-squad2", + "quantized": false, + "ops": [ + "aten::Int", + "aten::add", + "aten::add_", + "aten::cumsum", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gelu", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::mul", + "aten::ne", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct" + ] + }, + "qa-squeezebert-mnli": { + "model_id": "typeform/squeezebert-mnli", + "quantized": false, + "ops": [ + "aten::Int", + "aten::_convolution", + "aten::add", + "aten::contiguous", + "aten::div", + "aten::dropout", + "aten::embedding", + "aten::gelu", + "aten::layer_norm", + "aten::linear", + "aten::matmul", + "aten::mul", + "aten::permute", + "aten::rsub", + "aten::select", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::tanh", + "aten::to", + "aten::unsqueeze", + "aten::view", + "aten::zeros", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct" + ] } } } diff --git 
a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index e2f270d35..e5c21871d 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -28,5 +28,9 @@ "_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.", "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true}, "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, - "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true} + "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, + + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models cannot be TorchScript-traced and are excluded.", + "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-squeezebert-mnli": "typeform/squeezebert-mnli" } diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index 0c853cdc5..a022d6a36 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -29,5 +29,9 @@ "es-multilingual-e5-small": "intfloat/multilingual-e5-small", "es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base" + "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", + + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). 
BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) cannot be TorchScript-traced and are excluded.", + "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-squeezebert-mnli": "typeform/squeezebert-mnli" } From 5f3c76a91869981613838b4ac973e20ace29b8f6 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 25 Mar 2026 11:56:29 +1300 Subject: [PATCH 2/4] [ML] Add BART model ops to graph validation allowlist BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) require 6 additional ops: aten::clone, aten::copy_, aten::fill_, aten::full, aten::new_zeros, and aten::triu. These models can be TorchScript-traced by setting use_cache=False in the config and using AutoModelForSequenceClassification (the default AutoModel returns EncoderDecoderCache which is not TorchScript- compatible). Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 7 ++ .../testfiles/reference_model_ops.json | 98 +++++++++++++++++++ .../extract_model_ops/reference_models.json | 6 +- .../extract_model_ops/validation_models.json | 6 +- 4 files changed, 113 insertions(+), 4 deletions(-) diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 821b9d176..d4172c4ee 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -39,6 +39,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // .multilingual-e5-small (prepacked), elastic/splade-v3, // elastic/test-elser-v2, .rerank-v1 (Elastic rerank model), // deepset/tinyroberta-squad2, typeform/squeezebert-mnli, +// facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6, // distilbert-base-uncased-finetuned-sst-2-english, // sentence-transformers/all-distilroberta-v1. // Eland-deployed variants of the above models (with pooling/normalization layers). 
@@ -64,7 +65,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::chunk"sv, "aten::clamp"sv, "aten::clamp_min"sv, + "aten::clone"sv, "aten::contiguous"sv, + "aten::copy_"sv, "aten::cumsum"sv, "aten::detach"sv, "aten::div"sv, @@ -74,7 +77,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::eq"sv, "aten::expand"sv, "aten::expand_as"sv, + "aten::fill_"sv, "aten::floor_divide"sv, + "aten::full"sv, "aten::full_like"sv, "aten::gather"sv, "aten::ge"sv, @@ -102,6 +107,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::ne"sv, "aten::neg"sv, "aten::new_ones"sv, + "aten::new_zeros"sv, "aten::norm"sv, "aten::ones"sv, "aten::pad"sv, @@ -127,6 +133,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::tensor"sv, "aten::to"sv, "aten::transpose"sv, + "aten::triu"sv, "aten::type_as"sv, "aten::unsqueeze"sv, "aten::view"sv, diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json index 516723e87..b2ba8be16 100644 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -1044,6 +1044,104 @@ "prim::TupleConstruct" ] }, + "qa-bart-large-mnli": { + "model_id": "facebook/bart-large-mnli", + "quantized": false, + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::add", + "aten::arange", + "aten::clone", + "aten::contiguous", + "aten::copy_", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::eq", + "aten::expand", + "aten::fill_", + "aten::full", + "aten::gelu", + "aten::gt", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::masked_fill_", + "aten::mul", + "aten::mul_", + "aten::new_zeros", + "aten::ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + 
"aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::triu", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack" + ] + }, + "qa-distilbart-mnli": { + "model_id": "valhalla/distilbart-mnli-12-6", + "quantized": false, + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::add", + "aten::arange", + "aten::clone", + "aten::contiguous", + "aten::copy_", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::eq", + "aten::expand", + "aten::fill_", + "aten::full", + "aten::gelu", + "aten::gt", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::masked_fill_", + "aten::mul", + "aten::mul_", + "aten::new_zeros", + "aten::ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::triu", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack" + ] + }, "qa-squeezebert-mnli": { "model_id": "typeform/squeezebert-mnli", "quantized": false, diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index e5c21871d..3b0db4c57 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -30,7 +30,9 @@ "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, - "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models cannot be TorchScript-traced and are excluded.", + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. 
BART models require use_cache=False and AutoModelForSequenceClassification to trace.", "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", - "qa-squeezebert-mnli": "typeform/squeezebert-mnli" + "qa-squeezebert-mnli": "typeform/squeezebert-mnli", + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false} } diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index a022d6a36..5c8fd239c 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -31,7 +31,9 @@ "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", - "_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) cannot be TorchScript-traced and are excluded.", + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models require use_cache=False and AutoModelForSequenceClassification to trace.", "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", - "qa-squeezebert-mnli": "typeform/squeezebert-mnli" + "qa-squeezebert-mnli": "typeform/squeezebert-mnli", + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false} } From 9ec64506f6922c19b1286cc50f7e66fb8b34229d Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 25 Mar 2026 13:31:21 +1300 Subject: [PATCH 3/4] [ML] Extend extraction tooling to support BART model tracing Address Copilot review: instead of excluding BART models from the extraction configs, extend the tooling to support them natively. 
Add auto_class and config_overrides fields to the model config schema. load_and_trace_hf_model now accepts an optional Auto class name (e.g. "AutoModelForSequenceClassification") and config kwargs (e.g. {"use_cache": false}), which BART models require to avoid returning the non-TorchScript-compatible EncoderDecoderCache. Restore BART entries in validation_models.json and reference_models.json using the new schema. Made-with: Cursor --- .../extract_model_ops/extract_model_ops.py | 12 ++++-- .../extract_model_ops/reference_models.json | 6 +-- .../extract_model_ops/torchscript_utils.py | 40 ++++++++++++++++--- .../extract_model_ops/validation_models.json | 6 +-- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index 451369a6d..a4e46376c 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -47,14 +47,18 @@ def extract_ops_for_model(model_name: str, - quantize: bool = False) -> Optional[set[str]]: + quantize: bool = False, + auto_class: str | None = None, + config_overrides: dict | None = None) -> Optional[set[str]]: """Trace a HuggingFace model and return its TorchScript op set. Returns None if the model could not be loaded or traced. 
""" label = f"{model_name} (quantized)" if quantize else model_name print(f" Loading {label}...", file=sys.stderr) - traced = load_and_trace_hf_model(model_name, quantize=quantize) + traced = load_and_trace_hf_model(model_name, quantize=quantize, + auto_class=auto_class, + config_overrides=config_overrides) if traced is None: return None return collect_inlined_ops(traced) @@ -93,7 +97,9 @@ def main(): failed = [] for arch, spec in reference_models.items(): ops = extract_ops_for_model(spec["model_id"], - quantize=spec["quantized"]) + quantize=spec["quantized"], + auto_class=spec.get("auto_class"), + config_overrides=spec.get("config_overrides")) if ops is None: failed.append(arch) print(f" {arch}: FAILED", file=sys.stderr) diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index 3b0db4c57..8a4d86293 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -30,9 +30,9 @@ "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, - "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require use_cache=False and AutoModelForSequenceClassification to trace.", + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. 
BART models require auto_class and config_overrides to trace correctly.", "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", "qa-squeezebert-mnli": "typeform/squeezebert-mnli", - "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false}, - "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false} + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} } diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py index 33042f261..322cd129b 100644 --- a/dev-tools/extract_model_ops/torchscript_utils.py +++ b/dev-tools/extract_model_ops/torchscript_utils.py @@ -11,6 +11,7 @@ # """Shared utilities for extracting and inspecting TorchScript operations.""" +import importlib import json import os import sys @@ -24,8 +25,14 @@ def load_model_config(config_path: Path) -> dict[str, dict]: """Load a model config JSON file and normalise entries. Each entry is either a plain model-name string or a dict with - ``model_id`` (required) and optional ``quantized`` boolean. All - entries are normalised to ``{"model_id": str, "quantized": bool}``. + ``model_id`` (required) and optional fields: + + - ``quantized`` (bool, default False) — apply dynamic quantization. + - ``auto_class`` (str) — transformers Auto class name to use instead + of ``AutoModel`` (e.g. ``"AutoModelForSequenceClassification"``). + - ``config_overrides`` (dict) — extra kwargs passed to + ``AutoConfig.from_pretrained`` (e.g. ``{"use_cache": false}``). + Keys starting with ``_comment`` are silently skipped. 
Raises ``ValueError`` for malformed entries so that config problems @@ -48,6 +55,8 @@ def load_model_config(config_path: Path) -> dict[str, dict]: models[key] = { "model_id": value["model_id"], "quantized": value.get("quantized", False), + "auto_class": value.get("auto_class"), + "config_overrides": value.get("config_overrides", {}), } else: raise ValueError( @@ -74,22 +83,43 @@ def collect_inlined_ops(module) -> set[str]: return collect_graph_ops(graph) -def load_and_trace_hf_model(model_name: str, quantize: bool = False): +def _resolve_auto_class(class_name: str | None): + """Resolve a transformers Auto class by name, defaulting to AutoModel.""" + if not class_name: + return AutoModel + import transformers + cls = getattr(transformers, class_name, None) + if cls is None: + raise ValueError(f"Unknown transformers class: {class_name}") + return cls + + +def load_and_trace_hf_model(model_name: str, quantize: bool = False, + auto_class: str | None = None, + config_overrides: dict | None = None): """Load a HuggingFace model, tokenize sample input, and trace to TorchScript. When *quantize* is True the model is dynamically quantized (nn.Linear layers converted to quantized::linear_dynamic) before tracing. This mirrors what Eland does when importing models for Elasticsearch. + *auto_class* selects a transformers Auto class by name (e.g. + ``"AutoModelForSequenceClassification"``). Defaults to ``AutoModel``. + + *config_overrides* supplies extra kwargs to ``AutoConfig.from_pretrained`` + (e.g. ``{"use_cache": False}`` for encoder-decoder models like BART). + Returns the traced module, or None if the model could not be loaded or traced. 
""" token = os.environ.get("HF_TOKEN") + model_cls = _resolve_auto_class(auto_class) + overrides = config_overrides or {} try: tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) config = AutoConfig.from_pretrained( - model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained( + model_name, torchscript=True, token=token, **overrides) + model = model_cls.from_pretrained( model_name, config=config, token=token) model.eval() except Exception as exc: diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index 5c8fd239c..fa4efee4c 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -31,9 +31,9 @@ "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", - "_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models require use_cache=False and AutoModelForSequenceClassification to trace.", + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. 
BART models require auto_class and config_overrides to trace correctly.", "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", "qa-squeezebert-mnli": "typeform/squeezebert-mnli", - "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false}, - "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false} + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} } From d67de97a92e06fd8f6ba0d8ae782c55ab460655a Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 25 Mar 2026 15:01:30 +1300 Subject: [PATCH 4/4] [ML] Address Copilot review: remove unused import, normalise config schema Remove unused importlib import. Ensure string config entries return the same schema shape as dict entries (with auto_class=None and config_overrides={}) so callers don't need defensive .get() calls. 
Made-with: Cursor --- dev-tools/extract_model_ops/extract_model_ops.py | 4 ++-- dev-tools/extract_model_ops/torchscript_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index a4e46376c..2a070d1cc 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -98,8 +98,8 @@ def main(): for arch, spec in reference_models.items(): ops = extract_ops_for_model(spec["model_id"], quantize=spec["quantized"], - auto_class=spec.get("auto_class"), - config_overrides=spec.get("config_overrides")) + auto_class=spec["auto_class"], + config_overrides=spec["config_overrides"]) if ops is None: failed.append(arch) print(f" {arch}: FAILED", file=sys.stderr) diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py index 322cd129b..af2b30f68 100644 --- a/dev-tools/extract_model_ops/torchscript_utils.py +++ b/dev-tools/extract_model_ops/torchscript_utils.py @@ -11,7 +11,6 @@ # """Shared utilities for extracting and inspecting TorchScript operations.""" -import importlib import json import os import sys @@ -46,7 +45,8 @@ def load_model_config(config_path: Path) -> dict[str, dict]: if key.startswith("_comment"): continue if isinstance(value, str): - models[key] = {"model_id": value, "quantized": False} + models[key] = {"model_id": value, "quantized": False, + "auto_class": None, "config_overrides": {}} elif isinstance(value, dict): if "model_id" not in value: raise ValueError(