diff --git a/docs/EMBEDDINGS.md b/docs/EMBEDDINGS.md new file mode 100644 index 0000000..2a60ba1 --- /dev/null +++ b/docs/EMBEDDINGS.md @@ -0,0 +1,76 @@ +# EmbeddingBlock Documentation + +## Overview + +The `EmbeddingBlock` class provides a flexible interface for generating text embeddings using LiteLLM. It supports multiple embedding providers through a single, consistent API. + +## Quick Start + +```python +from quantmind.config import EmbeddingConfig +from quantmind.llm import create_embedding_block + +# Simple configuration +config = EmbeddingConfig( + model="text-embedding-ada-002" +) + +embedding_block = create_embedding_block(config) +embedding = embedding_block.generate_embedding("Sample text") +``` + +## Configuration + +### Required Parameters +- `model`: Embedding model name (e.g., "text-embedding-ada-002") + +### Optional Parameters +- `user`: Unique identifier for end-user +- `dimensions`: Number of dimensions (OpenAI text-embedding-3+) +- `encoding_format`: "float" or "base64" (default: "float") +- `timeout`: Request timeout in seconds (default: 600) +- `api_base`: Custom API endpoint +- `api_version`: Azure-specific API version +- `api_key`: API key for authentication +- `api_type`: Type of API to use + +## Examples + +### Basic Usage +```python +config = EmbeddingConfig(model="text-embedding-ada-002") +embedding_block = create_embedding_block(config) +embedding = embedding_block.generate_embedding("Text to embed") +``` + +### With Custom Dimensions +```python +config = EmbeddingConfig( + model="text-embedding-3-small", + dimensions=512 +) +``` + +### Azure OpenAI +```python +config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key="azure-key", + api_base="https://your-resource.openai.azure.com/", + api_version="2023-05-15", + api_type="azure" +) +``` + +## Methods + +- `generate_embedding(text)`: Generate single embedding +- `generate_embeddings(texts)`: Generate multiple embeddings +- `batch_embed(texts, batch_size)`: Process large 
def example_openai_embeddings():
    """Demonstrate single and batched embedding generation with OpenAI.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.
    The function probes the connection first and returns early when the probe
    fails; otherwise it prints the size of one embedding, the sizes of a small
    batch of embeddings, and the block's model/provider information.
    """
    print("\n=== OpenAI Embeddings Example ===")

    # Configuration for OpenAI embeddings
    config = EmbeddingConfig(
        model="text-embedding-ada-002",
        api_key=os.getenv("OPENAI_API_KEY"),
        timeout=30,
        encoding_format="float",
    )
    embedding_block = create_embedding_block(config)

    # Probe once up front so a bad key yields a single clear message.
    if not embedding_block.test_connection():
        print("❌ OpenAI connection failed")
        return
    print("✅ OpenAI connection successful")

    # Single embedding
    text = "This is a sample text for embedding generation."
    embedding = embedding_block.generate_embedding(text)
    if embedding:
        print(f"✅ Generated embedding with {len(embedding)} dimensions")
        print(f" First 5 values: {embedding[:5]}")

    # Several embeddings in one request
    texts = [
        "First sample text for embedding.",
        "Second sample text with different content.",
        "Third sample text for batch processing.",
    ]
    embeddings = embedding_block.generate_embeddings(texts)
    if embeddings:
        print(f"✅ Generated {len(embeddings)} embeddings")
        for i, emb in enumerate(embeddings):
            print(f" Text {i + 1}: {len(emb)} dimensions")

    # Block metadata
    info = embedding_block.get_info()
    print(f"📊 Model info: {info['model']}")
    print(f"📊 Provider: {info['provider']}")
def example_error_handling():
    """Show how EmbeddingBlock reports failures.

    Two failure modes are exercised:

    1. An invalid API key. ``generate_embedding`` returns ``None`` instead of
       raising, so the result is checked explicitly.
    2. A non-existent model. Constructing the block does not contact the
       provider, so a request has to be issued before the failure can be
       observed; it surfaces as a ``None`` result.
    """
    print("\n=== Error Handling Example ===")

    # Try with invalid API key
    config = EmbeddingConfig(
        model="text-embedding-ada-002",
        api_key="invalid_key",
        timeout=5,
    )
    embedding_block = create_embedding_block(config)

    # This should fail gracefully (None, not an exception)
    embedding = embedding_block.generate_embedding("Test text")
    if embedding is None:
        print("✅ Gracefully handled invalid API key")
    else:
        print("❌ Expected the invalid API key to be rejected")

    # Try with non-existent model
    config = EmbeddingConfig(
        model="non-existent-model",
        timeout=5,
    )
    try:
        embedding_block = create_embedding_block(config)
        # Creation alone never talks to the provider; only a real request can
        # reveal that the model does not exist.
        embedding = embedding_block.generate_embedding("Test text")
    except Exception as e:
        print(f"✅ Gracefully handled non-existent model: {e}")
    else:
        if embedding is None:
            print("✅ Gracefully handled non-existent model")
        else:
            print("❌ Should have failed with non-existent model")
class EmbeddingConfig(BaseModel):
    """Configuration for EmbeddingBlock.

    Holds the model name, optional request parameters, retry settings and
    credentials. ``get_litellm_params`` converts the request-related fields
    into keyword arguments for a LiteLLM embedding call; ``timeout``,
    ``retry_attempts`` and ``retry_delay`` are deliberately not part of that
    mapping and are read directly by the caller.
    """

    # Model configuration
    model: str = Field(
        default="text-embedding-ada-002", description="Embedding model name"
    )

    # Optional parameters
    user: Optional[str] = Field(
        default=None,
        description="A unique identifier representing your end-user",
    )
    dimensions: Optional[int] = Field(
        default=None,
        description=(
            "The number of dimensions the resulting output embeddings should "
            "have. Only supported in OpenAI/Azure text-embedding-3 and later "
            "models"
        ),
    )
    encoding_format: str = Field(
        default="float",
        description=(
            "The format to return the embeddings in. Can be either 'float' "
            "or 'base64'"
        ),
    )
    timeout: int = Field(
        default=600,
        description="The maximum time, in seconds, to wait for the API to respond",
    )
    retry_attempts: int = Field(
        default=3,
        ge=0,
        description="The number of retry attempts",
    )
    retry_delay: float = Field(
        default=1.0,
        ge=0,
        description="The delay between retries in seconds",
    )
    api_base: Optional[str] = Field(
        default=None,
        description="The api endpoint you want to call the model with",
    )
    api_version: Optional[str] = Field(
        default=None,
        description="(Azure-specific) the api version for the call",
    )
    api_key: Optional[str] = Field(
        default=None,
        description=(
            "The API key to authenticate and authorize requests. "
            "If not provided, the default API key is used"
        ),
    )
    api_type: Optional[str] = Field(
        default=None, description="The type of API to use"
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        """Return the model name with surrounding whitespace removed.

        Raises:
            ValueError: If the value is not a string, or is empty once
                stripped (a whitespace-only name is rejected rather than
                silently becoming an empty model name).
        """
        if not isinstance(v, str):
            raise ValueError("Model name must be a non-empty string")
        stripped = v.strip()
        if not stripped:
            raise ValueError("Model name must be a non-empty string")
        return stripped

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v):
        """Accept ``None`` or a string API key.

        Raises:
            ValueError: If a non-string, non-None value is supplied.
        """
        if v is not None and not isinstance(v, str):
            raise ValueError("API key must be a string")
        return v

    def get_litellm_params(self) -> Dict[str, Any]:
        """Get parameters formatted for LiteLLM embedding.

        Returns:
            A dict that always contains ``model`` plus every optional request
            field whose value is truthy. Unset (``None``), empty and zero
            values are omitted.
        """
        params: Dict[str, Any] = {"model": self.model}

        optional_fields = (
            "user",
            "dimensions",
            "encoding_format",
            "api_base",
            "api_version",
            "api_key",
            "api_type",
        )
        for field_name in optional_fields:
            value = getattr(self, field_name)
            if value:
                params[field_name] = value

        return params

    def get_provider_type(self) -> str:
        """Extract provider type from model name.

        Returns:
            ``"openai"`` for the three known OpenAI embedding model names,
            ``"azure"`` or ``"gemini"`` when that word occurs anywhere in the
            (case-insensitive) model name, and ``"unknown"`` otherwise.
        """
        model_lower = self.model.lower()

        # OpenAI models are matched by exact name
        if model_lower in (
            "text-embedding-ada-002",
            "text-embedding-3-small",
            "text-embedding-3-large",
        ):
            return "openai"

        # Azure models
        if "azure" in model_lower:
            return "azure"

        # Gemini models
        if "gemini" in model_lower:
            return "gemini"

        # Anything else is reported as "unknown" rather than guessed
        return "unknown"

    def create_variant(self, **overrides) -> "EmbeddingConfig":
        """Create a variant of this config with parameter overrides.

        The current values are dumped, updated with ``overrides`` and passed
        through the constructor again, so the overrides are validated and this
        instance is left untouched.
        """
        current_dict = self.model_dump()
        current_dict.update(overrides)
        return EmbeddingConfig(**current_dict)
class EmbeddingBlock:
    """A reusable Embedding function block using LiteLLM.

    EmbeddingBlock provides a consistent interface for generating embeddings
    across different providers (OpenAI, Gemini, etc.).

    Unlike workflows, EmbeddingBlock focuses on providing basic embedding
    capabilities without business logic. The public generation methods never
    raise on request failure: they log the error and return ``None``.
    """

    # Environment variable that receives the API key for each detected provider.
    _API_KEY_ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "azure": "AZURE_API_KEY",
        "gemini": "GEMINI_API_KEY",
    }

    def __init__(self, config: "EmbeddingConfig"):
        """Initialize the EmbeddingBlock with configuration.

        Args:
            config: Embedding configuration

        Raises:
            ImportError: If LiteLLM is not available.
        """
        if not LITELLM_AVAILABLE:
            raise ImportError(
                "litellm is required for EmbeddingBlock but not installed."
            )

        self.config = config
        self._setup_litellm()

        logger.info(f"Initialized EmbeddingBlock with model: {config.model}")

    def _setup_litellm(self):
        """Setup LiteLLM configuration.

        Writes module-level LiteLLM settings and, when an API key is
        configured for a recognised provider, exports it through the
        provider's environment variable.
        """
        # Set global LiteLLM settings
        litellm.set_verbose = False  # Disable verbose logging by default

        # NOTE(review): _call_with_retry also loops retry_attempts + 1 times;
        # if LiteLLM honours num_retries as well, attempts may multiply.
        litellm.num_retries = self.config.retry_attempts
        litellm.request_timeout = self.config.timeout

        provider_type = self.config.get_provider_type()

        # Set API key as environment variable if provided
        if self.config.api_key:
            env_var = self._API_KEY_ENV_VARS.get(provider_type)
            if env_var is not None:
                os.environ[env_var] = self.config.api_key

        logger.debug(f"Configured LiteLLM for provider: {provider_type}")

    @staticmethod
    def _extract_vector(item: Any) -> Any:
        """Return the embedding stored in one response item.

        Response items may be plain dicts (``{"embedding": [...]}``) or
        objects exposing an ``embedding`` attribute; both shapes are accepted.
        """
        if isinstance(item, dict):
            return item["embedding"]
        return item.embedding

    def generate_embedding(self, text: str, **kwargs) -> Optional[List[float]]:
        """Generate embedding using the configured Embedding model.

        Args:
            text (str): The input text to embed.
            **kwargs: Additional parameters to override config

        Returns:
            List[float]: The embedding vector as a list of floats, or None if failed.
        """
        try:
            params = self.config.get_litellm_params()
            params.update(kwargs)  # Allow runtime overrides
            params["input"] = text

            response = self._call_with_retry(params)

            if response and hasattr(response, "data"):
                data = response.data
                first_item = data[0] if isinstance(data, list) else data
                return self._extract_vector(first_item)

            return None

        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            return None

    def generate_embeddings(
        self, texts: List[str], **kwargs
    ) -> Optional[List[List[float]]]:
        """Generate embeddings for multiple texts.

        Args:
            texts (List[str]): List of input texts to embed.
            **kwargs: Additional parameters to override config

        Returns:
            List[List[float]]: List of embedding vectors, or None if failed.
        """
        try:
            params = self.config.get_litellm_params()
            params.update(kwargs)  # Allow runtime overrides
            params["input"] = texts

            response = self._call_with_retry(params)

            if response and hasattr(response, "data"):
                return [self._extract_vector(item) for item in response.data]

            return None

        except Exception as e:
            logger.error(f"Failed to generate embeddings: {e}")
            return None

    def _call_with_retry(self, params: Dict[str, Any]) -> Optional[Any]:
        """Call LiteLLM embedding with retry logic.

        The call is attempted ``retry_attempts + 1`` times, sleeping
        ``retry_delay`` seconds between attempts.

        Args:
            params (Dict[str, Any]): The parameters to pass to the embedding
                function. Must contain ``input``; may contain ``model`` and
                ``timeout`` overrides.

        Returns:
            Optional[Any]: The embedding result or None if failed.
        """
        # Work on a copy so the caller's dict is never mutated.
        call_params = dict(params)
        input_value = call_params.pop("input")

        # A runtime "model" override wins; otherwise use the configured model.
        model_name = call_params.pop("model", self.config.model)

        # Pass the timeout per call so update_config()/temporary_config()
        # changes actually take effect (the global is only set at init).
        if "timeout" not in call_params and self.config.timeout > 0:
            call_params["timeout"] = self.config.timeout

        total_attempts = self.config.retry_attempts + 1
        last_exception = None

        for attempt in range(total_attempts):
            try:
                logger.debug(
                    f"Embedding call attempt {attempt + 1}/{total_attempts}"
                )

                response = embedding(
                    model=model_name, input=input_value, **call_params
                )

                if hasattr(response, "usage") and response.usage:
                    logger.debug(f"Token usage: {response.usage}")
                return response
            except Exception as e:
                last_exception = e
                logger.warning(
                    f"Embedding call attempt {attempt + 1} failed: {e}"
                )

                if attempt < total_attempts - 1:
                    time.sleep(self.config.retry_delay)
                else:
                    logger.error(f"All {total_attempts} attempts failed")

        # Log final error
        if last_exception:
            logger.error(f"Final error: {last_exception}")

        return None

    def test_connection(self) -> bool:
        """Test if the embedding connection is working.

        Returns:
            True if connection is working, False otherwise
        """
        try:
            response = self.generate_embedding("test")
            return response is not None and len(response) > 0
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False

    def get_info(self) -> Dict[str, Any]:
        """Get information about the embedding block.

        Returns:
            Dictionary with the keys ``model``, ``provider``, ``timeout`` and
            ``retry_attempts``.
        """
        return {
            "model": self.config.model,
            "provider": self.config.get_provider_type(),
            "timeout": self.config.timeout,
            "retry_attempts": self.config.retry_attempts,
        }

    def get_embedding_dimension(self) -> Optional[int]:
        """Get the dimension of embeddings generated by this model.

        Returns:
            The configured ``dimensions`` when set; otherwise the length of a
            freshly generated probe embedding, or None if that probe fails.
        """
        # First check if dimensions is specified in config
        if self.config.dimensions:
            return self.config.dimensions

        try:
            # Fall back to measuring a real embedding (this issues a request)
            test_embedding = self.generate_embedding("test")
            return len(test_embedding) if test_embedding else None
        except Exception as e:
            logger.error(f"Failed to get embedding dimension: {e}")
            return None

    def update_config(self, **kwargs) -> None:
        """Update the embedding configuration.

        Keys that are not attributes of the config are ignored with a warning.

        Args:
            **kwargs: Configuration parameters to update
        """
        applied = []
        ignored = []
        for key, value in kwargs.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)
                applied.append(key)
            else:
                ignored.append(key)

        if ignored:
            logger.warning(
                f"Ignored unknown embedding configuration keys: {ignored}"
            )

        # Log names only: values may contain secrets such as api_key.
        logger.info(f"Updated embedding configuration: {applied}")

    @contextmanager
    def temporary_config(self, **kwargs):
        """Temporarily modify configuration for a context.

        Unknown keys are ignored. Original values are restored on exit, even
        if the body raises.

        Args:
            **kwargs: Temporary configuration parameters

        Yields:
            Self with temporary configuration
        """
        original_config = {}
        try:
            for key, value in kwargs.items():
                if hasattr(self.config, key):
                    original_config[key] = getattr(self.config, key)
                    setattr(self.config, key, value)
            yield self
        finally:
            # Restore original configuration
            for key, value in original_config.items():
                setattr(self.config, key, value)

    def batch_embed(
        self, texts: List[str], batch_size: int = 32, **kwargs
    ) -> Optional[List[List[float]]]:
        """Generate embeddings in batches for large datasets.

        Batches are processed in order and ``retry_delay`` seconds are slept
        between consecutive batches. The first failing batch aborts the run.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts to process in each batch (must be >= 1)
            **kwargs: Additional parameters for embedding generation

        Returns:
            List of embedding vectors or None if failed
        """
        try:
            if batch_size < 1:
                # A non-positive step would either raise inside range() or
                # silently yield no batches at all.
                logger.error(
                    f"Batch embedding failed: batch_size must be >= 1, got {batch_size}"
                )
                return None

            all_embeddings = []

            for start in range(0, len(texts), batch_size):
                batch = texts[start : start + batch_size]
                batch_embeddings = self.generate_embeddings(batch, **kwargs)

                if batch_embeddings is None:
                    logger.error(
                        f"Failed to generate embeddings for batch {start // batch_size}"
                    )
                    return None

                all_embeddings.extend(batch_embeddings)

                # Add delay between batches if specified
                has_more = start + batch_size < len(texts)
                if self.config.retry_delay > 0 and has_more:
                    time.sleep(self.config.retry_delay)

            return all_embeddings

        except Exception as e:
            logger.error(f"Batch embedding failed: {e}")
            return None
class TestEmbeddingConfig(unittest.TestCase):
    """Test cases for EmbeddingConfig."""

    def test_default_config(self):
        """A config built without arguments exposes the documented defaults."""
        cfg = EmbeddingConfig()

        self.assertEqual(cfg.model, "text-embedding-ada-002")
        self.assertIsNone(cfg.user)
        self.assertIsNone(cfg.dimensions)
        self.assertEqual(cfg.encoding_format, "float")
        self.assertEqual(cfg.timeout, 600)
        for unset_field in ("api_base", "api_version", "api_key", "api_type"):
            self.assertIsNone(getattr(cfg, unset_field))

    def test_custom_config(self):
        """Every explicitly supplied value is stored unchanged."""
        supplied = {
            "model": "text-embedding-3-small",
            "user": "test_user_123",
            "dimensions": 512,
            "encoding_format": "base64",
            "timeout": 30,
            "api_key": "test-key",
            "api_base": "https://api.example.com",
            "api_version": "2023-05-15",
            "api_type": "azure",
        }
        cfg = EmbeddingConfig(**supplied)

        for field_name, expected in supplied.items():
            self.assertEqual(getattr(cfg, field_name), expected)

    def test_validation_model(self):
        """Model names are stripped; empty and None names are rejected."""
        cfg = EmbeddingConfig(model="text-embedding-ada-002")
        self.assertEqual(cfg.model, "text-embedding-ada-002")

        # Empty and None model names must both be refused
        for bad_model in ("", None):
            with self.assertRaises(ValueError):
                EmbeddingConfig(model=bad_model)

        # Surrounding whitespace is removed
        cfg = EmbeddingConfig(model=" text-embedding-ada-002 ")
        self.assertEqual(cfg.model, "text-embedding-ada-002")

    def test_validation_api_key(self):
        """API keys may be a string or None; other types are rejected."""
        cfg = EmbeddingConfig(api_key="test-key")
        self.assertEqual(cfg.api_key, "test-key")

        cfg = EmbeddingConfig(api_key=None)
        self.assertIsNone(cfg.api_key)

        for bad_key in (123, []):
            with self.assertRaises(ValueError):
                EmbeddingConfig(api_key=bad_key)

    def test_get_provider_type(self):
        """Provider detection covers OpenAI, Azure, Gemini and unknown names."""
        cases = [
            # OpenAI models
            ("text-embedding-ada-002", "openai"),
            ("text-embedding-3-small", "openai"),
            ("text-embedding-3-large", "openai"),
            # Azure models
            ("azure/text-embedding-ada-002", "azure"),
            ("text-embedding-ada-002-azure", "azure"),
            # Gemini models
            ("gemini/embed-multilingual-v3.0", "gemini"),
            # Unknown models
            ("unknown-model", "unknown"),
        ]
        for model_name, expected_provider in cases:
            cfg = EmbeddingConfig(model=model_name)
            self.assertEqual(cfg.get_provider_type(), expected_provider)

    def test_get_litellm_params_minimal(self):
        """A minimal config yields only model and encoding_format."""
        params = EmbeddingConfig(
            model="text-embedding-ada-002"
        ).get_litellm_params()

        self.assertEqual(params["model"], "text-embedding-ada-002")
        self.assertIn("encoding_format", params)
        self.assertEqual(len(params), 2)  # Only model and encoding_format

    def test_get_litellm_params_full(self):
        """A fully populated config yields every request parameter."""
        cfg = EmbeddingConfig(
            model="text-embedding-3-small",
            user="test_user",
            dimensions=512,
            encoding_format="base64",
            timeout=30,
            api_key="test-key",
            api_base="https://api.example.com",
            api_version="2023-05-15",
            api_type="azure",
        )

        self.assertEqual(
            cfg.get_litellm_params(),
            {
                "model": "text-embedding-3-small",
                "user": "test_user",
                "dimensions": 512,
                "encoding_format": "base64",
                "api_base": "https://api.example.com",
                "api_version": "2023-05-15",
                "api_key": "test-key",
                "api_type": "azure",
            },
        )

    def test_get_litellm_params_partial(self):
        """Only the supplied optional parameters appear in the result."""
        cfg = EmbeddingConfig(
            model="text-embedding-ada-002",
            user="test_user",
            dimensions=1536,
            api_key="test-key",
        )

        self.assertEqual(
            cfg.get_litellm_params(),
            {
                "model": "text-embedding-ada-002",
                "user": "test_user",
                "dimensions": 1536,
                "encoding_format": "float",
                "api_key": "test-key",
            },
        )

    def test_create_variant(self):
        """Overrides land on the variant and leave the base untouched."""
        base = EmbeddingConfig(
            model="text-embedding-ada-002",
            timeout=60,
            api_key="base-key",
        )

        derived = base.create_variant(
            timeout=30,
            api_key="variant-key",
            user="test_user",
        )

        # Original config should be unchanged
        self.assertEqual(base.timeout, 60)
        self.assertEqual(base.api_key, "base-key")
        self.assertIsNone(base.user)

        # Variant should have new values
        self.assertEqual(derived.timeout, 30)
        self.assertEqual(derived.api_key, "variant-key")
        self.assertEqual(derived.user, "test_user")
        self.assertEqual(derived.model, "text-embedding-ada-002")  # Unchanged

    def test_create_variant_empty(self):
        """A variant without overrides matches its base."""
        base = EmbeddingConfig(
            model="text-embedding-ada-002",
            timeout=60,
        )

        derived = base.create_variant()

        self.assertEqual(derived.model, base.model)
        self.assertEqual(derived.timeout, base.timeout)
        self.assertEqual(derived.encoding_format, base.encoding_format)

    def test_encoding_format_validation(self):
        """Both supported encoding formats are accepted."""
        for encoding in ("float", "base64"):
            cfg = EmbeddingConfig(encoding_format=encoding)
            self.assertEqual(cfg.encoding_format, encoding)

    def test_dimensions_validation(self):
        """Dimensions accept any integer or None."""
        # Typical sizes first, then zero and negative values, which should be
        # allowed (validation handled by API)
        for size in (512, 1536, 3072, 0, -1):
            cfg = EmbeddingConfig(dimensions=size)
            self.assertEqual(cfg.dimensions, size)

        # None is valid
        cfg = EmbeddingConfig(dimensions=None)
        self.assertIsNone(cfg.dimensions)

    def test_timeout_validation(self):
        """Timeouts accept any integer."""
        # Typical values first, then zero and negative values, which should be
        # allowed (validation handled by API)
        for seconds in (1, 600, 3600, 0, -1):
            cfg = EmbeddingConfig(timeout=seconds)
            self.assertEqual(cfg.timeout, seconds)

    def test_equality(self):
        """Configs compare equal exactly when their fields match."""
        shared = {"user": "test_user", "dimensions": 512}

        first = EmbeddingConfig(model="text-embedding-ada-002", **shared)
        twin = EmbeddingConfig(model="text-embedding-ada-002", **shared)
        other = EmbeddingConfig(model="text-embedding-3-small", **shared)

        self.assertEqual(first, twin)
        self.assertNotEqual(first, other)

    def test_repr(self):
        """repr() mentions the configured field values."""
        cfg = EmbeddingConfig(
            model="text-embedding-ada-002",
            user="test_user",
            dimensions=512,
        )

        rendered = repr(cfg)
        for fragment in ("text-embedding-ada-002", "test_user", "512"):
            self.assertIn(fragment, rendered)

    def test_str(self):
        """str() mentions the configured field values."""
        cfg = EmbeddingConfig(
            model="text-embedding-ada-002",
            user="test_user",
            dimensions=512,
        )

        rendered = str(cfg)
        for fragment in ("text-embedding-ada-002", "test_user", "512"):
            self.assertIn(fragment, rendered)
"""Test cases for EmbeddingBlock.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key="test-key", + timeout=30, + ) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + def test_init_success(self, mock_litellm): + """Test successful initialization.""" + block = EmbeddingBlock(self.config) + + self.assertEqual(block.config, self.config) + mock_litellm.set_verbose = False + self.assertEqual(mock_litellm.num_retries, 3) + self.assertEqual(mock_litellm.request_timeout, 30) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", False) + def test_init_litellm_unavailable(self): + """Test initialization when LiteLLM is not available.""" + with self.assertRaises(ImportError) as context: + EmbeddingBlock(self.config) + + self.assertIn( + "litellm is required for EmbeddingBlock", str(context.exception) + ) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("os.environ", {}) + def test_setup_litellm_openai(self, mock_litellm): + """Test LiteLLM setup for OpenAI.""" + config = EmbeddingConfig( + model="text-embedding-ada-002", api_key="test-key" + ) + + with patch("os.environ", {}) as mock_env: + block = EmbeddingBlock(config) + self.assertEqual(mock_env.get("OPENAI_API_KEY"), "test-key") + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("os.environ", {}) + def test_setup_litellm_azure(self, mock_litellm): + """Test LiteLLM setup for Azure.""" + config = EmbeddingConfig( + model="azure/text-embedding-ada-002", api_key="azure-key" + ) + + with patch("os.environ", {}) as mock_env: + block = EmbeddingBlock(config) + self.assertEqual(mock_env.get("AZURE_API_KEY"), "azure-key") + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + 
@patch("quantmind.llm.embedding.embedding") + def test_generate_embedding_success(self, mock_embedding, mock_litellm): + """Test successful single embedding generation.""" + # Mock response + mock_response = Mock() + mock_response.data = [Mock()] + mock_response.data[0].embedding = [0.1, 0.2, 0.3, 0.4, 0.5] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block.generate_embedding("Test text") + + self.assertEqual(result, [0.1, 0.2, 0.3, 0.4, 0.5]) + mock_embedding.assert_called_once() + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_generate_embedding_failure(self, mock_embedding, mock_litellm): + """Test embedding generation failure.""" + mock_embedding.side_effect = Exception("API Error") + + block = EmbeddingBlock(self.config) + result = block.generate_embedding("Test text") + + self.assertIsNone(result) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_generate_embeddings_success(self, mock_embedding, mock_litellm): + """Test successful multiple embedding generation.""" + # Mock response + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2, 0.3]), + Mock(embedding=[0.4, 0.5, 0.6]), + ] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block.generate_embeddings(["Text 1", "Text 2"]) + + self.assertEqual(result, [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + mock_embedding.assert_called_once() + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_generate_embeddings_failure(self, mock_embedding, mock_litellm): + """Test multiple embedding generation failure.""" + mock_embedding.side_effect = Exception("API Error") + + block = 
EmbeddingBlock(self.config) + result = block.generate_embeddings(["Text 1", "Text 2"]) + + self.assertIsNone(result) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_call_with_retry_success(self, mock_embedding, mock_litellm): + """Test successful call with retry.""" + mock_response = Mock() + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block._call_with_retry( + {"model": "text-embedding-ada-002", "input": "test"} + ) + + self.assertEqual(result, mock_response) + mock_embedding.assert_called_once() + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + @patch("time.sleep") + def test_call_with_retry_failure_then_success( + self, mock_sleep, mock_embedding, mock_litellm + ): + """Test retry logic with failure then success.""" + mock_response = Mock() + mock_embedding.side_effect = [ + Exception("First failure"), + mock_response, + ] + + block = EmbeddingBlock(self.config) + result = block._call_with_retry( + {"model": "text-embedding-ada-002", "input": "test"} + ) + + self.assertEqual(result, mock_response) + self.assertEqual(mock_embedding.call_count, 2) + mock_sleep.assert_called_once_with(1.0) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + @patch("time.sleep") + def test_call_with_retry_all_failures( + self, mock_sleep, mock_embedding, mock_litellm + ): + """Test retry logic with all failures.""" + mock_embedding.side_effect = Exception("Always fails") + + block = EmbeddingBlock(self.config) + result = block._call_with_retry( + {"model": "text-embedding-ada-002", "input": "test"} + ) + + self.assertIsNone(result) + self.assertEqual(mock_embedding.call_count, 4) # 1 initial + 3 retries + 
self.assertEqual(mock_sleep.call_count, 3) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_test_connection_success(self, mock_embedding, mock_litellm): + """Test successful connection test.""" + mock_response = Mock() + mock_response.data = [Mock()] + mock_response.data[0].embedding = [0.1, 0.2, 0.3] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block.test_connection() + + self.assertTrue(result) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_test_connection_failure(self, mock_embedding, mock_litellm): + """Test connection test failure.""" + mock_embedding.side_effect = Exception("Connection failed") + + block = EmbeddingBlock(self.config) + result = block.test_connection() + + self.assertFalse(result) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_get_embedding_dimension_from_config( + self, mock_embedding, mock_litellm + ): + """Test getting embedding dimension from config.""" + config = EmbeddingConfig( + model="text-embedding-3-small", + dimensions=512, + ) + block = EmbeddingBlock(config) + dimension = block.get_embedding_dimension() + + self.assertEqual(dimension, 512) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_get_embedding_dimension_from_test_embedding( + self, mock_embedding, mock_litellm + ): + """Test getting embedding dimension by generating test embedding.""" + mock_response = Mock() + mock_response.data = [Mock()] + mock_response.data[0].embedding = [0.1, 0.2, 0.3, 0.4, 0.5] + mock_embedding.return_value = mock_response + + block = 
EmbeddingBlock(self.config) + dimension = block.get_embedding_dimension() + + self.assertEqual(dimension, 5) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_get_embedding_dimension_failure( + self, mock_embedding, mock_litellm + ): + """Test getting embedding dimension when test embedding fails.""" + mock_embedding.side_effect = Exception("API Error") + + block = EmbeddingBlock(self.config) + dimension = block.get_embedding_dimension() + + self.assertIsNone(dimension) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + def test_get_info(self, mock_litellm): + """Test getting block info.""" + block = EmbeddingBlock(self.config) + info = block.get_info() + + expected_keys = ["model", "provider", "timeout", "retry_attempts"] + for key in expected_keys: + self.assertIn(key, info) + + self.assertEqual(info["model"], "text-embedding-ada-002") + self.assertEqual(info["provider"], "openai") + self.assertEqual(info["timeout"], 30) + self.assertEqual(info["retry_attempts"], 3) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + def test_update_config(self, mock_litellm): + """Test configuration update.""" + block = EmbeddingBlock(self.config) + + # Check initial config + self.assertEqual(block.config.timeout, 30) + self.assertEqual(block.config.api_key, "test-key") + + # Update config + block.update_config(timeout=60, api_key="new-key") + + # Check updated config + self.assertEqual(block.config.timeout, 60) + self.assertEqual(block.config.api_key, "new-key") + # Other values should remain unchanged + self.assertEqual(block.config.model, "text-embedding-ada-002") + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + def test_temporary_config(self, mock_litellm): + """Test temporary configuration 
context manager.""" + block = EmbeddingBlock(self.config) + + # Check initial config + self.assertEqual(block.config.timeout, 30) + + # Use temporary config + with block.temporary_config(timeout=60): + self.assertEqual(block.config.timeout, 60) + + # Check config is restored + self.assertEqual(block.config.timeout, 30) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_batch_embed_success(self, mock_embedding, mock_litellm): + """Test successful batch embedding.""" + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2]), + Mock(embedding=[0.3, 0.4]), + ] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + texts = ["Text 1", "Text 2", "Text 3", "Text 4"] + result = block.batch_embed(texts, batch_size=2) + + expected = [[0.1, 0.2], [0.3, 0.4], [0.1, 0.2], [0.3, 0.4]] + self.assertEqual(result, expected) + self.assertEqual(mock_embedding.call_count, 2) # 2 batches + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_batch_embed_failure(self, mock_embedding, mock_litellm): + """Test batch embedding failure.""" + mock_embedding.side_effect = Exception("API Error") + + block = EmbeddingBlock(self.config) + texts = ["Text 1", "Text 2", "Text 3", "Text 4"] + result = block.batch_embed(texts, batch_size=2) + + self.assertIsNone(result) + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + @patch("time.sleep") + def test_batch_embed_with_delay( + self, mock_sleep, mock_embedding, mock_litellm + ): + """Test batch embedding with delay between batches.""" + config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key="test-key", + timeout=30, + retry_delay=0.1, + ) + + mock_response = Mock() + 
mock_response.data = [ + Mock(embedding=[0.1, 0.2]), + Mock(embedding=[0.3, 0.4]), + ] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(config) + texts = ["Text 1", "Text 2", "Text 3", "Text 4"] + result = block.batch_embed(texts, batch_size=2) + + # Should have delay between batches + self.assertEqual(mock_sleep.call_count, 1) # Delay between 2 batches + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_generate_embedding_with_kwargs(self, mock_embedding, mock_litellm): + """Test embedding generation with additional kwargs.""" + mock_response = Mock() + mock_response.data = [Mock()] + mock_response.data[0].embedding = [0.1, 0.2, 0.3] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block.generate_embedding( + "Test text", dimensions=512, user="test_user" + ) + + # Check that kwargs were passed to the embedding call + call_args = mock_embedding.call_args + self.assertIn("dimensions", call_args[1]) + self.assertIn("user", call_args[1]) + self.assertEqual(call_args[1]["dimensions"], 512) + self.assertEqual(call_args[1]["user"], "test_user") + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + @patch("quantmind.llm.embedding.embedding") + def test_generate_embeddings_with_kwargs( + self, mock_embedding, mock_litellm + ): + """Test multiple embedding generation with additional kwargs.""" + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2]), + Mock(embedding=[0.3, 0.4]), + ] + mock_embedding.return_value = mock_response + + block = EmbeddingBlock(self.config) + result = block.generate_embeddings( + ["Text 1", "Text 2"], dimensions=512, user="test_user" + ) + + # Check that kwargs were passed to the embedding call + call_args = mock_embedding.call_args + self.assertIn("dimensions", call_args[1]) + 
self.assertIn("user", call_args[1]) + self.assertEqual(call_args[1]["dimensions"], 512) + self.assertEqual(call_args[1]["user"], "test_user") + + +class TestCreateEmbeddingBlock(unittest.TestCase): + """Test cases for create_embedding_block function.""" + + @patch("quantmind.llm.embedding.LITELLM_AVAILABLE", True) + @patch("quantmind.llm.embedding.litellm") + def test_create_embedding_block(self, mock_litellm): + """Test EmbeddingBlock creation.""" + config = EmbeddingConfig(model="text-embedding-ada-002") + block = create_embedding_block(config) + + self.assertIsInstance(block, EmbeddingBlock) + self.assertEqual(block.config, config) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/llm/test_block.py b/tests/llm/test_llm_block.py similarity index 100% rename from tests/llm/test_block.py rename to tests/llm/test_llm_block.py diff --git a/wiki/embedding.md b/wiki/embedding.md new file mode 100644 index 0000000..54bd300 --- /dev/null +++ b/wiki/embedding.md @@ -0,0 +1,324 @@ +# 🌟 Embedding Systems in QuantMind 🌟 + +## šŸ“‹ Table of Contents + +
+ 📜 Contents +
    +
  1. 📌 Overview
  2. +
  3. 📌 Theoretical Background
  4. +
  5. 📌 Architecture
  6. +
  7. 📌 Configuration
  8. +
  9. 📌 Usage Examples
  10. +
  11. 📌 Advanced Features
  12. +
  13. 📌 Best Practices
  14. +
+
+ +## 📌 Overview + +Embeddings are numerical representations of text that capture semantic meaning in high-dimensional vector spaces. In quantitative finance, embeddings enable: + +- **Document Analysis**: Converting financial reports into searchable vectors +- **Semantic Search**: Finding similar financial documents +- **Content Clustering**: Grouping related financial information +- **Feature Engineering**: Creating numerical features from textual data + +QuantMind provides a flexible embedding system through the `EmbeddingBlock` class, supporting multiple providers via a unified interface. + +## 📌 Theoretical Background + +### What are Embeddings? + +Embeddings map discrete objects (words, sentences, documents) to continuous vector spaces where: +- **Similar objects** are positioned close to each other +- **Mathematical operations** have semantic meaning +- **Dimensionality** typically ranges from a few hundred to 3072 dimensions (see the table below) + +### Supported Models + +| Model | Dimensions | Use Case | Provider | +|-------|------------|----------|----------| +| `text-embedding-ada-002` | 1536 | General purpose | OpenAI | +| `text-embedding-3-small` | 1536 | High performance | OpenAI | +| `text-embedding-3-large` | 3072 | Maximum quality | OpenAI | + +## 📌 Architecture + +### Core Components + +```python +from quantmind.config import EmbeddingConfig +from quantmind.llm import EmbeddingBlock, create_embedding_block +``` + +#### EmbeddingConfig +Manages all embedding parameters: + +```python +class EmbeddingConfig(BaseModel): + model: str = "text-embedding-ada-002" + user: Optional[str] = None + dimensions: Optional[int] = None + encoding_format: str = "float" + timeout: int = 600 + api_base: Optional[str] = None + api_version: Optional[str] = None + api_key: Optional[str] = None + api_type: Optional[str] = None +``` + +#### EmbeddingBlock +Main interface for generating embeddings: + +```python +class EmbeddingBlock: + def generate_embedding(self, text: str) -> 
Optional[List[float]] + def generate_embeddings(self, texts: List[str]) -> Optional[List[List[float]]] + def batch_embed(self, texts: List[str], batch_size: int = 100) -> Optional[List[List[float]]] + def test_connection(self) -> bool + def get_info(self) -> Dict[str, Any] +``` + +## 📌 Configuration + +### Basic Setup + +```python +from quantmind.config import EmbeddingConfig + +# Simple configuration +config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key="your-api-key" +) +``` + +### Advanced Configuration + +```python +# Custom dimensions (OpenAI text-embedding-3+) +config = EmbeddingConfig( + model="text-embedding-3-small", + dimensions=512, # Reduce from default 1536 + encoding_format="float", + timeout=30 +) + +# Azure OpenAI +config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key="azure-key", + api_base="https://your-resource.openai.azure.com/", + api_version="2023-05-15", + api_type="azure" +) +``` + +## 📌 Usage Examples + +### Basic Embedding Generation + +```python +from quantmind.config import EmbeddingConfig +from quantmind.llm import create_embedding_block + +# Create configuration +config = EmbeddingConfig( + model="text-embedding-ada-002", + api_key=os.getenv("OPENAI_API_KEY") +) + +# Create embedding block +embedding_block = create_embedding_block(config) + +# Generate single embedding +text = "Apple Inc. reported strong quarterly earnings with revenue growth of 15%." +embedding = embedding_block.generate_embedding(text) + +if embedding: + print(f"Generated embedding with {len(embedding)} dimensions") + print(f"First 5 values: {embedding[:5]}") +``` + +### Batch Processing + +```python +# Generate multiple embeddings +texts = [ + "Apple Inc. reported strong quarterly earnings.", + "Microsoft Corp. announced new AI initiatives.", + "Tesla Inc. delivered record vehicle production.", + "Amazon.com Inc. expanded cloud services." 
+] + +embeddings = embedding_block.generate_embeddings(texts) + +if embeddings: + print(f"Generated {len(embeddings)} embeddings") + for i, emb in enumerate(embeddings): + print(f"Text {i+1}: {len(emb)} dimensions") +``` + +### Semantic Similarity + +```python +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity + +def calculate_similarity(text1: str, text2: str) -> float: + """Calculate semantic similarity between two texts.""" + embeddings = embedding_block.generate_embeddings([text1, text2]) + + if embeddings and len(embeddings) == 2: + similarity = cosine_similarity( + [embeddings[0]], + [embeddings[1]] + )[0][0] + return similarity + + return 0.0 + +# Example usage +text1 = "Apple's iPhone sales exceeded expectations" +text2 = "iPhone revenue surpassed analyst predictions" +similarity = calculate_similarity(text1, text2) +print(f"Similarity: {similarity:.3f}") +``` + +## 📌 Advanced Features + +### Custom Dimensions + +```python +# Reduce embedding dimensions for efficiency +config = EmbeddingConfig( + model="text-embedding-3-small", + dimensions=512 # Reduce from 1536 to 512 +) + +embedding_block = create_embedding_block(config) +embedding = embedding_block.generate_embedding("Sample text") +print(f"Reduced dimensions: {len(embedding)}") # 512 +``` + +### Connection Testing + +```python +# Test API connection before processing +if embedding_block.test_connection(): + print("✅ API connection successful") + # Proceed with embedding generation +else: + print("❌ API connection failed") + # Handle error or retry +``` + +### Configuration Information + +```python +# Get detailed configuration information +info = embedding_block.get_info() +print(f"Model: {info['model']}") +print(f"Provider: {info['provider']}") +print(f"Dimensions: {info['dimension']}") +print(f"Format: {info['encoding_format']}") +``` + +## 📌 Best Practices + +### 1. 
Model Selection + +| Use Case | Recommended Model | Reasoning | +|----------|------------------|-----------| +| **General purpose** | `text-embedding-ada-002` | Good balance of quality and cost | +| **High performance** | `text-embedding-3-small` | Better quality, slightly higher cost | +| **Maximum quality** | `text-embedding-3-large` | Best quality, highest cost | +| **Multilingual** | `embed-multilingual-v3.0` | Support for multiple languages | + +### 2. Batch Processing + +```python +# Efficient batch processing +def efficient_batch_embedding(texts: List[str], batch_size: int = 100): + """Process texts in optimal batches.""" + all_embeddings = [] + + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + try: + batch_embeddings = embedding_block.generate_embeddings(batch) + if batch_embeddings: + all_embeddings.extend(batch_embeddings) + except Exception as e: + print(f"Error processing batch {i//batch_size + 1}: {e}") + + return all_embeddings +``` + +### 3. Error Handling + +```python +def robust_embedding_generation(text: str, max_retries: int = 3): + """Generate embedding with retry logic.""" + for attempt in range(max_retries): + try: + embedding = embedding_block.generate_embedding(text) + if embedding: + return embedding + except Exception as e: + print(f"Attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + time.sleep(2 ** attempt) # Exponential backoff + + return None +``` + +### 4. 
Caching + +```python +import hashlib +import pickle +import os + +class CachedEmbeddingBlock: + def __init__(self, embedding_block, cache_dir: str = "embedding_cache"): + self.embedding_block = embedding_block + self.cache_dir = cache_dir + os.makedirs(cache_dir, exist_ok=True) + + def get_embedding(self, text: str) -> Optional[List[float]]: + # Create cache key + text_hash = hashlib.md5(text.encode()).hexdigest() + cache_file = os.path.join(self.cache_dir, f"{text_hash}.pkl") + + # Check cache + if os.path.exists(cache_file): + with open(cache_file, 'rb') as f: + return pickle.load(f) + + # Generate embedding + embedding = self.embedding_block.generate_embedding(text) + + # Cache result + if embedding: + with open(cache_file, 'wb') as f: + pickle.dump(embedding, f) + + return embedding +``` + +## 📚 Related Documentation + +- [EmbeddingBlock API Reference](../docs/EMBEDDINGS.md) +- [Examples](../examples/llm/embedding_block_example.py) +- [Configuration Guide](../quantmind/config/embedding.py) + +## 🔗 External Resources + +- [OpenAI Embeddings Guide](https://platform.openai.com/docs/guides/embeddings) +- [Vector Similarity Search](https://www.pinecone.io/learn/vector-similarity-search/) + +--- + +*Last updated: January 2025*