From c6cec68507e6538a6a528a2a84bb856061e6705f Mon Sep 17 00:00:00 2001 From: Ganesh Murthy Date: Fri, 19 Dec 2025 10:13:00 -0500 Subject: [PATCH] APPENG-4252: Adds the following features Add new drop down to show all available vector databases and when a database is picked, show the documents already uploaded to the database. Ability to delete a document that was already added to a vector database Ability to upload document to the any chosen vector database Create new vector databases by clicking the Create New button --- deploy/helm/rag/templates/deployment.yaml | 12 + .../rag/templates/embedding-warmup-job.yaml | 91 +++ deploy/helm/rag/values.yaml | 7 + .../distribution/ui/page/upload/upload.py | 627 ++++++++++++++++-- frontend/pyproject.toml | 1 + tests/integration/test_upload_integration.py | 16 + tests/unit/test_upload.py | 450 ++++++++----- 7 files changed, 987 insertions(+), 217 deletions(-) create mode 100644 deploy/helm/rag/templates/embedding-warmup-job.yaml diff --git a/deploy/helm/rag/templates/deployment.yaml b/deploy/helm/rag/templates/deployment.yaml index eddbcccf..dbc49cea 100644 --- a/deploy/helm/rag/templates/deployment.yaml +++ b/deploy/helm/rag/templates/deployment.yaml @@ -40,6 +40,18 @@ spec: - name: TAVILY_SEARCH_API_KEY value: {{ (index .Values "llama-stack").secrets.TAVILY_SEARCH_API_KEY | quote }} {{- end }} + {{- if .Values.pgvector }} + - name: PGVECTOR_HOST + value: {{ .Values.pgvector.secret.host | quote }} + - name: PGVECTOR_PORT + value: {{ .Values.pgvector.secret.port | quote }} + - name: PGVECTOR_USER + value: {{ .Values.pgvector.secret.user | quote }} + - name: PGVECTOR_PASSWORD + value: {{ .Values.pgvector.secret.password | quote }} + - name: PGVECTOR_DB + value: {{ .Values.pgvector.secret.dbname | quote }} + {{- end }} {{- if .Values.suggestedQuestions }} - name: RAG_QUESTION_SUGGESTIONS valueFrom: diff --git a/deploy/helm/rag/templates/embedding-warmup-job.yaml b/deploy/helm/rag/templates/embedding-warmup-job.yaml new file mode 100644 index 00000000..0cf97622 --- /dev/null +++ b/deploy/helm/rag/templates/embedding-warmup-job.yaml @@ -0,0 +1,91 @@ +{{/* +Embedding Warmup Job +This job ensures the embedding model is fully loaded before ingestion pipelines run. +It prevents the race condition where pipelines try to embed documents before the embedding model is ready. +*/}} +{{- if .Values.global.embeddingWarmup.enabled | default true }} +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "rag.fullname" . }}-embedding-warmup + labels: + {{- include "rag.labels" . | nindent 4 }} + app.kubernetes.io/component: embedding-warmup + annotations: + # Run as a post-install hook with low weight to run early + "helm.sh/hook": post-install,post-upgrade + "helm.sh/hook-weight": "-10" + "helm.sh/hook-delete-policy": hook-succeeded,before-hook-creation +spec: + ttlSecondsAfterFinished: 300 + backoffLimit: 10 + template: + metadata: + labels: + {{- include "rag.selectorLabels" . | nindent 8 }} + app.kubernetes.io/component: embedding-warmup + spec: + restartPolicy: OnFailure + containers: + - name: warmup + image: "image-registry.openshift-image-registry.svc:5000/openshift/tools:latest" + imagePullPolicy: IfNotPresent + env: + - name: LLAMASTACK_URL + value: "http://llamastack:8321" + - name: EMBEDDING_MODEL + value: {{ .Values.global.embeddingWarmup.model | default "all-MiniLM-L6-v2" | quote }} + - name: MAX_RETRIES + value: {{ .Values.global.embeddingWarmup.maxRetries | default "60" | quote }} + - name: RETRY_INTERVAL + value: {{ .Values.global.embeddingWarmup.retryInterval | default "5" | quote }} + command: + - /bin/bash + - -c + - | + set -e + echo "=== Embedding Model Warmup Job ===" + echo "LlamaStack URL: $LLAMASTACK_URL" + echo "Embedding Model: $EMBEDDING_MODEL" + echo "Max Retries: $MAX_RETRIES" + echo "Retry Interval: ${RETRY_INTERVAL}s" + echo "" + + # First wait for LlamaStack to be available + echo "Step 1: Waiting for LlamaStack to be available..." + retries=0 + until curl -sf "$LLAMASTACK_URL/v1/models" > /dev/null 2>&1; do + retries=$((retries + 1)) + if [ $retries -ge $MAX_RETRIES ]; then + echo "ERROR: LlamaStack not available after $MAX_RETRIES retries" + exit 1 + fi + echo " Waiting for LlamaStack... (attempt $retries/$MAX_RETRIES)" + sleep $RETRY_INTERVAL + done + echo " LlamaStack is available!" + echo "" + + # Now warm up the embedding model by making an actual embedding request + echo "Step 2: Warming up embedding model..." + retries=0 + until curl -sf -X POST "$LLAMASTACK_URL/v1/inference/embeddings" \ + -H "Content-Type: application/json" \ + -d "{\"model_id\": \"$EMBEDDING_MODEL\", \"contents\": [\"warmup test\"]}" \ + --max-time 30 \ + | grep -q "embeddings"; do + retries=$((retries + 1)) + if [ $retries -ge $MAX_RETRIES ]; then + echo "ERROR: Embedding model not ready after $MAX_RETRIES retries" + exit 1 + fi + echo " Waiting for embedding model to load... (attempt $retries/$MAX_RETRIES)" + sleep $RETRY_INTERVAL + done + echo " Embedding model is ready!" + echo "" + + echo "=== Warmup Complete ===" + echo "The embedding model is now loaded and ready for ingestion pipelines." +{{- end }} + diff --git a/deploy/helm/rag/values.yaml b/deploy/helm/rag/values.yaml index 744fe059..3642c7ca 100644 --- a/deploy/helm/rag/values.yaml +++ b/deploy/helm/rag/values.yaml @@ -120,6 +120,13 @@ volumeMounts: global: models: {} mcp-servers: {} + # Embedding warmup configuration + # Ensures the embedding model is loaded before ingestion pipelines run + embeddingWarmup: + enabled: true + model: "all-MiniLM-L6-v2" + maxRetries: 60 # Maximum number of retries (60 * 5s = 5 minutes max wait) + retryInterval: 5 # Seconds between retries # Hugging Face Token for model downloads llm-service: diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index ffd01751..8547cd8f 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -1,67 +1,604 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import asyncpg +import os import streamlit as st -from llama_stack_client import RAGDocument +import traceback + +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name, data_url_from_file from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import data_url_from_file +from llama_stack_client import RAGDocument + + +# Module-level connection pool (initialized lazily) +_pg_pool = None + + +async def _get_pg_pool(): + """ + Get or create the PostgreSQL connection pool. + The pool is created lazily on first use and reused for subsequent calls. + + Returns: + asyncpg.Pool: The connection pool instance + """ + global _pg_pool + if _pg_pool is None: + pg_host = os.environ.get("PGVECTOR_HOST", "pgvector") + pg_port = os.environ.get("PGVECTOR_PORT", "5432") + pg_user = os.environ.get("PGVECTOR_USER", "postgres") + pg_password = os.environ.get("PGVECTOR_PASSWORD", "rag_password") + pg_database = os.environ.get("PGVECTOR_DB", "rag_blueprint") + + _pg_pool = await asyncpg.create_pool( + host=pg_host, + port=int(pg_port), + user=pg_user, + password=pg_password, + database=pg_database, + min_size=1, + max_size=5, + ) + return _pg_pool + def upload_page(): """ - Page to upload documents and create a vector database for RAG. + Page to upload documents and manage vector databases for RAG. + Supports creating new vector databases and uploading documents to existing ones. + """ + st.title("📄 Upload Documents") + + # Initialize session state for creation status messages + if "creation_status" not in st.session_state: + st.session_state["creation_status"] = None + if "creation_message" not in st.session_state: + st.session_state["creation_message"] = "" + + # Initialize session state for selected vector database + # This persists the selection when navigating away and back to this page + if "selected_vector_db" not in st.session_state: + st.session_state["selected_vector_db"] = "" + + # Initialize the widget key to match our tracked selection + # This ensures the selectbox displays the correct value on page load + if "vector_db_selector" not in st.session_state: + st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] + + # Initialize newly created VDB tracker + if "newly_created_vdb" not in st.session_state: + st.session_state["newly_created_vdb"] = None + + # Show status messages at the top level (before dropdown) + if st.session_state["creation_status"] == "success": + st.success(st.session_state["creation_message"]) + # Clear the message after showing it + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + elif st.session_state["creation_status"] == "error": + st.error(st.session_state["creation_message"]) + # Clear the message after showing it + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + + # Fetch all vector databases + vdb_list = llama_stack_api.client.vector_dbs.list() + + # Build dropdown options based on whether databases exist + dropdown_options = [] + vdb_info = {} + + # Define the "Create New" option with emoji for visibility + CREATE_NEW_OPTION = "➕ Create New" + + if vdb_list: + # When databases exist: list actual DBs first, then "Create New" LAST + existing_vdbs = {get_vector_db_name(v): v.to_dict() for v in vdb_list} + dropdown_options.extend(list(existing_vdbs.keys())) + dropdown_options.append(CREATE_NEW_OPTION) # Add "Create New" as LAST item + vdb_info = existing_vdbs + else: + # When NO databases exist: only show "Create New" + dropdown_options = [CREATE_NEW_OPTION] + + # Sync session state for widget - ensure it shows the right value + # Priority 1: If a database was just created, auto-select it (highest priority) + if st.session_state["newly_created_vdb"]: + newly_created_name = st.session_state["newly_created_vdb"] + if newly_created_name in dropdown_options: + # Update both session variables to sync state + st.session_state["selected_vector_db"] = newly_created_name + st.session_state["vector_db_selector"] = newly_created_name + st.session_state["newly_created_vdb"] = None + # Priority 2: Use the previously selected database from session if it still exists + elif st.session_state["selected_vector_db"] and st.session_state["selected_vector_db"] in dropdown_options: + # Sync widget state with our tracked state + st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] + # Priority 3: If no saved selection or saved selection doesn't exist, use smart default + else: + if vdb_list: + # When databases exist: default to FIRST actual database (not "Create New") + first_db = dropdown_options[0] # First item is first actual database + st.session_state["selected_vector_db"] = first_db + st.session_state["vector_db_selector"] = first_db + else: + # When NO databases exist: default to "Create New" + st.session_state["selected_vector_db"] = CREATE_NEW_OPTION + st.session_state["vector_db_selector"] = CREATE_NEW_OPTION + + # Vector database selection dropdown with persistent selection + # Using key parameter to bind directly to session state - NO index parameter to avoid conflicts + def on_vector_db_change(): + """Callback to update session state when selection changes""" + st.session_state["selected_vector_db"] = st.session_state["vector_db_selector"] + + selected_vector_db = st.selectbox( + "Select a vector database", + dropdown_options, + key="vector_db_selector", # Key binds to session state (session state controls the value) + on_change=on_vector_db_change, # Callback updates our tracking variable + help="Your selection will be remembered when you navigate to other pages" + ) + + # Ensure session state is updated (in case callback didn't fire) + if selected_vector_db != st.session_state["selected_vector_db"]: + st.session_state["selected_vector_db"] = selected_vector_db + + # Get the actual vector database object for API calls (do this before using it) + selected_vdb_obj = None + if selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + for vdb in vdb_list: + if get_vector_db_name(vdb) == selected_vector_db: + selected_vdb_obj = vdb + break + + if selected_vector_db == CREATE_NEW_OPTION: + # Show vector database creation UI + _show_create_vector_db_ui() + elif selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + # Show existing documents in the database (heading will show only if documents exist) + _show_existing_documents_table(selected_vector_db, selected_vdb_obj) + + # Add Browse functionality for uploading documents to this database + st.subheader(f"📁 Upload Documents to '{selected_vector_db}'") + _show_document_upload_ui(selected_vector_db, selected_vdb_obj) + # If empty string is selected, show nothing (clean default state) + + +def _show_create_vector_db_ui(): + """ + Display UI for creating a new vector database. + """ + st.subheader("Create New Vector Database") + + # Initialize session state for creation form + if "new_vdb_name" not in st.session_state: + st.session_state["new_vdb_name"] = "" + + # Vector database name input + new_vdb_name = st.text_input( + "Add New Vector Database", + value=st.session_state["new_vdb_name"], + help="Enter a unique name for the new vector database", + key="new_vdb_name_input" + ) + + # Update session state + st.session_state["new_vdb_name"] = new_vdb_name + + # Add button + if st.button("Add", type="primary", disabled=not new_vdb_name.strip()): + _create_vector_database(new_vdb_name.strip()) + + +def _create_vector_database(vdb_name): + """ + Create a new vector database using the LlamaStack API. + + Args: + vdb_name (str): Name for the new vector database + """ + try: + # Reset status + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + + # Validate input + if not vdb_name or not vdb_name.strip(): + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = "Vector database name cannot be empty." + return + + # Check for duplicate names + existing_vdbs = llama_stack_api.client.vector_dbs.list() + existing_names = [get_vector_db_name(vdb) for vdb in existing_vdbs] + if vdb_name in existing_names: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = f"Vector database '{vdb_name}' already exists. Please choose a different name." + return + + # Get vector IO provider + providers = llama_stack_api.client.providers.list() + vector_io_provider = None + for provider in providers: + if provider.api == "vector_io": + vector_io_provider = provider.provider_id + break + + if not vector_io_provider: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = "No vector IO provider found. Cannot create vector database." + return + + # Create the vector database + with st.spinner(f"Creating vector database '{vdb_name}'..."): + vector_db = llama_stack_api.client.vector_dbs.register( + vector_db_id=vdb_name, + embedding_dimension=384, + embedding_model="all-MiniLM-L6-v2", + provider_id=vector_io_provider, + ) + + # Success + st.session_state["creation_status"] = "success" + st.session_state["creation_message"] = f"Vector database '{vdb_name}' created successfully!" + + # Mark this database to be auto-selected after refresh + st.session_state["newly_created_vdb"] = vdb_name + + # Clear the input field + st.session_state["new_vdb_name"] = "" + + # Trigger page refresh to update the dropdown - this will show the message at the top + st.rerun() + + except Exception as e: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = f"Error creating vector database: {str(e)}" + + +def _show_document_upload_ui(vector_db_name, vector_db_obj=None): + """ + Display UI for uploading documents to an existing vector database. + + Args: + vector_db_name (str): Name of the selected vector database """ - st.title("📄 Upload") - # File/Directory Upload Section - st.subheader("Create Vector DB") - # Let user select files to ingest + # Initialize session state for upload status + if "upload_status" not in st.session_state: + st.session_state["upload_status"] = None + if "upload_message" not in st.session_state: + st.session_state["upload_message"] = "" + + # Show upload status messages + if st.session_state["upload_status"] == "success": + st.success(st.session_state["upload_message"]) + # Clear after showing + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + elif st.session_state["upload_status"] == "error": + st.error(st.session_state["upload_message"]) + # Clear after showing + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + + # Initialize session state to track processed files + upload_key = f"processed_files_{vector_db_name}" + if upload_key not in st.session_state: + st.session_state[upload_key] = set() + + # File uploader uploaded_files = st.file_uploader( - "Upload file(s) or directory", + "Browse and select files to upload (files will upload automatically)", accept_multiple_files=True, - type=["txt", "pdf", "doc", "docx"], # supported file types + type=["txt", "pdf", "doc", "docx"], + key=f"uploader_{vector_db_name}", # Unique key per database + help="Select one or more documents - they will be uploaded automatically to this vector database" ) - # Process uploaded files + + # Auto-upload when files are selected if uploaded_files: - # Show upload success and prompt for DB name - st.success(f"Successfully uploaded {len(uploaded_files)} files") - vector_db_name = st.text_input( - "Vector Database Name", - value="rag_vector_db", - help="Enter a unique identifier for this vector database", - ) - if st.button("Create Vector Database"): - # Convert uploaded files into RAGDocument instances + # Create a unique identifier for this set of files + file_set_id = frozenset([f.name + str(f.size) for f in uploaded_files]) + + # Only process if this is a new set of files + if file_set_id not in st.session_state[upload_key]: + # Mark as processed IMMEDIATELY before upload to prevent re-triggering + st.session_state[upload_key].add(file_set_id) + + # Get the correct database ID for upload + vector_db_id = vector_db_obj.identifier if vector_db_obj and hasattr(vector_db_obj, 'identifier') else vector_db_name + + # Upload automatically + _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id) + + +def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=None): + """ + Upload documents to an existing vector database. + + Args: + vector_db_name (str): Name of the target vector database + uploaded_files: List of uploaded files from Streamlit file uploader + """ + try: + # Reset status + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + + if not uploaded_files: + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = "No files selected for upload." + return + + # Convert uploaded files into RAGDocument instances + with st.spinner(f"Processing {len(uploaded_files)} file(s)..."): documents = [ RAGDocument( document_id=uploaded_file.name, content=data_url_from_file(uploaded_file), + metadata={"source": uploaded_file.name, "type": "uploaded_file"} # LlamaStack maps 'source' to chunk_metadata.source ) - for i, uploaded_file in enumerate(uploaded_files) + for uploaded_file in uploaded_files ] - - # Determine provider for vector IO - providers = llama_stack_api.client.providers.list() - vector_io_provider = None - for x in providers: - if x.api == "vector_io": - vector_io_provider = x.provider_id - break - - # Register new vector database - vector_db = llama_stack_api.client.vector_dbs.register( - vector_db_id=vector_db_name, - embedding_dimension=384, - embedding_model="all-MiniLM-L6-v2", - provider_id=vector_io_provider, - ) - vector_db_id = vector_db.identifier - - # Insert documents into the vector database + + # Insert documents into the existing vector database + actual_db_id = vector_db_id or vector_db_name + with st.spinner(f"Uploading documents to '{vector_db_name}'..."): llama_stack_api.client.tool_runtime.rag_tool.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_db_id, # Use the correct database ID documents=documents, chunk_size_in_tokens=512, ) - st.success("Vector database created successfully!") - # Reset form fields - uploaded_files.clear() - vector_db_name = "" + + # Success + st.session_state["upload_status"] = "success" + st.session_state["upload_message"] = f"Successfully uploaded {len(uploaded_files)} document(s) to '{vector_db_name}'!" + + # Trigger refresh to show the success message + st.rerun() + + except Exception as e: + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = f"Error uploading documents: {str(e)}" + st.rerun() + + +def _get_documents_from_pgvector(vector_db_id): + """ + Query pgvector directly to get document IDs stored in the database. + Uses a connection pool for efficient connection reuse. + + Args: + vector_db_id (str): The vector database identifier + + Returns: + list: List of unique document IDs, or None if query fails + """ + try: + async def fetch_documents(): + try: + # Get connection from pool + pool = await _get_pg_pool() + + async with pool.acquire() as conn: + # Query for unique document IDs from the document JSONB column + # The vector_db_id is used as the table name with underscores replacing hyphens + table_name = f"vs_{vector_db_id.replace('-', '_')}" + + # Query metadata.source where LlamaStack stores the filename + # Try multiple paths since different upload methods use different structures: + # - Ingestion pipeline: metadata.source + # - Manual upload: chunk_metadata.source + # Fall back to auto-generated document_id if source is null + query = f""" + SELECT DISTINCT + COALESCE( + NULLIF(document->'metadata'->>'source', 'null'), + NULLIF(document->'chunk_metadata'->>'source', 'null'), + document->'metadata'->>'document_id' + ) as document_id + FROM {table_name} + WHERE document->'metadata'->>'document_id' IS NOT NULL + OR document->'metadata'->>'source' IS NOT NULL + ORDER BY document_id + """ + + queries = [query] + + doc_ids = [] + for query in queries: + try: + rows = await conn.fetch(query) + if rows: + doc_ids = [row['document_id'] for row in rows if row['document_id']] + if doc_ids: + break + except Exception as e: + continue # Try next query pattern + + return doc_ids if doc_ids else None + # Connection automatically returned to pool + + except Exception as e: + return None + + # Run the async function + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(fetch_documents()) + + except Exception as e: + return None + + +def _delete_document_from_pgvector(vector_db_id, filename): + """ + Delete a document and all its chunks/embeddings from pgvector. + Uses a connection pool for efficient connection reuse. + + Args: + vector_db_id (str): The vector database identifier + filename (str): The filename/source to delete + + Returns: + tuple: (success: bool, deleted_count: int, error_message: str) + """ + try: + async def delete_document(): + try: + # Get connection from pool + pool = await _get_pg_pool() + + async with pool.acquire() as conn: + # The vector_db_id is used as the table name with underscores replacing hyphens + table_name = f"vs_{vector_db_id.replace('-', '_')}" + + # Delete all chunks where the source matches the filename + # Handle both document structures: + # - Ingestion pipeline: metadata.source + # - Manual upload: chunk_metadata.source + query = f""" + DELETE FROM {table_name} + WHERE document->'metadata'->>'source' = $1 + OR document->'chunk_metadata'->>'source' = $1 + """ + + result = await conn.execute(query, filename) + + # Parse the result to get the number of deleted rows + # Result format is like "DELETE 5" where 5 is the number of rows + deleted_count = int(result.split()[-1]) if result else 0 + + return True, deleted_count, None + # Connection automatically returned to pool + + except Exception as e: + return False, 0, str(e) + + # Run the async function + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(delete_document()) + + except Exception as e: + return False, 0, str(e) + + +def _show_existing_documents_table(vector_db_name, vector_db_obj=None): + """ + Display information about documents in the selected vector database. + + Args: + vector_db_name (str): Display name of the selected vector database + vector_db_obj: The actual vector database object with identifier + """ + try: + # Get the correct vector database ID + if vector_db_obj and hasattr(vector_db_obj, 'identifier'): + vector_db_id = vector_db_obj.identifier + else: + vector_db_id = vector_db_name # Fallback to display name + + # Initialize session state for deletion status + if "delete_status" not in st.session_state: + st.session_state["delete_status"] = None + if "delete_message" not in st.session_state: + st.session_state["delete_message"] = "" + + # Show deletion status messages (before checking documents, so last delete shows) + if st.session_state["delete_status"] == "success": + st.success(st.session_state["delete_message"]) + st.session_state["delete_status"] = None + st.session_state["delete_message"] = "" + elif st.session_state["delete_status"] == "error": + st.error(st.session_state["delete_message"]) + st.session_state["delete_status"] = None + st.session_state["delete_message"] = "" + + with st.spinner("Checking for documents..."): + # First, try to get document list from pgvector directly + document_ids = _get_documents_from_pgvector(vector_db_id) + + if document_ids: + # Success! We have the actual document filenames + # Show heading for documents section + st.subheader(f"📄 Documents in '{vector_db_name}'") + + # Add CSS for bordered table rows + st.markdown(""" + + """, unsafe_allow_html=True) + + # Display table header + col1, col2, col3 = st.columns([0.5, 5, 0.5]) + with col1: + st.markdown("**#**") + with col2: + st.markdown("**Filename**") + with col3: + st.markdown("**Del**") + + # Display each document in a row with delete button + for idx, doc_id in enumerate(document_ids, start=1): + col1, col2, col3 = st.columns([0.5, 5, 0.5]) + + with col1: + st.write(idx) + + with col2: + st.write(doc_id) + + with col3: + delete_key = f"delete_{vector_db_name}_{doc_id}_{idx}" + + if st.button("✕", key=delete_key, help=f"Delete {doc_id}"): + # Delete immediately without confirmation + success, deleted_count, error = _delete_document_from_pgvector( + vector_db_id, + doc_id + ) + + if success: + st.session_state["delete_status"] = "success" + st.session_state["delete_message"] = f"✅ Successfully deleted '{doc_id}' ({deleted_count} chunk(s) removed)" + else: + st.session_state["delete_status"] = "error" + st.session_state["delete_message"] = f"❌ Failed to delete '{doc_id}': {error}" + + st.rerun() + + # else: Database appears empty or pgvector query not available + # For newly created databases, this is expected - just show nothing + # The upload section below will allow users to add documents + + except Exception as e: + st.error(f"Error loading document information: {str(e)}") + with st.expander("Error Details"): + st.code(traceback.format_exc()) -upload_page() \ No newline at end of file +upload_page() diff --git a/frontend/pyproject.toml b/frontend/pyproject.toml index be65e231..31e4233e 100644 --- a/frontend/pyproject.toml +++ b/frontend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "streamlit-option-menu", "llama-stack==0.2.23", "fire", + "asyncpg", ] [tool.setuptools] diff --git a/tests/integration/test_upload_integration.py b/tests/integration/test_upload_integration.py index 549f28fd..378ac1a9 100644 --- a/tests/integration/test_upload_integration.py +++ b/tests/integration/test_upload_integration.py @@ -11,6 +11,22 @@ # Add the frontend directory to the path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../frontend')) +# Mock all external dependencies before any imports from the upload module +# This is required because @patch decorators try to import the target module +mock_streamlit = MagicMock() +mock_streamlit.session_state = {} +sys.modules['streamlit'] = mock_streamlit +sys.modules['asyncpg'] = MagicMock() +sys.modules['pandas'] = MagicMock() + +# Mock llama_stack_client with a proper RAGDocument mock +mock_llama_stack_client = MagicMock() +def mock_rag_document(**kwargs): + """Create a dict-like RAGDocument mock""" + return kwargs +mock_llama_stack_client.RAGDocument = mock_rag_document +sys.modules['llama_stack_client'] = mock_llama_stack_client + # Configuration LLAMA_STACK_ENDPOINT = os.getenv("LLAMA_STACK_ENDPOINT", "http://localhost:8321") diff --git a/tests/unit/test_upload.py b/tests/unit/test_upload.py index 0e4f0011..de5fe7bb 100644 --- a/tests/unit/test_upload.py +++ b/tests/unit/test_upload.py @@ -2,217 +2,323 @@ Unit tests for the upload module Tests document upload and vector DB creation logic """ -import pytest -from unittest.mock import Mock, patch, MagicMock -import sys +import asyncio import os +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest # Add the frontend directory to the path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../frontend')) -# Mock streamlit before importing -sys.modules['streamlit'] = MagicMock() +# Mock all external dependencies before any imports from the upload module +# This is required because @patch decorators try to import the target module +mock_streamlit = MagicMock() +mock_streamlit.session_state = {} +sys.modules['streamlit'] = mock_streamlit +sys.modules['asyncpg'] = MagicMock() +sys.modules['pandas'] = MagicMock() + +# Mock llama_stack_client with a proper RAGDocument mock +mock_llama_stack_client = MagicMock() +def mock_rag_document(**kwargs): + """Create a dict-like RAGDocument mock""" + return kwargs +mock_llama_stack_client.RAGDocument = mock_rag_document +sys.modules['llama_stack_client'] = mock_llama_stack_client +# Now we can safely import modules that will be patched +# Pre-import the modules so @patch can find them +from llama_stack_ui.distribution.ui.modules import api, utils +from llama_stack_ui.distribution.ui.page.upload import upload as upload_module -class TestVectorDBConfiguration: - """Test vector database configuration and setup""" + +class MockAsyncContextManager: + """Mock async context manager for pool.acquire()""" + def __init__(self, conn): + self.conn = conn - def test_vector_db_default_name(self): - """Test default vector database name""" - default_name = "rag_vector_db" - assert default_name == "rag_vector_db" - assert len(default_name) > 0 + async def __aenter__(self): + return self.conn - def test_vector_db_embedding_dimension(self): - """Test that embedding dimension is set correctly for all-MiniLM-L6-v2""" - embedding_dimension = 384 - embedding_model = "all-MiniLM-L6-v2" - - assert embedding_dimension == 384 - assert embedding_model == "all-MiniLM-L6-v2" + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + +def create_mock_pool_with_connection(mock_conn): + """ + Helper to create a mock connection pool that yields the given connection. - def test_chunk_size_configuration(self): - """Test that chunk size is set to 512 tokens""" - chunk_size = 512 - assert chunk_size == 512 + Args: + mock_conn: The mock connection to return from pool.acquire() + + Returns: + MagicMock: A mock pool with proper acquire() context manager + """ + mock_pool = MagicMock() + mock_pool.acquire.return_value = MockAsyncContextManager(mock_conn) + return mock_pool -class TestDocumentProcessing: - """Test document processing and RAGDocument creation""" +class TestGetDocumentsFromPgvector: + """Unit tests for _get_documents_from_pgvector function""" - def test_supported_file_types(self): - """Test that supported file types are correctly defined""" - supported_types = ["txt", "pdf", "doc", "docx"] - - assert "txt" in supported_types - assert "pdf" in supported_types - assert "doc" in supported_types - assert "docx" in supported_types + def test_get_documents_success(self): + """Test successful retrieval of documents from pgvector""" + # Setup mock connection + mock_conn = AsyncMock() + mock_rows = [ + {'document_id': 'document1.pdf'}, + {'document_id': 'document2.txt'}, + {'document_id': 'document3.docx'}, + ] + mock_conn.fetch = AsyncMock(return_value=mock_rows) + + # Create mock pool + mock_pool = create_mock_pool_with_connection(mock_conn) + + # Patch _get_pg_pool to return our mock pool + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + # Call the actual function + result = upload_module._get_documents_from_pgvector("my-test-db") + + # Verify the result + assert result == ['document1.pdf', 'document2.txt', 'document3.docx'] + + # Verify acquire was called (connection borrowed from pool) + mock_pool.acquire.assert_called_once() - def test_document_id_from_filename(self): - """Test that document ID is created from filename""" - filename = "test_document.pdf" - document_id = filename + def test_get_documents_empty_result(self): + """Test that empty results return None""" + # Setup mock connection with empty result + mock_conn = AsyncMock() + mock_conn.fetch = AsyncMock(return_value=[]) - assert document_id == "test_document.pdf" - assert document_id.endswith(".pdf") + mock_pool = create_mock_pool_with_connection(mock_conn) + + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("empty-db") + + # Empty result should return None + assert result is None - @patch('llama_stack_ui.distribution.ui.modules.utils.data_url_from_file') - def test_rag_document_creation(self, mock_data_url): - """Test RAGDocument creation from uploaded file""" - from llama_stack_client import RAGDocument - - # Mock file and data URL - mock_data_url.return_value = "data:text/plain;base64,SGVsbG8gV29ybGQ=" - - mock_file = Mock() - mock_file.name = "test.txt" - - # Create RAGDocument as done in upload.py - document = RAGDocument( - document_id=mock_file.name, - content=mock_data_url(mock_file), - ) - - # RAGDocument returns a dict-like object - assert document['document_id'] == "test.txt" - assert document['content'].startswith("data:") - mock_data_url.assert_called_once() + def test_get_documents_connection_error(self): + """Test that connection errors return None""" + # Setup mock pool that raises an exception on acquire + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("error-db") + + # Error should return None + assert result is None - def test_multiple_documents_processing(self): - """Test processing multiple uploaded files""" - # Simulate multiple uploaded files - mock_file1 = Mock() - mock_file1.name = "doc1.txt" - mock_file2 = Mock() - mock_file2.name = "doc2.pdf" - mock_file3 = Mock() - mock_file3.name = "doc3.docx" + def test_get_documents_filters_null_ids(self): + """Test that null document IDs are filtered out""" + mock_conn = AsyncMock() + mock_rows = [ + {'document_id': 'valid1.pdf'}, + {'document_id': None}, # Should be filtered + {'document_id': 'valid2.txt'}, + {'document_id': None}, # Should be filtered + ] + mock_conn.fetch = AsyncMock(return_value=mock_rows) - uploaded_files = [mock_file1, mock_file2, mock_file3] + mock_pool = create_mock_pool_with_connection(mock_conn) - # Simulate creating document list - document_ids = [f.name for f in uploaded_files] + async def mock_get_pool(): + return mock_pool - assert len(document_ids) == 3 - assert "doc1.txt" in document_ids - assert "doc2.pdf" in document_ids - assert "doc3.docx" in document_ids + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("mixed-db") + + # Only valid IDs should be returned + assert result == ['valid1.pdf', 'valid2.txt'] + assert len(result) == 2 -class TestVectorDBOperations: - """Test vector database operations""" +class TestDeleteDocumentFromPgvector: + """Unit tests for _delete_document_from_pgvector function""" - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_vector_db_registration_params(self, mock_api): - """Test that vector DB registration uses correct parameters""" - mock_client = Mock() - mock_api.client = mock_client + def test_delete_document_success(self): + """Test successful deletion of document from pgvector""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 5") + + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - embedding_dimension = 384 - embedding_model = "all-MiniLM-L6-v2" - provider_id = "pgvector" - - # Simulate registration call - mock_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) - - # Verify the call was made with correct params - mock_client.vector_dbs.register.assert_called_once_with( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is True + assert count == 5 + assert error is None + mock_pool.acquire.assert_called_once() - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_document_insertion_params(self, mock_api): - """Test that document insertion uses correct parameters""" - from llama_stack_client import RAGDocument + def test_delete_document_not_found(self): + """Test deletion when document doesn't exist""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 0") - mock_client = Mock() - mock_api.client = mock_client + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - documents = [ - RAGDocument(document_id="doc1", content="content1"), - RAGDocument(document_id="doc2", content="content2"), + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "nonexistent.pdf" + ) + + assert success is True + assert count == 0 + assert error is None + + def test_delete_document_connection_error(self): + """Test deletion with connection error""" + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is False + assert count == 0 + assert error is not None + assert "Connection refused" in str(error) + + +class TestCreateVectorDatabase: + """Unit tests for _create_vector_database function""" + + def test_create_vector_database_success(self): + """Test successful creation of vector database""" + # Mock the API client + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [] + mock_client.providers.list.return_value = [ + MagicMock(api="vector_io", provider_id="pgvector") ] - chunk_size = 512 - - # Simulate insertion call - mock_client.tool_runtime.rag_tool.insert( - vector_db_id=vector_db_id, - documents=documents, - chunk_size_in_tokens=chunk_size, - ) - - # Verify the call was made - mock_client.tool_runtime.rag_tool.insert.assert_called_once() - call_args = mock_client.tool_runtime.rag_tool.insert.call_args - assert call_args[1]['vector_db_id'] == vector_db_id - assert call_args[1]['chunk_size_in_tokens'] == chunk_size - assert len(call_args[1]['documents']) == 2 + mock_client.vector_dbs.register.return_value = MagicMock() + + mock_api = MagicMock() + mock_api.client = mock_client + + # Mock session state + mock_st = MagicMock() + mock_st.session_state = {} + + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("new-test-db") + + # Verify registration was called with correct parameters + mock_client.vector_dbs.register.assert_called_once() + call_kwargs = mock_client.vector_dbs.register.call_args[1] + assert call_kwargs['vector_db_id'] == "new-test-db" + assert call_kwargs['embedding_model'] == "all-MiniLM-L6-v2" + assert call_kwargs['embedding_dimension'] == 384 + assert call_kwargs['provider_id'] == "pgvector" - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_provider_detection(self, mock_api): - """Test vector IO provider detection""" - mock_client = Mock() + def test_create_vector_database_duplicate_name(self): + """Test that duplicate names are rejected""" + # Mock existing database with same name + existing_db = MagicMock() + existing_db.identifier = "existing-db" + + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [existing_db] + + mock_api = MagicMock() mock_api.client = mock_client - # Mock provider list - mock_providers = [ - Mock(api="inference", provider_id="ollama"), - Mock(api="vector_io", provider_id="pgvector"), - Mock(api="memory", provider_id="redis"), + mock_st = MagicMock() + mock_st.session_state = {} + + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("existing-db") + + # Registration should NOT be called for duplicates + mock_client.vector_dbs.register.assert_not_called() + + # Error status should be set + assert mock_st.session_state.get("creation_status") == "error" + + def test_create_vector_database_no_provider(self): + """Test error when no vector_io provider exists""" + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [] + mock_client.providers.list.return_value = [ + MagicMock(api="inference", provider_id="ollama") # No vector_io ] - mock_client.providers.list.return_value = mock_providers - # Simulate provider detection logic - providers = mock_client.providers.list() - vector_io_provider = None - for x in providers: - if x.api == "vector_io": - vector_io_provider = x.provider_id + mock_api = MagicMock() + mock_api.client = mock_client + + mock_st = MagicMock() + mock_st.session_state = {} - assert vector_io_provider == "pgvector" + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("new-db") + + # Registration should NOT be called without provider + mock_client.vector_dbs.register.assert_not_called() + + # Error status should be set + assert mock_st.session_state.get("creation_status") == "error" -class TestUploadValidation: - """Test upload validation and error handling""" - - def test_empty_upload_list(self): - """Test handling of empty upload list""" - uploaded_files = [] - assert len(uploaded_files) == 0 +class TestConnectionPool: + """Unit tests for the connection pool functionality""" - def test_upload_count_display(self): - """Test upload count display logic""" - uploaded_files = [Mock(), Mock(), Mock()] - count = len(uploaded_files) - message = f"Successfully uploaded {count} files" - - assert message == "Successfully uploaded 3 files" - assert str(count) in message - - def test_vector_db_name_validation(self): - """Test vector database name validation""" - # Valid names - valid_names = ["rag_vector_db", "test-db-123", "my_documents"] - for name in valid_names: - assert len(name) > 0 - assert name.replace('_', '').replace('-', '').isalnum() + def test_pool_is_reused(self): + """Test that the same pool is returned on subsequent calls""" + # Reset the global pool + upload_module._pg_pool = None + + mock_pool = AsyncMock() + + # Mock asyncpg.create_pool + async def mock_create_pool(**kwargs): + return mock_pool - # Invalid names should be caught - invalid_name = "" - assert len(invalid_name) == 0 + with patch.object(upload_module.asyncpg, 'create_pool', mock_create_pool): + # Get pool twice + async def get_pools(): + pool1 = await upload_module._get_pg_pool() + pool2 = await upload_module._get_pg_pool() + return pool1, pool2 + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + pool1, pool2 = loop.run_until_complete(get_pools()) + + # Should be the same pool instance + assert pool1 is pool2 + + # Clean up + upload_module._pg_pool = None if __name__ == "__main__": pytest.main([__file__, "-v"]) -