From e0fa0d64ef94c7a558f0d0013bcf3a7d2f526498 Mon Sep 17 00:00:00 2001 From: jfinlon on Thor Date: Fri, 5 Jun 2026 04:38:44 -0400 Subject: [PATCH] fix(text-embeddings): compare norm type case-insensitively LTX-2.3 distilled checkpoints can serialize text_encoder_norm_type as PER_TOKEN_RMS while the expected value is per_token_rms. Compare string config values case-insensitively while keeping non-string expectations strict. Add pytest coverage for both casings and strict non-string behavior. --- .../test_text_embeddings_config_validation.py | 31 +++++++++++++++++++ text_embeddings_connectors.py | 9 +++++- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/test_text_embeddings_config_validation.py diff --git a/tests/test_text_embeddings_config_validation.py b/tests/test_text_embeddings_config_validation.py new file mode 100644 index 0000000..dd3627e --- /dev/null +++ b/tests/test_text_embeddings_config_validation.py @@ -0,0 +1,31 @@ +import ast +from pathlib import Path + + +def load_config_value_matches(): + source_path = Path(__file__).resolve().parents[1] / "text_embeddings_connectors.py" + source = ast.parse(source_path.read_text()) + helper = next( + node + for node in source.body + if isinstance(node, ast.FunctionDef) and node.name == "_config_value_matches" + ) + module = ast.Module(body=[helper], type_ignores=[]) + ast.fix_missing_locations(module) + namespace = {} + exec(compile(module, str(source_path), "exec"), namespace) + return namespace["_config_value_matches"] + + +def test_text_encoder_norm_type_matches_both_casings(): + matches = load_config_value_matches() + + assert matches("per_token_rms", "per_token_rms") + assert matches("PER_TOKEN_RMS", "per_token_rms") + + +def test_non_string_expectations_remain_strict(): + matches = load_config_value_matches() + + assert matches(False, False) + assert not matches("False", False) diff --git a/text_embeddings_connectors.py b/text_embeddings_connectors.py index a5a7688..a788b64 100644 --- a/text_embeddings_connectors.py +++ b/text_embeddings_connectors.py @@ -47,6 +47,13 @@ def _filter_sd(sd: dict, prefix: str) -> dict: return {k[len(prefix) :]: v for k, v in sd.items() if k.startswith(prefix)} +def _config_value_matches(actual, expected_val) -> bool: + """Compare string config values case-insensitively while keeping other types strict.""" + if isinstance(actual, str) and isinstance(expected_val, str): + return actual.casefold() == expected_val.casefold() + return actual == expected_val + + def _load_aggregate_embed(sd: dict, modality: str, dtype) -> nn.Linear: """Load an aggregate_embed Linear from the state dict. @@ -359,7 +366,7 @@ def load_text_embeddings_pipeline( } for key, expected_val in _expected.items(): actual = transformer_config.get(key) - assert actual == expected_val, ( + assert _config_value_matches(actual, expected_val), ( f"Unexpected config for dual-aggregate model: " f"{key}={actual!r}, expected {expected_val!r}" )