From 776570a41acf487a1c1b10dc15a2093ca693d8c8 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Thu, 26 Mar 2026 10:03:53 +1300 Subject: [PATCH] [ML] Add aten::split and aten::stack for question-answering models The deepset/tinyroberta-squad2 model uses aten::split (and aten::stack per ES node logs) in its answer span extraction logic. These ops only appear when traced with AutoModelForQuestionAnswering rather than AutoModel. Update the extraction configs to use the correct auto_class. Also verified that LaBSE, BAAI/bge-reranker-base, and castorini/bpr-nq-ctx-encoder (from the supported models docs) are all covered by the existing allowlist. Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 2 ++ .../unittest/testfiles/reference_model_ops.json | 11 ++++++----- dev-tools/extract_model_ops/reference_models.json | 2 +- dev-tools/extract_model_ops/validation_models.json | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index d4172c4ee..56dbbaa84 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -124,8 +124,10 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::size"sv, "aten::slice"sv, "aten::softmax"sv, + "aten::split"sv, "aten::sqrt"sv, "aten::squeeze"sv, + "aten::stack"sv, "aten::str"sv, "aten::sub"sv, "aten::sum"sv, diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json index b2ba8be16..bdc975c53 100644 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -1011,9 +1011,9 @@ "model_id": "deepset/tinyroberta-squad2", "quantized": false, "ops": [ - "aten::Int", "aten::add", "aten::add_", + "aten::contiguous", "aten::cumsum", "aten::detach", "aten::dropout", @@ -1027,11 +1027,11 @@ "aten::ne", "aten::reshape", "aten::scaled_dot_product_attention", - "aten::select", "aten::size", "aten::slice", + "aten::split", + "aten::squeeze", "aten::sub", - "aten::tanh", "aten::to", "aten::transpose", "aten::type_as", @@ -1040,9 +1040,10 @@ "prim::Constant", "prim::GetAttr", "prim::ListConstruct", - "prim::NumToTensor", + "prim::ListUnpack", "prim::TupleConstruct" - ] + ], + "auto_class": "AutoModelForQuestionAnswering" }, "qa-bart-large-mnli": { "model_id": "facebook/bart-large-mnli", diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index 8a4d86293..5170a0e2e 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -31,7 +31,7 @@ "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.", - "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"}, "qa-squeezebert-mnli": "typeform/squeezebert-mnli", "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index fa4efee4c..1b36747fd 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -32,7 +32,7 @@ "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.", - "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"}, "qa-squeezebert-mnli": "typeform/squeezebert-mnli", "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}