From c07999af4de9aa24221431fc928536777e691f8e Mon Sep 17 00:00:00 2001 From: jacantwell Date: Wed, 26 Mar 2025 14:56:19 +0000 Subject: [PATCH 1/3] Adding default factory to message type --- uncertainty_engine_types/message.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uncertainty_engine_types/message.py b/uncertainty_engine_types/message.py index 3550961..8cab135 100644 --- a/uncertainty_engine_types/message.py +++ b/uncertainty_engine_types/message.py @@ -1,13 +1,13 @@ from datetime import datetime from typing import Any, Literal -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, Field class Message(BaseModel): role: Literal["instruction", "user", "engine"] content: str - timestamp: datetime + timestamp: datetime = Field(default_factory=datetime.now) @field_validator("content", mode="before") @classmethod From 17890f43d7ede3f9f1eafdad43ffd9581bf1cb37 Mon Sep 17 00:00:00 2001 From: jacantwell Date: Mon, 31 Mar 2025 13:51:18 +0100 Subject: [PATCH 2/3] Updating tests --- tests/test_embeddings.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index e2a07c2..42becdc 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -36,7 +36,7 @@ def test_text_embeddings_config_raise_missing(text_embeddings_config_data: dict) @pytest.mark.parametrize( "embeddings_provider_field, field", - [("openai", "model"), ("openai", "ollama_url"), ("ollama", "openai_api_key")], + [("openai", "ollama_url"), ("ollama", "openai_api_key")], indirect=["embeddings_provider_field"], ) def test_text_embeddings_config_optional(text_embeddings_config_data: dict, field: str): @@ -57,6 +57,28 @@ def test_text_embeddings_config_optional(text_embeddings_config_data: dict, fiel assert text_embeddings_config.model_dump()[field] is None +@pytest.mark.parametrize( + "text_embeddings_config_data", + ["openai", "ollama"], + indirect=True, +) +def test_text_embeddings_config_default_model(text_embeddings_config_data: dict): + """Test that the default model is set correctly when not specified.""" + + # Remove the model field + del text_embeddings_config_data["model"] + + # Instantiate a TextEmbeddingsConfig object + text_embeddings_config = TextEmbeddingsConfig(**text_embeddings_config_data) + + # Check that the model field is set to the default value + provider = text_embeddings_config_data["provider"] + if provider == "openai": + assert text_embeddings_config.model == "text-embedding-ada-002" + elif provider == "ollama": + assert text_embeddings_config.model == "nomic-embed-text" + + @pytest.mark.parametrize( "embeddings_provider_field, field", [("openai", "openai_api_key"), ("ollama", "ollama_url")], From 65e3e7160233961d2f7f6e8d0f0de522bf46bf6e Mon Sep 17 00:00:00 2001 From: jacantwell Date: Mon, 31 Mar 2025 13:54:54 +0100 Subject: [PATCH 3/3] Updating type --- uncertainty_engine_types/embeddings.py | 11 +++++++++++ uncertainty_engine_types/message.py | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/uncertainty_engine_types/embeddings.py b/uncertainty_engine_types/embeddings.py index 0c57c23..894e06d 100644 --- a/uncertainty_engine_types/embeddings.py +++ b/uncertainty_engine_types/embeddings.py @@ -32,3 +32,14 @@ def check_provider(cls, values): ): raise ValueError("openai_api_key must be provided for 'openai' provider.") return values + + @model_validator(mode="before") + @classmethod + def check_model(cls, values): + model = values.get("model", None) + provider = values.get("provider") + if provider == TextEmbeddingsProvider.OLLAMA.value and not model: + values["model"] = "nomic-embed-text" + if provider == TextEmbeddingsProvider.OPENAI.value and not model: + values["model"] = "text-embedding-ada-002" + return values diff --git a/uncertainty_engine_types/message.py b/uncertainty_engine_types/message.py index 8cab135..3550961 100644 --- a/uncertainty_engine_types/message.py +++ b/uncertainty_engine_types/message.py @@ -1,13 +1,13 @@ from datetime import datetime from typing import Any, Literal -from pydantic import BaseModel, field_validator, Field +from pydantic import BaseModel, field_validator class Message(BaseModel): role: Literal["instruction", "user", "engine"] content: str - timestamp: datetime = Field(default_factory=datetime.now) + timestamp: datetime @field_validator("content", mode="before") @classmethod