
[ML] Add SqueezeBERT and TinyRoBERTa to graph validation allowlist#3011

Merged
valeriy42 merged 4 commits into elastic:main from edsavage:fix/add-squeezebert-convolution-op
Mar 25, 2026

Conversation

@edsavage
Contributor

@edsavage edsavage commented Mar 24, 2026

Summary

  • Adds aten::_convolution to the allowlist — SqueezeBERT (typeform/squeezebert-mnli) uses 1D grouped convolutions instead of standard attention, which requires this op.
  • Adds 6 ops needed by BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6): aten::clone, aten::copy_, aten::fill_, aten::full, aten::new_zeros, and aten::triu. BART models can be TorchScript-traced by setting use_cache=False in the config and using AutoModelForSequenceClassification (the default AutoModel returns EncoderDecoderCache which is not TorchScript-compatible).
  • Adds deepset/tinyroberta-squad2, typeform/squeezebert-mnli, facebook/bart-large-mnli, and valhalla/distilbart-mnli-12-6 to the validation config and reference golden file.

Addresses failures in https://buildkite.com/elastic/appex-qa-stateful-custom-ml-cpp-build-testing/builds/734 triggered from a PyTorch build pipeline. The PyTorch pipeline tests are a superset of the regular ml-cpp PR pipeline tests, which is how the 4 extra models were overlooked.
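The BART tracing workaround described above can be sketched roughly as follows. This is a minimal illustration of the recipe from the summary, not the actual validation-tool code; the function name and example input are invented for illustration:

```python
def trace_bart_for_validation(model_id="facebook/bart-large-mnli"):
    """Trace a BART sequence-classification model to TorchScript.

    Sketch of the workaround above: use_cache=False avoids the
    EncoderDecoderCache return type, and AutoModelForSequenceClassification
    (rather than plain AutoModel) yields a traceable graph.
    """
    # Heavy imports are kept local so this sketch imports without
    # torch/transformers installed.
    import torch
    from transformers import (AutoConfig, AutoModelForSequenceClassification,
                              AutoTokenizer)

    config = AutoConfig.from_pretrained(model_id, torchscript=True,
                                        use_cache=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_id,
                                                               config=config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("a short example sentence", return_tensors="pt")

    # Trace with a (input_ids, attention_mask) tuple, the standard
    # HuggingFace TorchScript export path.
    return torch.jit.trace(model, (inputs["input_ids"],
                                   inputs["attention_mask"]))
```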

Test plan

  • CI passes (allowlist drift test validates new models against allowlist)
  • typeform/squeezebert-mnli no longer rejected by graph validator
  • facebook/bart-large-mnli and valhalla/distilbart-mnli-12-6 no longer rejected by graph validator
  • deepset/tinyroberta-squad2 continues to pass (all ops were already in allowlist)
  • Trigger buildkite run_pytorch_tests after merge and verify the 4 models pass in the Appex QA pipeline

Add aten::_convolution to the allowlist — SqueezeBERT uses 1D
grouped convolutions instead of standard attention, which requires
this op. Also add deepset/tinyroberta-squad2 and
typeform/squeezebert-mnli to the validation and reference model
configs.

BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6)
cannot be TorchScript-traced due to their EncoderDecoderCache return
type and are excluded from validation.

Made-with: Cursor
@prodsecmachine

prodsecmachine commented Mar 24, 2026

Snyk checks have passed. No issues have been found so far.

Status  Scan Engine           Critical  High  Medium  Low  Total (0)
        Open Source Security  0         0     0       0    0 issues
        Licenses              0         0     0       0    0 issues


BART models (facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6)
require 6 additional ops: aten::clone, aten::copy_, aten::fill_,
aten::full, aten::new_zeros, and aten::triu.

These models can be TorchScript-traced by setting use_cache=False in
the config and using AutoModelForSequenceClassification (the default
AutoModel returns EncoderDecoderCache which is not TorchScript-
compatible).

Made-with: Cursor

Copilot AI left a comment


Pull request overview

Updates the PyTorch TorchScript graph validation setup to cover additional HuggingFace models used in the Appex QA pytorch_tests suite, and expands the C++ allowlist to accept the newly observed ops in those models’ traced graphs.

Changes:

  • Add four QA models (TinyRoBERTa SQuAD2, SqueezeBERT MNLI, BART-large MNLI, DistilBART MNLI) to the validation/reference model configs.
  • Expand the C++ ALLOWED_OPERATIONS set with ops required by SqueezeBERT (aten::_convolution) and BART (aten::clone, aten::copy_, aten::fill_, aten::full, aten::new_zeros, aten::triu).
  • Update the golden reference_model_ops.json with per-model op sets for the new QA models.

Reviewed changes

Copilot reviewed 2 out of 4 changed files in this pull request and generated 2 comments.

File Description
dev-tools/extract_model_ops/validation_models.json Adds QA models to the Python allowlist validation config.
dev-tools/extract_model_ops/reference_models.json Adds QA models to the reference extraction config used for allowlist generation.
bin/pytorch_inference/unittest/testfiles/reference_model_ops.json Adds golden op lists for the new QA models (drift test input).
bin/pytorch_inference/CSupportedOperations.cc Adds newly required ops to the C++ allowlist and updates reference model comments.


Comment on lines +34 to +38
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}

Copilot AI Mar 24, 2026


The model comment notes BART tracing requires use_cache=False and AutoModelForSequenceClassification, but the validation tool (torchscript_utils.load_and_trace_hf_model) always uses AutoConfig(..., torchscript=True) + AutoModel.from_pretrained(...) and has no way to apply per-model config overrides or select a different AutoModel class. As a result, these BART entries are likely to either fail to trace in validate_allowlist.py or validate a different (non-sequence-classification) graph than the one used in production.

Consider extending the config schema to support per-model tracing options (e.g., model class selection and config overrides like use_cache) and updating load_and_trace_hf_model accordingly; alternatively, keep BART out of validation_models.json and validate via a local traced .pt artifact that matches the intended tracing path.

Suggested change (replace the entries quoted above with):

    "_comment:qa-models": "Models from the Appex QA pytorch_tests suite (test_trained_model_boot_check). BART-based QA models are validated via separately traced artifacts and are intentionally omitted here.",
    "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
    "qa-squeezebert-mnli": "typeform/squeezebert-mnli"

Comment on lines +33 to +37
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require use_cache=False and AutoModelForSequenceClassification to trace.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "quantized": false},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "quantized": false}

Copilot AI Mar 24, 2026


reference_models.json is the input for extract_model_ops.py, but the extractor’s tracing helper (torchscript_utils.load_and_trace_hf_model) always uses AutoModel and does not apply per-model overrides like use_cache=False or allow selecting AutoModelForSequenceClassification. With BART models listed here, regenerating the allowlist via extract_model_ops.py --cpp/--golden is likely to either fail for these entries or generate ops for a different graph than intended.

Recommend updating the extractor tooling/config schema to support the required tracing mode for these models (model class selection + config overrides), or removing these BART entries from reference_models.json and sourcing their ops from an explicitly traced .pt file that matches the production tracing path.

Suggested change (replace the entries quoted above with):

    "_comment:qa-models": "Models from the Appex QA pytorch_tests suite.",
    "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
    "qa-squeezebert-mnli": "typeform/squeezebert-mnli"

Address Copilot review: instead of excluding BART models from the
extraction configs, extend the tooling to support them natively.

Add auto_class and config_overrides fields to the model config
schema. load_and_trace_hf_model now accepts an optional Auto class
name (e.g. "AutoModelForSequenceClassification") and config kwargs
(e.g. {"use_cache": false}), which BART models require to avoid
returning the non-TorchScript-compatible EncoderDecoderCache.

Restore BART entries in validation_models.json and
reference_models.json using the new schema.

Made-with: Cursor
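Based on the commit message above, a config entry under the extended schema would presumably look something like the following. The field values are an illustration of the described `auto_class` and `config_overrides` fields, not the merged file contents:

```json
"qa-bart-large-mnli": {
    "model_id": "facebook/bart-large-mnli",
    "quantized": false,
    "auto_class": "AutoModelForSequenceClassification",
    "config_overrides": {"use_cache": false}
}
```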
@edsavage edsavage force-pushed the fix/add-squeezebert-convolution-op branch from 390f666 to 9ec6450 on March 25, 2026 at 00:37
@edsavage edsavage requested a review from Copilot March 25, 2026 00:39

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 6 changed files in this pull request and generated 2 comments.



#
"""Shared utilities for extracting and inspecting TorchScript operations."""

import importlib

Copilot AI Mar 25, 2026


importlib is imported but not used anywhere in this module. Please remove it to avoid unused-import warnings and keep the dependency surface minimal.

Suggested change
import importlib

Comment on lines 48 to 60
    if isinstance(value, str):
        models[key] = {"model_id": value, "quantized": False}
    elif isinstance(value, dict):
        if "model_id" not in value:
            raise ValueError(
                f"Config entry {key!r} is a dict but missing required "
                f"'model_id' key: {value!r}")
        models[key] = {
            "model_id": value["model_id"],
            "quantized": value.get("quantized", False),
            "auto_class": value.get("auto_class"),
            "config_overrides": value.get("config_overrides", {}),
        }

Copilot AI Mar 25, 2026


load_model_config() claims to “normalise entries”, but for string entries it only returns {model_id, quantized} while dict entries also include auto_class and config_overrides. This makes the returned schema depend on input type and forces callers to defensively use .get(). Consider always returning defaults (auto_class=None, config_overrides={}) for the string case too so all specs have a consistent shape.

…chema

Remove unused importlib import. Ensure string config entries return
the same schema shape as dict entries (with auto_class=None and
config_overrides={}) so callers don't need defensive .get() calls.

Made-with: Cursor
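The normalisation described in this commit can be sketched as follows. This is a hypothetical reconstruction based on the review snippet and commit message, with an invented helper name, not the merged code:

```python
def normalise_model_entry(value):
    """Return a model spec with a consistent shape regardless of input type.

    String entries get the same default fields (auto_class=None,
    config_overrides={}) as dict entries, so callers never need
    defensive .get() calls.
    """
    if isinstance(value, str):
        # Promote a bare model-id string to the dict form first.
        value = {"model_id": value}
    if not isinstance(value, dict) or "model_id" not in value:
        raise ValueError(
            f"Config entry missing required 'model_id' key: {value!r}")
    return {
        "model_id": value["model_id"],
        "quantized": value.get("quantized", False),
        "auto_class": value.get("auto_class"),
        "config_overrides": value.get("config_overrides", {}),
    }
```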
@edsavage edsavage requested a review from wwang500 March 25, 2026 03:22

@wwang500 wwang500 left a comment


LGTM

@valeriy42 valeriy42 merged commit 4387502 into elastic:main Mar 25, 2026
19 checks passed