[ML] Add SqueezeBERT and TinyRoBERTa to graph validation allowlist#3011
Conversation
Add aten::_convolution to the allowlist — SqueezeBERT uses 1D grouped convolutions instead of standard attention, which requires this op. Also add deepset/tinyroberta-squad2 and typeform/squeezebert-mnli to the validation and reference model configs. BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) cannot be TorchScript-traced due to their EncoderDecoderCache return type and are excluded from validation. Made-with: Cursor
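The validation works by collecting the op kinds that appear in a traced graph and checking them against the allowlist. A minimal, illustrative sketch of that idea — the toy model and the `extract_ops` helper below are stand-ins, not the project's `extract_model_ops` code, and the exact convolution op name (`aten::conv1d` vs `aten::_convolution`) varies by PyTorch version:

```python
import torch
import torch.nn as nn


class GroupedConvBlock(nn.Module):
    """Toy stand-in for a SqueezeBERT-style 1D grouped convolution layer."""

    def __init__(self):
        super().__init__()
        # groups > 1 makes this a grouped convolution, as in SqueezeBERT.
        self.conv = nn.Conv1d(in_channels=8, out_channels=8,
                              kernel_size=3, padding=1, groups=4)

    def forward(self, x):
        return torch.relu(self.conv(x))


def extract_ops(traced):
    """Collect the distinct op kinds appearing in a traced module's graph."""
    # inlined_graph expands submodule calls so aten:: ops inside them are visible.
    return sorted({node.kind() for node in traced.inlined_graph.nodes()})


traced = torch.jit.trace(GroupedConvBlock(), torch.randn(1, 8, 16))
ops = extract_ops(traced)
print(ops)  # includes a convolution op and aten::relu, plus prim:: nodes
```

An op set like this is what gets diffed against `ALLOWED_OPERATIONS`; any op not in the allowlist causes the model to be rejected.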
✅ Snyk checks have passed. No issues have been found so far.
BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6) require 6 additional ops: aten::clone, aten::copy_, aten::fill_, aten::full, aten::new_zeros, and aten::triu. These models can be TorchScript-traced by setting use_cache=False in the config and using AutoModelForSequenceClassification (the default AutoModel returns EncoderDecoderCache, which is not TorchScript-compatible). Made-with: Cursor
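A sketch of the tracing path described above. The function name and dummy input shapes are illustrative, and imports are deferred so the snippet loads even without torch/transformers installed; actually tracing one of the BART models requires network access to download it:

```python
def trace_for_classification(model_id, seq_len=16):
    """Trace a HF sequence-classification model with caching disabled.

    use_cache=False stops BART from returning an EncoderDecoderCache,
    which TorchScript tracing cannot handle; AutoModelForSequenceClassification
    selects the classification head used in production rather than the bare
    AutoModel encoder-decoder.
    """
    # Deferred imports: this sketch only needs the libraries when called.
    import torch
    from transformers import AutoConfig, AutoModelForSequenceClassification

    config = AutoConfig.from_pretrained(model_id, torchscript=True,
                                        use_cache=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_id,
                                                               config=config)
    model.eval()

    input_ids = torch.ones((1, seq_len), dtype=torch.long)       # dummy token ids
    attention_mask = torch.ones((1, seq_len), dtype=torch.long)  # all positions attended
    return torch.jit.trace(model, (input_ids, attention_mask))


# Example (downloads the model, so not run here):
# traced = trace_for_classification("facebook/bart-large-mnli")
```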
Pull request overview
Updates the PyTorch TorchScript graph validation setup to cover additional HuggingFace models used in the Appex QA pytorch_tests suite, and expands the C++ allowlist to accept the newly observed ops in those models’ traced graphs.
Changes:
- Add four QA models (TinyRoBERTa SQuAD2, SqueezeBERT MNLI, BART-large MNLI, DistilBART MNLI) to the validation/reference model configs.
- Expand the C++ `ALLOWED_OPERATIONS` set with ops required by SqueezeBERT (`aten::_convolution`) and BART (`aten::clone`, `aten::copy_`, `aten::fill_`, `aten::full`, `aten::new_zeros`, `aten::triu`).
- Update the golden `reference_model_ops.json` with per-model op sets for the new QA models.
Reviewed changes
Copilot reviewed 2 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| dev-tools/extract_model_ops/validation_models.json | Adds QA models to the Python allowlist validation config. |
| dev-tools/extract_model_ops/reference_models.json | Adds QA models to the reference extraction config used for allowlist generation. |
| bin/pytorch_inference/unittest/testfiles/reference_model_ops.json | Adds golden op lists for the new QA models (drift test input). |
| bin/pytorch_inference/CSupportedOperations.cc | Adds newly required ops to the C++ allowlist and updates reference model comments. |
```json
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}
```
The model comment notes BART tracing requires use_cache=False and AutoModelForSequenceClassification, but the validation tool (torchscript_utils.load_and_trace_hf_model) always uses AutoConfig(..., torchscript=True) + AutoModel.from_pretrained(...) and has no way to apply per-model config overrides or select a different AutoModel class. As a result, these BART entries are likely to either fail to trace in validate_allowlist.py or validate a different (non-sequence-classification) graph than the one used in production.
Consider extending the config schema to support per-model tracing options (e.g., model class selection and config overrides like use_cache) and updating load_and_trace_hf_model accordingly; alternatively, keep BART out of validation_models.json and validate via a local traced .pt artifact that matches the intended tracing path.
Suggested change:

```diff
-"_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
+"_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART-based QA models are validated via separately traced artifacts and are intentionally omitted here.",
 "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
-"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
-"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
-"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}
+"qa-squeezebert-mnli": "typeform/squeezebert-mnli"
```
```json
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}
```
reference_models.json is the input for extract_model_ops.py, but the extractor’s tracing helper (torchscript_utils.load_and_trace_hf_model) always uses AutoModel and does not apply per-model overrides like use_cache=False or allow selecting AutoModelForSequenceClassification. With BART models listed here, regenerating the allowlist via extract_model_ops.py --cpp/--golden is likely to either fail for these entries or generate ops for a different graph than intended.
Recommend updating the extractor tooling/config schema to support the required tracing mode for these models (model class selection + config overrides), or removing these BART entries from reference_models.json and sourcing their ops from an explicitly traced .pt file that matches the production tracing path.
Suggested change:

```diff
-"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
+"_comment:qa-models": "Models from the Appex QA pytorch_tests suite.",
 "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
-"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
-"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
-"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}
+"qa-squeezebert-mnli": "typeform/squeezebert-mnli"
```
Address Copilot review: instead of excluding BART models from the
extraction configs, extend the tooling to support them natively.
Add auto_class and config_overrides fields to the model config
schema. load_and_trace_hf_model now accepts an optional Auto class
name (e.g. "AutoModelForSequenceClassification") and config kwargs
(e.g. {"use_cache": false}), which BART models require to avoid
returning the non-TorchScript-compatible EncoderDecoderCache.
Restore BART entries in validation_models.json and
reference_models.json using the new schema.
Made-with: Cursor
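Under the new schema, a BART entry might look like the following. The field names (`auto_class`, `config_overrides`) are as described in the commit message; the exact surrounding file layout is an assumption:

```json
"qa-bart-large-mnli": {
    "model_id": "facebook/bart-large-mnli",
    "quantized": false,
    "auto_class": "AutoModelForSequenceClassification",
    "config_overrides": {"use_cache": false}
}
```

Plain-string entries and dicts without these keys keep their old meaning, so existing configs stay valid.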
Force-pushed from 390f666 to 9ec6450.
Pull request overview
Copilot reviewed 4 out of 6 changed files in this pull request and generated 2 comments.
```python
#
"""Shared utilities for extracting and inspecting TorchScript operations."""


import importlib
```
importlib is imported but not used anywhere in this module. Please remove it to avoid unused-import warnings and keep the dependency surface minimal.
Suggested change: delete the `import importlib` line.
```python
if isinstance(value, str):
    models[key] = {"model_id": value, "quantized": False}
elif isinstance(value, dict):
    if "model_id" not in value:
        raise ValueError(
            f"Config entry {key!r} is a dict but missing required "
            f"'model_id' key: {value!r}")
    models[key] = {
        "model_id": value["model_id"],
        "quantized": value.get("quantized", False),
        "auto_class": value.get("auto_class"),
        "config_overrides": value.get("config_overrides", {}),
    }
```
load_model_config() claims to “normalise entries”, but for string entries it only returns {model_id, quantized} while dict entries also include auto_class and config_overrides. This makes the returned schema depend on input type and forces callers to defensively use .get(). Consider always returning defaults (auto_class=None, config_overrides={}) for the string case too so all specs have a consistent shape.
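One way to implement the reviewer's suggestion is to funnel both entry shapes through a single helper, so every returned spec has the same keys. A sketch (the function name is illustrative, not the project's actual code):

```python
def normalise_model_entry(key, value):
    """Return a spec dict with a consistent shape for both str and dict entries."""
    if isinstance(value, str):
        # Promote shorthand "name": "model/id" entries to the dict form.
        value = {"model_id": value}
    elif not isinstance(value, dict):
        raise TypeError(f"Config entry {key!r} must be a string or dict")
    if "model_id" not in value:
        raise ValueError(f"Config entry {key!r} is missing required 'model_id' key")
    return {
        "model_id": value["model_id"],
        "quantized": value.get("quantized", False),
        "auto_class": value.get("auto_class"),        # None unless overridden
        "config_overrides": value.get("config_overrides", {}),
    }


# Both forms now yield the same four keys:
print(normalise_model_entry("qa-tinyroberta-squad2", "deepset/tinyroberta-squad2"))
print(normalise_model_entry("qa-bart-large-mnli",
                            {"model_id": "facebook/bart-large-mnli",
                             "auto_class": "AutoModelForSequenceClassification",
                             "config_overrides": {"use_cache": False}}))
```

With this shape guarantee, callers can index `spec["auto_class"]` directly instead of defensively using `.get()`.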
…chema
Remove unused importlib import. Ensure string config entries return
the same schema shape as dict entries (with auto_class=None and
config_overrides={}) so callers don't need defensive .get() calls.
Made-with: Cursor
Summary
- Add `aten::_convolution` to the allowlist — SqueezeBERT (`typeform/squeezebert-mnli`) uses 1D grouped convolutions instead of standard attention, which requires this op.
- Add the six ops required by the BART models (`facebook/bart-large-mnli`, `valhalla/distilbart-mnli-12-6`): `aten::clone`, `aten::copy_`, `aten::fill_`, `aten::full`, `aten::new_zeros`, and `aten::triu`. BART models can be TorchScript-traced by setting `use_cache=False` in the config and using `AutoModelForSequenceClassification` (the default `AutoModel` returns `EncoderDecoderCache`, which is not TorchScript-compatible).
- Add `deepset/tinyroberta-squad2`, `typeform/squeezebert-mnli`, `facebook/bart-large-mnli`, and `valhalla/distilbart-mnli-12-6` to the validation config and reference golden file.

Addresses failures in https://buildkite.com/elastic/appex-qa-stateful-custom-ml-cpp-build-testing/builds/734 triggered from a PyTorch build pipeline. The PyTorch pipeline tests are a superset of the regular ml-cpp PR pipeline tests, which is how the 4 extra models were overlooked.
Test plan
- `typeform/squeezebert-mnli` no longer rejected by graph validator
- `facebook/bart-large-mnli` and `valhalla/distilbart-mnli-12-6` no longer rejected by graph validator
- `deepset/tinyroberta-squad2` continues to pass (all ops were already in allowlist)
- Trigger `buildkite run_pytorch_tests` after merge and verify the 4 models pass in the Appex QA pipeline