diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index e2a07c2..42becdc 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -36,7 +36,7 @@ def test_text_embeddings_config_raise_missing(text_embeddings_config_data: dict) @pytest.mark.parametrize( "embeddings_provider_field, field", - [("openai", "model"), ("openai", "ollama_url"), ("ollama", "openai_api_key")], + [("openai", "ollama_url"), ("ollama", "openai_api_key")], indirect=["embeddings_provider_field"], ) def test_text_embeddings_config_optional(text_embeddings_config_data: dict, field: str): @@ -57,6 +57,28 @@ def test_text_embeddings_config_optional(text_embeddings_config_data: dict, fiel assert text_embeddings_config.model_dump()[field] is None +@pytest.mark.parametrize( + "text_embeddings_config_data", + ["openai", "ollama"], + indirect=True, +) +def test_text_embeddings_config_default_model(text_embeddings_config_data: dict): + """Test that the default model is set correctly when not specified.""" + + # Remove the model field + del text_embeddings_config_data["model"] + + # Instantiate a TextEmbeddingsConfig object + text_embeddings_config = TextEmbeddingsConfig(**text_embeddings_config_data) + + # Check that the model field is set to the default value + provider = text_embeddings_config_data["provider"] + if provider == "openai": + assert text_embeddings_config.model == "text-embedding-ada-002" + elif provider == "ollama": + assert text_embeddings_config.model == "nomic-embed-text" + + @pytest.mark.parametrize( "embeddings_provider_field, field", [("openai", "openai_api_key"), ("ollama", "ollama_url")], diff --git a/uncertainty_engine_types/embeddings.py b/uncertainty_engine_types/embeddings.py index 0c57c23..894e06d 100644 --- a/uncertainty_engine_types/embeddings.py +++ b/uncertainty_engine_types/embeddings.py @@ -32,3 +32,14 @@ def check_provider(cls, values): ): raise ValueError("openai_api_key must be provided for 'openai' provider.") return values + + 
@model_validator(mode="before")
@classmethod
def check_model(cls, values):
    """Fill in the provider's default embedding model when none is given.

    This runs before field validation, so ``values`` is the raw input
    rather than validated fields.

    Args:
        values: Raw input for the model, normally a mapping of field
            names to values.

    Returns:
        ``values`` itself when it is not a mapping, when it already
        carries a truthy ``"model"``, or when the provider is not one of
        the known providers. Otherwise a new dict equal to ``values``
        with ``"model"`` set to the provider's default.
    """
    from collections.abc import Mapping

    # A "before" validator sees raw input; leave anything that is not a
    # mapping alone so the regular validation reports the problem.
    if not isinstance(values, Mapping):
        return values

    # An explicitly supplied (truthy) model always wins.
    if values.get("model"):
        return values

    # Accept the provider as either the enum member or its raw value.
    provider = values.get("provider")
    provider = getattr(provider, "value", provider)

    default_models = (
        (TextEmbeddingsProvider.OLLAMA.value, "nomic-embed-text"),
        (TextEmbeddingsProvider.OPENAI.value, "text-embedding-ada-002"),
    )
    for provider_value, default_model in default_models:
        if provider == provider_value:
            # Build a new dict rather than mutating the caller's input.
            return {**values, "model": default_model}

    return values