diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index d4172c4ee..56dbbaa84 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -124,8 +124,10 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::size"sv, "aten::slice"sv, "aten::softmax"sv, + "aten::split"sv, "aten::sqrt"sv, "aten::squeeze"sv, + "aten::stack"sv, "aten::str"sv, "aten::sub"sv, "aten::sum"sv, diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json index b2ba8be16..bdc975c53 100644 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -1011,9 +1011,9 @@ "model_id": "deepset/tinyroberta-squad2", "quantized": false, "ops": [ - "aten::Int", "aten::add", "aten::add_", + "aten::contiguous", "aten::cumsum", "aten::detach", "aten::dropout", @@ -1027,11 +1027,11 @@ "aten::ne", "aten::reshape", "aten::scaled_dot_product_attention", - "aten::select", "aten::size", "aten::slice", + "aten::split", + "aten::squeeze", "aten::sub", - "aten::tanh", "aten::to", "aten::transpose", "aten::type_as", @@ -1040,9 +1040,10 @@ "prim::Constant", "prim::GetAttr", "prim::ListConstruct", - "prim::NumToTensor", + "prim::ListUnpack", "prim::TupleConstruct" - ] + ], + "auto_class": "AutoModelForQuestionAnswering" }, "qa-bart-large-mnli": { "model_id": "facebook/bart-large-mnli", diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index 8a4d86293..5170a0e2e 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -31,7 +31,7 @@ "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.", - "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"}, "qa-squeezebert-mnli": "typeform/squeezebert-mnli", "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index fa4efee4c..1b36747fd 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -32,7 +32,7 @@ "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.", - "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"}, "qa-squeezebert-mnli": "typeform/squeezebert-mnli", "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}