Merged
28 changes: 28 additions & 0 deletions frontend/src/components/settings/EmbeddingConfigForm.tsx
@@ -17,6 +17,8 @@ import { Label } from '@/components/ui/label';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { Switch } from '@/components/ui/switch';
import { Separator } from '@/components/ui/separator';
import {
Select,
SelectContent,
@@ -153,6 +155,15 @@ export function EmbeddingConfigForm() {
}));
}, [currentProvider]);

// Handle active toggle
const handleToggle = useCallback((checked: boolean) => {
setFormState(prev => ({
...prev,
isEnabled: checked,
hasChanges: true,
}));
}, []);

// Handler for testing connection
const handleTestConnection = useCallback(async () => {
setFormState(prev => ({
@@ -433,6 +444,23 @@
</div>
)}

<Separator />

{/* Set as Active Toggle */}
<div className="flex items-center justify-between">
<div className="space-y-0.5">
<Label htmlFor="is-active">Set as active</Label>
<p className="text-sm text-muted-foreground">
Use this embedding provider for document processing
</p>
</div>
<Switch
id="is-active"
checked={formState.isEnabled}
onCheckedChange={handleToggle}
/>
</div>

{/* Action Buttons */}
<div className="flex gap-2 pt-2">
<Button
42 changes: 9 additions & 33 deletions ragitect/api/v1/chat.py
@@ -30,8 +30,11 @@
from ragitect.services.database.repositories.document_repo import DocumentRepository
from ragitect.services.database.repositories.vector_repo import VectorRepository
from ragitect.services.database.repositories.workspace_repo import WorkspaceRepository
from ragitect.services.embedding import create_embeddings_model, embed_text
from ragitect.services.llm_config_service import get_active_embedding_config
from ragitect.services.embedding import (
create_embeddings_model,
embed_text,
get_embedding_model_from_config,
)
from ragitect.services.llm_factory import create_llm_with_provider


@@ -166,21 +169,8 @@ async def retrieve_context_with_graph(
Returns:
List of chunks with content and metadata
"""
# Get embedding configuration and create model
embedding_config = await get_active_embedding_config(session)

if embedding_config:
config = EmbeddingConfig(
provider=embedding_config.provider_name,
model=embedding_config.model_name or "nomic-embed-text",
base_url=embedding_config.config_data.get("base_url"),
api_key=embedding_config.config_data.get("api_key"),
dimension=embedding_config.config_data.get("dimension", 768),
)
else:
config = EmbeddingConfig()

embedding_model = create_embeddings_model(config)
# Get embedding model from active DB config (or defaults)
embedding_model = await get_embedding_model_from_config(session)

# Create async embedding function for graph
async def embed_fn(text: str) -> list[float]:
@@ -373,22 +363,8 @@ async def chat_stream(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

# Get active embedding configuration and model
embed_config_dto = await get_active_embedding_config(session)
if embed_config_dto is None:
# Use default config if no active config found
embed_config = EmbeddingConfig()
else:
# Convert DTO to EmbeddingConfig
embed_config = EmbeddingConfig(
provider=embed_config_dto.provider_name,
model=embed_config_dto.model_name or "all-MiniLM-L6-v2",
api_key=embed_config_dto.api_key,
base_url=embed_config_dto.base_url,
dimension=embed_config_dto.dimension,
)

embed_model = create_embeddings_model(embed_config)
# Get embedding model from active DB config (or defaults)
embed_model = await get_embedding_model_from_config(session)

async def embed_fn(text: str) -> list[float]:
"""Embedding function for runtime dependency injection."""
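Both chat endpoints now reduce to the same two-step pattern: resolve the model via the helper, then wrap it in an async callable for the retrieval layer. A minimal sketch, not the PR's exact code; it assumes embed_text(model, text) is awaitable and returns list[float], which matches the import above but is not shown in full here:

from ragitect.services.embedding import embed_text, get_embedding_model_from_config


async def build_embed_fn(session):
    # Resolve the model once per request: active DB config, then env vars, then defaults.
    embedding_model = await get_embedding_model_from_config(session)

    async def embed_fn(text: str) -> list[float]:
        # Runtime dependency injection: downstream retrieval only sees this callable.
        return await embed_text(embedding_model, text)

    return embed_fn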
4 changes: 2 additions & 2 deletions ragitect/services/document_processing_service.py
@@ -24,7 +24,7 @@
from ragitect.services.config import load_document_config
from ragitect.services.database.repositories.document_repo import DocumentRepository
from ragitect.services.document_processor import process_file_bytes, split_document
from ragitect.services.embedding import create_embeddings_model, embed_documents
from ragitect.services.embedding import embed_documents, get_embedding_model_from_config

logger = logging.getLogger(__name__)

@@ -135,7 +135,7 @@ async def process_document(self, document_id: UUID) -> None:
# Generate embeddings if there are chunks
if chunks:
try:
embedding_model = create_embeddings_model()
embedding_model = await get_embedding_model_from_config(self.session)
embeddings = await embed_documents(embedding_model, chunks)
logger.info(
f"Generated {len(embeddings)} embeddings for document {document_id}"
49 changes: 48 additions & 1 deletion ragitect/services/embedding.py
@@ -7,14 +7,17 @@
"""

import logging
from typing import Any
from typing import TYPE_CHECKING, Any

import httpx
from langchain_core.embeddings import Embeddings
from langchain_ollama import OllamaEmbeddings

from ragitect.services.config import EmbeddingConfig

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession

logger = logging.getLogger(__name__)


@@ -307,3 +310,47 @@ def get_embedding_dimension(config: EmbeddingConfig | None = None) -> int:
if config is None:
return 768 # Default for nomic-embed-text
return config.dimension


async def get_embedding_model_from_config(session: "AsyncSession") -> Embeddings:
"""Get embedding model using active DB config, fallback to env vars, then defaults.

Priority order:
1. Active embedding config from database (UI-configured)
2. Environment variables (EMBEDDING_PROVIDER, EMBEDDING_API_KEY, etc.)
3. Hardcoded defaults (Ollama)

This is the DRY helper that encapsulates embedding config lookup, avoiding
duplicate patterns across chat.py and document_processing_service.py.

Args:
session: Database session for config lookup

Returns:
Embeddings model configured from DB, env vars, or defaults
"""
# Import inside function to avoid circular imports
from ragitect.services.llm_config_service import get_active_embedding_config
from ragitect.services.config import load_embedding_config

embedding_config = await get_active_embedding_config(session)
if embedding_config:
logger.info(
f"Using embedding config from DB: provider={embedding_config.provider_name}, "
f"model={embedding_config.model_name}"
)
config = EmbeddingConfig(
provider=embedding_config.provider_name,
model=embedding_config.model_name or "qwen3-embedding:0.6b",
base_url=embedding_config.base_url,
api_key=embedding_config.api_key,
dimension=embedding_config.dimension,
)
else:
# Fall back to environment variables
config = load_embedding_config()
logger.info(
f"No active embedding config in DB, using env vars: "
f"provider={config.provider}, model={config.model}"
)
return create_embeddings_model(config)
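As a usage sketch, any caller that holds an async database session gets the same three-tier fallback. The embed_documents(model, chunks) call is taken from the document-processing change above; the wrapper function itself is illustrative:

from ragitect.services.embedding import embed_documents, get_embedding_model_from_config


async def embed_chunks(session, chunks: list[str]) -> list[list[float]]:
    # Resolution order is handled inside the helper:
    # active DB config -> environment variables -> Ollama defaults.
    model = await get_embedding_model_from_config(session)
    return await embed_documents(model, chunks)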
18 changes: 2 additions & 16 deletions tests/api/v1/test_chat_streaming.py
@@ -32,23 +32,10 @@ def setup_langgraph_streaming_mocks(mocker):
Returns:
dict with mock objects for assertions
"""
# Mock embedding config - required for both retrieval and streaming
mock_embed_config = mocker.MagicMock()
mock_embed_config.provider_name = "ollama"
mock_embed_config.model_name = "all-MiniLM-L6-v2"
mock_embed_config.api_key = None
mock_embed_config.base_url = None
mock_embed_config.dimension = 768

mocker.patch(
"ragitect.api.v1.chat.get_active_embedding_config",
return_value=mock_embed_config,
)

# Mock embedding model and embed function
# Mock embedding model - now using the DRY helper
mock_embed_model = mocker.MagicMock()
mocker.patch(
"ragitect.api.v1.chat.create_embeddings_model",
"ragitect.api.v1.chat.get_embedding_model_from_config",
return_value=mock_embed_model,
)

@@ -96,7 +83,6 @@ async def mock_embed_fn(model, text: str):
)

return {
"embed_config": mock_embed_config,
"embed_model": mock_embed_model,
"vector_repo": mock_vector_repo,
"llm": mock_llm,
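One detail the fixture relies on is the patch target: the helper must be patched where chat.py looks it up, not where it is defined. A minimal illustrative sketch with a hypothetical test name; the asyncio marker assumes pytest-asyncio, so substitute whatever async test setup the project actually uses:

import pytest


@pytest.mark.asyncio
async def test_embed_model_resolution_is_mocked(mocker):
    mock_model = mocker.MagicMock()
    # Patch the name chat.py imported, not the definition in ragitect.services.embedding.
    patched = mocker.patch(
        "ragitect.api.v1.chat.get_embedding_model_from_config",
        new=mocker.AsyncMock(return_value=mock_model),
    )

    from ragitect.api.v1 import chat

    assert await chat.get_embedding_model_from_config(mocker.MagicMock()) is mock_model
    patched.assert_awaited_once()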
9 changes: 6 additions & 3 deletions tests/services/test_document_processing_service.py
@@ -104,8 +104,8 @@ async def test_process_document_success(
mock_model = MagicMock()
mock_model.aembed_documents = AsyncMock(return_value=[[0.1] * 768])
with patch(
"ragitect.services.document_processing_service.create_embeddings_model",
return_value=mock_model,
"ragitect.services.document_processing_service.get_embedding_model_from_config",
new=AsyncMock(return_value=mock_model),
):
with patch(
"ragitect.services.document_processing_service.embed_documents",
@@ -247,8 +247,11 @@ async def test_process_document_integration_with_repo_methods(
"ragitect.services.document_processing_service.split_document",
return_value=["Chunk"],
):
mock_model = MagicMock()
mock_model.aembed_documents = AsyncMock(return_value=[[0.1] * 768])
with patch(
"ragitect.services.document_processing_service.create_embeddings_model"
"ragitect.services.document_processing_service.get_embedding_model_from_config",
new=AsyncMock(return_value=mock_model),
):
with patch(
"ragitect.services.document_processing_service.embed_documents",