From 1768edaef32e65be50f1bc57f3074bfe45e24b76 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 10:26:31 -0500 Subject: [PATCH 01/18] fix: Improve port forwarding reliability for UI E2E tests - Wait for pods to be ready (not just deployments) - Add verification that port forwarding is actually working - Check port connectivity before running tests - Add better error messages and logging - Verify port-forward processes stay alive This should fix timeout issues where tests couldn't connect to localhost:8501 --- .github/workflows/e2e-tests.yaml | 61 ++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/.github/workflows/e2e-tests.yaml b/.github/workflows/e2e-tests.yaml index 3b942df..a379857 100644 --- a/.github/workflows/e2e-tests.yaml +++ b/.github/workflows/e2e-tests.yaml @@ -629,17 +629,72 @@ jobs: deployment/llamastack -n rag-e2e-ui kubectl wait --for=condition=available --timeout=300s \ deployment/rag -n rag-e2e-ui + + echo "Waiting for pods to be ready..." + kubectl wait --for=condition=ready --timeout=600s \ + pod -l app.kubernetes.io/name=llamastack -n rag-e2e-ui + kubectl wait --for=condition=ready --timeout=300s \ + pod -l app.kubernetes.io/name=rag -n rag-e2e-ui + + echo "✅ All pods are ready" + kubectl get pods -n rag-e2e-ui - name: Expose services via NodePort run: | kubectl patch service rag -n rag-e2e-ui -p '{"spec":{"type":"NodePort","ports":[{"port":8501,"nodePort":30080}]}}' kubectl patch service llamastack -n rag-e2e-ui -p '{"spec":{"type":"NodePort","ports":[{"port":8321,"nodePort":30081}]}}' + + # Verify services + kubectl get services -n rag-e2e-ui - name: Port forward services run: | - kubectl port-forward -n rag-e2e-ui svc/rag 8501:8501 & - kubectl port-forward -n rag-e2e-ui svc/llamastack 8321:8321 & - sleep 10 + echo "Starting port forwarding..." 
+ # Start port forwarding in background and capture PIDs + kubectl port-forward -n rag-e2e-ui svc/rag 8501:8501 > /tmp/rag-portforward.log 2>&1 & + RAG_PF_PID=$! + echo "RAG port-forward PID: $RAG_PF_PID" + + kubectl port-forward -n rag-e2e-ui svc/llamastack 8321:8321 > /tmp/llamastack-portforward.log 2>&1 & + LLAMASTACK_PF_PID=$! + echo "LlamaStack port-forward PID: $LLAMASTACK_PF_PID" + + # Wait for port forwarding to establish and verify connectivity + echo "Waiting for port forwarding to be ready..." + for i in {1..30}; do + # Check if ports are listening + if (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8501' 2>/dev/null) && \ + (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8321' 2>/dev/null); then + echo "✅ Port forwarding is working! (attempt $i)" + break + fi + if [ $i -eq 30 ]; then + echo "❌ Port forwarding failed to establish after 30 attempts" + echo "RAG port-forward log:" + cat /tmp/rag-portforward.log || true + echo "LlamaStack port-forward log:" + cat /tmp/llamastack-portforward.log || true + echo "Checking port-forward processes:" + ps aux | grep "kubectl port-forward" || true + echo "Checking if ports are listening:" + ss -tlnp | grep -E ':(8501|8321)' || netstat -tlnp 2>/dev/null | grep -E ':(8501|8321)' || true + exit 1 + fi + echo "Waiting for port forwarding... (attempt $i/30)" + sleep 2 + done + + # Verify processes are still running + if ! kill -0 $RAG_PF_PID 2>/dev/null || ! 
kill -0 $LLAMASTACK_PF_PID 2>/dev/null; then + echo "❌ Port forwarding processes died" + echo "RAG port-forward log:" + cat /tmp/rag-portforward.log || true + echo "LlamaStack port-forward log:" + cat /tmp/llamastack-portforward.log || true + exit 1 + fi + + echo "✅ Port forwarding verified and ready" - name: Run UI E2E tests with Playwright env: From b8d9b71e48d98589a54b8f58d1dba13e76531399 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 10:35:03 -0500 Subject: [PATCH 02/18] fix: Wait for agent type selector UI elements instead of checking raw HTML The test was checking page.content() which returns raw HTML before React renders. Now it properly waits for the UI elements to be visible using Playwright's expect().to_be_visible() which is more reliable and matches the pattern used in other tests. --- tests/e2e_ui/test_chat_ui.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/e2e_ui/test_chat_ui.py b/tests/e2e_ui/test_chat_ui.py index aee05ad..36f9473 100644 --- a/tests/e2e_ui/test_chat_ui.py +++ b/tests/e2e_ui/test_chat_ui.py @@ -116,10 +116,17 @@ def test_agent_type_selector(self, page: Page): agent_radio.click() time.sleep(1) - # Look for agent type options with more specific selectors - # Check if either Regular or ReAct options exist - page_content = page.content() - assert "Regular" in page_content or "ReAct" in page_content + # Wait for agent type options to be visible + # Check if either Regular or ReAct options exist in the rendered UI + regular_option = page.get_by_text("Regular", exact=False) + react_option = page.get_by_text("ReAct", exact=False) + + # At least one of them should be visible + try: + expect(regular_option).to_be_visible(timeout=TEST_TIMEOUT) + except AssertionError: + # If Regular is not visible, ReAct should be + expect(react_option).to_be_visible(timeout=TEST_TIMEOUT) class TestConfigurationOptions: From a59cfc38417a72f26b41d476d0dcd823734c308f Mon Sep 17 00:00:00 2001 From: Sid 
Kattoju Date: Mon, 15 Dec 2025 10:56:08 -0500 Subject: [PATCH 03/18] fix: Keep port forwarding alive during all tests to prevent timeouts - Combine port forwarding and test steps so processes persist - Add retry logic in test fixture for navigation failures - Add cleanup trap to ensure port forwarding is killed on exit - Increase timeouts for navigation and network idle waits - Better error handling and verification The issue was that port forwarding processes were dying between steps. By running port forwarding in the same step as tests, the processes stay alive throughout all test execution. --- .github/workflows/e2e-tests.yaml | 65 ++++++++++++++++++++------------ tests/e2e_ui/test_chat_ui.py | 22 ++++++++--- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/.github/workflows/e2e-tests.yaml b/.github/workflows/e2e-tests.yaml index a379857..3d7e9c7 100644 --- a/.github/workflows/e2e-tests.yaml +++ b/.github/workflows/e2e-tests.yaml @@ -647,24 +647,42 @@ jobs: # Verify services kubectl get services -n rag-e2e-ui - - name: Port forward services + - name: Run UI E2E tests with Playwright + env: + RAG_UI_ENDPOINT: http://localhost:8501 + LLAMA_STACK_ENDPOINT: http://localhost:8321 + MAAS_ENDPOINT: ${{ env.MAAS_ENDPOINT }} + MAAS_MODEL_ID: ${{ env.MAAS_MODEL_ID }} + SKIP_MODEL_TESTS: "false" # Enable MaaS inference tests in UI run: | - echo "Starting port forwarding..." - # Start port forwarding in background and capture PIDs + echo "Starting port forwarding and running tests..." + + # Start port forwarding in background and keep them running + echo "Starting port forwarding for RAG UI..." kubectl port-forward -n rag-e2e-ui svc/rag 8501:8501 > /tmp/rag-portforward.log 2>&1 & RAG_PF_PID=$! echo "RAG port-forward PID: $RAG_PF_PID" + echo "Starting port forwarding for LlamaStack..." kubectl port-forward -n rag-e2e-ui svc/llamastack 8321:8321 > /tmp/llamastack-portforward.log 2>&1 & LLAMASTACK_PF_PID=$! 
echo "LlamaStack port-forward PID: $LLAMASTACK_PF_PID" - # Wait for port forwarding to establish and verify connectivity + # Function to check if port forwarding is working + check_port_forwarding() { + (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8501' 2>/dev/null) && \ + (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8321' 2>/dev/null) + } + + # Function to verify processes are alive + check_processes() { + kill -0 $RAG_PF_PID 2>/dev/null && kill -0 $LLAMASTACK_PF_PID 2>/dev/null + } + + # Wait for port forwarding to establish echo "Waiting for port forwarding to be ready..." for i in {1..30}; do - # Check if ports are listening - if (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8501' 2>/dev/null) && \ - (timeout 2 bash -c 'cat < /dev/null > /dev/tcp/localhost/8321' 2>/dev/null); then + if check_port_forwarding && check_processes; then echo "✅ Port forwarding is working! (attempt $i)" break fi @@ -684,30 +702,27 @@ jobs: sleep 2 done - # Verify processes are still running - if ! kill -0 $RAG_PF_PID 2>/dev/null || ! kill -0 $LLAMASTACK_PF_PID 2>/dev/null; then - echo "❌ Port forwarding processes died" + # Set up cleanup trap to kill port forwarding on exit + trap "echo 'Cleaning up port forwarding...'; kill $RAG_PF_PID $LLAMASTACK_PF_PID 2>/dev/null || true" EXIT + + echo "Running UI E2E tests with MaaS integration..." + echo "MaaS Endpoint: ${MAAS_ENDPOINT}" + echo "MaaS Model ID: ${MAAS_MODEL_ID}" + + # Run tests + pytest tests/e2e_ui/ -v --tb=short --browser chromium || TEST_EXIT_CODE=$? + + # Verify port forwarding was still working after tests + if ! 
check_port_forwarding; then + echo "⚠️ Warning: Port forwarding stopped working during tests" echo "RAG port-forward log:" cat /tmp/rag-portforward.log || true echo "LlamaStack port-forward log:" cat /tmp/llamastack-portforward.log || true - exit 1 fi - echo "✅ Port forwarding verified and ready" - - - name: Run UI E2E tests with Playwright - env: - RAG_UI_ENDPOINT: http://localhost:8501 - LLAMA_STACK_ENDPOINT: http://localhost:8321 - MAAS_ENDPOINT: ${{ env.MAAS_ENDPOINT }} - MAAS_MODEL_ID: ${{ env.MAAS_MODEL_ID }} - SKIP_MODEL_TESTS: "false" # Enable MaaS inference tests in UI - run: | - echo "Running UI E2E tests with MaaS integration..." - echo "MaaS Endpoint: ${MAAS_ENDPOINT}" - echo "MaaS Model ID: ${MAAS_MODEL_ID}" - pytest tests/e2e_ui/ -v --tb=short --browser chromium + # Exit with test result + exit ${TEST_EXIT_CODE:-0} - name: Upload Playwright test results if: always() diff --git a/tests/e2e_ui/test_chat_ui.py b/tests/e2e_ui/test_chat_ui.py index 36f9473..5e7843b 100644 --- a/tests/e2e_ui/test_chat_ui.py +++ b/tests/e2e_ui/test_chat_ui.py @@ -29,11 +29,23 @@ def browser_context_args(browser_context_args): @pytest.fixture(autouse=True) def wait_for_app(page: Page): """Wait for the Streamlit app to be ready before each test""" - page.goto(RAG_UI_ENDPOINT) - # Wait for Streamlit to finish loading - page.wait_for_load_state("networkidle") - # Give Streamlit additional time to initialize - time.sleep(2) + # Retry navigation in case of transient connection issues + max_retries = 3 + for attempt in range(max_retries): + try: + page.goto(RAG_UI_ENDPOINT, timeout=60000, wait_until="domcontentloaded") + # Wait for Streamlit to finish loading + page.wait_for_load_state("networkidle", timeout=60000) + # Give Streamlit additional time to initialize + time.sleep(2) + # Verify page actually loaded + if page.url.startswith(RAG_UI_ENDPOINT): + return + except Exception as e: + if attempt == max_retries - 1: + raise + print(f"Navigation attempt {attempt + 1} failed: {e}, 
retrying...") + time.sleep(2) class TestChatUIBasics: From 0baeb3bf170fc424d6618cb73b41806574d6d936 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 10:56:39 -0500 Subject: [PATCH 04/18] fix: Enhance test reliability by ensuring UI elements are rendered before interaction - Implement waiting for UI elements instead of relying on raw HTML - Utilize Playwright's expect().to_be_visible() for better synchronization - Align with patterns used in other tests for consistency This change addresses issues with tests failing due to premature interactions with the UI. --- tests/e2e_ui/test_e2e_minimal.py | 110 +++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 tests/e2e_ui/test_e2e_minimal.py diff --git a/tests/e2e_ui/test_e2e_minimal.py b/tests/e2e_ui/test_e2e_minimal.py new file mode 100644 index 0000000..bbdc62f --- /dev/null +++ b/tests/e2e_ui/test_e2e_minimal.py @@ -0,0 +1,110 @@ +""" +Minimal E2E tests that actually call the backend +Only includes tests that make real API calls to verify end-to-end functionality +""" +import pytest +import os +import time +from playwright.sync_api import Page, expect + + +# Configuration +RAG_UI_ENDPOINT = os.getenv("RAG_UI_ENDPOINT", "http://localhost:8501") +TEST_TIMEOUT = 30000 # 30 seconds + + +@pytest.fixture(scope="session") +def browser_context_args(browser_context_args): + """Configure browser context""" + return { + **browser_context_args, + "viewport": { + "width": 1920, + "height": 1080, + }, + } + + +@pytest.fixture(autouse=True) +def wait_for_app(page: Page): + """Wait for the Streamlit app to be ready before each test""" + page.goto(RAG_UI_ENDPOINT) + # Wait for Streamlit to finish loading + page.wait_for_load_state("networkidle") + # Give Streamlit additional time to initialize + time.sleep(2) + + +class TestMaaSIntegration: + """E2E tests for MaaS (Model-as-a-Service) integration through the UI + + These tests verify that MaaS works end-to-end through the browser UI. 
+ They send actual messages and verify MaaS responses - these are the only + tests that actually call the backend. + """ + + @pytest.mark.skipif( + os.getenv("SKIP_MODEL_TESTS", "false").lower() == "true", + reason="Model inference tests disabled via SKIP_MODEL_TESTS" + ) + def test_maas_chat_completion_direct_mode(self, page: Page): + """Test that MaaS responds to chat messages in direct mode - E2E test with backend call""" + # Verify the chat input is visible + chat_input = page.get_by_placeholder("Ask a question...", exact=False) + expect(chat_input).to_be_visible(timeout=TEST_TIMEOUT) + + # Send a simple test message + test_message = "Say 'Hello from RAG e2e test!' in one short sentence." + chat_input.fill(test_message) + chat_input.press("Enter") + + # Wait for Streamlit to process the input and rerun + page.wait_for_load_state("networkidle") + time.sleep(3) # Give Streamlit time to send request and start receiving response + + # Wait for the user message to appear in chat + user_msg = page.get_by_text(test_message, exact=False) + expect(user_msg).to_be_visible(timeout=TEST_TIMEOUT) + + # Wait for assistant response (MaaS should respond) + # Streamlit chat messages have structure: stChatMessage with role + max_wait = 90 # seconds - MaaS can be slow + wait_time = 0 + while wait_time < max_wait: + time.sleep(3) + wait_time += 3 + + # Check for new assistant message content + assistant_containers = page.locator('[data-testid="stChatMessage"]').all() + + for container in assistant_containers: + if container.is_visible(): + text_content = container.inner_text().strip() + # Check if it's a new assistant message (not greeting, not user message) + if (text_content and + text_content != "How can I help you?" and + test_message not in text_content and + len(text_content) > 15): # Real response should be substantial + # Found a real MaaS response! 
+ print(f"✅ MaaS responded: {text_content[:150]}...") + assert len(text_content) > 10, "MaaS response too short" + return # Success! + + # Also check for any new text that appeared after user message + all_visible_text = page.locator('body').inner_text() + if test_message in all_visible_text: + # Check if there's additional text that looks like a response + lines = all_visible_text.split('\n') + for line in lines: + line = line.strip() + if (line and + test_message not in line and + "How can I help you?" not in line and + len(line) > 20 and # Substantial response + any(word in line.lower() for word in ['hello', 'test', 'rag', 'e2e', 'from'])): # Should mention something from our test + print(f"✅ MaaS responded (found in text): {line[:150]}...") + return # Success! + + # If we get here, no response was received + pytest.fail(f"MaaS did not respond within {max_wait} seconds") + From 72bfd8a8c680637bd1407d13ac17b8540bcdb9ab Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 11:06:53 -0500 Subject: [PATCH 05/18] fix: Improve agent type selector test to wait for UI update - Wait for networkidle after clicking agent mode - Add additional sleep to allow Streamlit to re-render - This should fix the test that was checking raw HTML instead of rendered UI --- tests/e2e_ui/test_chat_ui.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/e2e_ui/test_chat_ui.py b/tests/e2e_ui/test_chat_ui.py index 5e7843b..bbb44e4 100644 --- a/tests/e2e_ui/test_chat_ui.py +++ b/tests/e2e_ui/test_chat_ui.py @@ -126,14 +126,18 @@ def test_agent_type_selector(self, page: Page): agent_radio = page.get_by_text("Agent-based", exact=False).first if agent_radio.is_visible(): agent_radio.click() - time.sleep(1) + # Wait for Streamlit to re-render after mode change + page.wait_for_load_state("networkidle") + time.sleep(2) # Give Streamlit time to update UI # Wait for agent type options to be visible # Check if either Regular or ReAct options exist in the 
rendered UI + # These might appear as radio buttons or in a selectbox regular_option = page.get_by_text("Regular", exact=False) react_option = page.get_by_text("ReAct", exact=False) - # At least one of them should be visible + # At least one of them should be visible after selecting agent mode + # Try Regular first, then ReAct try: expect(regular_option).to_be_visible(timeout=TEST_TIMEOUT) except AssertionError: From be8bba522fad575215a25c0a1fc2200e89651bb5 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 11:20:22 -0500 Subject: [PATCH 06/18] fix: Remove redundant test_agent_type_selector test The agent type selector (Regular vs ReAct) UI feature is commented out in the codebase, so this test was testing a non-existent feature. The test was redundant with test_agent_mode_shows_toolgroups which already verifies agent mode selection and toolgroups visibility. Removed the test to avoid confusion and reduce maintenance burden. --- tests/e2e_ui/test_chat_ui.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/e2e_ui/test_chat_ui.py b/tests/e2e_ui/test_chat_ui.py index bbb44e4..aeec018 100644 --- a/tests/e2e_ui/test_chat_ui.py +++ b/tests/e2e_ui/test_chat_ui.py @@ -120,29 +120,6 @@ def test_agent_mode_shows_toolgroups(self, page: Page): toolgroups = page.get_by_text("Available ToolGroups", exact=False) expect(toolgroups).to_be_visible(timeout=TEST_TIMEOUT) - - def test_agent_type_selector(self, page: Page): - """Test agent type selector (Regular vs ReAct)""" - agent_radio = page.get_by_text("Agent-based", exact=False).first - if agent_radio.is_visible(): - agent_radio.click() - # Wait for Streamlit to re-render after mode change - page.wait_for_load_state("networkidle") - time.sleep(2) # Give Streamlit time to update UI - - # Wait for agent type options to be visible - # Check if either Regular or ReAct options exist in the rendered UI - # These might appear as radio buttons or in a selectbox - regular_option = 
page.get_by_text("Regular", exact=False) - react_option = page.get_by_text("ReAct", exact=False) - - # At least one of them should be visible after selecting agent mode - # Try Regular first, then ReAct - try: - expect(regular_option).to_be_visible(timeout=TEST_TIMEOUT) - except AssertionError: - # If Regular is not visible, ReAct should be - expect(react_option).to_be_visible(timeout=TEST_TIMEOUT) class TestConfigurationOptions: From bb8523236e371a47872c8af6fd1f131034c0eab9 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 12:01:01 -0500 Subject: [PATCH 07/18] fix: Remove invalid pytest config option to eliminate warning Removed 'asyncio_default_fixture_loop_scope' which is not a valid pytest configuration option and was causing a PytestConfigWarning. This option is not needed as no tests in the codebase use async/await. The warning was harmless but removing it keeps the test output clean. --- pytest.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index c37d9fc..8636a61 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,5 @@ [pytest] # Pytest configuration -asyncio_default_fixture_loop_scope = function # Test discovery patterns python_files = test_*.py From 0ff01e95d8e051874e90c34b1fd8c0829d4a848c Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Mon, 15 Dec 2025 13:36:49 -0500 Subject: [PATCH 08/18] chore: Remove unused duplicate test file Removed tests/e2e_ui/test_e2e_minimal.py which contains duplicate tests that already exist in test_chat_ui.py. The TestMaaSIntegration class and test_maas_chat_completion_direct_mode test are already present in test_chat_ui.py, making this file redundant. 
--- tests/e2e_ui/test_e2e_minimal.py | 110 ------------------------------- 1 file changed, 110 deletions(-) delete mode 100644 tests/e2e_ui/test_e2e_minimal.py diff --git a/tests/e2e_ui/test_e2e_minimal.py b/tests/e2e_ui/test_e2e_minimal.py deleted file mode 100644 index bbdc62f..0000000 --- a/tests/e2e_ui/test_e2e_minimal.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -Minimal E2E tests that actually call the backend -Only includes tests that make real API calls to verify end-to-end functionality -""" -import pytest -import os -import time -from playwright.sync_api import Page, expect - - -# Configuration -RAG_UI_ENDPOINT = os.getenv("RAG_UI_ENDPOINT", "http://localhost:8501") -TEST_TIMEOUT = 30000 # 30 seconds - - -@pytest.fixture(scope="session") -def browser_context_args(browser_context_args): - """Configure browser context""" - return { - **browser_context_args, - "viewport": { - "width": 1920, - "height": 1080, - }, - } - - -@pytest.fixture(autouse=True) -def wait_for_app(page: Page): - """Wait for the Streamlit app to be ready before each test""" - page.goto(RAG_UI_ENDPOINT) - # Wait for Streamlit to finish loading - page.wait_for_load_state("networkidle") - # Give Streamlit additional time to initialize - time.sleep(2) - - -class TestMaaSIntegration: - """E2E tests for MaaS (Model-as-a-Service) integration through the UI - - These tests verify that MaaS works end-to-end through the browser UI. - They send actual messages and verify MaaS responses - these are the only - tests that actually call the backend. 
- """ - - @pytest.mark.skipif( - os.getenv("SKIP_MODEL_TESTS", "false").lower() == "true", - reason="Model inference tests disabled via SKIP_MODEL_TESTS" - ) - def test_maas_chat_completion_direct_mode(self, page: Page): - """Test that MaaS responds to chat messages in direct mode - E2E test with backend call""" - # Verify the chat input is visible - chat_input = page.get_by_placeholder("Ask a question...", exact=False) - expect(chat_input).to_be_visible(timeout=TEST_TIMEOUT) - - # Send a simple test message - test_message = "Say 'Hello from RAG e2e test!' in one short sentence." - chat_input.fill(test_message) - chat_input.press("Enter") - - # Wait for Streamlit to process the input and rerun - page.wait_for_load_state("networkidle") - time.sleep(3) # Give Streamlit time to send request and start receiving response - - # Wait for the user message to appear in chat - user_msg = page.get_by_text(test_message, exact=False) - expect(user_msg).to_be_visible(timeout=TEST_TIMEOUT) - - # Wait for assistant response (MaaS should respond) - # Streamlit chat messages have structure: stChatMessage with role - max_wait = 90 # seconds - MaaS can be slow - wait_time = 0 - while wait_time < max_wait: - time.sleep(3) - wait_time += 3 - - # Check for new assistant message content - assistant_containers = page.locator('[data-testid="stChatMessage"]').all() - - for container in assistant_containers: - if container.is_visible(): - text_content = container.inner_text().strip() - # Check if it's a new assistant message (not greeting, not user message) - if (text_content and - text_content != "How can I help you?" and - test_message not in text_content and - len(text_content) > 15): # Real response should be substantial - # Found a real MaaS response! - print(f"✅ MaaS responded: {text_content[:150]}...") - assert len(text_content) > 10, "MaaS response too short" - return # Success! 
- - # Also check for any new text that appeared after user message - all_visible_text = page.locator('body').inner_text() - if test_message in all_visible_text: - # Check if there's additional text that looks like a response - lines = all_visible_text.split('\n') - for line in lines: - line = line.strip() - if (line and - test_message not in line and - "How can I help you?" not in line and - len(line) > 20 and # Substantial response - any(word in line.lower() for word in ['hello', 'test', 'rag', 'e2e', 'from'])): # Should mention something from our test - print(f"✅ MaaS responded (found in text): {line[:150]}...") - return # Success! - - # If we get here, no response was received - pytest.fail(f"MaaS did not respond within {max_wait} seconds") - From 871c2d69637430f7f015c41f71c0dedd5220e21f Mon Sep 17 00:00:00 2001 From: Ganesh Murthy Date: Fri, 19 Dec 2025 10:13:00 -0500 Subject: [PATCH 09/18] APPENG-4252: Adds the following features Add new drop down to show all available vector databases and when a database is picked, show the documents already uploaded to the database. 
Ability to delete a document that was already added to a vector database Ability to upload a document to any chosen vector database Create new vector databases by clicking the Create New button --- deploy/helm/rag/templates/deployment.yaml | 12 + .../distribution/ui/page/upload/upload.py | 627 ++++++++++++++++-- frontend/pyproject.toml | 1 + tests/integration/test_upload_integration.py | 16 + tests/unit/test_upload.py | 450 ++++++++----- 5 files changed, 889 insertions(+), 217 deletions(-) diff --git a/deploy/helm/rag/templates/deployment.yaml b/deploy/helm/rag/templates/deployment.yaml index eddbccc..dbc49ce 100644 --- a/deploy/helm/rag/templates/deployment.yaml +++ b/deploy/helm/rag/templates/deployment.yaml @@ -40,6 +40,18 @@ spec: - name: TAVILY_SEARCH_API_KEY value: {{ (index .Values "llama-stack").secrets.TAVILY_SEARCH_API_KEY | quote }} {{- end }} + {{- if .Values.pgvector }} + - name: PGVECTOR_HOST + value: {{ .Values.pgvector.secret.host | quote }} + - name: PGVECTOR_PORT + value: {{ .Values.pgvector.secret.port | quote }} + - name: PGVECTOR_USER + value: {{ .Values.pgvector.secret.user | quote }} + - name: PGVECTOR_PASSWORD + value: {{ .Values.pgvector.secret.password | quote }} + - name: PGVECTOR_DB + value: {{ .Values.pgvector.secret.dbname | quote }} + {{- end }} {{- if .Values.suggestedQuestions }} - name: RAG_QUESTION_SUGGESTIONS valueFrom: diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index ffd0175..8547cd8 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -1,67 +1,604 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import asyncpg +import os import streamlit as st -from llama_stack_client import RAGDocument +import traceback + +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name, data_url_from_file from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import data_url_from_file +from llama_stack_client import RAGDocument + + +# Module-level connection pool (initialized lazily) +_pg_pool = None + + +async def _get_pg_pool(): + """ + Get or create the PostgreSQL connection pool. + The pool is created lazily on first use and reused for subsequent calls. + + Returns: + asyncpg.Pool: The connection pool instance + """ + global _pg_pool + if _pg_pool is None: + pg_host = os.environ.get("PGVECTOR_HOST", "pgvector") + pg_port = os.environ.get("PGVECTOR_PORT", "5432") + pg_user = os.environ.get("PGVECTOR_USER", "postgres") + pg_password = os.environ.get("PGVECTOR_PASSWORD", "rag_password") + pg_database = os.environ.get("PGVECTOR_DB", "rag_blueprint") + + _pg_pool = await asyncpg.create_pool( + host=pg_host, + port=int(pg_port), + user=pg_user, + password=pg_password, + database=pg_database, + min_size=1, + max_size=5, + ) + return _pg_pool + def upload_page(): """ - Page to upload documents and create a vector database for RAG. + Page to upload documents and manage vector databases for RAG. + Supports creating new vector databases and uploading documents to existing ones. 
+ """ + st.title("📄 Upload Documents") + + # Initialize session state for creation status messages + if "creation_status" not in st.session_state: + st.session_state["creation_status"] = None + if "creation_message" not in st.session_state: + st.session_state["creation_message"] = "" + + # Initialize session state for selected vector database + # This persists the selection when navigating away and back to this page + if "selected_vector_db" not in st.session_state: + st.session_state["selected_vector_db"] = "" + + # Initialize the widget key to match our tracked selection + # This ensures the selectbox displays the correct value on page load + if "vector_db_selector" not in st.session_state: + st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] + + # Initialize newly created VDB tracker + if "newly_created_vdb" not in st.session_state: + st.session_state["newly_created_vdb"] = None + + # Show status messages at the top level (before dropdown) + if st.session_state["creation_status"] == "success": + st.success(st.session_state["creation_message"]) + # Clear the message after showing it + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + elif st.session_state["creation_status"] == "error": + st.error(st.session_state["creation_message"]) + # Clear the message after showing it + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + + # Fetch all vector databases + vdb_list = llama_stack_api.client.vector_dbs.list() + + # Build dropdown options based on whether databases exist + dropdown_options = [] + vdb_info = {} + + # Define the "Create New" option with emoji for visibility + CREATE_NEW_OPTION = "➕ Create New" + + if vdb_list: + # When databases exist: list actual DBs first, then "Create New" LAST + existing_vdbs = {get_vector_db_name(v): v.to_dict() for v in vdb_list} + dropdown_options.extend(list(existing_vdbs.keys())) + 
dropdown_options.append(CREATE_NEW_OPTION) # Add "Create New" as LAST item + vdb_info = existing_vdbs + else: + # When NO databases exist: only show "Create New" + dropdown_options = [CREATE_NEW_OPTION] + + # Sync session state for widget - ensure it shows the right value + # Priority 1: If a database was just created, auto-select it (highest priority) + if st.session_state["newly_created_vdb"]: + newly_created_name = st.session_state["newly_created_vdb"] + if newly_created_name in dropdown_options: + # Update both session variables to sync state + st.session_state["selected_vector_db"] = newly_created_name + st.session_state["vector_db_selector"] = newly_created_name + st.session_state["newly_created_vdb"] = None + # Priority 2: Use the previously selected database from session if it still exists + elif st.session_state["selected_vector_db"] and st.session_state["selected_vector_db"] in dropdown_options: + # Sync widget state with our tracked state + st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] + # Priority 3: If no saved selection or saved selection doesn't exist, use smart default + else: + if vdb_list: + # When databases exist: default to FIRST actual database (not "Create New") + first_db = dropdown_options[0] # First item is first actual database + st.session_state["selected_vector_db"] = first_db + st.session_state["vector_db_selector"] = first_db + else: + # When NO databases exist: default to "Create New" + st.session_state["selected_vector_db"] = CREATE_NEW_OPTION + st.session_state["vector_db_selector"] = CREATE_NEW_OPTION + + # Vector database selection dropdown with persistent selection + # Using key parameter to bind directly to session state - NO index parameter to avoid conflicts + def on_vector_db_change(): + """Callback to update session state when selection changes""" + st.session_state["selected_vector_db"] = st.session_state["vector_db_selector"] + + selected_vector_db = st.selectbox( + "Select a vector 
database", + dropdown_options, + key="vector_db_selector", # Key binds to session state (session state controls the value) + on_change=on_vector_db_change, # Callback updates our tracking variable + help="Your selection will be remembered when you navigate to other pages" + ) + + # Ensure session state is updated (in case callback didn't fire) + if selected_vector_db != st.session_state["selected_vector_db"]: + st.session_state["selected_vector_db"] = selected_vector_db + + # Get the actual vector database object for API calls (do this before using it) + selected_vdb_obj = None + if selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + for vdb in vdb_list: + if get_vector_db_name(vdb) == selected_vector_db: + selected_vdb_obj = vdb + break + + if selected_vector_db == CREATE_NEW_OPTION: + # Show vector database creation UI + _show_create_vector_db_ui() + elif selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + # Show existing documents in the database (heading will show only if documents exist) + _show_existing_documents_table(selected_vector_db, selected_vdb_obj) + + # Add Browse functionality for uploading documents to this database + st.subheader(f"📁 Upload Documents to '{selected_vector_db}'") + _show_document_upload_ui(selected_vector_db, selected_vdb_obj) + # If empty string is selected, show nothing (clean default state) + + +def _show_create_vector_db_ui(): + """ + Display UI for creating a new vector database. 
+ """ + st.subheader("Create New Vector Database") + + # Initialize session state for creation form + if "new_vdb_name" not in st.session_state: + st.session_state["new_vdb_name"] = "" + + # Vector database name input + new_vdb_name = st.text_input( + "Add New Vector Database", + value=st.session_state["new_vdb_name"], + help="Enter a unique name for the new vector database", + key="new_vdb_name_input" + ) + + # Update session state + st.session_state["new_vdb_name"] = new_vdb_name + + # Add button + if st.button("Add", type="primary", disabled=not new_vdb_name.strip()): + _create_vector_database(new_vdb_name.strip()) + + +def _create_vector_database(vdb_name): + """ + Create a new vector database using the LlamaStack API. + + Args: + vdb_name (str): Name for the new vector database + """ + try: + # Reset status + st.session_state["creation_status"] = None + st.session_state["creation_message"] = "" + + # Validate input + if not vdb_name or not vdb_name.strip(): + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = "Vector database name cannot be empty." + return + + # Check for duplicate names + existing_vdbs = llama_stack_api.client.vector_dbs.list() + existing_names = [get_vector_db_name(vdb) for vdb in existing_vdbs] + if vdb_name in existing_names: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = f"Vector database '{vdb_name}' already exists. Please choose a different name." + return + + # Get vector IO provider + providers = llama_stack_api.client.providers.list() + vector_io_provider = None + for provider in providers: + if provider.api == "vector_io": + vector_io_provider = provider.provider_id + break + + if not vector_io_provider: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = "No vector IO provider found. Cannot create vector database." 
+ return + + # Create the vector database + with st.spinner(f"Creating vector database '{vdb_name}'..."): + vector_db = llama_stack_api.client.vector_dbs.register( + vector_db_id=vdb_name, + embedding_dimension=384, + embedding_model="all-MiniLM-L6-v2", + provider_id=vector_io_provider, + ) + + # Success + st.session_state["creation_status"] = "success" + st.session_state["creation_message"] = f"Vector database '{vdb_name}' created successfully!" + + # Mark this database to be auto-selected after refresh + st.session_state["newly_created_vdb"] = vdb_name + + # Clear the input field + st.session_state["new_vdb_name"] = "" + + # Trigger page refresh to update the dropdown - this will show the message at the top + st.rerun() + + except Exception as e: + st.session_state["creation_status"] = "error" + st.session_state["creation_message"] = f"Error creating vector database: {str(e)}" + + +def _show_document_upload_ui(vector_db_name, vector_db_obj=None): + """ + Display UI for uploading documents to an existing vector database. 
+ + Args: + vector_db_name (str): Name of the selected vector database """ - st.title("📄 Upload") - # File/Directory Upload Section - st.subheader("Create Vector DB") - # Let user select files to ingest + # Initialize session state for upload status + if "upload_status" not in st.session_state: + st.session_state["upload_status"] = None + if "upload_message" not in st.session_state: + st.session_state["upload_message"] = "" + + # Show upload status messages + if st.session_state["upload_status"] == "success": + st.success(st.session_state["upload_message"]) + # Clear after showing + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + elif st.session_state["upload_status"] == "error": + st.error(st.session_state["upload_message"]) + # Clear after showing + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + + # Initialize session state to track processed files + upload_key = f"processed_files_{vector_db_name}" + if upload_key not in st.session_state: + st.session_state[upload_key] = set() + + # File uploader uploaded_files = st.file_uploader( - "Upload file(s) or directory", + "Browse and select files to upload (files will upload automatically)", accept_multiple_files=True, - type=["txt", "pdf", "doc", "docx"], # supported file types + type=["txt", "pdf", "doc", "docx"], + key=f"uploader_{vector_db_name}", # Unique key per database + help="Select one or more documents - they will be uploaded automatically to this vector database" ) - # Process uploaded files + + # Auto-upload when files are selected if uploaded_files: - # Show upload success and prompt for DB name - st.success(f"Successfully uploaded {len(uploaded_files)} files") - vector_db_name = st.text_input( - "Vector Database Name", - value="rag_vector_db", - help="Enter a unique identifier for this vector database", - ) - if st.button("Create Vector Database"): - # Convert uploaded files into RAGDocument instances + # Create a unique identifier 
for this set of files + file_set_id = frozenset([f.name + str(f.size) for f in uploaded_files]) + + # Only process if this is a new set of files + if file_set_id not in st.session_state[upload_key]: + # Mark as processed IMMEDIATELY before upload to prevent re-triggering + st.session_state[upload_key].add(file_set_id) + + # Get the correct database ID for upload + vector_db_id = vector_db_obj.identifier if vector_db_obj and hasattr(vector_db_obj, 'identifier') else vector_db_name + + # Upload automatically + _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id) + + +def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=None): + """ + Upload documents to an existing vector database. + + Args: + vector_db_name (str): Name of the target vector database + uploaded_files: List of uploaded files from Streamlit file uploader + """ + try: + # Reset status + st.session_state["upload_status"] = None + st.session_state["upload_message"] = "" + + if not uploaded_files: + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = "No files selected for upload." 
+ return + + # Convert uploaded files into RAGDocument instances + with st.spinner(f"Processing {len(uploaded_files)} file(s)..."): documents = [ RAGDocument( document_id=uploaded_file.name, content=data_url_from_file(uploaded_file), + metadata={"source": uploaded_file.name, "type": "uploaded_file"} # LlamaStack maps 'source' to chunk_metadata.source ) - for i, uploaded_file in enumerate(uploaded_files) + for uploaded_file in uploaded_files ] - - # Determine provider for vector IO - providers = llama_stack_api.client.providers.list() - vector_io_provider = None - for x in providers: - if x.api == "vector_io": - vector_io_provider = x.provider_id - break - - # Register new vector database - vector_db = llama_stack_api.client.vector_dbs.register( - vector_db_id=vector_db_name, - embedding_dimension=384, - embedding_model="all-MiniLM-L6-v2", - provider_id=vector_io_provider, - ) - vector_db_id = vector_db.identifier - - # Insert documents into the vector database + + # Insert documents into the existing vector database + actual_db_id = vector_db_id or vector_db_name + with st.spinner(f"Uploading documents to '{vector_db_name}'..."): llama_stack_api.client.tool_runtime.rag_tool.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_db_id, # Use the correct database ID documents=documents, chunk_size_in_tokens=512, ) - st.success("Vector database created successfully!") - # Reset form fields - uploaded_files.clear() - vector_db_name = "" + + # Success + st.session_state["upload_status"] = "success" + st.session_state["upload_message"] = f"Successfully uploaded {len(uploaded_files)} document(s) to '{vector_db_name}'!" + + # Trigger refresh to show the success message + st.rerun() + + except Exception as e: + st.session_state["upload_status"] = "error" + st.session_state["upload_message"] = f"Error uploading documents: {str(e)}" + st.rerun() + + +def _get_documents_from_pgvector(vector_db_id): + """ + Query pgvector directly to get document IDs stored in the database. 
+ Uses a connection pool for efficient connection reuse. + + Args: + vector_db_id (str): The vector database identifier + + Returns: + list: List of unique document IDs, or None if query fails + """ + try: + async def fetch_documents(): + try: + # Get connection from pool + pool = await _get_pg_pool() + + async with pool.acquire() as conn: + # Query for unique document IDs from the document JSONB column + # The vector_db_id is used as the table name with underscores replacing hyphens + table_name = f"vs_{vector_db_id.replace('-', '_')}" + + # Query metadata.source where LlamaStack stores the filename + # Try multiple paths since different upload methods use different structures: + # - Ingestion pipeline: metadata.source + # - Manual upload: chunk_metadata.source + # Fall back to auto-generated document_id if source is null + query = f""" + SELECT DISTINCT + COALESCE( + NULLIF(document->'metadata'->>'source', 'null'), + NULLIF(document->'chunk_metadata'->>'source', 'null'), + document->'metadata'->>'document_id' + ) as document_id + FROM {table_name} + WHERE document->'metadata'->>'document_id' IS NOT NULL + OR document->'metadata'->>'source' IS NOT NULL + ORDER BY document_id + """ + + queries = [query] + + doc_ids = [] + for query in queries: + try: + rows = await conn.fetch(query) + if rows: + doc_ids = [row['document_id'] for row in rows if row['document_id']] + if doc_ids: + break + except Exception as e: + continue # Try next query pattern + + return doc_ids if doc_ids else None + # Connection automatically returned to pool + + except Exception as e: + return None + + # Run the async function + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(fetch_documents()) + + except Exception as e: + return None + + +def _delete_document_from_pgvector(vector_db_id, filename): + """ + Delete a document and all its chunks/embeddings from pgvector. 
+ Uses a connection pool for efficient connection reuse. + + Args: + vector_db_id (str): The vector database identifier + filename (str): The filename/source to delete + + Returns: + tuple: (success: bool, deleted_count: int, error_message: str) + """ + try: + async def delete_document(): + try: + # Get connection from pool + pool = await _get_pg_pool() + + async with pool.acquire() as conn: + # The vector_db_id is used as the table name with underscores replacing hyphens + table_name = f"vs_{vector_db_id.replace('-', '_')}" + + # Delete all chunks where the source matches the filename + # Handle both document structures: + # - Ingestion pipeline: metadata.source + # - Manual upload: chunk_metadata.source + query = f""" + DELETE FROM {table_name} + WHERE document->'metadata'->>'source' = $1 + OR document->'chunk_metadata'->>'source' = $1 + """ + + result = await conn.execute(query, filename) + + # Parse the result to get the number of deleted rows + # Result format is like "DELETE 5" where 5 is the number of rows + deleted_count = int(result.split()[-1]) if result else 0 + + return True, deleted_count, None + # Connection automatically returned to pool + + except Exception as e: + return False, 0, str(e) + + # Run the async function + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(delete_document()) + + except Exception as e: + return False, 0, str(e) + + +def _show_existing_documents_table(vector_db_name, vector_db_obj=None): + """ + Display information about documents in the selected vector database. 
+ + Args: + vector_db_name (str): Display name of the selected vector database + vector_db_obj: The actual vector database object with identifier + """ + try: + # Get the correct vector database ID + if vector_db_obj and hasattr(vector_db_obj, 'identifier'): + vector_db_id = vector_db_obj.identifier + else: + vector_db_id = vector_db_name # Fallback to display name + + # Initialize session state for deletion status + if "delete_status" not in st.session_state: + st.session_state["delete_status"] = None + if "delete_message" not in st.session_state: + st.session_state["delete_message"] = "" + + # Show deletion status messages (before checking documents, so last delete shows) + if st.session_state["delete_status"] == "success": + st.success(st.session_state["delete_message"]) + st.session_state["delete_status"] = None + st.session_state["delete_message"] = "" + elif st.session_state["delete_status"] == "error": + st.error(st.session_state["delete_message"]) + st.session_state["delete_status"] = None + st.session_state["delete_message"] = "" + + with st.spinner("Checking for documents..."): + # First, try to get document list from pgvector directly + document_ids = _get_documents_from_pgvector(vector_db_id) + + if document_ids: + # Success! 
We have the actual document filenames + # Show heading for documents section + st.subheader(f"📄 Documents in '{vector_db_name}'") + + # Add CSS for bordered table rows + st.markdown(""" + + """, unsafe_allow_html=True) + + # Display table header + col1, col2, col3 = st.columns([0.5, 5, 0.5]) + with col1: + st.markdown("**#**") + with col2: + st.markdown("**Filename**") + with col3: + st.markdown("**Del**") + + # Display each document in a row with delete button + for idx, doc_id in enumerate(document_ids, start=1): + col1, col2, col3 = st.columns([0.5, 5, 0.5]) + + with col1: + st.write(idx) + + with col2: + st.write(doc_id) + + with col3: + delete_key = f"delete_{vector_db_name}_{doc_id}_{idx}" + + if st.button("✕", key=delete_key, help=f"Delete {doc_id}"): + # Delete immediately without confirmation + success, deleted_count, error = _delete_document_from_pgvector( + vector_db_id, + doc_id + ) + + if success: + st.session_state["delete_status"] = "success" + st.session_state["delete_message"] = f"✅ Successfully deleted '{doc_id}' ({deleted_count} chunk(s) removed)" + else: + st.session_state["delete_status"] = "error" + st.session_state["delete_message"] = f"❌ Failed to delete '{doc_id}': {error}" + + st.rerun() + + # else: Database appears empty or pgvector query not available + # For newly created databases, this is expected - just show nothing + # The upload section below will allow users to add documents + + except Exception as e: + st.error(f"Error loading document information: {str(e)}") + with st.expander("Error Details"): + st.code(traceback.format_exc()) -upload_page() \ No newline at end of file +upload_page() diff --git a/frontend/pyproject.toml b/frontend/pyproject.toml index be65e23..31e4233 100644 --- a/frontend/pyproject.toml +++ b/frontend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "streamlit-option-menu", "llama-stack==0.2.23", "fire", + "asyncpg", ] [tool.setuptools] diff --git a/tests/integration/test_upload_integration.py 
b/tests/integration/test_upload_integration.py index 549f28f..378ac1a 100644 --- a/tests/integration/test_upload_integration.py +++ b/tests/integration/test_upload_integration.py @@ -11,6 +11,22 @@ # Add the frontend directory to the path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../frontend')) +# Mock all external dependencies before any imports from the upload module +# This is required because @patch decorators try to import the target module +mock_streamlit = MagicMock() +mock_streamlit.session_state = {} +sys.modules['streamlit'] = mock_streamlit +sys.modules['asyncpg'] = MagicMock() +sys.modules['pandas'] = MagicMock() + +# Mock llama_stack_client with a proper RAGDocument mock +mock_llama_stack_client = MagicMock() +def mock_rag_document(**kwargs): + """Create a dict-like RAGDocument mock""" + return kwargs +mock_llama_stack_client.RAGDocument = mock_rag_document +sys.modules['llama_stack_client'] = mock_llama_stack_client + # Configuration LLAMA_STACK_ENDPOINT = os.getenv("LLAMA_STACK_ENDPOINT", "http://localhost:8321") diff --git a/tests/unit/test_upload.py b/tests/unit/test_upload.py index 0e4f001..de5fe7b 100644 --- a/tests/unit/test_upload.py +++ b/tests/unit/test_upload.py @@ -2,217 +2,323 @@ Unit tests for the upload module Tests document upload and vector DB creation logic """ -import pytest -from unittest.mock import Mock, patch, MagicMock -import sys +import asyncio import os +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest # Add the frontend directory to the path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../frontend')) -# Mock streamlit before importing -sys.modules['streamlit'] = MagicMock() +# Mock all external dependencies before any imports from the upload module +# This is required because @patch decorators try to import the target module +mock_streamlit = MagicMock() +mock_streamlit.session_state = {} +sys.modules['streamlit'] = mock_streamlit 
+sys.modules['asyncpg'] = MagicMock() +sys.modules['pandas'] = MagicMock() + +# Mock llama_stack_client with a proper RAGDocument mock +mock_llama_stack_client = MagicMock() +def mock_rag_document(**kwargs): + """Create a dict-like RAGDocument mock""" + return kwargs +mock_llama_stack_client.RAGDocument = mock_rag_document +sys.modules['llama_stack_client'] = mock_llama_stack_client +# Now we can safely import modules that will be patched +# Pre-import the modules so @patch can find them +from llama_stack_ui.distribution.ui.modules import api, utils +from llama_stack_ui.distribution.ui.page.upload import upload as upload_module -class TestVectorDBConfiguration: - """Test vector database configuration and setup""" + +class MockAsyncContextManager: + """Mock async context manager for pool.acquire()""" + def __init__(self, conn): + self.conn = conn - def test_vector_db_default_name(self): - """Test default vector database name""" - default_name = "rag_vector_db" - assert default_name == "rag_vector_db" - assert len(default_name) > 0 + async def __aenter__(self): + return self.conn - def test_vector_db_embedding_dimension(self): - """Test that embedding dimension is set correctly for all-MiniLM-L6-v2""" - embedding_dimension = 384 - embedding_model = "all-MiniLM-L6-v2" - - assert embedding_dimension == 384 - assert embedding_model == "all-MiniLM-L6-v2" + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + +def create_mock_pool_with_connection(mock_conn): + """ + Helper to create a mock connection pool that yields the given connection. 
- def test_chunk_size_configuration(self): - """Test that chunk size is set to 512 tokens""" - chunk_size = 512 - assert chunk_size == 512 + Args: + mock_conn: The mock connection to return from pool.acquire() + + Returns: + MagicMock: A mock pool with proper acquire() context manager + """ + mock_pool = MagicMock() + mock_pool.acquire.return_value = MockAsyncContextManager(mock_conn) + return mock_pool -class TestDocumentProcessing: - """Test document processing and RAGDocument creation""" +class TestGetDocumentsFromPgvector: + """Unit tests for _get_documents_from_pgvector function""" - def test_supported_file_types(self): - """Test that supported file types are correctly defined""" - supported_types = ["txt", "pdf", "doc", "docx"] - - assert "txt" in supported_types - assert "pdf" in supported_types - assert "doc" in supported_types - assert "docx" in supported_types + def test_get_documents_success(self): + """Test successful retrieval of documents from pgvector""" + # Setup mock connection + mock_conn = AsyncMock() + mock_rows = [ + {'document_id': 'document1.pdf'}, + {'document_id': 'document2.txt'}, + {'document_id': 'document3.docx'}, + ] + mock_conn.fetch = AsyncMock(return_value=mock_rows) + + # Create mock pool + mock_pool = create_mock_pool_with_connection(mock_conn) + + # Patch _get_pg_pool to return our mock pool + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + # Call the actual function + result = upload_module._get_documents_from_pgvector("my-test-db") + + # Verify the result + assert result == ['document1.pdf', 'document2.txt', 'document3.docx'] + + # Verify acquire was called (connection borrowed from pool) + mock_pool.acquire.assert_called_once() - def test_document_id_from_filename(self): - """Test that document ID is created from filename""" - filename = "test_document.pdf" - document_id = filename + def test_get_documents_empty_result(self): + """Test that empty results 
return None""" + # Setup mock connection with empty result + mock_conn = AsyncMock() + mock_conn.fetch = AsyncMock(return_value=[]) - assert document_id == "test_document.pdf" - assert document_id.endswith(".pdf") + mock_pool = create_mock_pool_with_connection(mock_conn) + + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("empty-db") + + # Empty result should return None + assert result is None - @patch('llama_stack_ui.distribution.ui.modules.utils.data_url_from_file') - def test_rag_document_creation(self, mock_data_url): - """Test RAGDocument creation from uploaded file""" - from llama_stack_client import RAGDocument - - # Mock file and data URL - mock_data_url.return_value = "data:text/plain;base64,SGVsbG8gV29ybGQ=" - - mock_file = Mock() - mock_file.name = "test.txt" - - # Create RAGDocument as done in upload.py - document = RAGDocument( - document_id=mock_file.name, - content=mock_data_url(mock_file), - ) - - # RAGDocument returns a dict-like object - assert document['document_id'] == "test.txt" - assert document['content'].startswith("data:") - mock_data_url.assert_called_once() + def test_get_documents_connection_error(self): + """Test that connection errors return None""" + # Setup mock pool that raises an exception on acquire + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("error-db") + + # Error should return None + assert result is None - def test_multiple_documents_processing(self): - """Test processing multiple uploaded files""" - # Simulate multiple uploaded files - mock_file1 = Mock() - mock_file1.name = "doc1.txt" - mock_file2 = Mock() - mock_file2.name = "doc2.pdf" - mock_file3 = Mock() - mock_file3.name = "doc3.docx" + def test_get_documents_filters_null_ids(self): + """Test that null 
document IDs are filtered out""" + mock_conn = AsyncMock() + mock_rows = [ + {'document_id': 'valid1.pdf'}, + {'document_id': None}, # Should be filtered + {'document_id': 'valid2.txt'}, + {'document_id': None}, # Should be filtered + ] + mock_conn.fetch = AsyncMock(return_value=mock_rows) - uploaded_files = [mock_file1, mock_file2, mock_file3] + mock_pool = create_mock_pool_with_connection(mock_conn) - # Simulate creating document list - document_ids = [f.name for f in uploaded_files] + async def mock_get_pool(): + return mock_pool - assert len(document_ids) == 3 - assert "doc1.txt" in document_ids - assert "doc2.pdf" in document_ids - assert "doc3.docx" in document_ids + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + result = upload_module._get_documents_from_pgvector("mixed-db") + + # Only valid IDs should be returned + assert result == ['valid1.pdf', 'valid2.txt'] + assert len(result) == 2 -class TestVectorDBOperations: - """Test vector database operations""" +class TestDeleteDocumentFromPgvector: + """Unit tests for _delete_document_from_pgvector function""" - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_vector_db_registration_params(self, mock_api): - """Test that vector DB registration uses correct parameters""" - mock_client = Mock() - mock_api.client = mock_client + def test_delete_document_success(self): + """Test successful deletion of document from pgvector""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 5") + + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - embedding_dimension = 384 - embedding_model = "all-MiniLM-L6-v2" - provider_id = "pgvector" - - # Simulate registration call - mock_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) - - # Verify the call was made with correct params - 
mock_client.vector_dbs.register.assert_called_once_with( - vector_db_id=vector_db_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - ) + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is True + assert count == 5 + assert error is None + mock_pool.acquire.assert_called_once() - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_document_insertion_params(self, mock_api): - """Test that document insertion uses correct parameters""" - from llama_stack_client import RAGDocument + def test_delete_document_not_found(self): + """Test deletion when document doesn't exist""" + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value="DELETE 0") - mock_client = Mock() - mock_api.client = mock_client + mock_pool = create_mock_pool_with_connection(mock_conn) - vector_db_id = "test_vector_db" - documents = [ - RAGDocument(document_id="doc1", content="content1"), - RAGDocument(document_id="doc2", content="content2"), + async def mock_get_pool(): + return mock_pool + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "nonexistent.pdf" + ) + + assert success is True + assert count == 0 + assert error is None + + def test_delete_document_connection_error(self): + """Test deletion with connection error""" + async def mock_get_pool(): + raise Exception("Connection refused") + + with patch.object(upload_module, '_get_pg_pool', mock_get_pool): + success, count, error = upload_module._delete_document_from_pgvector( + "my-test-db", + "document.pdf" + ) + + assert success is False + assert count == 0 + assert error is not None + assert "Connection refused" in str(error) + + +class 
TestCreateVectorDatabase: + """Unit tests for _create_vector_database function""" + + def test_create_vector_database_success(self): + """Test successful creation of vector database""" + # Mock the API client + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [] + mock_client.providers.list.return_value = [ + MagicMock(api="vector_io", provider_id="pgvector") ] - chunk_size = 512 - - # Simulate insertion call - mock_client.tool_runtime.rag_tool.insert( - vector_db_id=vector_db_id, - documents=documents, - chunk_size_in_tokens=chunk_size, - ) - - # Verify the call was made - mock_client.tool_runtime.rag_tool.insert.assert_called_once() - call_args = mock_client.tool_runtime.rag_tool.insert.call_args - assert call_args[1]['vector_db_id'] == vector_db_id - assert call_args[1]['chunk_size_in_tokens'] == chunk_size - assert len(call_args[1]['documents']) == 2 + mock_client.vector_dbs.register.return_value = MagicMock() + + mock_api = MagicMock() + mock_api.client = mock_client + + # Mock session state + mock_st = MagicMock() + mock_st.session_state = {} + + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("new-test-db") + + # Verify registration was called with correct parameters + mock_client.vector_dbs.register.assert_called_once() + call_kwargs = mock_client.vector_dbs.register.call_args[1] + assert call_kwargs['vector_db_id'] == "new-test-db" + assert call_kwargs['embedding_model'] == "all-MiniLM-L6-v2" + assert call_kwargs['embedding_dimension'] == 384 + assert call_kwargs['provider_id'] == "pgvector" - @patch('llama_stack_ui.distribution.ui.modules.api.llama_stack_api') - def test_provider_detection(self, mock_api): - """Test vector IO provider detection""" - mock_client = Mock() + def test_create_vector_database_duplicate_name(self): + """Test that duplicate names are rejected""" + # Mock existing database with same name + existing_db = 
MagicMock() + existing_db.identifier = "existing-db" + + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [existing_db] + + mock_api = MagicMock() mock_api.client = mock_client - # Mock provider list - mock_providers = [ - Mock(api="inference", provider_id="ollama"), - Mock(api="vector_io", provider_id="pgvector"), - Mock(api="memory", provider_id="redis"), + mock_st = MagicMock() + mock_st.session_state = {} + + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("existing-db") + + # Registration should NOT be called for duplicates + mock_client.vector_dbs.register.assert_not_called() + + # Error status should be set + assert mock_st.session_state.get("creation_status") == "error" + + def test_create_vector_database_no_provider(self): + """Test error when no vector_io provider exists""" + mock_client = MagicMock() + mock_client.vector_dbs.list.return_value = [] + mock_client.providers.list.return_value = [ + MagicMock(api="inference", provider_id="ollama") # No vector_io ] - mock_client.providers.list.return_value = mock_providers - # Simulate provider detection logic - providers = mock_client.providers.list() - vector_io_provider = None - for x in providers: - if x.api == "vector_io": - vector_io_provider = x.provider_id + mock_api = MagicMock() + mock_api.client = mock_client + + mock_st = MagicMock() + mock_st.session_state = {} - assert vector_io_provider == "pgvector" + with patch.object(upload_module, 'llama_stack_api', mock_api): + with patch.object(upload_module, 'st', mock_st): + upload_module._create_vector_database("new-db") + + # Registration should NOT be called without provider + mock_client.vector_dbs.register.assert_not_called() + + # Error status should be set + assert mock_st.session_state.get("creation_status") == "error" -class TestUploadValidation: - """Test upload validation and error handling""" - - def 
test_empty_upload_list(self): - """Test handling of empty upload list""" - uploaded_files = [] - assert len(uploaded_files) == 0 +class TestConnectionPool: + """Unit tests for the connection pool functionality""" - def test_upload_count_display(self): - """Test upload count display logic""" - uploaded_files = [Mock(), Mock(), Mock()] - count = len(uploaded_files) - message = f"Successfully uploaded {count} files" - - assert message == "Successfully uploaded 3 files" - assert str(count) in message - - def test_vector_db_name_validation(self): - """Test vector database name validation""" - # Valid names - valid_names = ["rag_vector_db", "test-db-123", "my_documents"] - for name in valid_names: - assert len(name) > 0 - assert name.replace('_', '').replace('-', '').isalnum() + def test_pool_is_reused(self): + """Test that the same pool is returned on subsequent calls""" + # Reset the global pool + upload_module._pg_pool = None + + mock_pool = AsyncMock() + + # Mock asyncpg.create_pool + async def mock_create_pool(**kwargs): + return mock_pool - # Invalid names should be caught - invalid_name = "" - assert len(invalid_name) == 0 + with patch.object(upload_module.asyncpg, 'create_pool', mock_create_pool): + # Get pool twice + async def get_pools(): + pool1 = await upload_module._get_pg_pool() + pool2 = await upload_module._get_pg_pool() + return pool1, pool2 + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + pool1, pool2 = loop.run_until_complete(get_pools()) + + # Should be the same pool instance + assert pool1 is pool2 + + # Clean up + upload_module._pg_pool = None if __name__ == "__main__": pytest.main([__file__, "-v"]) - From a1fc3e576958af2c12fad0edf0ea8a47ae8da47f Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: Wed, 4 Feb 2026 14:49:19 -0500 Subject: [PATCH 10/18] Use oc project instead of namespace Signed-off-by: Yuval Turgeman --- deploy/helm/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/helm/Makefile 
b/deploy/helm/Makefile index 836f6f0..7ed3dc6 100644 --- a/deploy/helm/Makefile +++ b/deploy/helm/Makefile @@ -312,7 +312,7 @@ show-config: ## Show configuration file contents # Create namespace and deploy namespace: @echo -e "$(BLUE)[INFO]$(NC) Creating namespace $(NAMESPACE)..." - @oc create namespace $(NAMESPACE) &> /dev/null && oc label namespace $(NAMESPACE) modelmesh-enabled=false ||: + @oc new-project $(NAMESPACE) &> /dev/null ||: @oc project $(NAMESPACE) &> /dev/null ||: @echo -e "$(GREEN)[SUCCESS]$(NC) Namespace $(NAMESPACE) is ready" From 4729cc714fbfcc8448eaa2303fac003a97cf99a8 Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: Wed, 4 Feb 2026 15:37:03 -0500 Subject: [PATCH 11/18] Try to label namespace for rhaoi Signed-off-by: Yuval Turgeman --- deploy/helm/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/helm/Makefile b/deploy/helm/Makefile index 7ed3dc6..a0d6b43 100644 --- a/deploy/helm/Makefile +++ b/deploy/helm/Makefile @@ -312,7 +312,7 @@ show-config: ## Show configuration file contents # Create namespace and deploy namespace: @echo -e "$(BLUE)[INFO]$(NC) Creating namespace $(NAMESPACE)..." 
- @oc new-project $(NAMESPACE) &> /dev/null ||: + @oc new-project $(NAMESPACE) &> /dev/null && oc label namespace $(NAMESPACE) modelmesh-enabled=false &> /dev/null ||: @oc project $(NAMESPACE) &> /dev/null ||: @echo -e "$(GREEN)[SUCCESS]$(NC) Namespace $(NAMESPACE) is ready" From e2b18ec99d3a8047d272099e912013cac3a7d3b8 Mon Sep 17 00:00:00 2001 From: Ganesh Murthy Date: Thu, 5 Feb 2026 13:01:24 -0500 Subject: [PATCH 12/18] Increase haproxy timeout and llamastack timeout to allow for upload of large PDFs --- deploy/helm/rag/templates/route.yaml | 3 +++ frontend/llama_stack_ui/distribution/ui/modules/api.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/deploy/helm/rag/templates/route.yaml b/deploy/helm/rag/templates/route.yaml index 486adc6..2c7b4b8 100644 --- a/deploy/helm/rag/templates/route.yaml +++ b/deploy/helm/rag/templates/route.yaml @@ -4,6 +4,9 @@ metadata: name: {{ include "rag.fullname" . }} labels: {{- include "rag.labels" . | nindent 4 }} + annotations: + # 10 minute timeout for large document uploads + haproxy.router.openshift.io/timeout: 600s spec: to: kind: Service diff --git a/frontend/llama_stack_ui/distribution/ui/modules/api.py b/frontend/llama_stack_ui/distribution/ui/modules/api.py index 9602323..2aa58ce 100644 --- a/frontend/llama_stack_ui/distribution/ui/modules/api.py +++ b/frontend/llama_stack_ui/distribution/ui/modules/api.py @@ -11,9 +11,13 @@ class LlamaStackApi: def __init__(self): + # Timeout of 600 seconds (10 minutes) for large document uploads + # Default is 60 seconds which is too short for large PDFs + timeout = float(os.environ.get("LLAMA_STACK_TIMEOUT", "600")) + self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), - + timeout=timeout, provider_data={ "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), "together_api_key": os.environ.get("TOGETHER_API_KEY", ""), From f8582cf5ac4e9a90cc0d5a89e2a8b38ab2ef36ca Mon Sep 17 00:00:00 2001 From: Yuval 
Turgeman Date: Mon, 9 Feb 2026 11:17:29 -0500 Subject: [PATCH 13/18] Update model readiness in README Signed-off-by: Yuval Turgeman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd5e732..b682038 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ Watch for all pods to reach Running or Completed status. Key pods to watch inclu oc get pods -l component=predictor ``` -Look for **3/3** under the Ready column. +Look for **2/2** (or **3/3** when RAW_DEPLOYMENT=false) under the Ready column. 8. **Verify Installation** From 2eb03f9769b1e6c3ba244b2298d9be4d124095cc Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: Wed, 11 Feb 2026 10:38:53 -0500 Subject: [PATCH 14/18] Refactor chat to use Responses API with agent and direct modes - Split chat.py into agent.py (tool calling), direct.py (manual RAG), chat.py (UI) - Implement agent mode using Responses API with file_search/web_search tools - Implement direct mode using Completions API with manual RAG - Add proper Python logging and fix linting issues - Fix reasoning persistence and display order in UI Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: Yuval Turgeman --- deploy/helm/Makefile | 26 +- deploy/helm/rag/Chart.lock | 16 +- deploy/helm/rag/Chart.yaml | 12 +- .../llama_stack_ui/distribution/ui/app.py | 11 +- .../distribution/ui/modules/api.py | 2 +- .../distribution/ui/modules/utils.py | 40 +- .../ui/page/distribution/inspect.py | 3 - .../ui/page/distribution/providers.py | 3 - .../ui/page/distribution/scoring_functions.py | 5 +- .../ui/page/distribution/vector_dbs.py | 4 +- .../ui/page/evaluations/app_eval.py | 6 +- .../ui/page/evaluations/evaluations.py | 2 +- .../ui/page/evaluations/native_eval.py | 8 +- .../distribution/ui/page/playground/agent.py | 356 ++++++ .../distribution/ui/page/playground/chat.py | 1132 ++++++++--------- .../distribution/ui/page/playground/direct.py | 224 ++++ .../distribution/ui/page/upload/__init__.py | 1 - 
.../distribution/ui/page/upload/upload.py | 461 +++---- frontend/pyproject.toml | 4 +- 19 files changed, 1355 insertions(+), 961 deletions(-) create mode 100644 frontend/llama_stack_ui/distribution/ui/page/playground/agent.py create mode 100644 frontend/llama_stack_ui/distribution/ui/page/playground/direct.py diff --git a/deploy/helm/Makefile b/deploy/helm/Makefile index a0d6b43..e90f2f1 100644 --- a/deploy/helm/Makefile +++ b/deploy/helm/Makefile @@ -85,10 +85,24 @@ endef # Helper function to validate values file define validate_values_file echo -e "$(BLUE)[INFO]$(NC) Validating configuration values..."; \ - HF_TOKEN=$$(grep -A 2 "^llm-service:" "$(VALUES_FILE)" | grep "hf_token:" | sed 's/.*hf_token: *//' | tr -d '"' | tr -d ' '); \ - LLAMA_STACK_TAVILY=$$(grep "TAVILY_SEARCH_API_KEY:" "$(VALUES_FILE)" 2>/dev/null | sed 's/.*TAVILY_SEARCH_API_KEY: *//' | tr -d '"' | tr -d ' '); \ UPDATED=0; \ - if [ -z "$$HF_TOKEN" ] || [ "$$HF_TOKEN" = "" ]; then \ + \ + if [ -n "$$HF_TOKEN" ]; then \ + echo -e "$(GREEN)[SUCCESS]$(NC) Using HF_TOKEN from environment variable."; \ + sed -i.bak "/^llm-service:/,/^[^ ]/ s|hf_token:.*|hf_token: \"$$HF_TOKEN\"|" "$(VALUES_FILE)"; \ + UPDATED=1; \ + fi; \ + \ + if [ -n "$$TAVILY_API_KEY" ]; then \ + echo -e "$(GREEN)[SUCCESS]$(NC) Using TAVILY_API_KEY from environment variable."; \ + sed -i.bak "s/TAVILY_SEARCH_API_KEY:.*/TAVILY_SEARCH_API_KEY: \"$$TAVILY_API_KEY\"/" "$(VALUES_FILE)"; \ + UPDATED=1; \ + fi; \ + \ + HF_TOKEN_FILE=$$(grep -A 2 "^llm-service:" "$(VALUES_FILE)" | grep "hf_token:" | sed 's/.*hf_token: *//' | tr -d '"' | tr -d ' '); \ + LLAMA_STACK_TAVILY=$$(grep "TAVILY_SEARCH_API_KEY:" "$(VALUES_FILE)" 2>/dev/null | sed 's/.*TAVILY_SEARCH_API_KEY: *//' | tr -d '"' | tr -d ' '); \ + \ + if [ -z "$$HF_TOKEN_FILE" ] || [ "$$HF_TOKEN_FILE" = "" ]; then \ echo -e "$(YELLOW)[WARNING]$(NC) Hugging Face token is not set. 
Model downloads may fail."; \ echo -e "$(BLUE)[INFO]$(NC) Get your token from: https://huggingface.co/settings/tokens"; \ echo -e ""; \ @@ -102,6 +116,7 @@ define validate_values_file fi; \ echo -e ""; \ fi; \ + \ if [ -z "$$LLAMA_STACK_TAVILY" ] || [ "$$LLAMA_STACK_TAVILY" = "Paste-your-key-here" ]; then \ echo -e "$(YELLOW)[WARNING]$(NC) TAVILY search API key is not set. Web search will be disabled."; \ echo -e "$(BLUE)[INFO]$(NC) Get your key from: https://tavily.com/"; \ @@ -116,6 +131,7 @@ define validate_values_file fi; \ echo -e ""; \ fi; \ + \ if [ $$UPDATED -eq 1 ]; then \ echo -e "$(GREEN)[SUCCESS]$(NC) Configuration updated. Proceeding with installation..."; \ echo -e ""; \ @@ -479,11 +495,11 @@ install: ## Install the RAG deployment fi; \ if [ -n "$(LLM_URL)" ]; then \ echo -e "$(BLUE)[INFO]$(NC) Setting LLM URL: $(LLM_URL)"; \ - HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).url='$(LLM_URL)'"; \ + HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).url=$(LLM_URL)"; \ fi; \ if [ -n "$(LLM_API_TOKEN)" ]; then \ echo -e "$(BLUE)[INFO]$(NC) Setting LLM API token"; \ - HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).apiToken='$(LLM_API_TOKEN)'"; \ + HELM_ARGS="$$HELM_ARGS --set global.models.$(LLM).apiToken=$(LLM_API_TOKEN)"; \ fi; \ fi; \ if [ -n "$(SAFETY)" ]; then \ diff --git a/deploy/helm/rag/Chart.lock b/deploy/helm/rag/Chart.lock index 45c5ac1..36eeea4 100644 --- a/deploy/helm/rag/Chart.lock +++ b/deploy/helm/rag/Chart.lock @@ -1,21 +1,21 @@ dependencies: - name: pgvector repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.1 + version: 0.5.5 - name: llm-service repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.2 + version: 0.5.9 - name: configure-pipeline repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.4 + version: 0.5.6 - name: ingestion-pipeline repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.1 + version: 
0.6.5 - name: llama-stack repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.2 + version: 0.6.10 - name: mcp-servers repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.5.7 -digest: sha256:d7abd4b5f5c4080a241c567f0bde351f927a5ac0d95fea4bbdf8f364f7a92866 -generated: "2025-12-05T10:53:08.788253807-05:00" + version: 0.5.15 +digest: sha256:a55fed2a9164c13653c7e341350f0a2eeba302cd627232800016be49fafc09db +generated: "2026-02-06T15:10:51.764446626-05:00" diff --git a/deploy/helm/rag/Chart.yaml b/deploy/helm/rag/Chart.yaml index 92e6fa9..624621d 100644 --- a/deploy/helm/rag/Chart.yaml +++ b/deploy/helm/rag/Chart.yaml @@ -7,26 +7,26 @@ appVersion: "0.2.30" dependencies: - name: pgvector - version: 0.5.1 + version: 0.5.5 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: pgvector.enabled - name: llm-service - version: 0.5.2 + version: 0.5.9 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llm-service.enabled - name: configure-pipeline - version: 0.5.4 + version: 0.5.6 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: configure-pipeline.enabled - name: ingestion-pipeline - version: 0.5.1 + version: 0.6.5 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: ingestion-pipeline.enabled - name: llama-stack - version: 0.5.2 + version: 0.6.10 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llama-stack.enabled - name: mcp-servers - version: 0.5.7 + version: 0.5.15 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: mcp-servers.enabled diff --git a/frontend/llama_stack_ui/distribution/ui/app.py b/frontend/llama_stack_ui/distribution/ui/app.py index 772cb91..ec08af7 100644 --- a/frontend/llama_stack_ui/distribution/ui/app.py +++ b/frontend/llama_stack_ui/distribution/ui/app.py @@ -3,9 +3,17 @@ # # This source code is licensed under the 
def get_vector_db_name(vector_db):
    """
    Get the display name for a vector database.
    Falls back to id if the name attribute is missing, None, or empty.

    Args:
        vector_db: Vector database object from API

    Returns:
        str: The vector database name, or its id when no usable name is set
    """
    # `or` also covers a name attribute that exists but is unset (None / ""),
    # and only reads `.id` when the fallback is actually needed.
    return getattr(vector_db, 'name', None) or vector_db.id
b/frontend/llama_stack_ui/distribution/ui/page/distribution/inspect.py @@ -10,10 +10,7 @@ from streamlit_option_menu import option_menu import streamlit as st -from llama_stack_ui.distribution.ui.page.distribution.datasets import datasets -from llama_stack_ui.distribution.ui.page.distribution.eval_tasks import benchmarks from llama_stack_ui.distribution.ui.page.distribution.models import models -from llama_stack_ui.distribution.ui.page.distribution.scoring_functions import scoring_functions from llama_stack_ui.distribution.ui.page.distribution.shields import shields from llama_stack_ui.distribution.ui.page.distribution.providers import providers from llama_stack_ui.distribution.ui.page.distribution.vector_dbs import vector_dbs diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py index 7549d0e..749ad81 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/providers.py @@ -29,6 +29,3 @@ def providers(): for api_name, providers in api_to_providers.items(): st.markdown(f"###### {api_name}") st.dataframe([p.to_dict() for p in providers], width=500) - - - diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py index 3fd51bd..5f531f0 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/scoring_functions.py @@ -22,5 +22,8 @@ def scoring_functions(): scoring_functions_info = {s.identifier: s.to_dict() for s in sf_list} # Let user select and view a scoring function - selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys())) + selected_scoring_function = st.selectbox( + "Select a scoring function", + 
list(scoring_functions_info.keys()), + ) st.json(scoring_functions_info[selected_scoring_function], expanded=True) diff --git a/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py b/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py index 37c4661..0ccb9a9 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py +++ b/frontend/llama_stack_ui/distribution/ui/page/distribution/vector_dbs.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name def vector_dbs(): @@ -16,7 +16,7 @@ def vector_dbs(): """ st.header("Vector Databases") # Fetch all vector databases - vdb_list = llama_stack_api.client.vector_dbs.list() + vdb_list = llama_stack_api.client.vector_stores.list() if not vdb_list: st.info("No vector databases found.") return diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py index f595382..e4dc6e7 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/app_eval.py @@ -85,8 +85,12 @@ def application_evaluation_page(): ) new_params[param_name] = value else: + label = ( + f"Enter value for **{param_name}** in " + f"{scoring_fn_id} in valid JSON format" + ) value = st.text_area( - f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format", + label, value=json.dumps(param_value, indent=2), height=80, ) diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py index 
be165adc..47ed906 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/evaluations.py @@ -10,4 +10,4 @@ def evaluations_page(): with tabs[1]: native_evaluation_page() -evaluations_page() \ No newline at end of file +evaluations_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py b/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py index 5539ea9..b4aecdc 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py +++ b/frontend/llama_stack_ui/distribution/ui/page/evaluations/native_eval.py @@ -5,7 +5,8 @@ from llama_stack_ui.distribution.ui.modules.api import llama_stack_api """ -Native Evaluation page: select a benchmark, configure eval candidate, and run full generation + scoring. +Native Evaluation page: select a benchmark, configure eval candidate, +and run full generation + scoring. """ def select_benchmark_1(): @@ -48,7 +49,8 @@ def define_eval_candidate_2(): st.subheader("2. Define Eval Candidate") st.info( - "Define generation configuration: choose 'model' for inference API or 'agent' for agent API." + "Define generation configuration: choose 'model' for inference API " + "or 'agent' for agent API." ) with st.expander("Define Eval Candidate", expanded=True): @@ -216,4 +218,4 @@ def native_evaluation_page(): run_evaluation_3() -native_evaluation_page() \ No newline at end of file +native_evaluation_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py new file mode 100644 index 0000000..649da36 --- /dev/null +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
def build_response_tools(toolgroup_selection, selected_vector_dbs, client):
    """
    Convert toolgroup selections to LlamaStack Responses API compatible tool format.

    Each selected toolgroup is translated as follows:
      * ``builtin::rag`` -> one ``file_search`` tool over the selected vector
        stores (nothing is added when no vector database is selected)
      * names containing ``search`` -> a ``web_search`` tool
      * ``mcp::*`` -> an ``mcp`` tool pointing at the matching registered server
      * anything else -> one ``function`` tool per tool in the group

    Lookups of MCP servers and tool definitions are best effort: a failure is
    logged and that toolgroup is skipped, the remaining ones are still converted.

    Args:
        toolgroup_selection: List of selected toolgroup IDs
        selected_vector_dbs: List of selected vector database names
        client: LlamaStack client instance

    Returns:
        List of tools in Responses API format (works for both Agent and Direct modes)
    """
    agent_tools = []

    for toolgroup_name in toolgroup_selection:
        if toolgroup_name == "builtin::rag":
            if selected_vector_dbs:
                vector_dbs = client.vector_stores.list() or []
                vector_db_ids = [
                    vector_db.id for vector_db in vector_dbs
                    if get_vector_db_name(vector_db) in selected_vector_dbs
                ]
                # Use file_search tool format
                agent_tools.append({
                    "type": "file_search",
                    "vector_store_ids": vector_db_ids,
                })
        elif "web_search" in toolgroup_name or "search" in toolgroup_name.lower():
            # Convert search tools to web_search format
            agent_tools.append({"type": "web_search"})
        elif toolgroup_name.startswith("mcp::"):
            # For MCP tools, get server info
            try:
                toolgroups = client.toolgroups.list()
                for toolgroup in toolgroups:
                    if str(toolgroup.identifier) == toolgroup_name:
                        agent_tools.append({
                            "type": "mcp",
                            "server_label": toolgroup.args.get(
                                "name", str(toolgroup.identifier)
                            ),
                            "server_url": toolgroup.mcp_endpoint.uri,
                        })
                        break
            except Exception as e:
                # `logger` is already the Logger instance; going through
                # `logger.logger` raised AttributeError inside this handler.
                logger.debug("Failed to get MCP server info for %s: %s", toolgroup_name, e)
        else:
            # For other toolgroups, get individual tools and convert to function format
            # NOTE(review): the Responses API reference appears to describe function
            # tools with name/description/parameters at the top level rather than
            # nested under "function" - confirm the server accepts this shape.
            try:
                tools_in_group = client.tools.list(toolgroup_id=toolgroup_name)
                for tool in tools_in_group:
                    # Convert to function tool dict
                    agent_tools.append({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description or "",
                            "parameters": tool.parameters or {}
                        }
                    })
            except Exception as e:
                logger.debug("Failed to get tools for %s: %s", toolgroup_name, e)

    return agent_tools
+ 'content': item.results + }) + with state.containers.tool_results: + with st.expander("📄 File Search Results", expanded=False): + st.json(item.results) + + elif item_type == "web_search_call": + # Web search - API doesn't expose raw results, just status + pass + + elif item_type == "function_call": + # Function call output + if hasattr(item, 'output') and item.output: + tool_name = getattr(item, 'name', 'function') + state.tool_results.append({ + 'title': f'🔧 Tool Output: {tool_name}', + 'type': 'code', + 'content': str(item.output) + }) + with state.containers.tool_results: + with st.expander(f"🔧 Tool Output: {tool_name}", expanded=False): + st.code(str(item.output)) + + elif item_type == "mcp_call": + # MCP call output + if hasattr(item, 'output') and item.output: + tool_name = getattr(item, 'name', 'mcp') + state.tool_results.append({ + 'title': f'🔧 MCP Tool Output: {tool_name}', + 'type': 'code', + 'content': str(item.output) + }) + with state.containers.tool_results: + with st.expander(f"🔧 MCP Tool Output: {tool_name}", expanded=False): + st.code(str(item.output)) + + elif item_type and item_type.endswith("_call"): + # Generic handler for any other tool call types + if hasattr(item, 'results') and item.results: + formatted_name = item_type.replace("_", " ").title() + state.tool_results.append({ + 'title': f'🔧 {formatted_name} Results', + 'type': 'json', + 'content': item.results + }) + with state.containers.tool_results: + with st.expander(f"🔧 {formatted_name} Results", expanded=False): + st.json(item.results) + elif hasattr(item, 'output') and item.output: + formatted_name = item_type.replace("_", " ").title() + state.tool_results.append({ + 'title': f'🔧 {formatted_name} Output', + 'type': 'json', + 'content': item.output + }) + with state.containers.tool_results: + with st.expander(f"🔧 {formatted_name} Output", expanded=False): + st.json(item.output) + + +def handle_chunk_error(chunk): + """Handle error chunk and return whether to stop streaming.""" + 
def handle_chunk_done(chunk, state):
    """Copy the final output text of a done chunk onto the streaming state.

    ``state.full_response`` is overwritten only when ``chunk.response`` exists
    and carries a truthy ``output_text``; otherwise ``state`` is left untouched.
    """
    if not hasattr(chunk, 'response'):
        return

    response = chunk.response
    if not hasattr(response, 'output_text'):
        return

    final_text = response.output_text
    if final_text:
        state.full_response = final_text
Returns True to stop streaming.""" + chunk_type = chunk.type + + # Handle file_search tool + if chunk_type == "response.file_search_call.in_progress": + handle_agent_file_search_chunk(state, selected_vector_dbs) + + # Handle web_search tool + elif chunk_type == "response.web_search_call.in_progress": + handle_agent_web_search_chunk(state) + + elif chunk_type in ("response.web_search_call.searching", + "response.web_search_call.completed"): + pass # Just for event tracking + + # Handle tool results + elif chunk_type == "response.output_item.done": + handle_agent_output_item_done(chunk, state) + + # Handle reasoning + elif chunk_type == "response.reasoning_text.delta": + if hasattr(chunk, 'delta') and chunk.delta: + state.update_reasoning(chunk.delta) + + # Handle message content + elif chunk_type == "response.output_text.delta": + if hasattr(chunk, 'delta') and chunk.delta: + state.update_message(chunk.delta) + + # Handle errors + elif chunk_type == "response.failed": + return handle_chunk_error(chunk) + + # Handle completion + elif chunk_type == "response.completed": + handle_chunk_completed(chunk) + + # Handle done + elif chunk_type == "response.done": + handle_chunk_done(chunk, state) + + return False # Continue streaming + + +# ============================================================================ +# Agent Mode - Main Functions +# ============================================================================ + +def stream_agent_response(response, state, selected_vector_dbs): + """ + Stream and process chunks from Responses API. + Updates state containers as chunks arrive. 
def agent_process_prompt(prompt, state, config):
    """Run one chat turn in agent mode through the Responses API.

    Tools are built from ``config.toolgroup_selection`` (only when it is
    non-empty), a streaming ``responses.create`` request is issued, the stream
    is rendered into ``state`` and the finished turn is stored in the session.
    """
    # Translate the selected toolgroups only when there is a selection at all
    tools = None
    if config.toolgroup_selection:
        tools = build_response_tools(
            config.toolgroup_selection,
            config.selected_vector_dbs,
            llama_stack_api.client,
        )

    # Request payload; the "tools" key is present only for a non-empty tool list
    request_kwargs = {
        "model": config.model,
        "instructions": config.system_prompt,
        "input": prompt,
        "conversation": config.conversation_id,
        "temperature": config.sampling.temperature,
        "max_infer_iters": config.sampling.max_infer_iters,
        "stream": True,
        **({"tools": tools} if tools else {}),
    }

    logger.debug("Request: %s", request_kwargs)
    stream = llama_stack_api.client.responses.create(**request_kwargs)

    # Render chunks as they arrive, then persist the completed turn
    stream_agent_response(stream, state, config.selected_vector_dbs)
    save_agent_response_to_session(state)
def render_tool_results(tool_results):
    """Show every stored tool result inside its own collapsed expander.

    Each entry is indexed by ``title``, ``type`` and ``content``; entries of
    type ``json`` go through ``st.json`` and all other types through ``st.code``.
    """
    for entry in tool_results:
        with st.expander(entry['title'], expanded=False):
            show = st.json if entry['type'] == 'json' else st.code
            show(entry['content'])
- # debug_events: [events_for_A_1, events_for_A_2, ...] - # For A_1 (msg index 2), the debug_events index is (2//2)-1 = 0. - debug_event_list_index = (i // 2) - 1 - if 0 <= debug_event_list_index < len(st.session_state.debug_events): - current_turn_events_list = st.session_state.debug_events[debug_event_list_index] - - if current_turn_events_list: # Only show expander if there are events - with st.expander("Tool/Debug Events", expanded=False): - if isinstance(current_turn_events_list, list) and len(current_turn_events_list) > 0: - for event_idx, event_item in enumerate(current_turn_events_list): - with st.container(): - if isinstance(event_item, dict): - st.json(event_item, expanded=False) - elif isinstance(event_item, str): - st.text_area( - label=f"Debug Event {event_idx + 1}", - value=event_item, - height=100, - disabled=True, - key=f"debug_event_msg{i}_item{event_idx}" # Unique key for each text area - ) - else: - st.write(event_item) # Fallback for other data types - if event_idx < len(current_turn_events_list) - 1: - st.divider() - elif isinstance(current_turn_events_list, list) and not current_turn_events_list: - st.caption("No debug events recorded for this turn.") - else: # Should not happen with current logic - st.write("Debug data for this turn (unexpected format):") - st.write(current_turn_events_list) -def tool_chat_page(): - st.title("💬 Chat") + for msg in st.session_state.messages: + render_message(msg) + +def fetch_models_and_tools(): + """Fetch and categorize models and toolgroups from LlamaStack.""" client = llama_stack_api.client + + # Fetch models models = client.models.list() model_list = [model.identifier for model in models if model.api_model_type == "llm"] + # Fetch and categorize toolgroups tool_groups = client.toolgroups.list() + logger.debug("Raw tool groups from LlamaStack: %s", tool_groups) tool_groups_list = [tool_group.identifier for tool_group in tool_groups] + logger.debug("Tool group identifiers: %s", tool_groups_list) + 
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] + logger.debug("MCP tools: %s", mcp_tools_list) + builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] + logger.debug("Built-in tools: %s", builtin_tools_list) - selected_vector_dbs = [] + return model_list, builtin_tools_list, mcp_tools_list - def reset_agent(): - st.session_state.clear() - st.cache_resource.clear() - with st.sidebar: - st.title("Configuration") - st.subheader("Model") - model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed") - - ## Added mode - processing_mode = st.radio( - "Processing mode", - ["Direct", "Agent-based"], - index=0, # Default to Direct - captions=[ - "Directly calls the model with optional RAG.", - "Uses an Agent with tools.", - ], - on_change=reset_agent, - help="Choose how requests are processed. 'Direct' bypasses agents, 'Agent-based' uses them.", - ) +def render_toolgroup_selection(builtin_tools_list, mcp_tools_list, selected_vector_dbs, + on_toolgroup_change, on_reset): + """Render toolgroup selection UI and return selected toolgroups.""" + st.subheader("Available ToolGroups") - - toolgroup_selection = [] - if processing_mode == "Direct": - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - if not vector_dbs: - st.info("No vector databases available for selection.") - vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] - selected_vector_dbs = st.multiselect( - label="Select Document Collections to use in RAG queries", - options=vector_db_names, - on_change=reset_agent, - ) - if processing_mode == "Agent-based": - st.subheader("Available ToolGroups") - - toolgroup_selection = st.pills( - label="Built-in tools", - options=builtin_tools_list, - selection_mode="multi", - on_change=reset_agent, - format_func=lambda tool: "".join(tool.split("::")[1:]), - help="List of built-in tools from your llama stack server.", - ) - - if 
"builtin::rag" in toolgroup_selection: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - if not vector_dbs: - st.info("No vector databases available for selection.") - vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] - selected_vector_dbs = st.multiselect( - label="Select Document Collections to use in RAG queries", - options=vector_db_names, - on_change=reset_agent, - ) - - # Display mcp list only if there are mcp tools - if len(mcp_tools_list) > 0: - mcp_selection = st.pills( - label="MCP Servers", - options=mcp_tools_list, - selection_mode="multi", - on_change=reset_agent, - format_func=lambda tool: "".join(tool.split("::")[1:]), - help="List of MCP servers registered to your llama stack server.", - ) - - toolgroup_selection.extend(mcp_selection) - - grouped_tools = {} - total_tools = 0 - - for toolgroup_id in toolgroup_selection: - tools = client.tools.list(toolgroup_id=toolgroup_id) - grouped_tools[toolgroup_id] = [tool.identifier for tool in tools] - total_tools += len(tools) - - st.markdown(f"Active Tools: 🛠 {total_tools}") - - for group_id, tools in grouped_tools.items(): - with st.expander(f"🔧 Tools from `{group_id}`"): - for idx, tool in enumerate(tools, start=1): - st.markdown(f"{idx}. 
`{tool.split(':')[-1]}`") - - # st.subheader("Agent Configurations") - # st.subheader("Agent Type") - # agent_type = st.radio( - # label="Select Agent Type", - # options=["Regular", "ReAct"], - # on_change=reset_agent, - # ) - - # if agent_type == "ReAct": - # agent_type = AgentType.REACT - # else: - # agent_type = AgentType.REGULAR - agent_type = AgentType.REGULAR - - if processing_mode == "Agent-based": - input_shields = [] - output_shields = [] - - st.subheader("Security Shields") - shields_available = client.shields.list() - shield_options = [s.identifier for s in shields_available if hasattr(s, 'identifier')] - input_shields = st.multiselect("Input Shields", options=shield_options, on_change=reset_agent) - output_shields = st.multiselect("Output Shields", options=shield_options, on_change=reset_agent) - - st.subheader("Sampling Parameters") - temperature = st.slider("Temperature", 0.0, 2.0, 0.1, 0.05, on_change=reset_agent) - top_p = st.slider("Top P", 0.0, 1.0, 0.95, 0.05, on_change=reset_agent) - max_tokens = st.slider("Max Tokens", 1, 4096, 512, 64, on_change=reset_agent) - repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.0, 0.05, on_change=reset_agent) - - st.subheader("System Prompt") - default_prompt = "You are a helpful AI assistant." - if processing_mode == "Agent-based" and agent_type == AgentType.REACT: - default_prompt = "You are a helpful ReAct agent. Reason step-by-step to fulfill the user query using available tools." - system_prompt = st.text_area( - "System Prompt", value=default_prompt, on_change=reset_agent, height=100 + # Initialize toolgroup selector if not present + if "toolgroup_selector" not in st.session_state: + # Default: include RAG if a vector DB is selected + if selected_vector_dbs: + st.session_state["toolgroup_selector"] = ["builtin::rag"] + else: + st.session_state["toolgroup_selector"] = [] + + # Built-in tools selection (web_search, etc.) 
+ toolgroup_selection = st.pills( + label="Built-in tools", + options=builtin_tools_list, + selection_mode="multi", + key="toolgroup_selector", + on_change=on_toolgroup_change, + format_func=lambda tool: "".join(tool.split("::")[1:]), + help="List of built-in tools from your llama stack server.", + ) + + # MCP tools selection (if available) + if mcp_tools_list: + mcp_selection = st.pills( + label="MCP Servers", + options=mcp_tools_list, + selection_mode="multi", + on_change=on_reset, + format_func=lambda tool: "".join(tool.split("::")[1:]), + help="List of MCP servers registered to your llama stack server.", ) + toolgroup_selection = list(toolgroup_selection) + list(mcp_selection) - st.subheader("Response Handling") - #stream_opt = st.toggle("Stream Response", value=True, on_change=reset_agent) - tool_debug = st.toggle("Show Tool/Debug Info", value=False) + # Display active tools summary + client = llama_stack_api.client + grouped_tools = {} + total_tools = 0 - if st.button("Clear Chat & Reset Config", use_container_width=True): - reset_agent() - st.rerun() - + for toolgroup_id in toolgroup_selection: + tools = client.tools.list(toolgroup_id=toolgroup_id) + logger.debug("Raw tools from toolgroup '%s': %s", toolgroup_id, tools) + grouped_tools[toolgroup_id] = [tool.name for tool in tools] + total_tools += len(tools) - updated_toolgroup_selection = [] - if processing_mode == "Agent-based": - for i, tool_name in enumerate(toolgroup_selection): - if tool_name == "builtin::rag": - if len(selected_vector_dbs) > 0: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - vector_db_ids = [vector_db.identifier for vector_db in vector_dbs if get_vector_db_name(vector_db) in selected_vector_dbs] - tool_dict = dict( - name="builtin::rag/knowledge_search", - args={ - "vector_db_ids": list(vector_db_ids), - # Defaults - "query_config": { - "chunk_size_in_tokens": 512, - "chunk_overlap_in_tokens": 50, - }, - }, - ) - updated_toolgroup_selection.append(tool_dict) - else: - 
updated_toolgroup_selection.append(tool_name) - - @st.cache_resource - def create_agent(): - if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT: - return ReActAgent( - client=client, - model=model, - tools=updated_toolgroup_selection, - response_format=JsonSchemaResponseFormat( - type="json_schema", - json_schema=ReActOutput.model_json_schema() - ), - sampling_params=SamplingParams( - strategy=StrategyTopPSamplingStrategy(type="top_p", temperature=temperature, top_p=top_p), - max_tokens=max_tokens, - repetition_penalty=repetition_penalty, - ), - input_shields= input_shields, - output_shields= output_shields, - ) - else: - updated_system_prompt = system_prompt.strip() - updated_system_prompt = updated_system_prompt if updated_system_prompt.strip().endswith('.') else updated_system_prompt + '.' - return Agent( - client, - model=model, - instructions=f"{updated_system_prompt} When you use a tool always respond with a summary of the result.", - tools=updated_toolgroup_selection, - sampling_params=SamplingParams( - strategy=StrategyTopPSamplingStrategy(type="top_p", temperature=temperature, top_p=top_p), - max_tokens=max_tokens, - repetition_penalty=repetition_penalty, - ), - input_shields= input_shields, - output_shields= output_shields, - ) + logger.debug("Grouped tools summary: %s", grouped_tools) + + if total_tools > 0: + st.markdown(f"Active Tools: 🛠 {total_tools}") + for group_id, tools in grouped_tools.items(): + with st.expander(f"🔧 Tools from `{group_id}`"): + for idx, tool in enumerate(tools, start=1): + st.markdown(f"{idx}. 
`{tool.split(':')[-1]}`") + + return toolgroup_selection + + +def reset_agent(): + """Reset the agent by clearing session state and cache.""" + st.session_state.clear() + st.cache_resource.clear() + + +def create_vector_db_callbacks(processing_mode, vector_dbs): + """Create callbacks for vector DB and toolgroup synchronization.""" + def on_vector_db_change(): + """When vector DB changes, update toolgroup selection""" + if processing_mode != "Agent-based": + return + + selected_vdbs = st.session_state.get("chat_vector_db_selector", []) + current_toolgroups = st.session_state.get("toolgroup_selector", []) + + if not selected_vdbs: + # Remove RAG from toolgroups + if "builtin::rag" in current_toolgroups: + filtered = [t for t in current_toolgroups if t != "builtin::rag"] + st.session_state["toolgroup_selector"] = filtered + else: + # Add RAG to toolgroups if not present + if "builtin::rag" not in current_toolgroups: + updated = list(current_toolgroups) + ["builtin::rag"] + st.session_state["toolgroup_selector"] = updated + + def on_toolgroup_change(): + """When toolgroup changes, update vector DB selection""" + current_toolgroups = st.session_state.get("toolgroup_selector", []) + selected_vdbs = st.session_state.get("chat_vector_db_selector", []) + + if "builtin::rag" not in current_toolgroups: + # RAG deselected, clear vector DB selection + st.session_state["chat_vector_db_selector"] = [] + elif "builtin::rag" in current_toolgroups and not selected_vdbs: + # RAG selected but no vector DB, select first available + if vector_dbs: + first_vdb = get_vector_db_name(vector_dbs[0]) + st.session_state["chat_vector_db_selector"] = [first_vdb] + + return on_vector_db_change, on_toolgroup_change + + +def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list): + """Render sidebar configuration and return selected parameters.""" + st.title("Configuration") + st.subheader("Model") + model = st.selectbox( + label="Model", + options=model_list, + 
on_change=reset_agent, + label_visibility="collapsed", + ) + + # Processing Mode + processing_mode = st.radio( + "Processing mode", + ["Direct", "Agent-based"], + index=0, + captions=[ + "Passes vector store search results as context to LLM", + "Uses Responses API with tool calling", + ], + on_change=reset_agent, + help="Choose how requests are processed.", + ) + + # Vector Database Selection + vector_dbs = list(llama_stack_api.client.vector_stores.list() or []) + on_vector_db_change, on_toolgroup_change = create_vector_db_callbacks( + processing_mode, vector_dbs + ) + + selected_vector_dbs = render_vector_db_selector( + vector_dbs, processing_mode, on_vector_db_change + ) + + # Toolgroup Selection (Agent-based mode only) + toolgroup_selection = [] if processing_mode == "Agent-based": - st.session_state.agent_type = agent_type - agent = create_agent() + toolgroup_selection = render_toolgroup_selection( + builtin_tools_list, mcp_tools_list, selected_vector_dbs, + on_toolgroup_change, reset_agent + ) + + # Sampling Parameters + st.subheader("Sampling Parameters") + temperature = st.slider( + "Temperature", + 0.0, 2.0, 0.1, 0.05, + on_change=reset_agent, + help="Controls randomness. Higher values = more random.", + ) + max_infer_iters = st.slider( + "Max Inference Iterations", + 1, 50, 10, 1, + on_change=reset_agent, + help="Maximum number of inference iterations before stopping", + ) + + # System Prompt + st.subheader("System Prompt") + default_prompt = "You are a helpful AI assistant." 
+ system_prompt = st.text_area( + "System Prompt", value=default_prompt, on_change=reset_agent, height=100 + ) + + if st.button("Clear Chat & Reset Config", use_container_width=True): + reset_agent() + st.rerun() - if "agent_session_id" not in st.session_state: - st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}") + return { + 'model': model, + 'processing_mode': processing_mode, + 'selected_vector_dbs': selected_vector_dbs, + 'toolgroup_selection': toolgroup_selection, + 'temperature': temperature, + 'max_infer_iters': max_infer_iters, + 'system_prompt': system_prompt, + } - session_id = st.session_state["agent_session_id"] + +def render_vector_db_selector(vector_dbs, processing_mode, on_vector_db_change): + """Render vector database selector and return selected databases.""" + selected_vector_dbs = [] + + # Initialize vector DB selector if not present + if "chat_vector_db_selector" not in st.session_state: + st.session_state["chat_vector_db_selector"] = [] + + if not vector_dbs: + return selected_vector_dbs + + vector_db_names = [get_vector_db_name(vector_db) for vector_db in vector_dbs] + selected_vector_dbs = st.multiselect( + label="Select Document Collections for RAG queries", + options=vector_db_names, + key="chat_vector_db_selector", + on_change=on_vector_db_change, + help=( + "Select one or more vector databases to use for retrieval, " + "or leave empty for normal chat" + ) + ) + + # Store the selected DBs for Direct mode + if selected_vector_dbs: + if processing_mode == "Direct": + st.session_state["direct_vector_dbs"] = [ + vdb for vdb in vector_dbs + if get_vector_db_name(vdb) in selected_vector_dbs + ] + else: + # Clear direct_vector_dbs if nothing is selected + if "direct_vector_dbs" in st.session_state: + del st.session_state["direct_vector_dbs"] + + return selected_vector_dbs + + +def initialize_session_state(): + """Initialize session state variables.""" + if "conversation_id" not in 
st.session_state: + conversation = llama_stack_api.client.conversations.create() + st.session_state["conversation_id"] = conversation.id + logger.debug("Created new conversation: %s", conversation.id) if "messages" not in st.session_state: - st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?", "stop_reason": "end_of_turn"}] - - if "debug_events" not in st.session_state: # Per-turn debug logs - st.session_state["debug_events"] = [] - + st.session_state["messages"] = [ + { + "role": "assistant", + "content": "How can I help you?", + "stop_reason": "end_of_turn", + } + ] + if "show_more_questions" not in st.session_state: st.session_state["show_more_questions"] = False - + if "selected_question" not in st.session_state: st.session_state["selected_question"] = None - render_history(tool_debug) # Display the current chat history and any past debug events - # Display suggested questions if databases are selected - def display_suggested_questions(): - """Display suggested questions based on selected databases.""" - if not selected_vector_dbs: - return - - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - suggestions = get_suggestions_for_databases(selected_vector_dbs, vector_dbs) - - if not suggestions: - return - - st.markdown("### 💡 Suggested Questions") - - # Determine how many questions to show - num_to_show = len(suggestions) if st.session_state.show_more_questions else min(4, len(suggestions)) - - # Display questions in a grid-like format using columns - cols_per_row = 2 - for i in range(0, num_to_show, cols_per_row): - cols = st.columns(cols_per_row) - for j in range(cols_per_row): - idx = i + j - if idx < num_to_show: - question, db_name = suggestions[idx] - with cols[j]: - # Create a button for each question - button_key = f"question_btn_{idx}_{hash(question)}" - if st.button( - question, - key=button_key, - use_container_width=True, - help=f"From: {db_name}" - ): - st.session_state.selected_question = question - 
st.rerun() - - # Show "Show More" or "Show Less" button if there are more than 4 questions - if len(suggestions) > 4: - col1, col2, col3 = st.columns([1, 1, 1]) - with col2: - if st.session_state.show_more_questions: - if st.button("Show Less", use_container_width=True): - st.session_state.show_more_questions = False - st.rerun() - else: - if st.button(f"Show More ({len(suggestions) - 4} more)", use_container_width=True): - st.session_state.show_more_questions = True - st.rerun() - - st.markdown("---") - - display_suggested_questions() - - def response_generator(turn_response, debug_events_list): - if st.session_state.get("agent_type") == AgentType.REACT: - return _handle_react_response(turn_response) +# ============================================================================ +# Configuration Classes +# ============================================================================ + +@dataclass +class SamplingParams: + """Sampling parameters for model inference.""" + temperature: float + max_infer_iters: int + + +@dataclass +class ChatConfig: + """Configuration for chat processing.""" + model: str + processing_mode: str + system_prompt: str + conversation_id: str + toolgroup_selection: list + selected_vector_dbs: list + sampling: SamplingParams + + +# ============================================================================ +# UI Container Classes +# ============================================================================ + +class Containers: + """Simple container for UI elements. + + Note: Containers are created in visual display order (top to bottom). + """ + def __init__(self): + # Create containers in visual order: tools -> reasoning -> message + self.tool_status = st.container() + self.tool_results = st.container() + self.reasoning = st.empty() + self.message = st.empty() + +class ResponseState: + """ + State container for assistant response UI components and data. + Shared by both Agent-based and Direct modes. 
+ """ + def __init__(self): + # UI containers (grouped) + self.containers = Containers() + + # Reasoning state + self.reasoning_text = "" + self.reasoning_placeholder = None + + # Tool state + self.tool_status = None + self.tool_results = [] + + # Response content + self.full_response = "" + + @property + def has_reasoning(self): + """Check if reasoning has been started.""" + return self.reasoning_placeholder is not None + + @property + def tool_used(self): + """Check if any tool has been used.""" + return self.tool_status is not None + + def update_reasoning(self, delta_text): + """Add reasoning text and update display.""" + self.reasoning_text += delta_text + + # Create reasoning expander on first delta + if not self.has_reasoning: + with self.containers.reasoning.container(): + reasoning_expander = st.expander("🧠 Reasoning", expanded=True) + self.reasoning_placeholder = reasoning_expander.empty() + + # Update reasoning text with cursor + if self.reasoning_placeholder: + self.reasoning_placeholder.markdown(self.reasoning_text + "▌") + + def finalize_reasoning(self): + """Remove cursor from reasoning display.""" + if self.reasoning_placeholder and self.reasoning_text: + self.reasoning_placeholder.markdown(self.reasoning_text) + + def update_message(self, delta_text): + """Add message text and update display.""" + self.full_response += delta_text + self.containers.message.markdown(self.full_response + "▌") + + def finalize_message(self): + """Remove cursor from message display.""" + self.containers.message.markdown(self.full_response) + + +# ============================================================================ +# Suggested Questions UI +# ============================================================================ + +def render_question_button(question, db_name, idx): + """Render a single question button.""" + button_key = f"question_btn_{idx}_{hash(question)}" + if st.button( + question, + key=button_key, + use_container_width=True, + help=f"From: 
{db_name}" + ): + st.session_state.selected_question = question + st.rerun() + + +def render_question_grid(suggestions, num_to_show): + """Render questions in a grid layout.""" + cols_per_row = 2 + for i in range(0, num_to_show, cols_per_row): + cols = st.columns(cols_per_row) + for j in range(cols_per_row): + idx = i + j + if idx < num_to_show: + question, db_name = suggestions[idx] + with cols[j]: + render_question_button(question, db_name, idx) + + +def render_show_more_button(suggestions): + """Render show more/less button if needed.""" + if len(suggestions) <= 4: + return + + _, col2, _ = st.columns([1, 1, 1]) + with col2: + if st.session_state.show_more_questions: + if st.button("Show Less", use_container_width=True): + st.session_state.show_more_questions = False + st.rerun() else: - return _handle_regular_response(turn_response, debug_events_list) - - def _handle_react_response(turn_response): - current_step_content = "" - final_answer = None - tool_results = [] - - for response in turn_response: - if not hasattr(response.event, "payload"): - yield ( - "\n\n🚨 :red[_Llama Stack server Error:_]\n" - "The response received is missing an expected `payload` attribute.\n" - "This could indicate a malformed response or an internal issue within the server.\n\n" - f"Error details: {response}" - ) - return - - payload = response.event.payload - - if payload.event_type == "step_progress" and hasattr(payload.delta, "text"): - current_step_content += payload.delta.text - continue - - if payload.event_type == "step_complete": - step_details = payload.step_details - - if step_details.step_type == "inference": - yield from _process_inference_step(current_step_content, tool_results, final_answer) - current_step_content = "" - elif step_details.step_type == "tool_execution": - tool_results = _process_tool_execution(step_details, tool_results) - current_step_content = "" - else: - current_step_content = "" - - if not final_answer and tool_results: - yield from 
_format_tool_results_summary(tool_results) - - def _process_inference_step(current_step_content, tool_results, final_answer): - try: - react_output_data = json.loads(current_step_content) - thought = react_output_data.get("thought") - action = react_output_data.get("action") - answer = react_output_data.get("answer") - - if answer and answer != "null" and answer is not None: - final_answer = answer - - if thought: - with st.expander("🤔 Thinking...", expanded=False): - st.markdown(f":grey[__{thought}__]") - - if action and isinstance(action, dict): - tool_name = action.get("tool_name") - tool_params = action.get("tool_params") - with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False): - st.json(tool_params) - - if answer and answer != "null" and answer is not None: - yield f"\n\n✅ **Final Answer:**\n{answer}" - - except json.JSONDecodeError: - yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```" - except Exception as e: - yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```" - - return final_answer - - def _process_tool_execution(step_details, tool_results): - try: - if hasattr(step_details, "tool_responses") and step_details.tool_responses: - for tool_response in step_details.tool_responses: - tool_name = tool_response.tool_name - content = tool_response.content - tool_results.append((tool_name, content)) - with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False): - try: - parsed_content = json.loads(content) - st.json(parsed_content) - except json.JSONDecodeError: - st.code(content, language=None) - else: - with st.expander("⚙️ Observation", expanded=False): - st.markdown(":grey[_Tool execution step completed, but no response data found._]") - except Exception as e: - with st.expander("⚙️ Error in Tool Execution", expanded=False): - st.markdown(f":red[_Error processing tool execution: {str(e)}_]") - - return tool_results - - def 
_format_tool_results_summary(tool_results): - yield "\n\n**Here's what I found:**\n" - for tool_name, content in tool_results: - try: - parsed_content = json.loads(content) - - if tool_name == "web_search" and "top_k" in parsed_content: - yield from _format_web_search_results(parsed_content) - elif "results" in parsed_content and isinstance(parsed_content["results"], list): - yield from _format_results_list(parsed_content["results"]) - elif isinstance(parsed_content, dict) and len(parsed_content) > 0: - yield from _format_dict_results(parsed_content) - elif isinstance(parsed_content, list) and len(parsed_content) > 0: - yield from _format_list_results(parsed_content) - except json.JSONDecodeError: - yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n" - except (TypeError, AttributeError, KeyError, IndexError) as e: - print(f"Error processing {tool_name} result: {type(e).__name__}: {e}") - - def _format_web_search_results(parsed_content): - for i, result in enumerate(parsed_content["top_k"], 1): - if i <= 3: - title = result.get("title", "Untitled") - url = result.get("url", "") - content_text = result.get("content", "").strip() - yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n" - - def _format_results_list(results): - for i, result in enumerate(results, 1): - if i <= 3: - if isinstance(result, dict): - name = result.get("name", result.get("title", "Result " + str(i))) - description = result.get("description", result.get("content", result.get("summary", ""))) - yield f"\n- **{name}**\n {description}\n" - else: - yield f"\n- {result}\n" - - def _format_dict_results(parsed_content): - yield "\n```\n" - for key, value in list(parsed_content.items())[:5]: - if isinstance(value, str) and len(value) < 100: - yield f"{key}: {value}\n" - else: - yield f"{key}: [Complex data]\n" - yield "```\n" - - def _format_list_results(parsed_content): - yield "\n" - for _, item in enumerate(parsed_content[:3], 1): - if 
isinstance(item, str): - yield f"- {item}\n" - elif isinstance(item, dict) and "text" in item: - yield f"- {item['text']}\n" - elif isinstance(item, dict) and len(item) > 0: - first_value = next(iter(item.values())) - if isinstance(first_value, str) and len(first_value) < 100: - yield f"- {first_value}\n" - - def _handle_regular_response(turn_response, debug_events_list): - - # Use itertools.tee to duplicate the stream for UI and debug logging - # This is crucial because a generator can only be consumed once. - from itertools import tee - ui_stream, debug_log_stream = tee(turn_response, 2) - - for response in ui_stream: - if hasattr(response.event, "payload"): - if response.event.payload.event_type == "step_progress": - if hasattr(response.event.payload.delta, "text"): - yield response.event.payload.delta.text - if response.event.payload.event_type == "step_complete": - if response.event.payload.step_details.step_type == "tool_execution": - if response.event.payload.step_details.tool_calls: - tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name) - yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n' - else: - yield "No tool_calls present in step_details" - if response.event.payload.step_details.step_type == "shield_call": - if response.event.payload.step_details.violation: - yield response.event.payload.step_details.violation.user_message - else: - yield f"Error occurred in the Llama Stack Cluster: {response}" - debug_events_list.append({"type": "warning", "source": "_handle_regular_response", "details": "Unexpected event structure", "event": str(response)[:200]}) - - # Process the debug log stream separately - # EventLogger helps parse and structure these events - for log_entry in EventLogger().log(debug_log_stream): - if log_entry.role == "tool_execution": # Or other relevant roles - debug_events_list.append({"type": "tool_log", "content": log_entry.content}) - # Add other log types as needed for debugging - - def 
agent_process_prompt(prompt, debug_events_list): - print(f"In agent_process_prompt: {prompt}") - # Send the prompt to the agent - turn_response = agent.create_turn( - session_id=session_id, - messages=[UserMessage(role="user", content=prompt)], - stream=True, + button_text = f"Show More ({len(suggestions) - 4} more)" + if st.button(button_text, use_container_width=True): + st.session_state.show_more_questions = True + st.rerun() + + +def display_suggested_questions(selected_vector_dbs): + """Display suggested questions based on selected databases.""" + if not selected_vector_dbs: + return + + vector_dbs = list(llama_stack_api.client.vector_stores.list() or []) + suggestions = get_suggestions_for_databases(selected_vector_dbs, vector_dbs) + + if not suggestions: + return + + st.markdown("### 💡 Suggested Questions") + + # Determine how many questions to show + num_to_show = ( + len(suggestions) if st.session_state.show_more_questions + else min(4, len(suggestions)) + ) + + # Display questions and controls + render_question_grid(suggestions, num_to_show) + render_show_more_button(suggestions) + st.markdown("---") + + +# ============================================================================ +# Main Prompt Processing +# ============================================================================ + +def process_prompt(prompt, config): + """Process user prompt and generate response.""" + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + + # Create assistant message context and setup shared containers once + with st.chat_message("assistant"): + state = ResponseState() + + # Call the appropriate mode-specific function + if config.processing_mode == "Direct": + direct_process_prompt(prompt, state, config) + elif config.processing_mode == "Agent-based": + agent_process_prompt(prompt, state, config) + + st.rerun() + + +def tool_chat_page(): + """Main chat page with RAG support in Direct and 
Agent-based modes.""" + st.title("💬 Chat") + + # Fetch models and tools + model_list, builtin_tools_list, mcp_tools_list = fetch_models_and_tools() + + # Render sidebar and get configuration + with st.sidebar: + sidebar_config = render_sidebar_configuration( + model_list, builtin_tools_list, mcp_tools_list ) - print(f"In agent_process_prompt: {turn_response}") - response_content = st.write_stream(response_generator(turn_response, debug_events_list)) - print(f"In agent_process_prompt: {response_content}") - st.session_state.messages.append({"role": "assistant", "content": response_content}) + # Initialize session state + initialize_session_state() + + # Create chat configuration object + chat_config = ChatConfig( + model=sidebar_config['model'], + processing_mode=sidebar_config['processing_mode'], + system_prompt=sidebar_config['system_prompt'], + conversation_id=st.session_state["conversation_id"], + toolgroup_selection=sidebar_config['toolgroup_selection'], + selected_vector_dbs=sidebar_config['selected_vector_dbs'], + sampling=SamplingParams( + temperature=sidebar_config['temperature'], + max_infer_iters=sidebar_config['max_infer_iters'] + ) + ) - def direct_process_prompt(prompt, debug_events_list): - # Query the vector DB - if selected_vector_dbs: - vector_dbs = llama_stack_api.client.vector_dbs.list() or [] - vector_db_ids = [vector_db.identifier for vector_db in vector_dbs if get_vector_db_name(vector_db) in selected_vector_dbs] - with st.spinner("Retrieving context (RAG)..."): - try: - rag_response = llama_stack_api.client.tool_runtime.rag_tool.query( - content=prompt, vector_db_ids=list(vector_db_ids) - ) - prompt_context = rag_response.content - debug_events_list.append({ - "type": "rag_query_direct_mode", "query": prompt, - "vector_dbs": selected_vector_dbs, - "context_length": len(prompt_context) if prompt_context else 0, - "context_preview": (str(prompt_context[:200]) + "..." 
if prompt_context else "None") - }) - except Exception as e: - st.warning(f"RAG Error (Direct Mode): {e}") - debug_events_list.append({"type": "error", "source": "rag_direct_mode", "content": str(e)}) - else: - prompt_context = None - - with st.chat_message("assistant"): - message_placeholder = st.empty() - full_response = "" - retrieval_response = "" - - # Construct the extended prompt - if prompt_context: - extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" - else: - extended_prompt = f"Please answer the following query. \n\nQUERY:\n{prompt}" - - # Run inference directly - #st.session_state.messages.append({"role": "user", "content": extended_prompt}) - messages_for_direct_api = ( - [{'role': 'system', 'content': system_prompt}] + - [{'role': 'user', 'content': extended_prompt}] - ) - response = llama_stack_api.client.inference.chat_completion( - messages=messages_for_direct_api, - model_id=model, - sampling_params={ - "strategy": get_strategy(temperature, top_p), - "max_tokens": max_tokens, - "repetition_penalty": repetition_penalty, - }, - stream=True, - timeout=120, - ) - - # Display assistant response - for chunk in response: - if chunk.event: - response_delta = chunk.event.delta - if isinstance(response_delta, ToolCallDelta): - retrieval_response += response_delta.tool_call.replace("====", "").strip() - #retrieval_message_placeholder.info(retrieval_response) - else: - full_response += chunk.event.delta.text - message_placeholder.markdown(full_response + "▌") - message_placeholder.markdown(full_response) - - response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"} - st.session_state.messages.append(response_dict) - #st.session_state.displayed_messages.append(response_dict) - - def process_prompt(prompt): - print(f"In process_prompt: {prompt}") - st.session_state.messages.append({"role": "user", "content": prompt}) - with 
st.chat_message("user"): - st.markdown(prompt) - - # Prepare for assistant's response - # Each assistant turn gets its own list for debug events - st.session_state.debug_events.append([]) - current_turn_debug_events_list = st.session_state.debug_events[-1] # Get the list for this turn - - st.session_state.prompt = prompt - - print(f"In process_prompt: {st.session_state.prompt}") - print(f"In processing mode: {processing_mode}") - if processing_mode == "Agent-based": - agent_process_prompt(st.session_state.prompt, current_turn_debug_events_list) - else: # rag_mode == "Direct" - direct_process_prompt(st.session_state.prompt, current_turn_debug_events_list) - #st.session_state.prompt = None - st.rerun() + # Display chat history + render_history() + + # Display suggested questions + display_suggested_questions(chat_config.selected_vector_dbs) # Handle selected question from suggestions if st.session_state.selected_question: prompt = st.session_state.selected_question - st.session_state.selected_question = None # Clear the selected question + st.session_state.selected_question = None + process_prompt(prompt, chat_config) - process_prompt(prompt) - - # Handle manual chat input if prompt := st.chat_input(placeholder="Ask a question..."): - # Append the user message to history and display it - process_prompt(prompt) - - - + process_prompt(prompt, chat_config) tool_chat_page() diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py new file mode 100644 index 0000000..ee0656c --- /dev/null +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Direct mode implementation for chat with manual RAG. 
+""" + +import logging +import traceback + +import streamlit as st + +from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name + + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Direct Mode - Helper Functions +# ============================================================================ + +def extract_text_from_search_result(result): + """Extract text content from a search result object.""" + # Handle Data objects with content list + if hasattr(result, 'content') and isinstance(result.content, list): + for content_item in result.content: + if hasattr(content_item, 'text'): + return content_item.text + return None + + # Handle simple content attribute + if hasattr(result, 'content') and isinstance(result.content, str): + return result.content + + # Handle dict format + if isinstance(result, dict) and 'content' in result: + if isinstance(result['content'], list) and result['content']: + return result['content'][0].get('text', '') + return result['content'] + + return None + + +def search_vector_store_direct(prompt, vector_db_id, vector_db_name, state): + """Search vector store and extract context for Direct mode.""" + search_results = [] + context_parts = [] + + # Show search status + with state.containers.tool_status: + st.markdown(f"🛠 :grey[_Searching vector store: {vector_db_name}_]") + + logger.debug("Searching vector store %s with query: %s", vector_db_id, prompt) + + # Call vector store search API + search_response = llama_stack_api.client.vector_stores.search( + vector_store_id=vector_db_id, + query=prompt, + ) + + logger.debug("Search response: %s", search_response) + + # Extract search results from response + if hasattr(search_response, 'data') and search_response.data: + search_results = search_response.data + elif hasattr(search_response, 'chunks') and search_response.chunks: + 
search_results = search_response.chunks + elif hasattr(search_response, 'results') and search_response.results: + search_results = search_response.results + + # Display and process search results + if search_results: + with state.containers.tool_results: + with st.expander(f"📄 Search Results from '{vector_db_name}'", expanded=False): + st.json(search_results) + + # Build context from search results + for idx, result in enumerate(search_results[:5], 1): # Top 5 results + text_content = extract_text_from_search_result(result) + if text_content: + context_parts.append(f"[Document {idx}]: {text_content}") + + logger.debug("Built context with %s documents", len(context_parts)) + else: + # No results found + with state.containers.tool_results: + st.info(f"No results found in '{vector_db_name}'") + + return search_results, context_parts + + +def build_rag_messages(prompt, context_parts, system_prompt): + """Build messages for LLM - with or without RAG context.""" + if context_parts: + # RAG mode: Format user message with explicit CONTEXT and QUERY sections + context = "\n\n".join(context_parts) + extended_prompt = ( + f"Please answer the following query using the context below.\n\n" + f"CONTEXT:\n{context}\n\n" + f"QUERY:\n{prompt}" + ) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": extended_prompt} + ] + logger.debug( + "Built RAG prompt with %s documents, total context length: %s", + len(context_parts), len(context) + ) + else: + # Normal chat mode: no context + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ] + logger.debug("No context - using normal chat mode") + + return messages + + +def stream_completions_direct(completion_response, state): + """Stream chunks from Completions API and update state.""" + for chunk in completion_response: + logger.debug("Completion chunk: %s", chunk) + if hasattr(chunk, 'choices') and len(chunk.choices) > 0: + delta = 
chunk.choices[0].delta + + # Handle reasoning content (for models that support it like R1) + if hasattr(delta, 'reasoning_content') and delta.reasoning_content: + state.update_reasoning(delta.reasoning_content) + + # Handle regular content + if hasattr(delta, 'content') and delta.content: + state.update_message(delta.content) + + +def save_direct_response_to_session(state, all_search_results): + """Save direct response to session state.""" + state.finalize_reasoning() + state.finalize_message() + + response_dict = { + "role": "assistant", + "content": state.full_response, + "stop_reason": "end_of_message" + } + + # Save reasoning if present + if state.reasoning_text: + response_dict["reasoning"] = state.reasoning_text + + # Save search results for history display if we had any + if all_search_results: + response_dict["tool_results"] = [ + { + 'title': f'📄 Search Results from \'{name}\'', + 'type': 'json', + 'content': results + } + for name, results in all_search_results + ] + db_names = [name for name, _ in all_search_results] + response_dict["tool_status"] = ( + f"🛠 :grey[_Searched vector stores: {', '.join(db_names)}_]" + ) + + st.session_state.messages.append(response_dict) + + +# ============================================================================ +# Direct Mode - Main Function +# ============================================================================ + +def direct_process_prompt(prompt, state, config): + """Direct mode: Manual RAG with completions API.""" + context_parts = [] + all_search_results = [] + + vector_dbs = st.session_state.get("direct_vector_dbs", []) + if not vector_dbs: + logger.debug("No vector DB selected - normal chat mode") + + try: + # Step 1: Search each selected vector store + for vector_db in vector_dbs: + vector_db_id = vector_db.id + vector_db_name = get_vector_db_name(vector_db) + search_results, parts = search_vector_store_direct( + prompt, vector_db_id, vector_db_name, state + ) + if search_results: + 
all_search_results.append((vector_db_name, search_results)) + context_parts.extend(parts) + + # Step 2: Build messages (with or without RAG context) + messages = build_rag_messages(prompt, context_parts, config.system_prompt) + + # Step 3: Call completions API + logger.debug("Calling completions API with %s messages", len(messages)) + for i, msg in enumerate(messages): + logger.debug(" Message %s (%s): %s...", i, msg['role'], msg['content'][:200]) + + completion_response = llama_stack_api.client.chat.completions.create( + model=config.model, + messages=messages, + temperature=config.sampling.temperature, + stream=True, + ) + + # Step 4: Stream response and update UI + stream_completions_direct(completion_response, state) + + # Step 5: Save to session + save_direct_response_to_session(state, all_search_results) + + except Exception as e: + st.error(f"Error in Direct mode: {str(e)}") + logger.debug("Direct mode error: %s", e) + logger.debug("%s", traceback.format_exc()) diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py b/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py index d4a3e15..756f351 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/__init__.py @@ -3,4 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index 8547cd8..47107a3 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -4,47 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import asyncio -import asyncpg -import os -import streamlit as st import traceback -from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name, data_url_from_file -from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_client import RAGDocument - - -# Module-level connection pool (initialized lazily) -_pg_pool = None - +import streamlit as st -async def _get_pg_pool(): - """ - Get or create the PostgreSQL connection pool. - The pool is created lazily on first use and reused for subsequent calls. - - Returns: - asyncpg.Pool: The connection pool instance - """ - global _pg_pool - if _pg_pool is None: - pg_host = os.environ.get("PGVECTOR_HOST", "pgvector") - pg_port = os.environ.get("PGVECTOR_PORT", "5432") - pg_user = os.environ.get("PGVECTOR_USER", "postgres") - pg_password = os.environ.get("PGVECTOR_PASSWORD", "rag_password") - pg_database = os.environ.get("PGVECTOR_DB", "rag_blueprint") - - _pg_pool = await asyncpg.create_pool( - host=pg_host, - port=int(pg_port), - user=pg_user, - password=pg_password, - database=pg_database, - min_size=1, - max_size=5, - ) - return _pg_pool +from llama_stack_ui.distribution.ui.modules.api import llama_stack_api +from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name def upload_page(): @@ -53,27 +18,27 @@ def upload_page(): Supports creating new vector databases and uploading documents to existing ones. 
""" st.title("📄 Upload Documents") - + # Initialize session state for creation status messages if "creation_status" not in st.session_state: st.session_state["creation_status"] = None if "creation_message" not in st.session_state: st.session_state["creation_message"] = "" - + # Initialize session state for selected vector database # This persists the selection when navigating away and back to this page if "selected_vector_db" not in st.session_state: st.session_state["selected_vector_db"] = "" - + # Initialize the widget key to match our tracked selection # This ensures the selectbox displays the correct value on page load if "vector_db_selector" not in st.session_state: st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] - + # Initialize newly created VDB tracker if "newly_created_vdb" not in st.session_state: st.session_state["newly_created_vdb"] = None - + # Show status messages at the top level (before dropdown) if st.session_state["creation_status"] == "success": st.success(st.session_state["creation_message"]) @@ -82,30 +47,28 @@ def upload_page(): st.session_state["creation_message"] = "" elif st.session_state["creation_status"] == "error": st.error(st.session_state["creation_message"]) - # Clear the message after showing it + # Clear the message after showing it st.session_state["creation_status"] = None st.session_state["creation_message"] = "" - + # Fetch all vector databases - vdb_list = llama_stack_api.client.vector_dbs.list() - + vdb_list = llama_stack_api.client.vector_stores.list() + # Build dropdown options based on whether databases exist dropdown_options = [] - vdb_info = {} - + # Define the "Create New" option with emoji for visibility - CREATE_NEW_OPTION = "➕ Create New" - + create_new_option = "➕ Create New" + if vdb_list: # When databases exist: list actual DBs first, then "Create New" LAST existing_vdbs = {get_vector_db_name(v): v.to_dict() for v in vdb_list} dropdown_options.extend(list(existing_vdbs.keys())) - 
dropdown_options.append(CREATE_NEW_OPTION) # Add "Create New" as LAST item - vdb_info = existing_vdbs + dropdown_options.append(create_new_option) # Add "Create New" as LAST item else: # When NO databases exist: only show "Create New" - dropdown_options = [CREATE_NEW_OPTION] - + dropdown_options = [create_new_option] + # Sync session state for widget - ensure it shows the right value # Priority 1: If a database was just created, auto-select it (highest priority) if st.session_state["newly_created_vdb"]: @@ -116,7 +79,10 @@ def upload_page(): st.session_state["vector_db_selector"] = newly_created_name st.session_state["newly_created_vdb"] = None # Priority 2: Use the previously selected database from session if it still exists - elif st.session_state["selected_vector_db"] and st.session_state["selected_vector_db"] in dropdown_options: + elif ( + st.session_state["selected_vector_db"] + and st.session_state["selected_vector_db"] in dropdown_options + ): # Sync widget state with our tracked state st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] # Priority 3: If no saved selection or saved selection doesn't exist, use smart default @@ -128,42 +94,42 @@ def upload_page(): st.session_state["vector_db_selector"] = first_db else: # When NO databases exist: default to "Create New" - st.session_state["selected_vector_db"] = CREATE_NEW_OPTION - st.session_state["vector_db_selector"] = CREATE_NEW_OPTION - + st.session_state["selected_vector_db"] = create_new_option + st.session_state["vector_db_selector"] = create_new_option + # Vector database selection dropdown with persistent selection # Using key parameter to bind directly to session state - NO index parameter to avoid conflicts def on_vector_db_change(): """Callback to update session state when selection changes""" st.session_state["selected_vector_db"] = st.session_state["vector_db_selector"] - + selected_vector_db = st.selectbox( - "Select a vector database", + "Select a vector database", 
dropdown_options, key="vector_db_selector", # Key binds to session state (session state controls the value) on_change=on_vector_db_change, # Callback updates our tracking variable help="Your selection will be remembered when you navigate to other pages" ) - + # Ensure session state is updated (in case callback didn't fire) if selected_vector_db != st.session_state["selected_vector_db"]: st.session_state["selected_vector_db"] = selected_vector_db - + # Get the actual vector database object for API calls (do this before using it) selected_vdb_obj = None - if selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + if selected_vector_db and selected_vector_db != create_new_option: for vdb in vdb_list: if get_vector_db_name(vdb) == selected_vector_db: selected_vdb_obj = vdb break - - if selected_vector_db == CREATE_NEW_OPTION: + + if selected_vector_db == create_new_option: # Show vector database creation UI _show_create_vector_db_ui() - elif selected_vector_db and selected_vector_db != CREATE_NEW_OPTION: + elif selected_vector_db and selected_vector_db != create_new_option: # Show existing documents in the database (heading will show only if documents exist) _show_existing_documents_table(selected_vector_db, selected_vdb_obj) - + # Add Browse functionality for uploading documents to this database st.subheader(f"📁 Upload Documents to '{selected_vector_db}'") _show_document_upload_ui(selected_vector_db, selected_vdb_obj) @@ -175,11 +141,11 @@ def _show_create_vector_db_ui(): Display UI for creating a new vector database. 
""" st.subheader("Create New Vector Database") - + # Initialize session state for creation form if "new_vdb_name" not in st.session_state: st.session_state["new_vdb_name"] = "" - + # Vector database name input new_vdb_name = st.text_input( "Add New Vector Database", @@ -187,10 +153,10 @@ def _show_create_vector_db_ui(): help="Enter a unique name for the new vector database", key="new_vdb_name_input" ) - + # Update session state st.session_state["new_vdb_name"] = new_vdb_name - + # Add button if st.button("Add", type="primary", disabled=not new_vdb_name.strip()): _create_vector_database(new_vdb_name.strip()) @@ -199,7 +165,7 @@ def _show_create_vector_db_ui(): def _create_vector_database(vdb_name): """ Create a new vector database using the LlamaStack API. - + Args: vdb_name (str): Name for the new vector database """ @@ -207,56 +173,45 @@ def _create_vector_database(vdb_name): # Reset status st.session_state["creation_status"] = None st.session_state["creation_message"] = "" - + # Validate input if not vdb_name or not vdb_name.strip(): st.session_state["creation_status"] = "error" st.session_state["creation_message"] = "Vector database name cannot be empty." return - + # Check for duplicate names - existing_vdbs = llama_stack_api.client.vector_dbs.list() + existing_vdbs = llama_stack_api.client.vector_stores.list() existing_names = [get_vector_db_name(vdb) for vdb in existing_vdbs] if vdb_name in existing_names: st.session_state["creation_status"] = "error" - st.session_state["creation_message"] = f"Vector database '{vdb_name}' already exists. Please choose a different name." - return - - # Get vector IO provider - providers = llama_stack_api.client.providers.list() - vector_io_provider = None - for provider in providers: - if provider.api == "vector_io": - vector_io_provider = provider.provider_id - break - - if not vector_io_provider: - st.session_state["creation_status"] = "error" - st.session_state["creation_message"] = "No vector IO provider found. 
Cannot create vector database." + st.session_state["creation_message"] = ( + f"Vector database '{vdb_name}' already exists. " + "Please choose a different name." + ) return - - # Create the vector database + + # Create the vector database using the new simplified API + # Note: embedding settings (dimension, model, provider) must be + # configured at the server level with st.spinner(f"Creating vector database '{vdb_name}'..."): - vector_db = llama_stack_api.client.vector_dbs.register( - vector_db_id=vdb_name, - embedding_dimension=384, - embedding_model="all-MiniLM-L6-v2", - provider_id=vector_io_provider, + _vector_db = llama_stack_api.client.vector_stores.create( + name=vdb_name, ) - + # Success st.session_state["creation_status"] = "success" st.session_state["creation_message"] = f"Vector database '{vdb_name}' created successfully!" - + # Mark this database to be auto-selected after refresh st.session_state["newly_created_vdb"] = vdb_name - + # Clear the input field st.session_state["new_vdb_name"] = "" - + # Trigger page refresh to update the dropdown - this will show the message at the top st.rerun() - + except Exception as e: st.session_state["creation_status"] = "error" st.session_state["creation_message"] = f"Error creating vector database: {str(e)}" @@ -265,7 +220,7 @@ def _create_vector_database(vdb_name): def _show_document_upload_ui(vector_db_name, vector_db_obj=None): """ Display UI for uploading documents to an existing vector database. 
- + Args: vector_db_name (str): Name of the selected vector database """ @@ -274,7 +229,7 @@ def _show_document_upload_ui(vector_db_name, vector_db_obj=None): st.session_state["upload_status"] = None if "upload_message" not in st.session_state: st.session_state["upload_message"] = "" - + # Show upload status messages if st.session_state["upload_status"] == "success": st.success(st.session_state["upload_message"]) @@ -286,42 +241,50 @@ def _show_document_upload_ui(vector_db_name, vector_db_obj=None): # Clear after showing st.session_state["upload_status"] = None st.session_state["upload_message"] = "" - + # Initialize session state to track processed files upload_key = f"processed_files_{vector_db_name}" if upload_key not in st.session_state: st.session_state[upload_key] = set() - + # File uploader uploaded_files = st.file_uploader( "Browse and select files to upload (files will upload automatically)", accept_multiple_files=True, type=["txt", "pdf", "doc", "docx"], key=f"uploader_{vector_db_name}", # Unique key per database - help="Select one or more documents - they will be uploaded automatically to this vector database" + help=( + "Select one or more documents - they will be uploaded " + "automatically to this vector database" + ), ) - + # Auto-upload when files are selected if uploaded_files: # Create a unique identifier for this set of files file_set_id = frozenset([f.name + str(f.size) for f in uploaded_files]) - + # Only process if this is a new set of files if file_set_id not in st.session_state[upload_key]: # Mark as processed IMMEDIATELY before upload to prevent re-triggering st.session_state[upload_key].add(file_set_id) - + # Get the correct database ID for upload - vector_db_id = vector_db_obj.identifier if vector_db_obj and hasattr(vector_db_obj, 'identifier') else vector_db_name - + if vector_db_obj and hasattr(vector_db_obj, 'id'): + vector_db_id = vector_db_obj.id + else: + vector_db_id = vector_db_name + # Upload automatically - 
_upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id) + _upload_documents_to_database( + vector_db_name, uploaded_files, vector_db_id + ) def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=None): """ Upload documents to an existing vector database. - + Args: vector_db_name (str): Name of the target vector database uploaded_files: List of uploaded files from Streamlit file uploader @@ -330,195 +293,117 @@ def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=N # Reset status st.session_state["upload_status"] = None st.session_state["upload_message"] = "" - + if not uploaded_files: st.session_state["upload_status"] = "error" st.session_state["upload_message"] = "No files selected for upload." return - - # Convert uploaded files into RAGDocument instances - with st.spinner(f"Processing {len(uploaded_files)} file(s)..."): - documents = [ - RAGDocument( - document_id=uploaded_file.name, - content=data_url_from_file(uploaded_file), - metadata={"source": uploaded_file.name, "type": "uploaded_file"} # LlamaStack maps 'source' to chunk_metadata.source - ) - for uploaded_file in uploaded_files - ] - - # Insert documents into the existing vector database + + # Upload files using the new Files API + Vector Stores API actual_db_id = vector_db_id or vector_db_name - with st.spinner(f"Uploading documents to '{vector_db_name}'..."): - llama_stack_api.client.tool_runtime.rag_tool.insert( - vector_db_id=actual_db_id, # Use the correct database ID - documents=documents, - chunk_size_in_tokens=512, - ) - + uploaded_file_ids = [] + + with st.spinner(f"Uploading {len(uploaded_files)} file(s)..."): + for uploaded_file in uploaded_files: + # Step 1: Upload file to Files API + file_response = llama_stack_api.client.files.create( + file=uploaded_file, + purpose="assistants" + ) + + # Step 2: Attach file to vector store + llama_stack_api.client.vector_stores.files.create( + vector_store_id=actual_db_id, + 
file_id=file_response.id, + ) + + uploaded_file_ids.append(file_response.id) + # Success st.session_state["upload_status"] = "success" - st.session_state["upload_message"] = f"Successfully uploaded {len(uploaded_files)} document(s) to '{vector_db_name}'!" - + st.session_state["upload_message"] = ( + f"Successfully uploaded {len(uploaded_files)} document(s) " + f"to '{vector_db_name}'!" + ) + # Trigger refresh to show the success message st.rerun() - + except Exception as e: st.session_state["upload_status"] = "error" st.session_state["upload_message"] = f"Error uploading documents: {str(e)}" st.rerun() -def _get_documents_from_pgvector(vector_db_id): +def _get_documents_from_vector_store(vector_store_id): """ - Query pgvector directly to get document IDs stored in the database. - Uses a connection pool for efficient connection reuse. - + Get files from a vector store using the Files API. + Args: - vector_db_id (str): The vector database identifier - + vector_store_id (str): The vector store identifier + Returns: - list: List of unique document IDs, or None if query fails + list: List of file objects, or None if query fails """ try: - async def fetch_documents(): - try: - # Get connection from pool - pool = await _get_pg_pool() - - async with pool.acquire() as conn: - # Query for unique document IDs from the document JSONB column - # The vector_db_id is used as the table name with underscores replacing hyphens - table_name = f"vs_{vector_db_id.replace('-', '_')}" - - # Query metadata.source where LlamaStack stores the filename - # Try multiple paths since different upload methods use different structures: - # - Ingestion pipeline: metadata.source - # - Manual upload: chunk_metadata.source - # Fall back to auto-generated document_id if source is null - query = f""" - SELECT DISTINCT - COALESCE( - NULLIF(document->'metadata'->>'source', 'null'), - NULLIF(document->'chunk_metadata'->>'source', 'null'), - document->'metadata'->>'document_id' - ) as document_id - FROM 
{table_name} - WHERE document->'metadata'->>'document_id' IS NOT NULL - OR document->'metadata'->>'source' IS NOT NULL - ORDER BY document_id - """ - - queries = [query] - - doc_ids = [] - for query in queries: - try: - rows = await conn.fetch(query) - if rows: - doc_ids = [row['document_id'] for row in rows if row['document_id']] - if doc_ids: - break - except Exception as e: - continue # Try next query pattern - - return doc_ids if doc_ids else None - # Connection automatically returned to pool - - except Exception as e: - return None - - # Run the async function - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - return loop.run_until_complete(fetch_documents()) - + # List files in the vector store + files_response = llama_stack_api.client.vector_stores.files.list( + vector_store_id=vector_store_id + ) + + # Extract file information + if hasattr(files_response, 'data'): + return files_response.data + else: + return list(files_response) if files_response else None + except Exception as e: + print(f"Error listing files from vector store: {e}") return None -def _delete_document_from_pgvector(vector_db_id, filename): +def _delete_file_from_vector_store(vector_store_id, file_id): """ - Delete a document and all its chunks/embeddings from pgvector. - Uses a connection pool for efficient connection reuse. - + Delete a file from a vector store using the Files API. 
+ Args: - vector_db_id (str): The vector database identifier - filename (str): The filename/source to delete - + vector_store_id (str): The vector store identifier + file_id (str): The file ID to delete + Returns: - tuple: (success: bool, deleted_count: int, error_message: str) + tuple: (success: bool, error_message: str) """ try: - async def delete_document(): - try: - # Get connection from pool - pool = await _get_pg_pool() - - async with pool.acquire() as conn: - # The vector_db_id is used as the table name with underscores replacing hyphens - table_name = f"vs_{vector_db_id.replace('-', '_')}" - - # Delete all chunks where the source matches the filename - # Handle both document structures: - # - Ingestion pipeline: metadata.source - # - Manual upload: chunk_metadata.source - query = f""" - DELETE FROM {table_name} - WHERE document->'metadata'->>'source' = $1 - OR document->'chunk_metadata'->>'source' = $1 - """ - - result = await conn.execute(query, filename) - - # Parse the result to get the number of deleted rows - # Result format is like "DELETE 5" where 5 is the number of rows - deleted_count = int(result.split()[-1]) if result else 0 - - return True, deleted_count, None - # Connection automatically returned to pool - - except Exception as e: - return False, 0, str(e) - - # Run the async function - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - return loop.run_until_complete(delete_document()) - + llama_stack_api.client.vector_stores.files.delete( + file_id=file_id, + vector_store_id=vector_store_id + ) + return True, None except Exception as e: - return False, 0, str(e) + return False, str(e) def _show_existing_documents_table(vector_db_name, vector_db_obj=None): """ Display information about documents in the selected vector database. 
- + Args: vector_db_name (str): Display name of the selected vector database vector_db_obj: The actual vector database object with identifier """ try: # Get the correct vector database ID - if vector_db_obj and hasattr(vector_db_obj, 'identifier'): - vector_db_id = vector_db_obj.identifier + if vector_db_obj and hasattr(vector_db_obj, 'id'): + vector_db_id = vector_db_obj.id else: vector_db_id = vector_db_name # Fallback to display name - + # Initialize session state for deletion status if "delete_status" not in st.session_state: st.session_state["delete_status"] = None if "delete_message" not in st.session_state: st.session_state["delete_message"] = "" - + # Show deletion status messages (before checking documents, so last delete shows) if st.session_state["delete_status"] == "success": st.success(st.session_state["delete_message"]) @@ -528,16 +413,17 @@ def _show_existing_documents_table(vector_db_name, vector_db_obj=None): st.error(st.session_state["delete_message"]) st.session_state["delete_status"] = None st.session_state["delete_message"] = "" - + with st.spinner("Checking for documents..."): - # First, try to get document list from pgvector directly - document_ids = _get_documents_from_pgvector(vector_db_id) - - if document_ids: - # Success! We have the actual document filenames + # Get files from vector store using the Files API + files = _get_documents_from_vector_store(vector_db_id) + + if files: + # Success! 
We have the files # Show heading for documents section st.subheader(f"📄 Documents in '{vector_db_name}'") - + st.info("ℹ️ File deletion is currently not available.") + # Add CSS for bordered table rows st.markdown(""" """, unsafe_allow_html=True) - + # Display table header col1, col2, col3 = st.columns([0.5, 5, 0.5]) with col1: @@ -561,40 +447,37 @@ def _show_existing_documents_table(vector_db_name, vector_db_obj=None): st.markdown("**Filename**") with col3: st.markdown("**Del**") - - # Display each document in a row with delete button - for idx, doc_id in enumerate(document_ids, start=1): + + # Display each file in a row with delete button + for idx, file_obj in enumerate(files, start=1): col1, col2, col3 = st.columns([0.5, 5, 0.5]) - + + # Extract file information + file_id = getattr(file_obj, 'id', 'unknown') + # Try to get filename from metadata or use file_id + filename = file_id + with col1: st.write(idx) - + with col2: - st.write(doc_id) - + st.write(filename) + with col3: - delete_key = f"delete_{vector_db_name}_{doc_id}_{idx}" - - if st.button("✕", key=delete_key, help=f"Delete {doc_id}"): - # Delete immediately without confirmation - success, deleted_count, error = _delete_document_from_pgvector( - vector_db_id, - doc_id - ) - - if success: - st.session_state["delete_status"] = "success" - st.session_state["delete_message"] = f"✅ Successfully deleted '{doc_id}' ({deleted_count} chunk(s) removed)" - else: - st.session_state["delete_status"] = "error" - st.session_state["delete_message"] = f"❌ Failed to delete '{doc_id}': {error}" - - st.rerun() - + delete_key = f"delete_{vector_db_name}_{file_id}_{idx}" + + # Disable delete button + st.button( + "✕", + key=delete_key, + disabled=True, + help="Delete is currently not available", + ) + # else: Database appears empty or pgvector query not available # For newly created databases, this is expected - just show nothing # The upload section below will allow users to add documents - + except Exception as e: 
st.error(f"Error loading document information: {str(e)}") with st.expander("Error Details"): diff --git a/frontend/pyproject.toml b/frontend/pyproject.toml index 31e4233..41c56e7 100644 --- a/frontend/pyproject.toml +++ b/frontend/pyproject.toml @@ -10,10 +10,10 @@ requires-python = ">=3.12" dependencies = [ "streamlit", "pandas", - "llama-stack-client==0.2.23", + "llama-stack-client==0.3.5", "requests", "streamlit-option-menu", - "llama-stack==0.2.23", + "llama-stack==0.3.5", "fire", "asyncpg", ] From f71de459ee573563f0ba81a4c3a3d527284ccc8c Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: Wed, 11 Feb 2026 16:41:11 -0500 Subject: [PATCH 15/18] Use source for filenames Signed-off-by: Yuval Turgeman Assisted-by: Claude --- .../distribution/ui/page/upload/upload.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index 47107a3..1502ed8 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -422,7 +422,6 @@ def _show_existing_documents_table(vector_db_name, vector_db_obj=None): # Success! 
We have the files # Show heading for documents section st.subheader(f"📄 Documents in '{vector_db_name}'") - st.info("ℹ️ File deletion is currently not available.") # Add CSS for bordered table rows st.markdown(""" @@ -439,34 +438,52 @@ def _show_existing_documents_table(vector_db_name, vector_db_obj=None): """, unsafe_allow_html=True) + # Retrieve source for all files: prefer attributes["source"], fallback to filename + source_names = {} + for file_obj in files: + file_id = getattr(file_obj, 'id', None) + if file_id: + # Check attributes["source"] from the vector store file + attrs = getattr(file_obj, 'attributes', None) or {} + source = attrs.get("source") + if not source: + try: + file_info = llama_stack_api.client.files.retrieve(file_id) + source = getattr(file_info, 'filename', None) + except Exception: + source = None + source_names[file_id] = source + # Display table header - col1, col2, col3 = st.columns([0.5, 5, 0.5]) + col1, col2, col3, col4 = st.columns([0.5, 3, 3, 0.5]) with col1: st.markdown("**#**") with col2: - st.markdown("**Filename**") + st.markdown("**Source**") with col3: + st.markdown("**Document ID**") + with col4: st.markdown("**Del**") - # Display each file in a row with delete button + # Display each file in a row for idx, file_obj in enumerate(files, start=1): - col1, col2, col3 = st.columns([0.5, 5, 0.5]) + col1, col2, col3, col4 = st.columns([0.5, 3, 3, 0.5]) # Extract file information file_id = getattr(file_obj, 'id', 'unknown') - # Try to get filename from metadata or use file_id - filename = file_id + source = source_names.get(file_id) or "unknown" with col1: st.write(idx) with col2: - st.write(filename) + st.write(source) with col3: - delete_key = f"delete_{vector_db_name}_{file_id}_{idx}" + st.write(file_id) - # Disable delete button + with col4: + delete_key = f"delete_{vector_db_name}_{file_id}_{idx}" st.button( "✕", key=delete_key, From 7c16f80188c69c6495f2137c1c7cd6bbeb1903a3 Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: 
Wed, 11 Feb 2026 17:23:55 -0500 Subject: [PATCH 16/18] disable delete file and refactor upload Signed-off-by: Yuval Turgeman --- .../distribution/ui/page/upload/upload.py | 421 +++++++----------- 1 file changed, 173 insertions(+), 248 deletions(-) diff --git a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py index 1502ed8..2f8e09a 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py +++ b/frontend/llama_stack_ui/distribution/ui/page/upload/upload.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +"""Upload documents page for managing vector databases and document ingestion.""" + import traceback import streamlit as st @@ -12,141 +14,122 @@ from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name -def upload_page(): - """ - Page to upload documents and manage vector databases for RAG. - Supports creating new vector databases and uploading documents to existing ones. 
- """ - st.title("📄 Upload Documents") - - # Initialize session state for creation status messages - if "creation_status" not in st.session_state: - st.session_state["creation_status"] = None - if "creation_message" not in st.session_state: - st.session_state["creation_message"] = "" - - # Initialize session state for selected vector database - # This persists the selection when navigating away and back to this page - if "selected_vector_db" not in st.session_state: - st.session_state["selected_vector_db"] = "" +def _init_upload_page_session_state(): + """Initialize all session state variables needed by the upload page.""" + defaults = { + "creation_status": None, + "creation_message": "", + "selected_vector_db": "", + "newly_created_vdb": None, + } + for key, value in defaults.items(): + if key not in st.session_state: + st.session_state[key] = value - # Initialize the widget key to match our tracked selection - # This ensures the selectbox displays the correct value on page load if "vector_db_selector" not in st.session_state: st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] - # Initialize newly created VDB tracker - if "newly_created_vdb" not in st.session_state: - st.session_state["newly_created_vdb"] = None - # Show status messages at the top level (before dropdown) - if st.session_state["creation_status"] == "success": - st.success(st.session_state["creation_message"]) - # Clear the message after showing it - st.session_state["creation_status"] = None - st.session_state["creation_message"] = "" - elif st.session_state["creation_status"] == "error": - st.error(st.session_state["creation_message"]) - # Clear the message after showing it - st.session_state["creation_status"] = None - st.session_state["creation_message"] = "" +def _show_status(status_key, message_key): + """Show and clear a status message from session state.""" + status = st.session_state[status_key] + if status == "success": + 
st.success(st.session_state[message_key]) + elif status == "error": + st.error(st.session_state[message_key]) + else: + return + st.session_state[status_key] = None + st.session_state[message_key] = "" - # Fetch all vector databases - vdb_list = llama_stack_api.client.vector_stores.list() - # Build dropdown options based on whether databases exist - dropdown_options = [] +def _build_dropdown_options(vdb_list): + """Build dropdown options from the vector database list. - # Define the "Create New" option with emoji for visibility + Returns: + tuple: (dropdown_options, create_new_option) + """ create_new_option = "➕ Create New" if vdb_list: - # When databases exist: list actual DBs first, then "Create New" LAST - existing_vdbs = {get_vector_db_name(v): v.to_dict() for v in vdb_list} - dropdown_options.extend(list(existing_vdbs.keys())) - dropdown_options.append(create_new_option) # Add "Create New" as LAST item - else: - # When NO databases exist: only show "Create New" - dropdown_options = [create_new_option] - - # Sync session state for widget - ensure it shows the right value - # Priority 1: If a database was just created, auto-select it (highest priority) - if st.session_state["newly_created_vdb"]: - newly_created_name = st.session_state["newly_created_vdb"] - if newly_created_name in dropdown_options: - # Update both session variables to sync state - st.session_state["selected_vector_db"] = newly_created_name - st.session_state["vector_db_selector"] = newly_created_name - st.session_state["newly_created_vdb"] = None - # Priority 2: Use the previously selected database from session if it still exists - elif ( - st.session_state["selected_vector_db"] - and st.session_state["selected_vector_db"] in dropdown_options - ): - # Sync widget state with our tracked state - st.session_state["vector_db_selector"] = st.session_state["selected_vector_db"] - # Priority 3: If no saved selection or saved selection doesn't exist, use smart default + existing_vdbs = 
[get_vector_db_name(v) for v in vdb_list] + return existing_vdbs + [create_new_option], create_new_option + + return [create_new_option], create_new_option + + +def _sync_vector_db_selection(dropdown_options, vdb_list): + """Sync the vector database selection state with available options.""" + # Priority 1: Auto-select a newly created database + newly_created = st.session_state["newly_created_vdb"] + if newly_created and newly_created in dropdown_options: + st.session_state["selected_vector_db"] = newly_created + st.session_state["vector_db_selector"] = newly_created + st.session_state["newly_created_vdb"] = None + return + + # Priority 2: Keep the previously selected database if it still exists + selected = st.session_state["selected_vector_db"] + if selected and selected in dropdown_options: + st.session_state["vector_db_selector"] = selected + return + + # Priority 3: Smart default + if vdb_list: + first_db = dropdown_options[0] + st.session_state["selected_vector_db"] = first_db + st.session_state["vector_db_selector"] = first_db else: - if vdb_list: - # When databases exist: default to FIRST actual database (not "Create New") - first_db = dropdown_options[0] # First item is first actual database - st.session_state["selected_vector_db"] = first_db - st.session_state["vector_db_selector"] = first_db - else: - # When NO databases exist: default to "Create New" - st.session_state["selected_vector_db"] = create_new_option - st.session_state["vector_db_selector"] = create_new_option + st.session_state["selected_vector_db"] = dropdown_options[0] + st.session_state["vector_db_selector"] = dropdown_options[0] + + +def upload_page(): + """Page to upload documents and manage vector databases for RAG.""" + st.title("📄 Upload Documents") + + _init_upload_page_session_state() + _show_status("creation_status", "creation_message") + + vdb_list = llama_stack_api.client.vector_stores.list() + dropdown_options, create_new_option = _build_dropdown_options(vdb_list) + 
_sync_vector_db_selection(dropdown_options, vdb_list) - # Vector database selection dropdown with persistent selection - # Using key parameter to bind directly to session state - NO index parameter to avoid conflicts def on_vector_db_change(): - """Callback to update session state when selection changes""" st.session_state["selected_vector_db"] = st.session_state["vector_db_selector"] selected_vector_db = st.selectbox( "Select a vector database", dropdown_options, - key="vector_db_selector", # Key binds to session state (session state controls the value) - on_change=on_vector_db_change, # Callback updates our tracking variable + key="vector_db_selector", + on_change=on_vector_db_change, help="Your selection will be remembered when you navigate to other pages" ) - # Ensure session state is updated (in case callback didn't fire) if selected_vector_db != st.session_state["selected_vector_db"]: st.session_state["selected_vector_db"] = selected_vector_db - # Get the actual vector database object for API calls (do this before using it) - selected_vdb_obj = None - if selected_vector_db and selected_vector_db != create_new_option: + if selected_vector_db == create_new_option: + _show_create_vector_db_ui() + elif selected_vector_db: + selected_vdb_obj = None for vdb in vdb_list: if get_vector_db_name(vdb) == selected_vector_db: selected_vdb_obj = vdb break - if selected_vector_db == create_new_option: - # Show vector database creation UI - _show_create_vector_db_ui() - elif selected_vector_db and selected_vector_db != create_new_option: - # Show existing documents in the database (heading will show only if documents exist) _show_existing_documents_table(selected_vector_db, selected_vdb_obj) - - # Add Browse functionality for uploading documents to this database st.subheader(f"📁 Upload Documents to '{selected_vector_db}'") _show_document_upload_ui(selected_vector_db, selected_vdb_obj) - # If empty string is selected, show nothing (clean default state) def 
_show_create_vector_db_ui(): - """ - Display UI for creating a new vector database. - """ + """Display UI for creating a new vector database.""" st.subheader("Create New Vector Database") - # Initialize session state for creation form if "new_vdb_name" not in st.session_state: st.session_state["new_vdb_name"] = "" - # Vector database name input new_vdb_name = st.text_input( "Add New Vector Database", value=st.session_state["new_vdb_name"], @@ -154,33 +137,27 @@ def _show_create_vector_db_ui(): key="new_vdb_name_input" ) - # Update session state st.session_state["new_vdb_name"] = new_vdb_name - # Add button if st.button("Add", type="primary", disabled=not new_vdb_name.strip()): _create_vector_database(new_vdb_name.strip()) def _create_vector_database(vdb_name): - """ - Create a new vector database using the LlamaStack API. + """Create a new vector database using the LlamaStack API. Args: vdb_name (str): Name for the new vector database """ try: - # Reset status st.session_state["creation_status"] = None st.session_state["creation_message"] = "" - # Validate input if not vdb_name or not vdb_name.strip(): st.session_state["creation_status"] = "error" st.session_state["creation_message"] = "Vector database name cannot be empty." return - # Check for duplicate names existing_vdbs = llama_stack_api.client.vector_stores.list() existing_names = [get_vector_db_name(vdb) for vdb in existing_vdbs] if vdb_name in existing_names: @@ -191,106 +168,78 @@ def _create_vector_database(vdb_name): ) return - # Create the vector database using the new simplified API - # Note: embedding settings (dimension, model, provider) must be - # configured at the server level with st.spinner(f"Creating vector database '{vdb_name}'..."): _vector_db = llama_stack_api.client.vector_stores.create( name=vdb_name, ) - # Success st.session_state["creation_status"] = "success" - st.session_state["creation_message"] = f"Vector database '{vdb_name}' created successfully!" 
- - # Mark this database to be auto-selected after refresh + st.session_state["creation_message"] = ( + f"Vector database '{vdb_name}' created successfully!" + ) st.session_state["newly_created_vdb"] = vdb_name - - # Clear the input field st.session_state["new_vdb_name"] = "" - - # Trigger page refresh to update the dropdown - this will show the message at the top st.rerun() - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught st.session_state["creation_status"] = "error" st.session_state["creation_message"] = f"Error creating vector database: {str(e)}" def _show_document_upload_ui(vector_db_name, vector_db_obj=None): - """ - Display UI for uploading documents to an existing vector database. + """Display UI for uploading documents to an existing vector database. Args: vector_db_name (str): Name of the selected vector database + vector_db_obj: The actual vector database object with identifier """ - # Initialize session state for upload status if "upload_status" not in st.session_state: st.session_state["upload_status"] = None if "upload_message" not in st.session_state: st.session_state["upload_message"] = "" - # Show upload status messages - if st.session_state["upload_status"] == "success": - st.success(st.session_state["upload_message"]) - # Clear after showing - st.session_state["upload_status"] = None - st.session_state["upload_message"] = "" - elif st.session_state["upload_status"] == "error": - st.error(st.session_state["upload_message"]) - # Clear after showing - st.session_state["upload_status"] = None - st.session_state["upload_message"] = "" + _show_status("upload_status", "upload_message") - # Initialize session state to track processed files upload_key = f"processed_files_{vector_db_name}" if upload_key not in st.session_state: st.session_state[upload_key] = set() - # File uploader uploaded_files = st.file_uploader( "Browse and select files to upload (files will upload automatically)", accept_multiple_files=True, 
type=["txt", "pdf", "doc", "docx"], - key=f"uploader_{vector_db_name}", # Unique key per database + key=f"uploader_{vector_db_name}", help=( "Select one or more documents - they will be uploaded " "automatically to this vector database" ), ) - # Auto-upload when files are selected if uploaded_files: - # Create a unique identifier for this set of files file_set_id = frozenset([f.name + str(f.size) for f in uploaded_files]) - # Only process if this is a new set of files if file_set_id not in st.session_state[upload_key]: - # Mark as processed IMMEDIATELY before upload to prevent re-triggering st.session_state[upload_key].add(file_set_id) - # Get the correct database ID for upload if vector_db_obj and hasattr(vector_db_obj, 'id'): vector_db_id = vector_db_obj.id else: vector_db_id = vector_db_name - # Upload automatically _upload_documents_to_database( vector_db_name, uploaded_files, vector_db_id ) def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=None): - """ - Upload documents to an existing vector database. + """Upload documents to an existing vector database. Args: vector_db_name (str): Name of the target vector database uploaded_files: List of uploaded files from Streamlit file uploader + vector_db_id (str): The actual database identifier for API calls """ try: - # Reset status st.session_state["upload_status"] = None st.session_state["upload_message"] = "" @@ -299,45 +248,36 @@ def _upload_documents_to_database(vector_db_name, uploaded_files, vector_db_id=N st.session_state["upload_message"] = "No files selected for upload." 
return - # Upload files using the new Files API + Vector Stores API actual_db_id = vector_db_id or vector_db_name uploaded_file_ids = [] with st.spinner(f"Uploading {len(uploaded_files)} file(s)..."): for uploaded_file in uploaded_files: - # Step 1: Upload file to Files API file_response = llama_stack_api.client.files.create( file=uploaded_file, purpose="assistants" ) - - # Step 2: Attach file to vector store llama_stack_api.client.vector_stores.files.create( vector_store_id=actual_db_id, file_id=file_response.id, ) - uploaded_file_ids.append(file_response.id) - # Success st.session_state["upload_status"] = "success" st.session_state["upload_message"] = ( f"Successfully uploaded {len(uploaded_files)} document(s) " f"to '{vector_db_name}'!" ) - - # Trigger refresh to show the success message st.rerun() - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught st.session_state["upload_status"] = "error" st.session_state["upload_message"] = f"Error uploading documents: {str(e)}" st.rerun() def _get_documents_from_vector_store(vector_store_id): - """ - Get files from a vector store using the Files API. + """Get files from a vector store using the Files API. 
Args: vector_store_id (str): The vector store identifier @@ -346,25 +286,21 @@ def _get_documents_from_vector_store(vector_store_id): list: List of file objects, or None if query fails """ try: - # List files in the vector store files_response = llama_stack_api.client.vector_stores.files.list( vector_store_id=vector_store_id ) - # Extract file information if hasattr(files_response, 'data'): return files_response.data - else: - return list(files_response) if files_response else None + return list(files_response) if files_response else None - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught print(f"Error listing files from vector store: {e}") return None def _delete_file_from_vector_store(vector_store_id, file_id): - """ - Delete a file from a vector store using the Files API. + """Delete a file from a vector store using the Files API. Args: vector_store_id (str): The vector store identifier @@ -379,123 +315,112 @@ def _delete_file_from_vector_store(vector_store_id, file_id): vector_store_id=vector_store_id ) return True, None - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught return False, str(e) -def _show_existing_documents_table(vector_db_name, vector_db_obj=None): +def _get_file_sources(files): + """Retrieve the source name for each file. + + Prefers attributes["source"], falls back to filename from Files API. + + Args: + files: List of vector store file objects + + Returns: + dict: Mapping of file_id to source name """ - Display information about documents in the selected vector database. 
+ source_names = {} + for file_obj in files: + file_id = getattr(file_obj, 'id', None) + if not file_id: + continue + attrs = getattr(file_obj, 'attributes', None) or {} + source = attrs.get("source") + if not source: + try: + file_info = llama_stack_api.client.files.retrieve(file_id) + source = getattr(file_info, 'filename', None) + except Exception: # pylint: disable=broad-exception-caught + source = None + source_names[file_id] = source + return source_names + + +def _render_documents_table(files, source_names): + """Render the documents table with source and document ID columns. + + Args: + files: List of vector store file objects + source_names (dict): Mapping of file_id to source name + """ + # Add CSS for bordered table rows + st.markdown(""" + + """, unsafe_allow_html=True) + + # Display table header + col1, col2, col3 = st.columns([0.5, 3, 3]) + with col1: + st.markdown("**#**") + with col2: + st.markdown("**Source**") + with col3: + st.markdown("**Document ID**") + + # Display each file in a row + for idx, file_obj in enumerate(files, start=1): + col1, col2, col3 = st.columns([0.5, 3, 3]) + file_id = getattr(file_obj, 'id', 'unknown') + source = source_names.get(file_id) or "unknown" + + with col1: + st.write(idx) + with col2: + st.write(source) + with col3: + st.write(file_id) + + +def _show_existing_documents_table(vector_db_name, vector_db_obj=None): + """Display information about documents in the selected vector database. 
Args: vector_db_name (str): Display name of the selected vector database vector_db_obj: The actual vector database object with identifier """ try: - # Get the correct vector database ID if vector_db_obj and hasattr(vector_db_obj, 'id'): vector_db_id = vector_db_obj.id else: - vector_db_id = vector_db_name # Fallback to display name + vector_db_id = vector_db_name - # Initialize session state for deletion status if "delete_status" not in st.session_state: st.session_state["delete_status"] = None if "delete_message" not in st.session_state: st.session_state["delete_message"] = "" - # Show deletion status messages (before checking documents, so last delete shows) - if st.session_state["delete_status"] == "success": - st.success(st.session_state["delete_message"]) - st.session_state["delete_status"] = None - st.session_state["delete_message"] = "" - elif st.session_state["delete_status"] == "error": - st.error(st.session_state["delete_message"]) - st.session_state["delete_status"] = None - st.session_state["delete_message"] = "" + _show_status("delete_status", "delete_message") with st.spinner("Checking for documents..."): - # Get files from vector store using the Files API files = _get_documents_from_vector_store(vector_db_id) if files: - # Success! 
We have the files - # Show heading for documents section st.subheader(f"📄 Documents in '{vector_db_name}'") + source_names = _get_file_sources(files) + _render_documents_table(files, source_names) - # Add CSS for bordered table rows - st.markdown(""" - - """, unsafe_allow_html=True) - - # Retrieve source for all files: prefer attributes["source"], fallback to filename - source_names = {} - for file_obj in files: - file_id = getattr(file_obj, 'id', None) - if file_id: - # Check attributes["source"] from the vector store file - attrs = getattr(file_obj, 'attributes', None) or {} - source = attrs.get("source") - if not source: - try: - file_info = llama_stack_api.client.files.retrieve(file_id) - source = getattr(file_info, 'filename', None) - except Exception: - source = None - source_names[file_id] = source - - # Display table header - col1, col2, col3, col4 = st.columns([0.5, 3, 3, 0.5]) - with col1: - st.markdown("**#**") - with col2: - st.markdown("**Source**") - with col3: - st.markdown("**Document ID**") - with col4: - st.markdown("**Del**") - - # Display each file in a row - for idx, file_obj in enumerate(files, start=1): - col1, col2, col3, col4 = st.columns([0.5, 3, 3, 0.5]) - - # Extract file information - file_id = getattr(file_obj, 'id', 'unknown') - source = source_names.get(file_id) or "unknown" - - with col1: - st.write(idx) - - with col2: - st.write(source) - - with col3: - st.write(file_id) - - with col4: - delete_key = f"delete_{vector_db_name}_{file_id}_{idx}" - st.button( - "✕", - key=delete_key, - disabled=True, - help="Delete is currently not available", - ) - - # else: Database appears empty or pgvector query not available - # For newly created databases, this is expected - just show nothing - # The upload section below will allow users to add documents - - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught st.error(f"Error loading document information: {str(e)}") with st.expander("Error Details"): 
st.code(traceback.format_exc()) From 5d983a354a351676ff18ea5b966df5d14f724b56 Mon Sep 17 00:00:00 2001 From: Yuval Turgeman Date: Fri, 13 Feb 2026 14:02:51 -0500 Subject: [PATCH 17/18] Use new ingestion pipeline Format search results nicely Signed-off-by: Yuval Turgeman Assisted-by: Claude --- deploy/helm/rag/Chart.lock | 8 +-- deploy/helm/rag/Chart.yaml | 4 +- .../distribution/ui/modules/utils.py | 6 +++ .../distribution/ui/page/playground/agent.py | 12 +++-- .../distribution/ui/page/playground/direct.py | 53 +++++++++++-------- 5 files changed, 51 insertions(+), 32 deletions(-) diff --git a/deploy/helm/rag/Chart.lock b/deploy/helm/rag/Chart.lock index 36eeea4..f3dfd71 100644 --- a/deploy/helm/rag/Chart.lock +++ b/deploy/helm/rag/Chart.lock @@ -10,12 +10,12 @@ dependencies: version: 0.5.6 - name: ingestion-pipeline repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.6.5 + version: 0.6.6 - name: llama-stack repository: https://rh-ai-quickstart.github.io/ai-architecture-charts - version: 0.6.10 + version: 0.6.11 - name: mcp-servers repository: https://rh-ai-quickstart.github.io/ai-architecture-charts version: 0.5.15 -digest: sha256:a55fed2a9164c13653c7e341350f0a2eeba302cd627232800016be49fafc09db -generated: "2026-02-06T15:10:51.764446626-05:00" +digest: sha256:1065a9cbf8dfb460fd9c9a6d3571fdfc33e3693503aae6535e07e75919a6c9f2 +generated: "2026-02-13T13:19:24.192726731-05:00" diff --git a/deploy/helm/rag/Chart.yaml b/deploy/helm/rag/Chart.yaml index 624621d..0a10099 100644 --- a/deploy/helm/rag/Chart.yaml +++ b/deploy/helm/rag/Chart.yaml @@ -19,11 +19,11 @@ dependencies: repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: configure-pipeline.enabled - name: ingestion-pipeline - version: 0.6.5 + version: 0.6.6 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: ingestion-pipeline.enabled - name: llama-stack - version: 0.6.10 + version: 0.6.11 repository: 
https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llama-stack.enabled - name: mcp-servers diff --git a/frontend/llama_stack_ui/distribution/ui/modules/utils.py b/frontend/llama_stack_ui/distribution/ui/modules/utils.py index 3d07310..a706e78 100644 --- a/frontend/llama_stack_ui/distribution/ui/modules/utils.py +++ b/frontend/llama_stack_ui/distribution/ui/modules/utils.py @@ -7,6 +7,7 @@ import base64 import json import os +import re import pandas as pd import streamlit as st @@ -57,6 +58,11 @@ def data_url_from_file(file) -> str: return data_url +def clean_text(text): + """Collapse consecutive whitespace into a single space.""" + return re.sub(r'\s+', ' ', text).strip() + + def get_vector_db_name(vector_db): """ Get the display name for a vector database. diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py index 649da36..56bc38d 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py @@ -13,7 +13,7 @@ import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name logger = logging.getLogger(__name__) @@ -131,14 +131,20 @@ def handle_agent_output_item_done(chunk, state): if item_type == "file_search_call": # File search results if hasattr(item, 'results') and item.results: + display_results = [] + for r in item.results: + text = getattr(r, 'text', '') + attrs = getattr(r, 'attributes', {}) + source = attrs.get('source') or getattr(r, 'filename', 'unknown') + display_results.append({"source": source, "text": clean_text(text)}) state.tool_results.append({ 'title': '📄 File Search Results', 'type': 'json', - 'content': item.results + 'content': display_results }) with 
state.containers.tool_results: with st.expander("📄 File Search Results", expanded=False): - st.json(item.results) + st.json(display_results) elif item_type == "web_search_call": # Web search - API doesn't expose raw results, just status diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py index ee0656c..1b4744b 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py @@ -14,7 +14,7 @@ import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import get_vector_db_name +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name logger = logging.getLogger(__name__) @@ -25,31 +25,35 @@ # ============================================================================ def extract_text_from_search_result(result): - """Extract text content from a search result object.""" + """Extract and clean text content from a search result object.""" + text = None + # Handle Data objects with content list if hasattr(result, 'content') and isinstance(result.content, list): for content_item in result.content: if hasattr(content_item, 'text'): - return content_item.text - return None + text = content_item.text + break # Handle simple content attribute - if hasattr(result, 'content') and isinstance(result.content, str): - return result.content + elif hasattr(result, 'content') and isinstance(result.content, str): + text = result.content # Handle dict format - if isinstance(result, dict) and 'content' in result: + elif isinstance(result, dict) and 'content' in result: if isinstance(result['content'], list) and result['content']: - return result['content'][0].get('text', '') - return result['content'] + text = result['content'][0].get('text', '') + else: + text = result['content'] - return None + 
return clean_text(text) if text else None def search_vector_store_direct(prompt, vector_db_id, vector_db_name, state): """Search vector store and extract context for Direct mode.""" search_results = [] context_parts = [] + display_results = [] # Show search status with state.containers.tool_status: @@ -75,15 +79,18 @@ def search_vector_store_direct(prompt, vector_db_id, vector_db_name, state): # Display and process search results if search_results: - with state.containers.tool_results: - with st.expander(f"📄 Search Results from '{vector_db_name}'", expanded=False): - st.json(search_results) - - # Build context from search results - for idx, result in enumerate(search_results[:5], 1): # Top 5 results + # Build context and display data from search results + for result in search_results: text_content = extract_text_from_search_result(result) if text_content: - context_parts.append(f"[Document {idx}]: {text_content}") + attrs = getattr(result, 'attributes', {}) + source = attrs.get('source') or getattr(result, 'filename', 'unknown') + context_parts.append(f"[Source: {source}]: {text_content}") + display_results.append({"source": source, "text": text_content}) + + with state.containers.tool_results: + with st.expander(f"📄 Search Results from '{vector_db_name}'", expanded=False): + st.json(display_results) logger.debug("Built context with %s documents", len(context_parts)) else: @@ -91,7 +98,7 @@ def search_vector_store_direct(prompt, vector_db_id, vector_db_name, state): with state.containers.tool_results: st.info(f"No results found in '{vector_db_name}'") - return search_results, context_parts + return search_results, context_parts, display_results def build_rag_messages(prompt, context_parts, system_prompt): @@ -156,15 +163,15 @@ def save_direct_response_to_session(state, all_search_results): # Save search results for history display if we had any if all_search_results: + db_names = [name for name, _ in all_search_results] response_dict["tool_results"] = [ { 'title': 
f'📄 Search Results from \'{name}\'', 'type': 'json', - 'content': results + 'content': display } - for name, results in all_search_results + for name, display in all_search_results ] - db_names = [name for name, _ in all_search_results] response_dict["tool_status"] = ( f"🛠 :grey[_Searched vector stores: {', '.join(db_names)}_]" ) @@ -190,11 +197,11 @@ def direct_process_prompt(prompt, state, config): for vector_db in vector_dbs: vector_db_id = vector_db.id vector_db_name = get_vector_db_name(vector_db) - search_results, parts = search_vector_store_direct( + search_results, parts, display = search_vector_store_direct( prompt, vector_db_id, vector_db_name, state ) if search_results: - all_search_results.append((vector_db_name, search_results)) + all_search_results.append((vector_db_name, display)) context_parts.extend(parts) # Step 2: Build messages (with or without RAG context) From 73d5490f2d18076ab752de53ad3c02379ffe0a3d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 16 Feb 2026 17:12:53 +0000 Subject: [PATCH 18/18] chore: bump version to 0.2.32 --- deploy/helm/rag/Chart.yaml | 4 ++-- deploy/helm/rag/values.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deploy/helm/rag/Chart.yaml b/deploy/helm/rag/Chart.yaml index 0a10099..8da8df2 100644 --- a/deploy/helm/rag/Chart.yaml +++ b/deploy/helm/rag/Chart.yaml @@ -2,8 +2,8 @@ apiVersion: v2 name: rag description: A Helm chart for Kubernetes type: application -version: 0.2.30 -appVersion: "0.2.30" +version: 0.2.32 +appVersion: "0.2.32" dependencies: - name: pgvector diff --git a/deploy/helm/rag/values.yaml b/deploy/helm/rag/values.yaml index 1a36c59..59a0e95 100644 --- a/deploy/helm/rag/values.yaml +++ b/deploy/helm/rag/values.yaml @@ -3,7 +3,7 @@ replicaCount: 1 image: repository: quay.io/rh-ai-quickstart/llamastack-dist-ui pullPolicy: Always - tag: 0.2.30 + tag: 0.2.32 service: type: ClusterIP