9 changes: 9 additions & 0 deletions bin/pytorch_inference/CSupportedOperations.cc
@@ -38,6 +38,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
// elastic/multilingual-e5-small-optimized, intfloat/multilingual-e5-small,
// .multilingual-e5-small (prepacked), elastic/splade-v3,
// elastic/test-elser-v2, .rerank-v1 (Elastic rerank model),
// deepset/tinyroberta-squad2, typeform/squeezebert-mnli,
// facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6,
// distilbert-base-uncased-finetuned-sst-2-english,
// sentence-transformers/all-distilroberta-v1.
// Eland-deployed variants of the above models (with pooling/normalization layers).
@@ -51,6 +53,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::IntImplicit"sv,
"aten::ScalarImplicit"sv,
"aten::__and__"sv,
"aten::_convolution"sv,
"aten::abs"sv,
"aten::add"sv,
"aten::add_"sv,
@@ -62,7 +65,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::chunk"sv,
"aten::clamp"sv,
"aten::clamp_min"sv,
"aten::clone"sv,
"aten::contiguous"sv,
"aten::copy_"sv,
"aten::cumsum"sv,
"aten::detach"sv,
"aten::div"sv,
@@ -72,7 +77,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::eq"sv,
"aten::expand"sv,
"aten::expand_as"sv,
"aten::fill_"sv,
"aten::floor_divide"sv,
"aten::full"sv,
"aten::full_like"sv,
"aten::gather"sv,
"aten::ge"sv,
@@ -100,6 +107,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::ne"sv,
"aten::neg"sv,
"aten::new_ones"sv,
"aten::new_zeros"sv,
"aten::norm"sv,
"aten::ones"sv,
"aten::pad"sv,
@@ -125,6 +133,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::tensor"sv,
"aten::to"sv,
"aten::transpose"sv,
"aten::triu"sv,
"aten::type_as"sv,
"aten::unsqueeze"sv,
"aten::view"sv,
169 changes: 169 additions & 0 deletions bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
@@ -1006,6 +1006,175 @@
"prim::ListConstruct",
"prim::NumToTensor"
]
},
"qa-tinyroberta-squad2": {
"model_id": "deepset/tinyroberta-squad2",
"quantized": false,
"ops": [
"aten::Int",
"aten::add",
"aten::add_",
"aten::cumsum",
"aten::detach",
"aten::dropout",
"aten::embedding",
"aten::expand",
"aten::gelu",
"aten::layer_norm",
"aten::linear",
"aten::masked_fill",
"aten::mul",
"aten::ne",
"aten::reshape",
"aten::scaled_dot_product_attention",
"aten::select",
"aten::size",
"aten::slice",
"aten::sub",
"aten::tanh",
"aten::to",
"aten::transpose",
"aten::type_as",
"aten::unsqueeze",
"aten::view",
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"prim::TupleConstruct"
]
},
"qa-bart-large-mnli": {
"model_id": "facebook/bart-large-mnli",
"quantized": false,
"ops": [
"aten::Int",
"aten::ScalarImplicit",
"aten::add",
"aten::arange",
"aten::clone",
"aten::contiguous",
"aten::copy_",
"aten::detach",
"aten::dropout",
"aten::embedding",
"aten::eq",
"aten::expand",
"aten::fill_",
"aten::full",
"aten::gelu",
"aten::gt",
"aten::index",
"aten::layer_norm",
"aten::linear",
"aten::masked_fill",
"aten::masked_fill_",
"aten::mul",
"aten::mul_",
"aten::new_zeros",
"aten::ones",
"aten::reshape",
"aten::scaled_dot_product_attention",
"aten::select",
"aten::size",
"aten::slice",
"aten::sub",
"aten::tanh",
"aten::to",
"aten::transpose",
"aten::triu",
"aten::unsqueeze",
"aten::view",
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"prim::TupleConstruct",
"prim::TupleUnpack"
]
},
"qa-distilbart-mnli": {
"model_id": "valhalla/distilbart-mnli-12-6",
"quantized": false,
"ops": [
"aten::Int",
"aten::ScalarImplicit",
"aten::add",
"aten::arange",
"aten::clone",
"aten::contiguous",
"aten::copy_",
"aten::detach",
"aten::dropout",
"aten::embedding",
"aten::eq",
"aten::expand",
"aten::fill_",
"aten::full",
"aten::gelu",
"aten::gt",
"aten::index",
"aten::layer_norm",
"aten::linear",
"aten::masked_fill",
"aten::masked_fill_",
"aten::mul",
"aten::mul_",
"aten::new_zeros",
"aten::ones",
"aten::reshape",
"aten::scaled_dot_product_attention",
"aten::select",
"aten::size",
"aten::slice",
"aten::sub",
"aten::tanh",
"aten::to",
"aten::transpose",
"aten::triu",
"aten::unsqueeze",
"aten::view",
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"prim::TupleConstruct",
"prim::TupleUnpack"
]
},
"qa-squeezebert-mnli": {
"model_id": "typeform/squeezebert-mnli",
"quantized": false,
"ops": [
"aten::Int",
"aten::_convolution",
"aten::add",
"aten::contiguous",
"aten::div",
"aten::dropout",
"aten::embedding",
"aten::gelu",
"aten::layer_norm",
"aten::linear",
"aten::matmul",
"aten::mul",
"aten::permute",
"aten::rsub",
"aten::select",
"aten::size",
"aten::slice",
"aten::softmax",
"aten::tanh",
"aten::to",
"aten::unsqueeze",
"aten::view",
"aten::zeros",
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"prim::TupleConstruct"
]
}
}
}
12 changes: 9 additions & 3 deletions dev-tools/extract_model_ops/extract_model_ops.py
@@ -47,14 +47,18 @@


def extract_ops_for_model(model_name: str,
quantize: bool = False) -> Optional[set[str]]:
quantize: bool = False,
auto_class: str | None = None,
config_overrides: dict | None = None) -> Optional[set[str]]:
"""Trace a HuggingFace model and return its TorchScript op set.

Returns None if the model could not be loaded or traced.
"""
label = f"{model_name} (quantized)" if quantize else model_name
print(f" Loading {label}...", file=sys.stderr)
traced = load_and_trace_hf_model(model_name, quantize=quantize)
traced = load_and_trace_hf_model(model_name, quantize=quantize,
auto_class=auto_class,
config_overrides=config_overrides)
if traced is None:
return None
return collect_inlined_ops(traced)
@@ -93,7 +97,9 @@ def main():
failed = []
for arch, spec in reference_models.items():
ops = extract_ops_for_model(spec["model_id"],
quantize=spec["quantized"])
quantize=spec["quantized"],
auto_class=spec["auto_class"],
config_overrides=spec["config_overrides"])
if ops is None:
failed.append(arch)
print(f" {arch}: FAILED", file=sys.stderr)
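Note (not part of the change set): a minimal sketch of the spec dict this loop consumes, assuming the normalised shape produced by load_model_config in torchscript_utils.py and using the facebook/bart-large-mnli entry from reference_models.json for the example values:

    # Hand-built spec mirroring one normalised config entry (illustrative only).
    spec = {
        "model_id": "facebook/bart-large-mnli",
        "quantized": False,
        "auto_class": "AutoModelForSequenceClassification",
        "config_overrides": {"use_cache": False},
    }
    ops = extract_ops_for_model(spec["model_id"],
                                quantize=spec["quantized"],
                                auto_class=spec["auto_class"],
                                config_overrides=spec["config_overrides"])
    if ops is not None:
        print(sorted(ops))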
8 changes: 7 additions & 1 deletion dev-tools/extract_model_ops/reference_models.json
@@ -28,5 +28,11 @@
"_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.",
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},

"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}
}
42 changes: 36 additions & 6 deletions dev-tools/extract_model_ops/torchscript_utils.py
@@ -24,8 +24,14 @@ def load_model_config(config_path: Path) -> dict[str, dict]:
"""Load a model config JSON file and normalise entries.

Each entry is either a plain model-name string or a dict with
``model_id`` (required) and optional ``quantized`` boolean. All
entries are normalised to ``{"model_id": str, "quantized": bool}``.
``model_id`` (required) and optional fields:

- ``quantized`` (bool, default False) — apply dynamic quantization.
- ``auto_class`` (str) — transformers Auto class name to use instead
of ``AutoModel`` (e.g. ``"AutoModelForSequenceClassification"``).
- ``config_overrides`` (dict) — extra kwargs passed to
``AutoConfig.from_pretrained`` (e.g. ``{"use_cache": false}``).

Keys starting with ``_comment`` are silently skipped.

Raises ``ValueError`` for malformed entries so that config problems
@@ -39,7 +45,8 @@ def load_model_config(config_path: Path) -> dict[str, dict]:
if key.startswith("_comment"):
continue
if isinstance(value, str):
models[key] = {"model_id": value, "quantized": False}
models[key] = {"model_id": value, "quantized": False,
"auto_class": None, "config_overrides": {}}
elif isinstance(value, dict):
if "model_id" not in value:
raise ValueError(
@@ -48,6 +55,8 @@
models[key] = {
"model_id": value["model_id"],
"quantized": value.get("quantized", False),
"auto_class": value.get("auto_class"),
"config_overrides": value.get("config_overrides", {}),
}
else:
raise ValueError(
@@ -74,22 +83,43 @@ def collect_inlined_ops(module) -> set[str]:
return collect_graph_ops(graph)


def load_and_trace_hf_model(model_name: str, quantize: bool = False):
def _resolve_auto_class(class_name: str | None):
"""Resolve a transformers Auto class by name, defaulting to AutoModel."""
if not class_name:
return AutoModel
import transformers
cls = getattr(transformers, class_name, None)
if cls is None:
raise ValueError(f"Unknown transformers class: {class_name}")
return cls


def load_and_trace_hf_model(model_name: str, quantize: bool = False,
auto_class: str | None = None,
config_overrides: dict | None = None):
"""Load a HuggingFace model, tokenize sample input, and trace to TorchScript.

When *quantize* is True the model is dynamically quantized (nn.Linear
layers converted to quantized::linear_dynamic) before tracing. This
mirrors what Eland does when importing models for Elasticsearch.

*auto_class* selects a transformers Auto class by name (e.g.
``"AutoModelForSequenceClassification"``). Defaults to ``AutoModel``.

*config_overrides* supplies extra kwargs to ``AutoConfig.from_pretrained``
(e.g. ``{"use_cache": False}`` for encoder-decoder models like BART).

Returns the traced module, or None if the model could not be loaded or traced.
"""
token = os.environ.get("HF_TOKEN")
model_cls = _resolve_auto_class(auto_class)
overrides = config_overrides or {}

try:
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
config = AutoConfig.from_pretrained(
model_name, torchscript=True, token=token)
model = AutoModel.from_pretrained(
model_name, torchscript=True, token=token, **overrides)
model = model_cls.from_pretrained(
model_name, config=config, token=token)
model.eval()
except Exception as exc:
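Note (not part of the change set): a hedged usage sketch of the updated helpers, assuming only the signatures shown in this diff plus network access (or a local cache) for the HuggingFace weights; the model name and overrides are taken from the qa-bart-large-mnli entry in reference_models.json:

    # Illustrative standalone call; the tracing work is done by the functions above.
    traced = load_and_trace_hf_model(
        "facebook/bart-large-mnli",
        auto_class="AutoModelForSequenceClassification",
        config_overrides={"use_cache": False},
    )
    if traced is not None:
        ops = collect_inlined_ops(traced)
        print(sorted(ops))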
8 changes: 7 additions & 1 deletion dev-tools/extract_model_ops/validation_models.json
@@ -29,5 +29,11 @@
"es-multilingual-e5-small": "intfloat/multilingual-e5-small",
"es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base"
"es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base",

"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}},
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}
}