diff --git a/.github/workflows/e2e-tests.yaml b/.github/workflows/e2e-tests.yaml index 3b942df1..3d7e9c79 100644 --- a/.github/workflows/e2e-tests.yaml +++ b/.github/workflows/e2e-tests.yaml @@ -629,17 +629,23 @@ jobs: deployment/llamastack -n rag-e2e-ui kubectl wait --for=condition=available --timeout=300s \ deployment/rag -n rag-e2e-ui + + echo "Waiting for pods to be ready..." + kubectl wait --for=condition=ready --timeout=600s \ + pod -l app.kubernetes.io/name=llamastack -n rag-e2e-ui + kubectl wait --for=condition=ready --timeout=300s \ + pod -l app.kubernetes.io/name=rag -n rag-e2e-ui + + echo "✅ All pods are ready" + kubectl get pods -n rag-e2e-ui - name: Expose services via NodePort run: | kubectl patch service rag -n rag-e2e-ui -p '{"spec":{"type":"NodePort","ports":[{"port":8501,"nodePort":30080}]}}' kubectl patch service llamastack -n rag-e2e-ui -p '{"spec":{"type":"NodePort","ports":[{"port":8321,"nodePort":30081}]}}' - - - name: Port forward services - run: | - kubectl port-forward -n rag-e2e-ui svc/rag 8501:8501 & - kubectl port-forward -n rag-e2e-ui svc/llamastack 8321:8321 & - sleep 10 + + # Verify services + kubectl get services -n rag-e2e-ui - name: Run UI E2E tests with Playwright env: @@ -649,10 +655,74 @@ jobs: MAAS_MODEL_ID: ${{ env.MAAS_MODEL_ID }} SKIP_MODEL_TESTS: "false" # Enable MaaS inference tests in UI run: | + echo "Starting port forwarding and running tests..." + + # Start port forwarding in background and keep them running + echo "Starting port forwarding for RAG UI..." + kubectl port-forward -n rag-e2e-ui svc/rag 8501:8501 > /tmp/rag-portforward.log 2>&1 & + RAG_PF_PID=$! + echo "RAG port-forward PID: $RAG_PF_PID" + + echo "Starting port forwarding for LlamaStack..." + kubectl port-forward -n rag-e2e-ui svc/llamastack 8321:8321 > /tmp/llamastack-portforward.log 2>&1 & + LLAMASTACK_PF_PID=$! 
+ echo "LlamaStack port-forward PID: $LLAMASTACK_PF_PID" + + # Function to check if port forwarding is working + check_port_forwarding() { + (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8501' 2>/dev/null) && \ + (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8321' 2>/dev/null) + } + + # Function to verify processes are alive + check_processes() { + kill -0 $RAG_PF_PID 2>/dev/null && kill -0 $LLAMASTACK_PF_PID 2>/dev/null + } + + # Wait for port forwarding to establish + echo "Waiting for port forwarding to be ready..." + for i in {1..30}; do + if check_port_forwarding && check_processes; then + echo "✅ Port forwarding is working! (attempt $i)" + break + fi + if [ $i -eq 30 ]; then + echo "❌ Port forwarding failed to establish after 30 attempts" + echo "RAG port-forward log:" + cat /tmp/rag-portforward.log || true + echo "LlamaStack port-forward log:" + cat /tmp/llamastack-portforward.log || true + echo "Checking port-forward processes:" + ps aux | grep "kubectl port-forward" || true + echo "Checking if ports are listening:" + ss -tlnp | grep -E ':(8501|8321)' || netstat -tlnp 2>/dev/null | grep -E ':(8501|8321)' || true + exit 1 + fi + echo "Waiting for port forwarding... (attempt $i/30)" + sleep 2 + done + + # Set up cleanup trap to kill port forwarding on exit + trap "echo 'Cleaning up port forwarding...'; kill $RAG_PF_PID $LLAMASTACK_PF_PID 2>/dev/null || true" EXIT + echo "Running UI E2E tests with MaaS integration..." echo "MaaS Endpoint: ${MAAS_ENDPOINT}" echo "MaaS Model ID: ${MAAS_MODEL_ID}" - pytest tests/e2e_ui/ -v --tb=short --browser chromium + + # Run tests + pytest tests/e2e_ui/ -v --tb=short --browser chromium || TEST_EXIT_CODE=$? + + # Verify port forwarding was still working after tests + if ! 
check_port_forwarding; then + echo "⚠️ Warning: Port forwarding stopped working during tests" + echo "RAG port-forward log:" + cat /tmp/rag-portforward.log || true + echo "LlamaStack port-forward log:" + cat /tmp/llamastack-portforward.log || true + fi + + # Exit with test result + exit ${TEST_EXIT_CODE:-0} - name: Upload Playwright test results if: always() diff --git a/README.md b/README.md index cd5e732f..b6820384 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ Watch for all pods to reach Running or Completed status. Key pods to watch inclu oc get pods -l component=predictor ``` -Look for **3/3** under the Ready column. +Look for **2/2** (or **3/3** when RAW_DEPLOYMENT=false) under the Ready column. 8. **Verify Installation** diff --git a/deploy/helm/Makefile b/deploy/helm/Makefile index 836f6f04..e90f2f1f 100644 --- a/deploy/helm/Makefile +++ b/deploy/helm/Makefile @@ -85,10 +85,24 @@ endef # Helper function to validate values file define validate_values_file echo -e "$(BLUE)[INFO]$(NC) Validating configuration values..."; \ - HF_TOKEN=$$(grep -A 2 "^llm-service:" "$(VALUES_FILE)" | grep "hf_token:" | sed 's/.*hf_token: *//' | tr -d '"' | tr -d ' '); \ - LLAMA_STACK_TAVILY=$$(grep "TAVILY_SEARCH_API_KEY:" "$(VALUES_FILE)" 2>/dev/null | sed 's/.*TAVILY_SEARCH_API_KEY: *//' | tr -d '"' | tr -d ' '); \ UPDATED=0; \ - if [ -z "$$HF_TOKEN" ] || [ "$$HF_TOKEN" = "" ]; then \ + \ + if [ -n "$$HF_TOKEN" ]; then \ + echo -e "$(GREEN)[SUCCESS]$(NC) Using HF_TOKEN from environment variable."; \ + sed -i.bak "/^llm-service:/,/^[^ ]/ s|hf_token:.*|hf_token: \"$$HF_TOKEN\"|" "$(VALUES_FILE)"; \ + UPDATED=1; \ + fi; \ + \ + if [ -n "$$TAVILY_API_KEY" ]; then \ + echo -e "$(GREEN)[SUCCESS]$(NC) Using TAVILY_API_KEY from environment variable."; \ + sed -i.bak "s/TAVILY_SEARCH_API_KEY:.*/TAVILY_SEARCH_API_KEY: \"$$TAVILY_API_KEY\"/" "$(VALUES_FILE)"; \ + UPDATED=1; \ + fi; \ + \ + HF_TOKEN_FILE=$$(grep -A 2 "^llm-service:" "$(VALUES_FILE)" | grep "hf_token:" | sed 
's/.*hf_token: *//' | tr -d '"' | tr -d ' '); \ + LLAMA_STACK_TAVILY=$$(grep "TAVILY_SEARCH_API_KEY:" "$(VALUES_FILE)" 2>/dev/null | sed 's/.*TAVILY_SEARCH_API_KEY: *//' | tr -d '"' | tr -d ' '); \ + \ + if [ -z "$$HF_TOKEN_FILE" ] || [ "$$HF_TOKEN_FILE" = "" ]; then \ echo -e "$(YELLOW)[WARNING]$(NC) Hugging Face token is not set. Model downloads may fail."; \ echo -e "$(BLUE)[INFO]$(NC) Get your token from: https://huggingface.co/settings/tokens"; \ echo -e ""; \ @@ -102,6 +116,7 @@ define validate_values_file fi; \ echo -e ""; \ fi; \ + \ if [ -z "$$LLAMA_STACK_TAVILY" ] || [ "$$LLAMA_STACK_TAVILY" = "Paste-your-key-here" ]; then \ echo -e "$(YELLOW)[WARNING]$(NC) TAVILY search API key is not set. Web search will be disabled."; \ echo -e "$(BLUE)[INFO]$(NC) Get your key from: https://tavily.com/"; \ @@ -116,6 +131,7 @@ define validate_values_file fi; \ echo -e ""; \ fi; \ + \ if [ $$UPDATED -eq 1 ]; then \ echo -e "$(GREEN)[SUCCESS]$(NC) Configuration updated. Proceeding with installation..."; \ echo -e ""; \ @@ -312,7 +328,7 @@ show-config: ## Show configuration file contents # Create namespace and deploy namespace: @echo -e "$(BLUE)[INFO]$(NC) Creating namespace $(NAMESPACE)..." 
- @oc create namespace $(NAMESPACE) &> /dev/null && oc label namespace $(NAMESPACE) modelmesh-enabled=false ||: + @oc new-project $(NAMESPACE) &> /dev/null && oc label namespace $(NAMESPACE) modelmesh-enabled=false &> /dev/null ||: @oc project $(NAMESPACE) &> /dev/null ||: @echo -e "$(GREEN)[SUCCESS]$(NC) Namespace $(NAMESPACE) is ready" @@ -479,11 +495,11 @@ install: ## Install the RAG deployment fi; \ if [ -n "$(LLM_URL)" ]; then \ echo -e "$(BLUE)[INFO]$(NC) Setting LLM URL: $(LLM_URL)"; \ - HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).url='$(LLM_URL)'"; \ + HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).url=$(LLM_URL)"; \ fi; \ if [ -n "$(LLM_API_TOKEN)" ]; then \ echo -e "$(BLUE)[INFO]$(NC) Setting LLM API token"; \ - HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).apiToken='$(LLM_API_TOKEN)'"; \ + HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).apiToken=$(LLM_API_TOKEN)"; \ fi; \ fi; \ if [ -n "$(SAFETY)" ]; then \ diff --git a/deploy/helm/rag/Chart.lock b/deploy/helm/rag/Chart.lock index 45c5ac1b..f3dfd713 100644 --- a/deploy/helm/rag/Chart.lock +++ b/deploy/helm/rag/Chart.lock @@ -1,21 +1,21 @@ dependencies: - name: pgvector repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.1 + version: 0.5.5 - name: llm-service repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.2 + version: 0.5.9 - name: configure-pipeline repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.4 + version: 0.5.6 - name: ingestion-pipeline repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.1 + version: 0.6.6 - name: llama-stack repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.2 + version: 0.6.11 - name: mcp-servers repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.7 -digest: sha256:d7abd4b5f5c4080a241c567f0bde351f927a5ac0d95fea4bbdf8f364f7a92866 -generated: 
"2025-12-05T10:53:08.788253807-05:00" + version: 0.5.15 +digest: sha256:1065a9cbf8dfb460fd9c9a6d3571fdfc33e3693503aae6535e07e75919a6c9f2 +generated: "2026-02-13T13:19:24.192726731-05:00" diff --git a/deploy/helm/rag/Chart.yaml b/deploy/helm/rag/Chart.yaml index 9a5583ed..8da8df28 100644 --- a/deploy/helm/rag/Chart.yaml +++ b/deploy/helm/rag/Chart.yaml @@ -2,31 +2,31 @@ apiVersion: v2 name: rag description: A Helm chart for Kubernetes type: application -version: 0.2.31 -appVersion: "0.2.31" +version: 0.2.32 +appVersion: "0.2.32" dependencies: - name: pgvector - version: 0.5.1 + version: 0.5.5 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: pgvector.enabled - name: llm-service - version: 0.5.2 + version: 0.5.9 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llm-service.enabled - name: configure-pipeline - version: 0.5.4 + version: 0.5.6 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: configure-pipeline.enabled - name: ingestion-pipeline - version: 0.5.1 + version: 0.6.6 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: ingestion-pipeline.enabled - name: llama-stack - version: 0.5.2 + version: 0.6.11 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llama-stack.enabled - name: mcp-servers - version: 0.5.7 + version: 0.5.15 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: mcp-servers.enabled diff --git a/deploy/helm/rag/templates/deployment.yaml b/deploy/helm/rag/templates/deployment.yaml index eddbcccf..dbc49cea 100644 --- a/deploy/helm/rag/templates/deployment.yaml +++ b/deploy/helm/rag/templates/deployment.yaml @@ -40,6 +40,18 @@ spec: - name: TAVILY_SEARCH_API_KEY value: {{ (index .Values "llama-stack").secrets.TAVILY_SEARCH_API_KEY | quote }} {{- end }} + {{- if .Values.pgvector }} + - name: PGVECTOR_HOST + value: {{ .Values.pgvector.secret.host | quote }} + - name: 
PGVECTOR_PORT + value: {{ .Values.pgvector.secret.port | quote }} + - name: PGVECTOR_USER + value: {{ .Values.pgvector.secret.user | quote }} + - name: PGVECTOR_PASSWORD + value: {{ .Values.pgvector.secret.password | quote }} + - name: PGVECTOR_DB + value: {{ .Values.pgvector.secret.dbname | quote }} + {{- end }} {{- if .Values.suggestedQuestions }} - name: RAG_QUESTION_SUGGESTIONS valueFrom: diff --git a/deploy/helm/rag/templates/route.yaml b/deploy/helm/rag/templates/route.yaml index 486adc6d..2c7b4b88 100644 --- a/deploy/helm/rag/templates/route.yaml +++ b/deploy/helm/rag/templates/route.yaml @@ -4,6 +4,9 @@ metadata: name: {{ include "rag.fullname" . }} labels: {{- include "rag.labels" . | nindent 4 }} + annotations: + # 10 minute timeout for large document uploads + haproxy.router.openshift.io/timeout: 600s spec: to: kind: Service diff --git a/deploy/helm/rag/values.yaml b/deploy/helm/rag/values.yaml index d52a2fd6..59a0e95d 100644 --- a/deploy/helm/rag/values.yaml +++ b/deploy/helm/rag/values.yaml @@ -3,7 +3,7 @@ replicaCount: 1 image: repository: quay.io/rh-ai-quickstart/llamastack-dist-ui pullPolicy: Always - tag: 0.2.31 + tag: 0.2.32 service: type: ClusterIP diff --git a/frontend/llama_stack_ui/distribution/ui/app.py b/frontend/llama_stack_ui/distribution/ui/app.py index 772cb91f..ec08af75 100644 --- a/frontend/llama_stack_ui/distribution/ui/app.py +++ b/frontend/llama_stack_ui/distribution/ui/app.py @@ -3,9 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging + import streamlit as st + def main(): + # Configure logging to show DEBUG messages by default + logging.basicConfig( + level=logging.DEBUG, + format='[%(levelname)s] %(name)s: %(message)s' + ) # Define available pages: path and icon pages = { "Chat": ("page/playground/chat.py", "💬"), @@ -15,7 +23,7 @@ def main(): # Build navigation items dynamically nav_items = [ - st.Page(path, title=name, icon=icon, default=(name == "Chat")) + st.Page(path, title=name, icon=icon, default=name == "Chat") for name, (path, icon) in pages.items() ] # Render navigation @@ -25,4 +33,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/frontend/llama_stack_ui/distribution/ui/modules/api.py b/frontend/llama_stack_ui/distribution/ui/modules/api.py index 96023239..6e884ed2 100644 --- a/frontend/llama_stack_ui/distribution/ui/modules/api.py +++ b/frontend/llama_stack_ui/distribution/ui/modules/api.py @@ -11,9 +11,13 @@ class LlamaStackApi: def __init__(self): + # Timeout of 600 seconds (10 minutes) for large document uploads + # Default is 60 seconds which is too short for large PDFs + timeout = float(os.environ.get("LLAMA_STACK_TIMEOUT", "600")) + self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), - + timeout=timeout, provider_data={ "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), "together_api_key": os.environ.get("TOGETHER_API_KEY", ""), diff --git a/frontend/llama_stack_ui/distribution/ui/modules/utils.py b/frontend/llama_stack_ui/distribution/ui/modules/utils.py index b8bf725c..a706e786 100644 --- a/frontend/llama_stack_ui/distribution/ui/modules/utils.py +++ b/frontend/llama_stack_ui/distribution/ui/modules/utils.py @@ -7,6 +7,7 @@ import base64 import json import os +import re import pandas as pd import streamlit as st @@ -57,18 +58,23 @@ def data_url_from_file(file) -> str: return data_url +def clean_text(text): + """Collapse consecutive whitespace into a single space.""" + return 
re.sub(r'\s+', ' ', text).strip() + + def get_vector_db_name(vector_db): """ Get the display name for a vector database. - Falls back to identifier if vector_db_name attribute is not present. - + Falls back to id if name attribute is not present. + Args: vector_db: Vector database object from API - + Returns: str: The vector database name """ - return getattr(vector_db, 'vector_db_name', vector_db.identifier) + return getattr(vector_db, 'name', vector_db.id) def get_question_suggestions(): @@ -91,39 +97,39 @@ def get_question_suggestions(): def get_suggestions_for_databases(selected_dbs, all_vector_dbs): """ Get combined question suggestions for selected databases. - + Args: selected_dbs: List of selected vector DB names all_vector_dbs: List of all vector DB objects from API - + Returns: List of tuples (question, source_db_name) """ suggestions_map = get_question_suggestions() combined_suggestions = [] - + if not suggestions_map: return [] - - # Create a mapping from vector_db_name to identifier - db_name_to_identifier = { - get_vector_db_name(vdb): vdb.identifier + + # Create a mapping from vector_db_name to id + db_name_to_id = { + get_vector_db_name(vdb): vdb.id for vdb in all_vector_dbs } - + for db_name in selected_dbs: - # Get the identifier for this database name - db_identifier = db_name_to_identifier.get(db_name) - - # Try both the identifier and the db_name as keys in the suggestions map + # Get the id for this database name + db_id = db_name_to_id.get(db_name) + + # Try both the id and the db_name as keys in the suggestions map questions = None - if db_identifier and db_identifier in suggestions_map: - questions = suggestions_map[db_identifier] + if db_id and db_id in suggestions_map: + questions = suggestions_map[db_id] elif db_name in suggestions_map: questions = suggestions_map[db_name] - + if questions: for question in questions: combined_suggestions.append((question, db_name)) - + return combined_suggestions diff --git 
a/frontend/llama_stack_ui/distribution/ui/page/distribution/inspect.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/inspect.py index a159b7a1..46835167 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/inspect.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/inspect.py @@ -10,10 +10,7 @@ from streamlit_option_menu import option_menu import streamlit as st -from llama_stack_ui.distribution.ui.page.distribution.datasets import datasets -from llama_stack_ui.distribution.ui.page.distribution.eval_tasks import benchmarks from llama_stack_ui.distribution.ui.page.distribution.models import models -from llama_stack_ui.distribution.ui.page.distribution.scoring_functions import scoring_functions from llama_stack_ui.distribution.ui.page.distribution.shields import shields from llama_stack_ui.distribution.ui.page.distribution.providers import providers from llama_stack_ui.distribution.ui.page.distribution.vector_dbs import vector_dbs diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py index 7549d0e4..749ad812 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py @@ -29,6 +29,3 @@ def providers(): for api_name, providers in api_to_providers.items(): st.markdown(f"###### {api_name}") st.dataframe([p.to_dict() for p in providers], width=500) - - - diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py index 3fd51bd4..5f531f07 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py @@ -22,5 +22,8 @@ def scoring_functions(): scoring_functions_info = {s.identifier: s.to_dict() for s in 
sf_list} # Let user select and view a scoring function - selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys())) + selected_scoring_function = st.selectbox( + "Select a scoring function", + list(scoring_functions_info.keys()), + ) st.json(scoring_functions_info[selected_scoring_function], expanded=True) diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py index 37c46616..0ccb9a9b 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name def vector_dbs(): @@ -16,7 +16,7 @@ def vector_dbs(): """ st.header("Vector Databases") # Fetch all vector databases - vdb_list = llama_stack_api.client.vector_dbs.list() + vdb_list = llama_stack_api.client.vector_stores.list() if not vdb_list: st.info("No vector databases found.") return diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py index f595382e..e4dc6e73 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py @@ -85,8 +85,12 @@ def application_evaluation_page(): ) new_params[param_name] = value else: + label = ( + f"Enter value for **{param_name}** in " + f"{scoring_fn_id} in valid JSON format" + ) value = st.text_area( - f"Enter value for **{param_name}** in {scoring_fn_id} in valid 
JSON format", + label, value=json.dumps(param_value, indent=2), height=80, ) diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py index be165adc..47ed9066 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py @@ -10,4 +10,4 @@ def evaluations_page(): with tabs[1]: native_evaluation_page() -evaluations_page() \ No newline at end of file +evaluations_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py index 5539ea95..b4aecdc9 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py @@ -5,7 +5,8 @@ from llama_stack_ui.distribution.ui.modules.api import llama_stack_api """ -Native Evaluation page: select a benchmark, configure eval candidate, and run full generation + scoring. +Native Evaluation page: select a benchmark, configure eval candidate, +and run full generation + scoring. """ def select_benchmark_1(): @@ -48,7 +49,8 @@ def define_eval_candidate_2(): st.subheader("2. Define Eval Candidate") st.info( - "Define generation configuration: choose 'model' for inference API or 'agent' for agent API." + "Define generation configuration: choose 'model' for inference API " + "or 'agent' for agent API." 
) with st.expander("Define Eval Candidate", expanded=True): @@ -216,4 +218,4 @@ def native_evaluation_page(): run_evaluation_3() -native_evaluation_page() \ No newline at end of file +native_evaluation_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py new file mode 100644 index 00000000..56bc38de --- /dev/null +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Agent mode implementation for chat with automatic tool calling. +""" + +import logging + +import streamlit as st + +from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name + + +logger = logging.getLogger(__name__) + + +def build_response_tools(toolgroup_selection, selected_vector_dbs, client): + """ + Convert toolgroup selections to LlamaStack Responses API compatible tool format. 
+ + Args: + toolgroup_selection: List of selected toolgroup IDs + selected_vector_dbs: List of selected vector database names + client: LlamaStack client instance + + Returns: + List of tools in Responses API format (works for both Agent and Direct modes) + """ + agent_tools = [] + + for toolgroup_name in toolgroup_selection: + if toolgroup_name == "builtin::rag": + if len(selected_vector_dbs) > 0: + vector_dbs = client.vector_stores.list() or [] + vector_db_ids = [ + vector_db.id for vector_db in vector_dbs + if get_vector_db_name(vector_db) in selected_vector_dbs + ] + # Use file_search tool format + agent_tools.append({ + "type": "file_search", + "vector_store_ids": list(vector_db_ids), + }) + elif "web_search" in toolgroup_name or "search" in toolgroup_name.lower(): + # Convert search tools to web_search format + agent_tools.append({"type": "web_search"}) + elif toolgroup_name.startswith("mcp::"): + # For MCP tools, get server info + try: + toolgroups = client.toolgroups.list() + for toolgroup in toolgroups: + if str(toolgroup.identifier) == toolgroup_name: + agent_tools.append({ + "type": "mcp", + "server_label": toolgroup.args.get( + "name", str(toolgroup.identifier) + ), + "server_url": toolgroup.mcp_endpoint.uri, + }) + break + except Exception as e: + logger.debug("Failed to get MCP server info for %s: %s", toolgroup_name, e) + else: + # For other toolgroups, get individual tools and convert to function format + try: + tools_in_group = client.tools.list(toolgroup_id=toolgroup_name) + for tool in tools_in_group: + # Convert to function tool dict + agent_tools.append({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.parameters or {} + } + }) + except Exception as e: + logger.debug("Failed to get tools for %s: %s", toolgroup_name, e) + + return agent_tools + + +# ============================================================================ +# Agent Mode - Chunk Handlers +# 
============================================================================ + +def handle_agent_file_search_chunk(state, selected_vector_dbs): + """Handle file_search tool chunk in agent mode.""" + if state.tool_used: + return + + # Show tool status message in persistent container + if selected_vector_dbs: + db_label = "vector store" if len(selected_vector_dbs) == 1 else "vector stores" + status_msg = ( + f"🛠 :grey[_Using file_search tool with {db_label}: " + f"{', '.join(selected_vector_dbs)}_]" + ) + else: + status_msg = "🛠 :grey[_Using file_search tool..._]" + + state.tool_status = status_msg + with state.containers.tool_status: + st.markdown(status_msg) + + +def handle_agent_web_search_chunk(state): + """Handle web_search tool chunk in agent mode.""" + if state.tool_used: + return + + status_msg = "🛠 :grey[_Using web_search tool..._]" + state.tool_status = status_msg + with state.containers.tool_status: + st.markdown(status_msg) + + +def handle_agent_output_item_done(chunk, state): + """Handle response.output_item.done - tool execution completion with results.""" + if not hasattr(chunk, 'item'): + return + + item = chunk.item + item_type = getattr(item, 'type', None) + + if item_type == "file_search_call": + # File search results + if hasattr(item, 'results') and item.results: + display_results = [] + for r in item.results: + text = getattr(r, 'text', '') + attrs = getattr(r, 'attributes', {}) + source = attrs.get('source') or getattr(r, 'filename', 'unknown') + display_results.append({"source": source, "text": clean_text(text)}) + state.tool_results.append({ + 'title': '📄 File Search Results', + 'type': 'json', + 'content': display_results + }) + with state.containers.tool_results: + with st.expander("📄 File Search Results", expanded=False): + st.json(display_results) + + elif item_type == "web_search_call": + # Web search - API doesn't expose raw results, just status + pass + + elif item_type == "function_call": + # Function call output + if hasattr(item, 
'output') and item.output: + tool_name = getattr(item, 'name', 'function') + state.tool_results.append({ + 'title': f'🔧 Tool Output: {tool_name}', + 'type': 'code', + 'content': str(item.output) + }) + with state.containers.tool_results: + with st.expander(f"🔧 Tool Output: {tool_name}", expanded=False): + st.code(str(item.output)) + + elif item_type == "mcp_call": + # MCP call output + if hasattr(item, 'output') and item.output: + tool_name = getattr(item, 'name', 'mcp') + state.tool_results.append({ + 'title': f'🔧 MCP Tool Output: {tool_name}', + 'type': 'code', + 'content': str(item.output) + }) + with state.containers.tool_results: + with st.expander(f"🔧 MCP Tool Output: {tool_name}", expanded=False): + st.code(str(item.output)) + + elif item_type and item_type.endswith("_call"): + # Generic handler for any other tool call types + if hasattr(item, 'results') and item.results: + formatted_name = item_type.replace("_", " ").title() + state.tool_results.append({ + 'title': f'🔧 {formatted_name} Results', + 'type': 'json', + 'content': item.results + }) + with state.containers.tool_results: + with st.expander(f"🔧 {formatted_name} Results", expanded=False): + st.json(item.results) + elif hasattr(item, 'output') and item.output: + formatted_name = item_type.replace("_", " ").title() + state.tool_results.append({ + 'title': f'🔧 {formatted_name} Output', + 'type': 'json', + 'content': item.output + }) + with state.containers.tool_results: + with st.expander(f"🔧 {formatted_name} Output", expanded=False): + st.json(item.output) + + +def handle_chunk_error(chunk): + """Handle error chunk and return whether to stop streaming.""" + error_msg = "Unknown error" + error_code = None + + # Try to get error from chunk.error first + if hasattr(chunk, 'error') and chunk.error: + if hasattr(chunk.error, 'message'): + error_msg = chunk.error.message + if hasattr(chunk.error, 'code'): + error_code = chunk.error.code + # Fallback to chunk attributes + elif hasattr(chunk, 
'error_message'): + error_msg = chunk.error_message + + error_display = f"❌ Error: {error_msg}" + if error_code: + error_display += f" (Code: {error_code})" + + st.error(error_display) + logger.debug("Response failed: %s", error_msg) + return True # Stop streaming + + +def handle_chunk_completed(chunk): + """Handle completed chunk.""" + logger.debug("Response completed successfully") + if hasattr(chunk, 'stop_reason'): + logger.debug("Stop reason: %s", chunk.stop_reason) + + +def handle_chunk_done(chunk, state): + """Handle done chunk and finalize response.""" + has_output = ( + hasattr(chunk, 'response') and + hasattr(chunk.response, 'output_text') and + chunk.response.output_text + ) + if has_output: + state.full_response = chunk.response.output_text + + +def process_chunk_by_type(chunk, state, selected_vector_dbs): + """Process a single chunk based on its type. Returns True to stop streaming.""" + chunk_type = chunk.type + + # Handle file_search tool + if chunk_type == "response.file_search_call.in_progress": + handle_agent_file_search_chunk(state, selected_vector_dbs) + + # Handle web_search tool + elif chunk_type == "response.web_search_call.in_progress": + handle_agent_web_search_chunk(state) + + elif chunk_type in ("response.web_search_call.searching", + "response.web_search_call.completed"): + pass # Just for event tracking + + # Handle tool results + elif chunk_type == "response.output_item.done": + handle_agent_output_item_done(chunk, state) + + # Handle reasoning + elif chunk_type == "response.reasoning_text.delta": + if hasattr(chunk, 'delta') and chunk.delta: + state.update_reasoning(chunk.delta) + + # Handle message content + elif chunk_type == "response.output_text.delta": + if hasattr(chunk, 'delta') and chunk.delta: + state.update_message(chunk.delta) + + # Handle errors + elif chunk_type == "response.failed": + return handle_chunk_error(chunk) + + # Handle completion + elif chunk_type == "response.completed": + handle_chunk_completed(chunk) + + # 
Handle done + elif chunk_type == "response.done": + handle_chunk_done(chunk, state) + + return False # Continue streaming + + +# ============================================================================ +# Agent Mode - Main Functions +# ============================================================================ + +def stream_agent_response(response, state, selected_vector_dbs): + """ + Stream and process chunks from Responses API. + Updates state containers as chunks arrive. + """ + chunk_count = 0 + + for chunk in response: + chunk_count += 1 + logger.debug("Chunk #%s: type=%s", chunk_count, getattr(chunk, 'type', 'NO_TYPE')) + logger.debug(" -> Full chunk: %s", chunk) + + if hasattr(chunk, 'type'): + should_stop = process_chunk_by_type(chunk, state, selected_vector_dbs) + if should_stop: + break + + +def save_agent_response_to_session(state): + """Save agent response to session state.""" + state.finalize_reasoning() + state.finalize_message() + + response_dict = { + "role": "assistant", + "content": state.full_response, + "stop_reason": "end_of_message" + } + + if state.reasoning_text: + response_dict["reasoning"] = state.reasoning_text + if state.tool_status: + response_dict["tool_status"] = state.tool_status + if state.tool_results: + response_dict["tool_results"] = state.tool_results + + st.session_state.messages.append(response_dict) + + +def agent_process_prompt(prompt, state, config): + """Agent-based mode: Use Responses API with automatic tool calling.""" + # Build tools list from selected toolgroups + tools = build_response_tools( + config.toolgroup_selection, config.selected_vector_dbs, llama_stack_api.client + ) if config.toolgroup_selection else None + + # Build request for Responses API + request_kwargs = { + "model": config.model, + "instructions": config.system_prompt, + "input": prompt, + "conversation": config.conversation_id, + "temperature": config.sampling.temperature, + "max_infer_iters": config.sampling.max_infer_iters, + "stream": True, 
+ } + + # Add tools if available + if tools: + request_kwargs["tools"] = tools + + logger.debug("Request: %s", request_kwargs) + response = llama_stack_api.client.responses.create(**request_kwargs) + + # Stream response and update UI + stream_agent_response(response, state, config.selected_vector_dbs) + + # Save response to session + save_agent_response_to_session(state) diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py b/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py index 03e91935..081d2335 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py @@ -4,671 +4,577 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import enum -import json -import uuid +""" +Chat page for LlamaStack UI with RAG support. +Provides both Direct mode (manual RAG) and Agent-based mode (automatic tool calling). 
+""" + +import logging +from dataclasses import dataclass import streamlit as st -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.lib.agents.react.agent import ReActAgent -from llama_stack_client.lib.agents.react.tool_parser import ReActOutput -from llama_stack.apis.common.content_types import ToolCallDelta + from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import get_suggestions_for_databases, get_vector_db_name -from llama_stack_client.types import UserMessage -from llama_stack_client.types.shared_params import SamplingParams -from llama_stack_client.types.shared_params.response_format import JsonSchemaResponseFormat -from llama_stack_client.types.shared_params.sampling_params import StrategyTopPSamplingStrategy +from llama_stack_ui.distribution.ui.modules.utils import ( + get_suggestions_for_databases, + get_vector_db_name, +) +from llama_stack_ui.distribution.ui.page.playground.agent import ( + agent_process_prompt, +) +from llama_stack_ui.distribution.ui.page.playground.direct import ( + direct_process_prompt, +) + + +logger = logging.getLogger(__name__) + +def render_tool_results(tool_results): + """Render tool results from a message.""" + for tool_result in tool_results: + with st.expander(tool_result['title'], expanded=False): + if tool_result['type'] == 'json': + st.json(tool_result['content']) + else: + st.code(tool_result['content']) -class AgentType(enum.Enum): - REGULAR = "Regular" - REACT = "ReAct" +def render_message(msg): + """Render a single message in chat history.""" + with st.chat_message(msg['role']): + # Display tool status if present + if msg.get('tool_status'): + st.markdown(msg['tool_status']) -def get_strategy(temperature, top_p): - """Determines the sampling strategy for the LLM based on temperature.""" - return {'type': 'greedy'} if temperature == 0 else { - 'type': 
'top_p', 'temperature': temperature, 'top_p': top_p - } + # Display tool results if present + if msg.get('tool_results'): + render_tool_results(msg['tool_results']) + # Display reasoning if present (right before the answer) + if msg.get('reasoning'): + with st.expander("🧠 Reasoning", expanded=False): + st.markdown(msg['reasoning']) -def render_history(tool_debug): - """Renders the chat history from the session state. - Also displays debug events for assistant messages if tool_debug is enabled. - """ + # Display the final answer + st.markdown(msg['content']) + + +def render_history(): + """Renders the chat history from the session state.""" # Initialize messages in the session state if not present if 'messages' not in st.session_state: st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}] - # Initialize debug_events in the session state if not present - if 'debug_events' not in st.session_state: - st.session_state.debug_events = [] - - for i, msg in enumerate(st.session_state.messages): - with st.chat_message(msg['role']): - st.markdown(msg['content']) - - # Display debug events expander for assistant messages (excluding the initial greeting) - if msg['role'] == 'assistant' and tool_debug and i > 0: - # Debug events are stored per assistant turn. - # The index for debug_events corresponds to the assistant message turn. - # messages: [A_initial, U_1, A_1, U_2, A_2, ...] - # debug_events: [events_for_A_1, events_for_A_2, ...] - # For A_1 (msg index 2), the debug_events index is (2//2)-1 = 0. 
- debug_event_list_index = (i // 2) - 1 - if 0 <= debug_event_list_index < len(st.session_state.debug_events): - current_turn_events_list = st.session_state.debug_events[debug_event_list_index] - - if current_turn_events_list: # Only show expander if there are events - with st.expander("Tool/Debug Events", expanded=False): - if isinstance(current_turn_events_list, list) and len(current_turn_events_list) > 0: - for event_idx, event_item in enumerate(current_turn_events_list): - with st.container(): - if isinstance(event_item, dict): - st.json(event_item, expanded=False) - elif isinstance(event_item, str): - st.text_area( - label=f"Debug Event {event_idx + 1}", - value=event_item, - height=100, - disabled=True, - key=f"debug_event_msg{i}_item{event_idx}" # Unique key for each text area - ) - else: - st.write(event_item) # Fallback for other data types - if event_idx < len(current_turn_events_list) - 1: - st.divider() - elif isinstance(current_turn_events_list, list) and not current_turn_events_list: - st.caption("No debug events recorded for this turn.") - else: # Should not happen with current logic - st.write("Debug data for this turn (unexpected format):") - st.write(current_turn_events_list) -def tool_chat_page(): - st.title("💬 Chat") + for msg in st.session_state.messages: + render_message(msg) + +def fetch_models_and_tools(): + """Fetch and categorize models and toolgroups from LlamaStack.""" client = llama_stack_api.client + + # Fetch models models = client.models.list() model_list = [model.identifier for model in models if model.api_model_type == "llm"] + # Fetch and categorize toolgroups tool_groups = client.toolgroups.list() + logger.debug("Raw tool groups from LlamaStack: %s", tool_groups) tool_groups_list = [tool_group.identifier for tool_group in tool_groups] + logger.debug("Tool group identifiers: %s", tool_groups_list) + mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] + logger.debug("MCP tools: %s", mcp_tools_list) + 
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] + logger.debug("Built-in tools: %s", builtin_tools_list) - selected_vector_dbs = [] + return model_list, builtin_tools_list, mcp_tools_list - def reset_agent(): - st.session_state.clear() - st.cache_resource.clear() - with st.sidebar: - st.title("Configuration") - st.subheader("Model") - model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed") - - ## Added mode - processing_mode = st.radio( - "Processing mode", - ["Direct", "Agent-based"], - index=0, # Default to Direct - captions=[ - "Directly calls the model with optional RAG.", - "Uses an Agent with tools.", - ], - on_change=reset_agent, - help="Choose how requests are processed. 'Direct' bypasses agents, 'Agent-based' uses them.", - ) +def render_toolgroup_selection(builtin_tools_list, mcp_tools_list, selected_vector_dbs, + on_toolgroup_change, on_reset): + """Render toolgroup selection UI and return selected toolgroups.""" + st.subheader("Available ToolGroups") - - toolgroup_selection = [] - if processing_mode == "Direct": - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - if not vector_dbs: - st.info("No vector databases available for selection.") - vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] - selected_vector_dbs = st.multiselect( - label="Select Document Collections to use in RAG queries", - options=vector_db_names, - on_change=reset_agent, - ) - if processing_mode == "Agent-based": - st.subheader("Available ToolGroups") - - toolgroup_selection = st.pills( - label="Built-in tools", - options=builtin_tools_list, - selection_mode="multi", - on_change=reset_agent, - format_func=lambda tool: "".join(tool.split("::")[1:]), - help="List of built-in tools from your llama stack server.", - ) - - if "builtin::rag" in toolgroup_selection: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - if not vector_dbs: - 
st.info("No vector databases available for selection.") - vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] - selected_vector_dbs = st.multiselect( - label="Select Document Collections to use in RAG queries", - options=vector_db_names, - on_change=reset_agent, - ) - - # Display mcp list only if there are mcp tools - if len(mcp_tools_list) > 0: - mcp_selection = st.pills( - label="MCP Servers", - options=mcp_tools_list, - selection_mode="multi", - on_change=reset_agent, - format_func=lambda tool: "".join(tool.split("::")[1:]), - help="List of MCP servers registered to your llama stack server.", - ) - - toolgroup_selection.extend(mcp_selection) - - grouped_tools = {} - total_tools = 0 - - for toolgroup_id in toolgroup_selection: - tools = client.tools.list(toolgroup_id=toolgroup_id) - grouped_tools[toolgroup_id] = [tool.identifier for tool in tools] - total_tools += len(tools) - - st.markdown(f"Active Tools: 🛠 {total_tools}") - - for group_id, tools in grouped_tools.items(): - with st.expander(f"🔧 Tools from `{group_id}`"): - for idx, tool in enumerate(tools, start=1): - st.markdown(f"{idx}. 
`{tool.split(':')[-1]}`") - - # st.subheader("Agent Configurations") - # st.subheader("Agent Type") - # agent_type = st.radio( - # label="Select Agent Type", - # options=["Regular", "ReAct"], - # on_change=reset_agent, - # ) - - # if agent_type == "ReAct": - # agent_type = AgentType.REACT - # else: - # agent_type = AgentType.REGULAR - agent_type = AgentType.REGULAR - - if processing_mode == "Agent-based": - input_shields = [] - output_shields = [] - - st.subheader("Security Shields") - shields_available = client.shields.list() - shield_options = [s.identifier for s in shields_available if hasattr(s, 'identifier')] - input_shields = st.multiselect("Input Shields", options=shield_options, on_change=reset_agent) - output_shields = st.multiselect("Output Shields", options=shield_options, on_change=reset_agent) - - st.subheader("Sampling Parameters") - temperature = st.slider("Temperature", 0.0, 2.0, 0.1, 0.05, on_change=reset_agent) - top_p = st.slider("Top P", 0.0, 1.0, 0.95, 0.05, on_change=reset_agent) - max_tokens = st.slider("Max Tokens", 1, 4096, 512, 64, on_change=reset_agent) - repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.0, 0.05, on_change=reset_agent) - - st.subheader("System Prompt") - default_prompt = "You are a helpful AI assistant." - if processing_mode == "Agent-based" and agent_type == AgentType.REACT: - default_prompt = "You are a helpful ReAct agent. Reason step-by-step to fulfill the user query using available tools." - system_prompt = st.text_area( - "System Prompt", value=default_prompt, on_change=reset_agent, height=100 + # Initialize toolgroup selector if not present + if "toolgroup_selector" not in st.session_state: + # Default: include RAG if a vector DB is selected + if selected_vector_dbs: + st.session_state["toolgroup_selector"] = ["builtin::rag"] + else: + st.session_state["toolgroup_selector"] = [] + + # Built-in tools selection (web_search, etc.) 
+ toolgroup_selection = st.pills( + label="Built-in tools", + options=builtin_tools_list, + selection_mode="multi", + key="toolgroup_selector", + on_change=on_toolgroup_change, + format_func=lambda tool: "".join(tool.split("::")[1:]), + help="List of built-in tools from your llama stack server.", + ) + + # MCP tools selection (if available) + if mcp_tools_list: + mcp_selection = st.pills( + label="MCP Servers", + options=mcp_tools_list, + selection_mode="multi", + on_change=on_reset, + format_func=lambda tool: "".join(tool.split("::")[1:]), + help="List of MCP servers registered to your llama stack server.", ) + toolgroup_selection = list(toolgroup_selection) + list(mcp_selection) - st.subheader("Response Handling") - #stream_opt = st.toggle("Stream Response", value=True, on_change=reset_agent) - tool_debug = st.toggle("Show Tool/Debug Info", value=False) + # Display active tools summary + client = llama_stack_api.client + grouped_tools = {} + total_tools = 0 - if st.button("Clear Chat & Reset Config", use_container_width=True): - reset_agent() - st.rerun() - + for toolgroup_id in toolgroup_selection: + tools = client.tools.list(toolgroup_id=toolgroup_id) + logger.debug("Raw tools from toolgroup '%s': %s", toolgroup_id, tools) + grouped_tools[toolgroup_id] = [tool.name for tool in tools] + total_tools += len(tools) - updated_toolgroup_selection = [] - if processing_mode == "Agent-based": - for i, tool_name in enumerate(toolgroup_selection): - if tool_name == "builtin::rag": - if len(selected_vector_dbs) > 0: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - vector_db_ids = [vector_db.identifier for vector_db in vector_dbs if get_vector_db_name(vector_db) in selected_vector_dbs] - tool_dict = dict( - name="builtin::rag/knowledge_search", - args={ - "vector_db_ids": list(vector_db_ids), - # Defaults - "query_config": { - "chunk_size_in_tokens": 512, - "chunk_overlap_in_tokens": 50, - }, - }, - ) - updated_toolgroup_selection.append(tool_dict) - else: - 
updated_toolgroup_selection.append(tool_name) - - @st.cache_resource - def create_agent(): - if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT: - return ReActAgent( - client=client, - model=model, - tools=updated_toolgroup_selection, - response_format=JsonSchemaResponseFormat( - type="json_schema", - json_schema=ReActOutput.model_json_schema() - ), - sampling_params=SamplingParams( - strategy=StrategyTopPSamplingStrategy(type="top_p", temperature=temperature, top_p=top_p), - max_tokens=max_tokens, - repetition_penalty=repetition_penalty, - ), - input_shields= input_shields, - output_shields= output_shields, - ) - else: - updated_system_prompt = system_prompt.strip() - updated_system_prompt = updated_system_prompt if updated_system_prompt.strip().endswith('.') else updated_system_prompt + '.' - return Agent( - client, - model=model, - instructions=f"{updated_system_prompt} When you use a tool always respond with a summary of the result.", - tools=updated_toolgroup_selection, - sampling_params=SamplingParams( - strategy=StrategyTopPSamplingStrategy(type="top_p", temperature=temperature, top_p=top_p), - max_tokens=max_tokens, - repetition_penalty=repetition_penalty, - ), - input_shields= input_shields, - output_shields= output_shields, - ) + logger.debug("Grouped tools summary: %s", grouped_tools) + + if total_tools > 0: + st.markdown(f"Active Tools: 🛠 {total_tools}") + for group_id, tools in grouped_tools.items(): + with st.expander(f"🔧 Tools from `{group_id}`"): + for idx, tool in enumerate(tools, start=1): + st.markdown(f"{idx}. 
`{tool.split(':')[-1]}`") + + return toolgroup_selection + + +def reset_agent(): + """Reset the agent by clearing session state and cache.""" + st.session_state.clear() + st.cache_resource.clear() + + +def create_vector_db_callbacks(processing_mode, vector_dbs): + """Create callbacks for vector DB and toolgroup synchronization.""" + def on_vector_db_change(): + """When vector DB changes, update toolgroup selection""" + if processing_mode != "Agent-based": + return + + selected_vdbs = st.session_state.get("chat_vector_db_selector", []) + current_toolgroups = st.session_state.get("toolgroup_selector", []) + + if not selected_vdbs: + # Remove RAG from toolgroups + if "builtin::rag" in current_toolgroups: + filtered = [t for t in current_toolgroups if t != "builtin::rag"] + st.session_state["toolgroup_selector"] = filtered + else: + # Add RAG to toolgroups if not present + if "builtin::rag" not in current_toolgroups: + updated = list(current_toolgroups) + ["builtin::rag"] + st.session_state["toolgroup_selector"] = updated + + def on_toolgroup_change(): + """When toolgroup changes, update vector DB selection""" + current_toolgroups = st.session_state.get("toolgroup_selector", []) + selected_vdbs = st.session_state.get("chat_vector_db_selector", []) + + if "builtin::rag" not in current_toolgroups: + # RAG deselected, clear vector DB selection + st.session_state["chat_vector_db_selector"] = [] + elif "builtin::rag" in current_toolgroups and not selected_vdbs: + # RAG selected but no vector DB, select first available + if vector_dbs: + first_vdb = get_vector_db_name(vector_dbs[0]) + st.session_state["chat_vector_db_selector"] = [first_vdb] + + return on_vector_db_change, on_toolgroup_change + + +def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list): + """Render sidebar configuration and return selected parameters.""" + st.title("Configuration") + st.subheader("Model") + model = st.selectbox( + label="Model", + options=model_list, + 
on_change=reset_agent, + label_visibility="collapsed", + ) + + # Processing Mode + processing_mode = st.radio( + "Processing mode", + ["Direct", "Agent-based"], + index=0, + captions=[ + "Passes vector store search results as context to LLM", + "Uses Responses API with tool calling", + ], + on_change=reset_agent, + help="Choose how requests are processed.", + ) + + # Vector Database Selection + vector_dbs = list(llama_stack_api.client.vector_stores.list() or []) + on_vector_db_change, on_toolgroup_change = create_vector_db_callbacks( + processing_mode, vector_dbs + ) + + selected_vector_dbs = render_vector_db_selector( + vector_dbs, processing_mode, on_vector_db_change + ) + + # Toolgroup Selection (Agent-based mode only) + toolgroup_selection = [] if processing_mode == "Agent-based": - st.session_state.agent_type = agent_type - agent = create_agent() + toolgroup_selection = render_toolgroup_selection( + builtin_tools_list, mcp_tools_list, selected_vector_dbs, + on_toolgroup_change, reset_agent + ) + + # Sampling Parameters + st.subheader("Sampling Parameters") + temperature = st.slider( + "Temperature", + 0.0, 2.0, 0.1, 0.05, + on_change=reset_agent, + help="Controls randomness. Higher values = more random.", + ) + max_infer_iters = st.slider( + "Max Inference Iterations", + 1, 50, 10, 1, + on_change=reset_agent, + help="Maximum number of inference iterations before stopping", + ) + + # System Prompt + st.subheader("System Prompt") + default_prompt = "You are a helpful AI assistant." 
+ system_prompt = st.text_area( + "System Prompt", value=default_prompt, on_change=reset_agent, height=100 + ) + + if st.button("Clear Chat & Reset Config", use_container_width=True): + reset_agent() + st.rerun() - if "agent_session_id" not in st.session_state: - st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}") + return { + 'model': model, + 'processing_mode': processing_mode, + 'selected_vector_dbs': selected_vector_dbs, + 'toolgroup_selection': toolgroup_selection, + 'temperature': temperature, + 'max_infer_iters': max_infer_iters, + 'system_prompt': system_prompt, + } - session_id = st.session_state["agent_session_id"] + +def render_vector_db_selector(vector_dbs, processing_mode, on_vector_db_change): + """Render vector database selector and return selected databases.""" + selected_vector_dbs = [] + + # Initialize vector DB selector if not present + if "chat_vector_db_selector" not in st.session_state: + st.session_state["chat_vector_db_selector"] = [] + + if not vector_dbs: + return selected_vector_dbs + + vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] + selected_vector_dbs = st.multiselect( + label="Select Document Collections for RAG queries", + options=vector_db_names, + key="chat_vector_db_selector", + on_change=on_vector_db_change, + help=( + "Select one or more vector databases to use for retrieval, " + "or leave empty for normal chat" + ) + ) + + # Store the selected DBs for Direct mode + if selected_vector_dbs: + if processing_mode == "Direct": + st.session_state["direct_vector_dbs"] = [ + vdb for vdb in vector_dbs + if get_vector_db_name(vdb) in selected_vector_dbs + ] + else: + # Clear direct_vector_dbs if nothing is selected + if "direct_vector_dbs" in st.session_state: + del st.session_state["direct_vector_dbs"] + + return selected_vector_dbs + + +def initialize_session_state(): + """Initialize session state variables.""" + if "conversation_id" not in 
st.session_state: + conversation = llama_stack_api.client.conversations.create() + st.session_state["conversation_id"] = conversation.id + logger.debug("Created new conversation: %s", conversation.id) if "messages" not in st.session_state: - st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?", "stop_reason": "end_of_turn"}] - - if "debug_events" not in st.session_state: # Per-turn debug logs - st.session_state["debug_events"] = [] - + st.session_state["messages"] = [ + { + "role": "assistant", + "content": "How can I help you?", + "stop_reason": "end_of_turn", + } + ] + if "show_more_questions" not in st.session_state: st.session_state["show_more_questions"] = False - + if "selected_question" not in st.session_state: st.session_state["selected_question"] = None - render_history(tool_debug) # Display the current chat history and any past debug events - # Display suggested questions if databases are selected - def display_suggested_questions(): - """Display suggested questions based on selected databases.""" - if not selected_vector_dbs: - return - - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - suggestions = get_suggestions_for_databases(selected_vector_dbs, vector_dbs) - - if not suggestions: - return - - st.markdown("### 💡 Suggested Questions") - - # Determine how many questions to show - num_to_show = len(suggestions) if st.session_state.show_more_questions else min(4, len(suggestions)) - - # Display questions in a grid-like format using columns - cols_per_row = 2 - for i in range(0, num_to_show, cols_per_row): - cols = st.columns(cols_per_row) - for j in range(cols_per_row): - idx = i + j - if idx < num_to_show: - question, db_name = suggestions[idx] - with cols[j]: - # Create a button for each question - button_key = f"question_btn_{idx}_{hash(question)}" - if st.button( - question, - key=button_key, - use_container_width=True, - help=f"From: {db_name}" - ): - st.session_state.selected_question = question - 
st.rerun() - - # Show "Show More" or "Show Less" button if there are more than 4 questions - if len(suggestions) > 4: - col1, col2, col3 = st.columns([1, 1, 1]) - with col2: - if st.session_state.show_more_questions: - if st.button("Show Less", use_container_width=True): - st.session_state.show_more_questions = False - st.rerun() - else: - if st.button(f"Show More ({len(suggestions) - 4} more)", use_container_width=True): - st.session_state.show_more_questions = True - st.rerun() - - st.markdown("---") - - display_suggested_questions() - - def response_generator(turn_response, debug_events_list): - if st.session_state.get("agent_type") == AgentType.REACT: - return _handle_react_response(turn_response) +# ============================================================================ +# Configuration Classes +# ============================================================================ + +@dataclass +class SamplingParams: + """Sampling parameters for model inference.""" + temperature: float + max_infer_iters: int + + +@dataclass +class ChatConfig: + """Configuration for chat processing.""" + model: str + processing_mode: str + system_prompt: str + conversation_id: str + toolgroup_selection: list + selected_vector_dbs: list + sampling: SamplingParams + + +# ============================================================================ +# UI Container Classes +# ============================================================================ + +class Containers: + """Simple container for UI elements. + + Note: Containers are created in visual display order (top to bottom). + """ + def __init__(self): + # Create containers in visual order: tools -> reasoning -> message + self.tool_status = st.container() + self.tool_results = st.container() + self.reasoning = st.empty() + self.message = st.empty() + +class ResponseState: + """ + State container for assistant response UI components and data. + Shared by both Agent-based and Direct modes. 
+ """ + def __init__(self): + # UI containers (grouped) + self.containers = Containers() + + # Reasoning state + self.reasoning_text = "" + self.reasoning_placeholder = None + + # Tool state + self.tool_status = None + self.tool_results = [] + + # Response content + self.full_response = "" + + @property + def has_reasoning(self): + """Check if reasoning has been started.""" + return self.reasoning_placeholder is not None + + @property + def tool_used(self): + """Check if any tool has been used.""" + return self.tool_status is not None + + def update_reasoning(self, delta_text): + """Add reasoning text and update display.""" + self.reasoning_text += delta_text + + # Create reasoning expander on first delta + if not self.has_reasoning: + with self.containers.reasoning.container(): + reasoning_expander = st.expander("🧠 Reasoning", expanded=True) + self.reasoning_placeholder = reasoning_expander.empty() + + # Update reasoning text with cursor + if self.reasoning_placeholder: + self.reasoning_placeholder.markdown(self.reasoning_text + "▌") + + def finalize_reasoning(self): + """Remove cursor from reasoning display.""" + if self.reasoning_placeholder and self.reasoning_text: + self.reasoning_placeholder.markdown(self.reasoning_text) + + def update_message(self, delta_text): + """Add message text and update display.""" + self.full_response += delta_text + self.containers.message.markdown(self.full_response + "▌") + + def finalize_message(self): + """Remove cursor from message display.""" + self.containers.message.markdown(self.full_response) + + +# ============================================================================ +# Suggested Questions UI +# ============================================================================ + +def render_question_button(question, db_name, idx): + """Render a single question button.""" + button_key = f"question_btn_{idx}_{hash(question)}" + if st.button( + question, + key=button_key, + use_container_width=True, + help=f"From: 
{db_name}" + ): + st.session_state.selected_question = question + st.rerun() + + +def render_question_grid(suggestions, num_to_show): + """Render questions in a grid layout.""" + cols_per_row = 2 + for i in range(0, num_to_show, cols_per_row): + cols = st.columns(cols_per_row) + for j in range(cols_per_row): + idx = i + j + if idx < num_to_show: + question, db_name = suggestions[idx] + with cols[j]: + render_question_button(question, db_name, idx) + + +def render_show_more_button(suggestions): + """Render show more/less button if needed.""" + if len(suggestions) <= 4: + return + + _, col2, _ = st.columns([1, 1, 1]) + with col2: + if st.session_state.show_more_questions: + if st.button("Show Less", use_container_width=True): + st.session_state.show_more_questions = False + st.rerun() else: - return _handle_regular_response(turn_response, debug_events_list) - - def _handle_react_response(turn_response): - current_step_content = "" - final_answer = None - tool_results = [] - - for response in turn_response: - if not hasattr(response.event, "payload"): - yield ( - "\n\n🚨 :red[_Llama Stack server Error:_]\n" - "The response received is missing an expected `payload` attribute.\n" - "This could indicate a malformed response or an internal issue within the server.\n\n" - f"Error details: {response}" - ) - return - - payload = response.event.payload - - if payload.event_type == "step_progress" and hasattr(payload.delta, "text"): - current_step_content += payload.delta.text - continue - - if payload.event_type == "step_complete": - step_details = payload.step_details - - if step_details.step_type == "inference": - yield from _process_inference_step(current_step_content, tool_results, final_answer) - current_step_content = "" - elif step_details.step_type == "tool_execution": - tool_results = _process_tool_execution(step_details, tool_results) - current_step_content = "" - else: - current_step_content = "" - - if not final_answer and tool_results: - yield from 
_format_tool_results_summary(tool_results) - - def _process_inference_step(current_step_content, tool_results, final_answer): - try: - react_output_data = json.loads(current_step_content) - thought = react_output_data.get("thought") - action = react_output_data.get("action") - answer = react_output_data.get("answer") - - if answer and answer != "null" and answer is not None: - final_answer = answer - - if thought: - with st.expander("🤔 Thinking...", expanded=False): - st.markdown(f":grey[__{thought}__]") - - if action and isinstance(action, dict): - tool_name = action.get("tool_name") - tool_params = action.get("tool_params") - with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False): - st.json(tool_params) - - if answer and answer != "null" and answer is not None: - yield f"\n\n✅ **Final Answer:**\n{answer}" - - except json.JSONDecodeError: - yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```" - except Exception as e: - yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```" - - return final_answer - - def _process_tool_execution(step_details, tool_results): - try: - if hasattr(step_details, "tool_responses") and step_details.tool_responses: - for tool_response in step_details.tool_responses: - tool_name = tool_response.tool_name - content = tool_response.content - tool_results.append((tool_name, content)) - with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False): - try: - parsed_content = json.loads(content) - st.json(parsed_content) - except json.JSONDecodeError: - st.code(content, language=None) - else: - with st.expander("⚙️ Observation", expanded=False): - st.markdown(":grey[_Tool execution step completed, but no response data found._]") - except Exception as e: - with st.expander("⚙️ Error in Tool Execution", expanded=False): - st.markdown(f":red[_Error processing tool execution: {str(e)}_]") - - return tool_results - - def 
_format_tool_results_summary(tool_results): - yield "\n\n**Here's what I found:**\n" - for tool_name, content in tool_results: - try: - parsed_content = json.loads(content) - - if tool_name == "web_search" and "top_k" in parsed_content: - yield from _format_web_search_results(parsed_content) - elif "results" in parsed_content and isinstance(parsed_content["results"], list): - yield from _format_results_list(parsed_content["results"]) - elif isinstance(parsed_content, dict) and len(parsed_content) > 0: - yield from _format_dict_results(parsed_content) - elif isinstance(parsed_content, list) and len(parsed_content) > 0: - yield from _format_list_results(parsed_content) - except json.JSONDecodeError: - yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n" - except (TypeError, AttributeError, KeyError, IndexError) as e: - print(f"Error processing {tool_name} result: {type(e).__name__}: {e}") - - def _format_web_search_results(parsed_content): - for i, result in enumerate(parsed_content["top_k"], 1): - if i <= 3: - title = result.get("title", "Untitled") - url = result.get("url", "") - content_text = result.get("content", "").strip() - yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n" - - def _format_results_list(results): - for i, result in enumerate(results, 1): - if i <= 3: - if isinstance(result, dict): - name = result.get("name", result.get("title", "Result " + str(i))) - description = result.get("description", result.get("content", result.get("summary", ""))) - yield f"\n- **{name}**\n {description}\n" - else: - yield f"\n- {result}\n" - - def _format_dict_results(parsed_content): - yield "\n```\n" - for key, value in list(parsed_content.items())[:5]: - if isinstance(value, str) and len(value) < 100: - yield f"{key}: {value}\n" - else: - yield f"{key}: [Complex data]\n" - yield "```\n" - - def _format_list_results(parsed_content): - yield "\n" - for _, item in enumerate(parsed_content[:3], 1): - if 
isinstance(item, str): - yield f"- {item}\n" - elif isinstance(item, dict) and "text" in item: - yield f"- {item['text']}\n" - elif isinstance(item, dict) and len(item) > 0: - first_value = next(iter(item.values())) - if isinstance(first_value, str) and len(first_value) < 100: - yield f"- {first_value}\n" - - def _handle_regular_response(turn_response, debug_events_list): - - # Use itertools.tee to duplicate the stream for UI and debug logging - # This is crucial because a generator can only be consumed once. - from itertools import tee - ui_stream, debug_log_stream = tee(turn_response, 2) - - for response in ui_stream: - if hasattr(response.event, "payload"): - if response.event.payload.event_type == "step_progress": - if hasattr(response.event.payload.delta, "text"): - yield response.event.payload.delta.text - if response.event.payload.event_type == "step_complete": - if response.event.payload.step_details.step_type == "tool_execution": - if response.event.payload.step_details.tool_calls: - tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name) - yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n' - else: - yield "No tool_calls present in step_details" - if response.event.payload.step_details.step_type == "shield_call": - if response.event.payload.step_details.violation: - yield response.event.payload.step_details.violation.user_message - else: - yield f"Error occurred in the Llama Stack Cluster: {response}" - debug_events_list.append({"type": "warning", "source": "_handle_regular_response", "details": "Unexpected event structure", "event": str(response)[:200]}) - - # Process the debug log stream separately - # EventLogger helps parse and structure these events - for log_entry in EventLogger().log(debug_log_stream): - if log_entry.role == "tool_execution": # Or other relevant roles - debug_events_list.append({"type": "tool_log", "content": log_entry.content}) - # Add other log types as needed for debugging - - def 
agent_process_prompt(prompt, debug_events_list): - print(f"In agent_process_prompt: {prompt}") - # Send the prompt to the agent - turn_response = agent.create_turn( - session_id=session_id, - messages=[UserMessage(role="user", content=prompt)], - stream=True, + button_text = f"Show More ({len(suggestions) - 4} more)" + if st.button(button_text, use_container_width=True): + st.session_state.show_more_questions = True + st.rerun() + + +def display_suggested_questions(selected_vector_dbs): + """Display suggested questions based on selected databases.""" + if not selected_vector_dbs: + return + + vector_dbs = list(llama_stack_api.client.vector_stores.list() or []) + suggestions = get_suggestions_for_databases(selected_vector_dbs, vector_dbs) + + if not suggestions: + return + + st.markdown("### 💡 Suggested Questions") + + # Determine how many questions to show + num_to_show = ( + len(suggestions) if st.session_state.show_more_questions + else min(4, len(suggestions)) + ) + + # Display questions and controls + render_question_grid(suggestions, num_to_show) + render_show_more_button(suggestions) + st.markdown("---") + + +# ============================================================================ +# Main Prompt Processing +# ============================================================================ + +def process_prompt(prompt, config): + """Process user prompt and generate response.""" + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + + # Create assistant message context and setup shared containers once + with st.chat_message("assistant"): + state = ResponseState() + + # Call the appropriate mode-specific function + if config.processing_mode == "Direct": + direct_process_prompt(prompt, state, config) + elif config.processing_mode == "Agent-based": + agent_process_prompt(prompt, state, config) + + st.rerun() + + +def tool_chat_page(): + """Main chat page with RAG support in Direct and 
Agent-based modes.""" + st.title("💬 Chat") + + # Fetch models and tools + model_list, builtin_tools_list, mcp_tools_list = fetch_models_and_tools() + + # Render sidebar and get configuration + with st.sidebar: + sidebar_config = render_sidebar_configuration( + model_list, builtin_tools_list, mcp_tools_list ) - print(f"In agent_process_prompt: {turn_response}") - response_content = st.write_stream(response_generator(turn_response, debug_events_list)) - print(f"In agent_process_prompt: {response_content}") - st.session_state.messages.append({"role": "assistant", "content": response_content}) + # Initialize session state + initialize_session_state() + + # Create chat configuration object + chat_config = ChatConfig( + model=sidebar_config['model'], + processing_mode=sidebar_config['processing_mode'], + system_prompt=sidebar_config['system_prompt'], + conversation_id=st.session_state["conversation_id"], + toolgroup_selection=sidebar_config['toolgroup_selection'], + selected_vector_dbs=sidebar_config['selected_vector_dbs'], + sampling=SamplingParams( + temperature=sidebar_config['temperature'], + max_infer_iters=sidebar_config['max_infer_iters'] + ) + ) - def direct_process_prompt(prompt, debug_events_list): - # Query the vector DB - if selected_vector_dbs: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - vector_db_ids = [vector_db.identifier for vector_db in vector_dbs if get_vector_db_name(vector_db) in selected_vector_dbs] - with st.spinner("Retrieving context (RAG)..."): - try: - rag_response = llama_stack_api.client.tool_runtime.rag_tool.query( - content=prompt, vector_db_ids=list(vector_db_ids) - ) - prompt_context = rag_response.content - debug_events_list.append({ - "type": "rag_query_direct_mode", "query": prompt, - "vector_dbs": selected_vector_dbs, - "context_length": len(prompt_context) if prompt_context else 0, - "context_preview": (str(prompt_context[:200]) + "..." 
if prompt_context else "None") - }) - except Exception as e: - st.warning(f"RAG Error (Direct Mode): {e}") - debug_events_list.append({"type": "error", "source": "rag_direct_mode", "content": str(e)}) - else: - prompt_context = None - - with st.chat_message("assistant"): - message_placeholder = st.empty() - full_response = "" - retrieval_response = "" - - # Construct the extended prompt - if prompt_context: - extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" - else: - extended_prompt = f"Please answer the following query. \n\nQUERY:\n{prompt}" - - # Run inference directly - #st.session_state.messages.append({"role": "user", "content": extended_prompt}) - messages_for_direct_api = ( - [{'role': 'system', 'content': system_prompt}] + - [{'role': 'user', 'content': extended_prompt}] - ) - response = llama_stack_api.client.inference.chat_completion( - messages=messages_for_direct_api, - model_id=model, - sampling_params={ - "strategy": get_strategy(temperature, top_p), - "max_tokens": max_tokens, - "repetition_penalty": repetition_penalty, - }, - stream=True, - timeout=120, - ) - - # Display assistant response - for chunk in response: - if chunk.event: - response_delta = chunk.event.delta - if isinstance(response_delta, ToolCallDelta): - retrieval_response += response_delta.tool_call.replace("====", "").strip() - #retrieval_message_placeholder.info(retrieval_response) - else: - full_response += chunk.event.delta.text - message_placeholder.markdown(full_response + "▌") - message_placeholder.markdown(full_response) - - response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"} - st.session_state.messages.append(response_dict) - #st.session_state.displayed_messages.append(response_dict) - - def process_prompt(prompt): - print(f"In process_prompt: {prompt}") - st.session_state.messages.append({"role": "user", "content": prompt}) - with 
st.chat_message("user"): - st.markdown(prompt) - - # Prepare for assistant's response - # Each assistant turn gets its own list for debug events - st.session_state.debug_events.append([]) - current_turn_debug_events_list = st.session_state.debug_events[-1] # Get the list for this turn - - st.session_state.prompt = prompt - - print(f"In process_prompt: {st.session_state.prompt}") - print(f"In processing mode: {processing_mode}") - if processing_mode == "Agent-based": - agent_process_prompt(st.session_state.prompt, current_turn_debug_events_list) - else: # rag_mode == "Direct" - direct_process_prompt(st.session_state.prompt, current_turn_debug_events_list) - #st.session_state.prompt = None - st.rerun() + # Display chat history + render_history() + + # Display suggested questions + display_suggested_questions(chat_config.selected_vector_dbs) # Handle selected question from suggestions if st.session_state.selected_question: prompt = st.session_state.selected_question - st.session_state.selected_question = None # Clear the selected question + st.session_state.selected_question = None + process_prompt(prompt, chat_config) - process_prompt(prompt) - - # Handle manual chat input if prompt := st.chat_input(placeholder="Ask a question..."): - # Append the user message to history and display it - process_prompt(prompt) - - - + process_prompt(prompt, chat_config) tool_chat_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py new file mode 100644 index 00000000..1b4744b0 --- /dev/null +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Direct mode implementation for chat with manual RAG. 
+""" + +import logging +import traceback + +import streamlit as st + +from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name + + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Direct Mode - Helper Functions +# ============================================================================ + +def extract_text_from_search_result(result): + """Extract and clean text content from a search result object.""" + text = None + + # Handle Data objects with content list + if hasattr(result, 'content') and isinstance(result.content, list): + for content_item in result.content: + if hasattr(content_item, 'text'): + text = content_item.text + break + + # Handle simple content attribute + elif hasattr(result, 'content') and isinstance(result.content, str): + text = result.content + + # Handle dict format + elif isinstance(result, dict) and 'content' in result: + if isinstance(result['content'], list) and result['content']: + text = result['content'][0].get('text', '') + else: + text = result['content'] + + return clean_text(text) if text else None + + +def search_vector_store_direct(prompt, vector_db_id, vector_db_name, state): + """Search vector store and extract context for Direct mode.""" + search_results = [] + context_parts = [] + display_results = [] + + # Show search status + with state.containers.tool_status: + st.markdown(f"🛠 :grey[_Searching vector store: {vector_db_name}_]") + + logger.debug("Searching vector store %s with query: %s", vector_db_id, prompt) + + # Call vector store search API + search_response = llama_stack_api.client.vector_stores.search( + vector_store_id=vector_db_id, + query=prompt, + ) + + logger.debug("Search response: %s", search_response) + + # Extract search results from response + if hasattr(search_response, 'data') and search_response.data: + search_results = 
search_response.data + elif hasattr(search_response, 'chunks') and search_response.chunks: + search_results = search_response.chunks + elif hasattr(search_response, 'results') and search_response.results: + search_results = search_response.results + + # Display and process search results + if search_results: + # Build context and display data from search results + for result in search_results: + text_content = extract_text_from_search_result(result) + if text_content: + attrs = getattr(result, 'attributes', {}) + source = attrs.get('source') or getattr(result, 'filename', 'unknown') + context_parts.append(f"[Source: {source}]: {text_content}") + display_results.append({"source": source, "text": text_content}) + + with state.containers.tool_results: + with st.expander(f"📄 Search Results from '{vector_db_name}'", expanded=False): + st.json(display_results) + + logger.debug("Built context with %s documents", len(context_parts)) + else: + # No results found + with state.containers.tool_results: + st.info(f"No results found in '{vector_db_name}'") + + return search_results, context_parts, display_results + + +def build_rag_messages(prompt, context_parts, system_prompt): + """Build messages for LLM - with or without RAG context.""" + if context_parts: + # RAG mode: Format user message with explicit CONTEXT and QUERY sections + context = "\n\n".join(context_parts) + extended_prompt = ( + f"Please answer the following query using the context below.\n\n" + f"CONTEXT:\n{context}\n\n" + f"QUERY:\n{prompt}" + ) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": extended_prompt} + ] + logger.debug( + "Built RAG prompt with %s documents, total context length: %s", + len(context_parts), len(context) + ) + else: + # Normal chat mode: no context + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ] + logger.debug("No context - using normal chat mode") + + return messages + + +def 
stream_completions_direct(completion_response, state): + """Stream chunks from Completions API and update state.""" + for chunk in completion_response: + logger.debug("Completion chunk: %s", chunk) + if hasattr(chunk, 'choices') and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + + # Handle reasoning content (for models that support it like R1) + if hasattr(delta, 'reasoning_content') and delta.reasoning_content: + state.update_reasoning(delta.reasoning_content) + + # Handle regular content + if hasattr(delta, 'content') and delta.content: + state.update_message(delta.content) + + +def save_direct_response_to_session(state, all_search_results): + """Save direct response to session state.""" + state.finalize_reasoning() + state.finalize_message() + + response_dict = { + "role": "assistant", + "content": state.full_response, + "stop_reason": "end_of_message" + } + + # Save reasoning if present + if state.reasoning_text: + response_dict["reasoning"] = state.reasoning_text + + # Save search results for history display if we had any + if all_search_results: + db_names = [name for name, _ in all_search_results] + response_dict["tool_results"] = [ + { + 'title': f'📄 Search Results from \'{name}\'', + 'type': 'json', + 'content': display + } + for name, display in all_search_results + ] + response_dict["tool_status"] = ( + f"🛠 :grey[_Searched vector stores: {', '.join(db_names)}_]" + ) + + st.session_state.messages.append(response_dict) + + +# ============================================================================ +# Direct Mode - Main Function +# ============================================================================ + +def direct_process_prompt(prompt, state, config): + """Direct mode: Manual RAG with completions API.""" + context_parts = [] + all_search_results = [] + + vector_dbs = st.session_state.get("direct_vector_dbs", []) + if not vector_dbs: + logger.debug("No vector DB selected - normal chat mode") + + try: + # Step 1: Search each selected 
vector store + for vector_db in vector_dbs: + vector_db_id = vector_db.id + vector_db_name = get_vector_db_name(vector_db) + search_results, parts, display = search_vector_store_direct( + prompt, vector_db_id, vector_db_name, state + ) + if search_results: + all_search_results.append((vector_db_name, display)) + context_parts.extend(parts) + + # Step 2: Build messages (with or without RAG context) + messages = build_rag_messages(prompt, context_parts, config.system_prompt) + + # Step 3: Call completions API + logger.debug("Calling completions API with %s messages", len(messages)) + for i, msg in enumerate(messages): + logger.debug(" Message %s (%s): %s...", i, msg['role'], msg['content'][:200]) + + completion_response = llama_stack_api.client.chat.completions.create( + model=config.model, + messages=messages, + temperature=config.sampling.temperature, + stream=True, + ) + + # Step 4: Stream response and update UI + stream_completions_direct(completion_response, state) + + # Step 5: Save to session + save_direct_response_to_session(state, all_search_results) + + except Exception as e: + st.error(f"Error in Direct mode: {str(e)}") + logger.debug("Direct mode error: %s", e) + logger.debug("%s", traceback.format_exc()) diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py b/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py index d4a3e15c..756f351d 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py @@ -3,4 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index ffd01751..2f8e09aa 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -1,67 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Upload documents page for managing vector databases and document ingestion.""" + +import traceback + import streamlit as st -from llama_stack_client import RAGDocument + from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import data_url_from_file +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name + + +def _init_upload_page_session_state(): + """Initialize all session state variables needed by the upload page.""" + defaults = { + "creation_status": None, + "creation_message": "", + "selected_vector_db": "", + "newly_created_vdb": None, + } + for key, value in defaults.items(): + if key not in st.session_state: + st.session_state[key] = value + + if "vector_db_selector" not in st.session_state: + st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] + + +def _show_status(status_key, message_key): + """Show and clear a status message from session state.""" + status = st.session_state[status_key] + if status == "success": + st.success(st.session_state[message_key]) + elif status == "error": + st.error(st.session_state[message_key]) + else: + return + st.session_state[status_key] = None + st.session_state[message_key] = "" + + +def _build_dropdown_options(vdb_list): + """Build dropdown options from the vector database list. 
+ + Returns: + tuple: (dropdown_options, create_new_option) + """ + create_new_option = "➕ Create New" + + if vdb_list: + existing_vdbs = [get_vector_db_name(v) for v in vdb_list] + return existing_vdbs + [create_new_option], create_new_option + + return [create_new_option], create_new_option + + +def _sync_vector_db_selection(dropdown_options, vdb_list): + """Sync the vector database selection state with available options.""" + # Priority 1: Auto-select a newly created database + newly_created = st.session_state["newly_created_vdb"] + if newly_created and newly_created in dropdown_options: + st.session_state["selected_vector_db"] = newly_created + st.session_state["vector_db_selector"] = newly_created + st.session_state["newly_created_vdb"] = None + return + + # Priority 2: Keep the previously selected database if it still exists + selected = st.session_state["selected_vector_db"] + if selected and selected in dropdown_options: + st.session_state["vector_db_selector"] = selected + return + + # Priority 3: Smart default + if vdb_list: + first_db = dropdown_options[0] + st.session_state["selected_vector_db"] = first_db + st.session_state["vector_db_selector"] = first_db + else: + st.session_state["selected_vector_db"] = dropdown_options[0] + st.session_state["vector_db_selector"] = dropdown_options[0] + def upload_page(): + """Page to upload documents and manage vector databases for RAG.""" + st.title("📄 Upload Documents") + + _init_upload_page_session_state() + _show_status("creation_status", "creation_message") + + vdb_list = llama_stack_api.client.vector_stores.list() + dropdown_options, create_new_option = _build_dropdown_options(vdb_list) + _sync_vector_db_selection(dropdown_options, vdb_list) + + def on_vector_db_change(): + st.session_state["selected_vector_db"] = st.session_state["vector_db_selector"] + + selected_vector_db = st.selectbox( + "Select a vector database", + dropdown_options, + key="vector_db_selector", + on_change=on_vector_db_change, + 
help="Your selection will be remembered when you navigate to other pages" + ) + + if selected_vector_db != st.session_state["selected_vector_db"]: + st.session_state["selected_vector_db"] = selected_vector_db + + if selected_vector_db == create_new_option: + _show_create_vector_db_ui() + elif selected_vector_db: + selected_vdb_obj = None + for vdb in vdb_list: + if get_vector_db_name(vdb) == selected_vector_db: + selected_vdb_obj = vdb + break + + _show_existing_documents_table(selected_vector_db, selected_vdb_obj) + st.subheader(f"📁 Upload Documents to '{selected_vector_db}'") + _show_document_upload_ui(selected_vector_db, selected_vdb_obj) + + +def _show_create_vector_db_ui(): + """Display UI for creating a new vector database.""" + st.subheader("Create New Vector Database") + + if "new_vdb_name" not in st.session_state: + st.session_state["new_vdb_name"] = "" + + new_vdb_name = st.text_input( + "Add New Vector Database", + value=st.session_state["new_vdb_name"], + help="Enter a unique name for the new vector database", + key="new_vdb_name_input" + ) + + st.session_state["new_vdb_name"] = new_vdb_name + + if st.button("Add", type="primary", disabled=not new_vdb_name.strip()): + _create_vector_database(new_vdb_name.strip()) + + +def _create_vector_database(vdb_name): + """Create a new vector database using the LlamaStack API. + + Args: + vdb_name (str): Name for the new vector database """ - Page to upload documents and create a vector database for RAG. + try: + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + + if not vdb_name or not vdb_name.strip(): + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = "Vector database name cannot be empty." 
+ return + + existing_vdbs = llama_stack_api.client.vector_stores.list() + existing_names = [get_vector_db_name(vdb) for vdb in existing_vdbs] + if vdb_name in existing_names: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = ( + f"Vector database '{vdb_name}' already exists. " + "Please choose a different name." + ) + return + + with st.spinner(f"Creating vector database '{vdb_name}'..."): + _vector_db = llama_stack_api.client.vector_stores.create( + name=vdb_name, + ) + + st.session_state["creation_status"] = "success" + st.session_state["creation_message"] = ( + f"Vector database '{vdb_name}' created successfully!" + ) + st.session_state["newly_created_vdb"] = vdb_name + st.session_state["new_vdb_name"] = "" + st.rerun() + + except Exception as e: # pylint: disable=broad-exception-caught + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = f"Error creating vector database: {str(e)}" + + +def _show_document_upload_ui(vector_db_name, vector_db_obj=None): + """Display UI for uploading documents to an existing vector database. 
+ + Args: + vector_db_name (str): Name of the selected vector database + vector_db_obj: The actual vector database object with identifier """ - st.title("📄 Upload") - # File/Directory Upload Section - st.subheader("Create Vector DB") - # Let user select files to ingest + if "upload_status" not in st.session_state: + st.session_state["upload_status"] = None + if "upload_message" not in st.session_state: + st.session_state["upload_message"] = "" + + _show_status("upload_status", "upload_message") + + upload_key = f"processed_files_{vector_db_name}" + if upload_key not in st.session_state: + st.session_state[upload_key] = set() + uploaded_files = st.file_uploader( - "Upload file(s) or directory", + "Browse and select files to upload (files will upload automatically)", accept_multiple_files=True, - type=["txt", "pdf", "doc", "docx"], # supported file types + type=["txt", "pdf", "doc", "docx"], + key=f"uploader_{vector_db_name}", + help=( + "Select one or more documents - they will be uploaded " + "automatically to this vector database" + ), ) - # Process uploaded files + if uploaded_files: - # Show upload success and prompt for DB name - st.success(f"Successfully uploaded {len(uploaded_files)} files") - vector_db_name = st.text_input( - "Vector Database Name", - value="rag_vector_db", - help="Enter a unique identifier for this vector database", - ) - if st.button("Create Vector Database"): - # Convert uploaded files into RAGDocument instances - documents = [ - RAGDocument( - document_id=uploaded_file.name, - content=data_url_from_file(uploaded_file), - ) - for i, uploaded_file in enumerate(uploaded_files) - ] - - # Determine provider for vector IO - providers = llama_stack_api.client.providers.list() - vector_io_provider = None - for x in providers: - if x.api == "vector_io": - vector_io_provider = x.provider_id - break - - # Register new vector database - vector_db = llama_stack_api.client.vector_dbs.register( - vector_db_id=vector_db_name, - embedding_dimension=384, 
- embedding_model="all-MiniLM-L6-v2", - provider_id=vector_io_provider, - ) - vector_db_id = vector_db.identifier + file_set_id = frozenset([f.name + str(f.size) for f in uploaded_files]) + + if file_set_id not in st.session_state[upload_key]: + st.session_state[upload_key].add(file_set_id) - # Insert documents into the vector database - llama_stack_api.client.tool_runtime.rag_tool.insert( - vector_db_id=vector_db_id, - documents=documents, - chunk_size_in_tokens=512, + if vector_db_obj and hasattr(vector_db_obj, 'id'): + vector_db_id = vector_db_obj.id + else: + vector_db_id = vector_db_name + + _upload_documents_to_database( + vector_db_name, uploaded_files, vector_db_id ) - st.success("Vector database created successfully!") - # Reset form fields - uploaded_files.clear() - vector_db_name = "" -upload_page() \ No newline at end of file +def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=None): + """Upload documents to an existing vector database. + + Args: + vector_db_name (str): Name of the target vector database + uploaded_files: List of uploaded files from Streamlit file uploader + vector_db_id (str): The actual database identifier for API calls + """ + try: + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + + if not uploaded_files: + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = "No files selected for upload." 
+ return + + actual_db_id = vector_db_id or vector_db_name + uploaded_file_ids = [] + + with st.spinner(f"Uploading {len(uploaded_files)} file(s)..."): + for uploaded_file in uploaded_files: + file_response = llama_stack_api.client.files.create( + file=uploaded_file, + purpose="assistants" + ) + llama_stack_api.client.vector_stores.files.create( + vector_store_id=actual_db_id, + file_id=file_response.id, + ) + uploaded_file_ids.append(file_response.id) + + st.session_state["upload_status"] = "success" + st.session_state["upload_message"] = ( + f"Successfully uploaded {len(uploaded_files)} document(s) " + f"to '{vector_db_name}'!" + ) + st.rerun() + + except Exception as e: # pylint: disable=broad-exception-caught + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = f"Error uploading documents: {str(e)}" + st.rerun() + + +def _get_documents_from_vector_store(vector_store_id): + """Get files from a vector store using the Files API. + + Args: + vector_store_id (str): The vector store identifier + + Returns: + list: List of file objects, or None if query fails + """ + try: + files_response = llama_stack_api.client.vector_stores.files.list( + vector_store_id=vector_store_id + ) + + if hasattr(files_response, 'data'): + return files_response.data + return list(files_response) if files_response else None + + except Exception as e: # pylint: disable=broad-exception-caught + print(f"Error listing files from vector store: {e}") + return None + + +def _delete_file_from_vector_store(vector_store_id, file_id): + """Delete a file from a vector store using the Files API. 
+ + Args: + vector_store_id (str): The vector store identifier + file_id (str): The file ID to delete + + Returns: + tuple: (success: bool, error_message: str) + """ + try: + llama_stack_api.client.vector_stores.files.delete( + file_id=file_id, + vector_store_id=vector_store_id + ) + return True, None + except Exception as e: # pylint: disable=broad-exception-caught + return False, str(e) + + +def _get_file_sources(files): + """Retrieve the source name for each file. + + Prefers attributes["source"], falls back to filename from Files API. + + Args: + files: List of vector store file objects + + Returns: + dict: Mapping of file_id to source name + """ + source_names = {} + for file_obj in files: + file_id = getattr(file_obj, 'id', None) + if not file_id: + continue + attrs = getattr(file_obj, 'attributes', None) or {} + source = attrs.get("source") + if not source: + try: + file_info = llama_stack_api.client.files.retrieve(file_id) + source = getattr(file_info, 'filename', None) + except Exception: # pylint: disable=broad-exception-caught + source = None + source_names[file_id] = source + return source_names + + +def _render_documents_table(files, source_names): + """Render the documents table with source and document ID columns. 
+ + Args: + files: List of vector store file objects + source_names (dict): Mapping of file_id to source name + """ + # Add CSS for bordered table rows + st.markdown(""" + + """, unsafe_allow_html=True) + + # Display table header + col1, col2, col3 = st.columns([0.5, 3, 3]) + with col1: + st.markdown("**#**") + with col2: + st.markdown("**Source**") + with col3: + st.markdown("**Document ID**") + + # Display each file in a row + for idx, file_obj in enumerate(files, start=1): + col1, col2, col3 = st.columns([0.5, 3, 3]) + file_id = getattr(file_obj, 'id', 'unknown') + source = source_names.get(file_id) or "unknown" + + with col1: + st.write(idx) + with col2: + st.write(source) + with col3: + st.write(file_id) + + +def _show_existing_documents_table(vector_db_name, vector_db_obj=None): + """Display information about documents in the selected vector database. + + Args: + vector_db_name (str): Display name of the selected vector database + vector_db_obj: The actual vector database object with identifier + """ + try: + if vector_db_obj and hasattr(vector_db_obj, 'id'): + vector_db_id = vector_db_obj.id + else: + vector_db_id = vector_db_name + + if "delete_status" not in st.session_state: + st.session_state["delete_status"] = None + if "delete_message" not in st.session_state: + st.session_state["delete_message"] = "" + + _show_status("delete_status", "delete_message") + + with st.spinner("Checking for documents..."): + files = _get_documents_from_vector_store(vector_db_id) + + if files: + st.subheader(f"📄 Documents in '{vector_db_name}'") + source_names = _get_file_sources(files) + _render_documents_table(files, source_names) + + except Exception as e: # pylint: disable=broad-exception-caught + st.error(f"Error loading document information: {str(e)}") + with st.expander("Error Details"): + st.code(traceback.format_exc()) + + +upload_page() diff --git a/frontend/pyproject.toml b/frontend/pyproject.toml index be65e231..41c56e7d 100644 --- a/frontend/pyproject.toml +++ 
@pytest.fixture(autouse=True)
def wait_for_app(page: Page):
    """Navigate to the Streamlit app and wait until it is ready before each test.

    Navigation is retried to absorb transient connection issues. An attempt
    counts as successful only when the page loads and its final URL still
    starts with ``RAG_UI_ENDPOINT``; anything else is retried after a short
    pause. The fixture never returns without a verified page.

    Args:
        page (Page): Playwright page object used by the test.

    Raises:
        Exception: Whatever the final navigation attempt raised.
        RuntimeError: If the final attempt loaded a page whose URL does not
            start with ``RAG_UI_ENDPOINT``.
    """
    max_retries = 3
    failure = None
    for attempt in range(1, max_retries + 1):
        try:
            page.goto(RAG_UI_ENDPOINT, timeout=60000, wait_until="domcontentloaded")
            # Wait for Streamlit to finish loading
            page.wait_for_load_state("networkidle", timeout=60000)
            # Give Streamlit additional time to initialize
            time.sleep(2)
            # Verify page actually loaded
            if page.url.startswith(RAG_UI_ENDPOINT):
                return
            failure = RuntimeError(
                f"Unexpected URL after navigation: {page.url!r} "
                f"(expected prefix {RAG_UI_ENDPOINT!r})"
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            if attempt == max_retries:
                # Out of retries: surface the original navigation error.
                raise
            failure = e

        if attempt == max_retries:
            break
        print(f"Navigation attempt {attempt} failed: {failure}, retrying...")
        time.sleep(2)

    # Only reachable when the last attempt loaded the wrong URL; failing here
    # keeps tests from running against a page that never loaded.
    raise failure
(Regular vs ReAct)""" - agent_radio = page.get_by_text("Agent-based", exact=False).first - if agent_radio.is_visible(): - agent_radio.click() - time.sleep(1) - - # Look for agent type options with more specific selectors - # Check if either Regular or ReAct options exist - page_content = page.content() - assert "Regular" in page_content or "ReAct" in page_content class TestConfigurationOptions: diff --git a/tests/integration/test_upload_integration.py b/tests/integration/test_upload_integration.py index 549f28fd..378ac1a9 100644 --- a/tests/integration/test_upload_integration.py +++ b/tests/integration/test_upload_integration.py @@ -11,6 +11,22 @@ # Add the frontend directory to the path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../frontend')) +# Mock all external dependencies before any imports from the upload module +# This is required because @patch decorators try to import the target module +mock_streamlit = MagicMock() +mock_streamlit.session_state = {} +sys.modules['streamlit'] = mock_streamlit +sys.modules['asyncpg'] = MagicMock() +sys.modules['pandas'] = MagicMock() + +# Mock llama_stack_client with a proper RAGDocument mock +mock_llama_stack_client = MagicMock() +def mock_rag_document(**kwargs): + """Create a dict-like RAGDocument mock""" + return kwargs +mock_llama_stack_client.RAGDocument = mock_rag_document +sys.modules['llama_stack_client'] = mock_llama_stack_client + # Configuration LLAMA_STACK_ENDPOINT = os.getenv("LLAMA_STACK_ENDPOINT", "http://localhost:8321") diff --git a/tests/unit/test_upload.py b/tests/unit/test_upload.py index 0e4f0011..de5fe7bb 100644 --- a/tests/unit/test_upload.py +++ b/tests/unit/test_upload.py @@ -2,217 +2,323 @@ Unit tests for the upload module Tests document upload and vector DB creation logic """ -import pytest -from unittest.mock import Mock, patch, MagicMock -import sys +import asyncio import os +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest # Add the 
class MockAsyncContextManager:
    """Async context manager stand-in for the object returned by pool.acquire().

    Entering the context yields the wrapped connection; leaving it does
    nothing and never suppresses exceptions.
    """

    def __init__(self, conn):
        # Connection handed back by ``async with``.
        self.conn = conn

    async def __aenter__(self):
        return self.conn

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Falsy result: exceptions raised inside the block propagate.
        return None


def create_mock_pool_with_connection(mock_conn):
    """Build a mock connection pool whose acquire() yields ``mock_conn``.

    Args:
        mock_conn: The mock connection to return from pool.acquire()

    Returns:
        MagicMock: A mock pool with proper acquire() context manager
    """
    acquire_context = MockAsyncContextManager(mock_conn)
    fake_pool = MagicMock()
    fake_pool.acquire.return_value = acquire_context
    return fake_pool
"""Test that empty results return None""" + # Setup mock connection with empty result + mock_conn = AsyncMock() + mock_conn.fetch = AsyncMock(return_value=[]) - assert document_id == "test_document.pdf" - assert document_id.endswith(".pdf") + mock_pool = create_mock_pool_with_connection(mock_conn) + + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("empty-db") + + # Empty result should return None + assert result is None - @patch('llama_stack_ui.distribution.ui.modules.utils.data_url_from_file') - def test_rag_document_creation(self, mock_data_url): - """Test RAGDocument creation from uploaded file""" - from llama_stack_client import RAGDocument - - # Mock file and data URL - mock_data_url.return_value = "data:text/plain;base64,SGVsbG8gV29ybGQ=" - - mock_file = Mock() - mock_file.name = "test.txt" - - # Create RAGDocument as done in upload.py - document = RAGDocument( - document_id=mock_file.name, - content=mock_data_url(mock_file), - ) - - # RAGDocument returns a dict-like object - assert document['document_id'] == "test.txt" - assert document['content'].startswith("data:") - mock_data_url.assert_called_once() + def test_get_documents_connection_error(self): + """Test that connection errors return None""" + # Setup mock pool that raises an exception on acquire + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("error-db") + + # Error should return None + assert result is None - def test_multiple_documents_processing(self): - """Test processing multiple uploaded files""" - # Simulate multiple uploaded files - mock_file1 = Mock() - mock_file1.name = "doc1.txt" - mock_file2 = Mock() - mock_file2.name = "doc2.pdf" - mock_file3 = Mock() - mock_file3.name = "doc3.docx" + def 
test_get_documents_filters_null_ids(self): + """Test that null document IDs are filtered out""" + mock_conn = AsyncMock() + mock_rows = [ + {'document_id': 'valid1.pdf'}, + {'document_id': None}, # Should be filtered + {'document_id': 'valid2.txt'}, + {'document_id': None}, # Should be filtered + ] + mock_conn.fetch = AsyncMock(return_value=mock_rows) - uploaded_files = [mock_file1, mock_file2, mock_file3] + mock_pool = create_mock_pool_with_connection(mock_conn) - # Simulate creating document list - document_ids = [f.name for f in uploaded_files] + async def mock_get_pool(): + return mock_pool - assert len(document_ids) == 3 - assert "doc1.txt" in document_ids - assert "doc2.pdf" in document_ids - assert "doc3.docx" in document_ids + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("mixed-db") + + # Only valid IDs should be returned + assert result == ['valid1.pdf', 'valid2.txt'] + assert len(result) == 2 -class TestVectorDBOperations: - """Test vector database operations""" +class TestDeleteDocumentFromPgvector: + """Unit tests for _delete_document_from_pgvector function""" - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_vector_db_registration_params(self, mock_api): - """Test that vector DB registration uses correct parameters""" - mock_client = Mock() - mock_api.client = mock_client + def test_delete_document_success(self): + """Test successful deletion of document from pgvector""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 5") + + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - embedding_dimension = 384 - embedding_model = "all-MiniLM-L6-v2" - provider_id = "pgvector" - - # Simulate registration call - mock_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) - - # 
Verify the call was made with correct params - mock_client.vector_dbs.register.assert_called_once_with( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is True + assert count == 5 + assert error is None + mock_pool.acquire.assert_called_once() - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_document_insertion_params(self, mock_api): - """Test that document insertion uses correct parameters""" - from llama_stack_client import RAGDocument + def test_delete_document_not_found(self): + """Test deletion when document doesn't exist""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 0") - mock_client = Mock() - mock_api.client = mock_client + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - documents = [ - RAGDocument(document_id="doc1", content="content1"), - RAGDocument(document_id="doc2", content="content2"), + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "nonexistent.pdf" + ) + + assert success is True + assert count == 0 + assert error is None + + def test_delete_document_connection_error(self): + """Test deletion with connection error""" + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is False + assert count == 0 + assert error is not None + assert "Connection 
class TestCreateVectorDatabase:
    """Unit tests for _create_vector_database function"""

    @staticmethod
    def _run_create(db_name, client):
        """Call _create_vector_database with the API and Streamlit patched.

        Args:
            db_name (str): Name passed to _create_vector_database.
            client: Mock used as ``llama_stack_api.client``.

        Returns:
            MagicMock: The Streamlit stand-in, whose ``session_state`` is a
            plain dict that starts empty.
        """
        fake_api = MagicMock()
        fake_api.client = client

        fake_st = MagicMock()
        fake_st.session_state = {}

        with patch.object(upload_module, 'llama_stack_api', fake_api), \
                patch.object(upload_module, 'st', fake_st):
            upload_module._create_vector_database(db_name)
        return fake_st

    def test_create_vector_database_success(self):
        """A new name with a vector_io provider available gets registered."""
        client = MagicMock()
        client.vector_dbs.list.return_value = []
        client.providers.list.return_value = [
            MagicMock(api="vector_io", provider_id="pgvector")
        ]
        client.vector_dbs.register.return_value = MagicMock()

        self._run_create("new-test-db", client)

        # Verify registration was called with correct parameters
        client.vector_dbs.register.assert_called_once()
        registered_with = client.vector_dbs.register.call_args[1]
        assert registered_with['vector_db_id'] == "new-test-db"
        assert registered_with['embedding_model'] == "all-MiniLM-L6-v2"
        assert registered_with['embedding_dimension'] == 384
        assert registered_with['provider_id'] == "pgvector"

    def test_create_vector_database_duplicate_name(self):
        """A name that already exists is rejected with an error status."""
        already_there = MagicMock()
        already_there.identifier = "existing-db"

        client = MagicMock()
        client.vector_dbs.list.return_value = [already_there]

        fake_st = self._run_create("existing-db", client)

        # Registration should NOT be called for duplicates
        client.vector_dbs.register.assert_not_called()

        # Error status should be set
        assert fake_st.session_state.get("creation_status") == "error"

    def test_create_vector_database_no_provider(self):
        """Without any vector_io provider nothing is registered."""
        client = MagicMock()
        client.vector_dbs.list.return_value = []
        client.providers.list.return_value = [
            MagicMock(api="inference", provider_id="ollama")  # No vector_io
        ]

        fake_st = self._run_create("new-db", client)

        # Registration should NOT be called without provider
        client.vector_dbs.register.assert_not_called()

        # Error status should be set
        assert fake_st.session_state.get("creation_status") == "error"
class TestConnectionPool:
    """Unit tests for the connection pool functionality"""

    def test_pool_is_reused(self):
        """Test that the same pool is returned on subsequent calls.

        The fake ``create_pool`` hands out a *new* object on every call, so
        the identity check can only pass when ``_get_pg_pool`` really caches
        the first pool instead of creating another one.
        """
        # Reset the global pool so the first call has to create one
        upload_module._pg_pool = None

        # Distinct pool per call; also records how often it was awaited
        fake_create_pool = AsyncMock(side_effect=lambda *args, **kwargs: AsyncMock())

        async def get_pools():
            first = await upload_module._get_pg_pool()
            second = await upload_module._get_pg_pool()
            return first, second

        # A private loop that is always closed again, and never installed as
        # the current loop, so no loop state leaks into other tests.
        loop = asyncio.new_event_loop()
        try:
            with patch.object(upload_module.asyncpg, 'create_pool', fake_create_pool):
                pool1, pool2 = loop.run_until_complete(get_pools())

            # Should be the same pool instance, created exactly once
            assert pool1 is pool2
            fake_create_pool.assert_awaited_once()
        finally:
            loop.close()
            # Clean up even when an assertion fails, so the mock pool never
            # leaks into later tests
            upload_module._pg_pool = None