diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 250b86f79..d4172c4ee 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -38,6 +38,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // elastic/multilingual-e5-small-optimized, intfloat/multilingual-e5-small, // .multilingual-e5-small (prepacked), elastic/splade-v3, // elastic/test-elser-v2, .rerank-v1 (Elastic rerank model), +// deepset/tinyroberta-squad2, typeform/squeezebert-mnli, +// facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6, // distilbert-base-uncased-finetuned-sst-2-english, // sentence-transformers/all-distilroberta-v1. // Eland-deployed variants of the above models (with pooling/normalization layers). @@ -51,6 +53,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::IntImplicit"sv, "aten::ScalarImplicit"sv, "aten::__and__"sv, + "aten::_convolution"sv, "aten::abs"sv, "aten::add"sv, "aten::add_"sv, @@ -62,7 +65,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::chunk"sv, "aten::clamp"sv, "aten::clamp_min"sv, + "aten::clone"sv, "aten::contiguous"sv, + "aten::copy_"sv, "aten::cumsum"sv, "aten::detach"sv, "aten::div"sv, @@ -72,7 +77,9 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::eq"sv, "aten::expand"sv, "aten::expand_as"sv, + "aten::fill_"sv, "aten::floor_divide"sv, + "aten::full"sv, "aten::full_like"sv, "aten::gather"sv, "aten::ge"sv, @@ -100,6 +107,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::ne"sv, "aten::neg"sv, "aten::new_ones"sv, + "aten::new_zeros"sv, "aten::norm"sv, "aten::ones"sv, "aten::pad"sv, @@ -125,6 +133,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::tensor"sv, "aten::to"sv, "aten::transpose"sv, + 
"aten::triu"sv, "aten::type_as"sv, "aten::unsqueeze"sv, "aten::view"sv, diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json index 30e985582..b2ba8be16 100644 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -1006,6 +1006,175 @@ "prim::ListConstruct", "prim::NumToTensor" ] + }, + "qa-tinyroberta-squad2": { + "model_id": "deepset/tinyroberta-squad2", + "quantized": false, + "ops": [ + "aten::Int", + "aten::add", + "aten::add_", + "aten::cumsum", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gelu", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::mul", + "aten::ne", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct" + ] + }, + "qa-bart-large-mnli": { + "model_id": "facebook/bart-large-mnli", + "quantized": false, + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::add", + "aten::arange", + "aten::clone", + "aten::contiguous", + "aten::copy_", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::eq", + "aten::expand", + "aten::fill_", + "aten::full", + "aten::gelu", + "aten::gt", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::masked_fill_", + "aten::mul", + "aten::mul_", + "aten::new_zeros", + "aten::ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::triu", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + 
"prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack" + ] + }, + "qa-distilbart-mnli": { + "model_id": "valhalla/distilbart-mnli-12-6", + "quantized": false, + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::add", + "aten::arange", + "aten::clone", + "aten::contiguous", + "aten::copy_", + "aten::detach", + "aten::dropout", + "aten::embedding", + "aten::eq", + "aten::expand", + "aten::fill_", + "aten::full", + "aten::gelu", + "aten::gt", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::masked_fill_", + "aten::mul", + "aten::mul_", + "aten::new_zeros", + "aten::ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::triu", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack" + ] + }, + "qa-squeezebert-mnli": { + "model_id": "typeform/squeezebert-mnli", + "quantized": false, + "ops": [ + "aten::Int", + "aten::_convolution", + "aten::add", + "aten::contiguous", + "aten::div", + "aten::dropout", + "aten::embedding", + "aten::gelu", + "aten::layer_norm", + "aten::linear", + "aten::matmul", + "aten::mul", + "aten::permute", + "aten::rsub", + "aten::select", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::tanh", + "aten::to", + "aten::unsqueeze", + "aten::view", + "aten::zeros", + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct" + ] } } } diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index 451369a6d..2a070d1cc 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -47,14 +47,18 @@ def 
extract_ops_for_model(model_name: str, - quantize: bool = False) -> Optional[set[str]]: + quantize: bool = False, + auto_class: str | None = None, + config_overrides: dict | None = None) -> Optional[set[str]]: """Trace a HuggingFace model and return its TorchScript op set. Returns None if the model could not be loaded or traced. """ label = f"{model_name} (quantized)" if quantize else model_name print(f" Loading {label}...", file=sys.stderr) - traced = load_and_trace_hf_model(model_name, quantize=quantize) + traced = load_and_trace_hf_model(model_name, quantize=quantize, + auto_class=auto_class, + config_overrides=config_overrides) if traced is None: return None return collect_inlined_ops(traced) @@ -93,7 +97,9 @@ def main(): failed = [] for arch, spec in reference_models.items(): ops = extract_ops_for_model(spec["model_id"], - quantize=spec["quantized"]) + quantize=spec["quantized"], + auto_class=spec["auto_class"], + config_overrides=spec["config_overrides"]) if ops is None: failed.append(arch) print(f" {arch}: FAILED", file=sys.stderr) diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index e2f270d35..8a4d86293 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -28,5 +28,11 @@ "_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.", "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true}, "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, - "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true} + "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, + + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. 
BART models require auto_class and config_overrides to trace correctly.", + "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-squeezebert-mnli": "typeform/squeezebert-mnli", + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} } diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py index 33042f261..af2b30f68 100644 --- a/dev-tools/extract_model_ops/torchscript_utils.py +++ b/dev-tools/extract_model_ops/torchscript_utils.py @@ -24,8 +24,14 @@ def load_model_config(config_path: Path) -> dict[str, dict]: """Load a model config JSON file and normalise entries. Each entry is either a plain model-name string or a dict with - ``model_id`` (required) and optional ``quantized`` boolean. All - entries are normalised to ``{"model_id": str, "quantized": bool}``. + ``model_id`` (required) and optional fields: + + - ``quantized`` (bool, default False) — apply dynamic quantization. + - ``auto_class`` (str, default None) — transformers Auto class name to use + instead of ``AutoModel`` (e.g. ``"AutoModelForSequenceClassification"``). + - ``config_overrides`` (dict, default {}) — extra kwargs passed to + ``AutoConfig.from_pretrained`` (e.g. ``{"use_cache": false}``). + Keys starting with ``_comment`` are silently skipped. 
Raises ``ValueError`` for malformed entries so that config problems @@ -39,7 +45,8 @@ def load_model_config(config_path: Path) -> dict[str, dict]: if key.startswith("_comment"): continue if isinstance(value, str): - models[key] = {"model_id": value, "quantized": False} + models[key] = {"model_id": value, "quantized": False, + "auto_class": None, "config_overrides": {}} elif isinstance(value, dict): if "model_id" not in value: raise ValueError( @@ -48,6 +55,8 @@ def load_model_config(config_path: Path) -> dict[str, dict]: models[key] = { "model_id": value["model_id"], "quantized": value.get("quantized", False), + "auto_class": value.get("auto_class"), + "config_overrides": value.get("config_overrides", {}), } else: raise ValueError( @@ -74,22 +83,43 @@ def collect_inlined_ops(module) -> set[str]: return collect_graph_ops(graph) -def load_and_trace_hf_model(model_name: str, quantize: bool = False): +def _resolve_auto_class(class_name: str | None): + """Resolve a transformers Auto class by name, defaulting to AutoModel.""" + if not class_name: + return AutoModel + import transformers + cls = getattr(transformers, class_name, None) + if cls is None: + raise ValueError(f"Unknown transformers class: {class_name}") + return cls + + +def load_and_trace_hf_model(model_name: str, quantize: bool = False, + auto_class: str | None = None, + config_overrides: dict | None = None): """Load a HuggingFace model, tokenize sample input, and trace to TorchScript. When *quantize* is True the model is dynamically quantized (nn.Linear layers converted to quantized::linear_dynamic) before tracing. This mirrors what Eland does when importing models for Elasticsearch. + *auto_class* selects a transformers Auto class by name (e.g. + ``"AutoModelForSequenceClassification"``). Defaults to ``AutoModel``. + + *config_overrides* supplies extra kwargs to ``AutoConfig.from_pretrained`` + (e.g. ``{"use_cache": False}`` for encoder-decoder models like BART). 
+ Returns the traced module, or None if the model could not be loaded or traced. """ token = os.environ.get("HF_TOKEN") + model_cls = _resolve_auto_class(auto_class) + overrides = config_overrides or {} try: tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) config = AutoConfig.from_pretrained( - model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained( + model_name, torchscript=True, token=token, **overrides) + model = model_cls.from_pretrained( model_name, config=config, token=token) model.eval() except Exception as exc: diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index 0c853cdc5..fa4efee4c 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -29,5 +29,11 @@ "es-multilingual-e5-small": "intfloat/multilingual-e5-small", "es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base" + "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base", + + "_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.", + "qa-tinyroberta-squad2": "deepset/tinyroberta-squad2", + "qa-squeezebert-mnli": "typeform/squeezebert-mnli", + "qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}, + "qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}} }