diff --git a/deploy/helm/rag/Chart.yaml b/deploy/helm/rag/Chart.yaml index 0ab3b23..341a1b8 100644 --- a/deploy/helm/rag/Chart.yaml +++ b/deploy/helm/rag/Chart.yaml @@ -2,8 +2,8 @@ apiVersion: v2 name: rag description: A Helm chart for Kubernetes type: application -version: 0.2.38 -appVersion: "0.2.38" +version: 0.2.39 +appVersion: "0.2.39" dependencies: - name: pgvector @@ -15,7 +15,7 @@ dependencies: repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: llm-service.enabled - name: configure-pipeline - version: 0.5.6 + version: 0.5.7 repository: https://rh-ai-quickstart.github.io/ai-architecture-charts condition: configure-pipeline.enabled - name: ingestion-pipeline diff --git a/deploy/helm/rag/values.yaml b/deploy/helm/rag/values.yaml index 1690843..576a098 100644 --- a/deploy/helm/rag/values.yaml +++ b/deploy/helm/rag/values.yaml @@ -3,7 +3,7 @@ replicaCount: 1 image: repository: quay.io/rh-ai-quickstart/llamastack-dist-ui pullPolicy: Always - tag: 0.2.38 + tag: 0.2.39 service: type: ClusterIP diff --git a/frontend/llama_stack_ui/distribution/ui/modules/utils.py b/frontend/llama_stack_ui/distribution/ui/modules/utils.py index 68cbc2d..6b2252c 100644 --- a/frontend/llama_stack_ui/distribution/ui/modules/utils.py +++ b/frontend/llama_stack_ui/distribution/ui/modules/utils.py @@ -5,13 +5,16 @@ # the root directory of this source tree. import base64 +import io import json +import logging import os import re import pandas as pd import streamlit as st +logger = logging.getLogger(__name__) """ Utility functions for file processing and data conversion in the UI. @@ -63,6 +66,37 @@ def clean_text(text): return re.sub(r'\s+', ' ', text).strip() +def strip_file_citations(text): + """ + Remove file citation markers injected by the Responses API file_search tool. + Strips bare file ID references and bracket-style annotation markers. + + Args: + text: Raw response text potentially containing citation markers + + Returns: + str: Text with citation markers removed + """ + text = re.sub(r'file<[^>]+>', '', text) + text = re.sub(r'<\|file-[^|]*\|>', '', text) + text = re.sub(r'【[^】]*†[^】]*】', '', text) + text = re.sub(r' +', ' ', text) + return text + + +def strip_file_citations_streaming(text): + """ + Strip citations for streaming display. Removes complete citation markers + and also trims trailing partial patterns that haven't fully arrived yet, + preventing citation fragments from briefly flashing in the UI. + """ + text = strip_file_citations(text) + text = re.sub(r'<\|(?:f(?:i(?:l(?:e(?:-[^|]*)?)?)?)?)?\s*$', '', text) + text = re.sub(r'\bfile<[^>]*$', '', text) + text = re.sub(r'【[^】]*$', '', text) + return text + + def get_vector_db_name(vector_db): """ Get the display name for a vector database. @@ -94,6 +128,101 @@ def get_question_suggestions(): return {} +def fetch_available_shields(client): + """ + Fetch available safety shields from the LlamaStack server. + + Args: + client: LlamaStack client instance + + Returns: + List of shield identifier strings + """ + try: + shields_list = client.shields.list() + if shields_list: + return [s.identifier for s in shields_list] + except Exception as e: + logger.debug("Failed to fetch shields: %s", e) + return [] + + +def run_input_shields(client, shield_ids, user_message): + """ + Run input safety shields on the user's message before processing. + + Args: + client: LlamaStack client instance + shield_ids: List of shield identifiers to run + user_message: The user's input text + + Returns: + Tuple of (is_blocked: bool, violation_message: str or None, shield_id: str or None) + """ + if not shield_ids: + return False, None, None + + for shield_id in shield_ids: + try: + logger.debug("Running input shield: %s", shield_id) + shield_response = client.safety.run_shield( + shield_id=shield_id, + messages=[{"role": "user", "content": user_message}], + params={}, + ) + logger.debug("Input shield %s response: %s", shield_id, shield_response) + if hasattr(shield_response, "violation") and shield_response.violation: + violation_msg = getattr( + shield_response.violation, "user_message", "Content blocked by safety guardrail" + ) + logger.warning("Input blocked by shield %s: %s", shield_id, violation_msg) + return True, violation_msg, shield_id + logger.debug("Input shield %s passed (no violation)", shield_id) + except Exception as e: + logger.warning("Error running input shield %s: %s", shield_id, e) + return False, None, None + + +def run_output_shields(client, shield_ids, user_message, assistant_response): + """ + Run output safety shields on the assistant's response after generation. + + Args: + client: LlamaStack client instance + shield_ids: List of shield identifiers to run + user_message: The original user prompt + assistant_response: The generated assistant response text + + Returns: + Tuple of (is_blocked: bool, violation_message: str or None, shield_id: str or None) + """ + if not shield_ids: + return False, None, None + + for shield_id in shield_ids: + try: + logger.debug("Running output shield: %s", shield_id) + shield_response = client.safety.run_shield( + shield_id=shield_id, + messages=[ + {"role": "user", "content": user_message}, + {"role": "assistant", "content": assistant_response}, + ], + params={}, + ) + logger.debug("Output shield %s response: %s", shield_id, shield_response) + if hasattr(shield_response, "violation") and shield_response.violation: + violation_msg = getattr( + shield_response.violation, "user_message", "Response blocked by safety guardrail" + ) + logger.warning("Output blocked by shield %s: %s", shield_id, violation_msg) + return True, violation_msg, shield_id + logger.debug("Output shield %s passed (no violation)", shield_id) + except Exception as e: + logger.warning("Error running output shield %s: %s", shield_id, e) + return False, None, None + + def get_suggestions_for_databases(selected_dbs, all_vector_dbs): """ Get combined question suggestions for selected databases. @@ -111,32 +240,22 @@ def get_suggestions_for_databases(selected_dbs, all_vector_dbs): if not suggestions_map: return [] - # Build a mapping from displayed DB name to the full DB object so we can - # resolve all possible identifiers used by different backend versions. - db_name_to_obj = { - get_vector_db_name(vdb): vdb + # Create a mapping from vector_db_name to id + db_name_to_id = { + get_vector_db_name(vdb): vdb.id for vdb in all_vector_dbs } for db_name in selected_dbs: - # Try several keys because the selected UI name may differ from the - # suggestion map key (e.g. vector_store_name/identifier/id/display name). - vdb = db_name_to_obj.get(db_name) - candidate_keys = [] - if vdb: - candidate_keys.extend([ - getattr(vdb, "vector_store_name", None), - getattr(vdb, "identifier", None), - getattr(vdb, "id", None), - getattr(vdb, "name", None), - ]) - candidate_keys.append(db_name) + # Get the id for this database name + db_id = db_name_to_id.get(db_name) + # Try both the id and the db_name as keys in the suggestions map questions = None - for key in candidate_keys: - if key and key in suggestions_map: - questions = suggestions_map[key] - break + if db_id and db_id in suggestions_map: + questions = suggestions_map[db_id] + elif db_name in suggestions_map: + questions = suggestions_map[db_name] if questions: for question in questions: diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py index 31a6241..3937b02 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/agent.py @@ -13,7 +13,7 @@ import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name, strip_file_citations, strip_file_citations_streaming, run_input_shields, run_output_shields logger = logging.getLogger(__name__) @@ -131,22 +131,8 @@ def handle_agent_output_item_done(chunk, state): item_type = getattr(item, 'type', None) if item_type == "file_search_call": - # File search results - if hasattr(item, 'results') and item.results: - display_results = [] - for r in item.results: - text = getattr(r, 'text', '') - attrs = getattr(r, 'attributes', {}) - source = attrs.get('source') or getattr(r, 'filename', 'unknown') - display_results.append({"source": source, "text": clean_text(text)}) - state.tool_results.append({ - 'title': '📄 File Search Results', - 'type': 'json', - 'content': display_results - }) - with state.containers.tool_results: - with st.expander("📄 File Search Results", expanded=False): - st.json(display_results) + # Results are fetched explicitly per-DB after streaming completes + pass elif item_type == "web_search_call": # Web search - API doesn't expose raw results, just status @@ -235,13 +221,90 @@ def handle_chunk_completed(chunk): def handle_chunk_done(chunk, state): """Handle done chunk and finalize response.""" - has_output = ( - hasattr(chunk, 'response') and - hasattr(chunk.response, 'output_text') and - chunk.response.output_text + if not hasattr(chunk, 'response'): + return + + response = chunk.response + + if hasattr(response, 'output_text') and response.output_text: + state.full_response = strip_file_citations(response.output_text) + + +def search_vector_stores_fallback(prompt, selected_vector_dbs, state): + """ + Explicitly search vector stores when the Responses API stream didn't + include file_search results (common on subsequent conversation turns). + """ + client = llama_stack_api.client + vector_dbs = list(client.vector_stores.list() or []) + + selected_vdb_objects = [ + vdb for vdb in vector_dbs + if get_vector_db_name(vdb) in selected_vector_dbs + ] + if not selected_vdb_objects: + return + + db_label = "vector store" if len(selected_vdb_objects) == 1 else "vector stores" + status_msg = ( + f"🛠 :grey[_Using file_search tool with {db_label}: " + f"{', '.join(selected_vector_dbs)}_]" ) - if has_output: - state.full_response = chunk.response.output_text + state.tool_status = status_msg + with state.containers.tool_status: + st.markdown(status_msg) + + for vdb in selected_vdb_objects: + vdb_id = vdb.id + vdb_name = get_vector_db_name(vdb) + + try: + search_response = client.vector_stores.search( + vector_store_id=vdb_id, + query=prompt, + ) + except Exception as e: + logger.debug("Fallback search failed for %s: %s", vdb_id, e) + continue + + search_results = None + if hasattr(search_response, 'data') and search_response.data: + search_results = search_response.data + elif hasattr(search_response, 'chunks') and search_response.chunks: + search_results = search_response.chunks + elif hasattr(search_response, 'results') and search_response.results: + search_results = search_response.results + + if not search_results: + continue + + display_results = [] + for result in search_results: + text = None + if hasattr(result, 'content') and isinstance(result.content, list): + for content_item in result.content: + if hasattr(content_item, 'text'): + text = content_item.text + break + elif hasattr(result, 'content') and isinstance(result.content, str): + text = result.content + elif hasattr(result, 'text'): + text = result.text + + if text: + attrs = getattr(result, 'attributes', {}) + source = attrs.get('source') or getattr(result, 'filename', 'unknown') + display_results.append({"source": source, "text": clean_text(text)}) + + if display_results: + state.tool_results.append({ + 'title': f"📄 File Search Results from '{vdb_name}'", + 'type': 'json', + 'content': display_results + }) + with state.containers.tool_results: + with st.expander(f"📄 File Search Results from '{vdb_name}'", expanded=False): + st.json(display_results) def process_chunk_by_type(chunk, state, selected_vector_dbs): @@ -272,7 +335,7 @@ def process_chunk_by_type(chunk, state, selected_vector_dbs): # Handle message content elif chunk_type == "response.output_text.delta": if hasattr(chunk, 'delta') and chunk.delta: - state.update_message(chunk.delta) + state.update_message(chunk.delta, display_fn=strip_file_citations_streaming) # Handle errors elif chunk_type == "response.failed": @@ -305,18 +368,6 @@ def stream_agent_response(response, state, selected_vector_dbs): logger.debug("Chunk #%s: type=%s", chunk_count, getattr(chunk, 'type', 'NO_TYPE')) logger.debug(" -> Full chunk: %s", chunk) - # Some server failures arrive as an error payload with type=None. - if hasattr(chunk, 'error') and chunk.error: - if isinstance(chunk.error, dict): - error_msg = chunk.error.get("message", "Unknown error") - else: - error_msg = str(chunk.error) - st.error( - f"❌ Error: {error_msg}" - ) - logger.debug("Response stream error: %s", error_msg) - break - if hasattr(chunk, 'type'): should_stop = process_chunk_by_type(chunk, state, selected_vector_dbs) if should_stop: @@ -325,6 +376,17 @@ def stream_agent_response(response, state, selected_vector_dbs): def save_agent_response_to_session(state): """Save agent response to session state.""" + if state.guardrail_blocked: + response_dict = { + "role": "assistant", + "content": f"🛡️ {state.guardrail_blocked}", + "guardrail_blocked": state.guardrail_blocked, + "stop_reason": "end_of_message", + } + st.session_state.messages.append(response_dict) + return + + state.full_response = strip_file_citations(state.full_response) state.finalize_reasoning() state.finalize_message() @@ -344,8 +406,34 @@ def save_agent_response_to_session(state): st.session_state.messages.append(response_dict) +def _get_live_shields(config): + """Read guardrail selections directly from widget state to avoid stale config.""" + input_shields = st.session_state.get("guardrail_input_selector", config.guardrails.input_shields) + output_shields = st.session_state.get("guardrail_output_selector", config.guardrails.output_shields) + return input_shields or [], output_shields or [] + + def agent_process_prompt(prompt, state, config): """Agent-based mode: Use Responses API with automatic tool calling.""" + input_shields, output_shields = _get_live_shields(config) + + # Run input guardrails before calling the API + if input_shields: + guardrail_status = state.containers.tool_status.empty() + guardrail_status.markdown("🛡️ :grey[_Running input guardrail check..._]") + is_blocked, violation_msg, blocked_shield = run_input_shields( + llama_stack_api.client, input_shields, prompt + ) + if is_blocked: + guardrail_status.empty() + blocked_msg = f"**Input Guardrail Triggered** (`{blocked_shield}`): {violation_msg}" + st.warning(blocked_msg, icon="🛡️") + state.guardrail_blocked = blocked_msg + state.full_response = "" + save_agent_response_to_session(state) + return + guardrail_status.empty() + # Build tools list from selected toolgroups tools = build_response_tools( config.toolgroup_selection, @@ -371,6 +459,7 @@ def agent_process_prompt(prompt, state, config): request_kwargs["tools"] = tools logger.debug("Request: %s", request_kwargs) + state.show_thinking() try: response = llama_stack_api.client.responses.create(**request_kwargs) except Exception as e: # pylint: disable=broad-exception-caught @@ -381,5 +470,29 @@ def agent_process_prompt(prompt, state, config): # Stream response and update UI stream_agent_response(response, state, config.selected_vector_dbs) + # Run output guardrails after response is fully streamed but before search results + if output_shields and state.full_response: + guardrail_status = state.containers.tool_status.empty() + guardrail_status.markdown("🛡️ :grey[_Running output guardrail check..._]") + is_blocked, violation_msg, blocked_shield = run_output_shields( + llama_stack_api.client, output_shields, prompt, state.full_response + ) + guardrail_status.empty() + if is_blocked: + blocked_msg = f"**Output Guardrail Triggered** (`{blocked_shield}`): {violation_msg}" + state.containers.clear_tools() + state.containers.message.empty() + st.warning(blocked_msg, icon="🛡️") + state.guardrail_blocked = blocked_msg + state.full_response = "" + state.tool_results = [] + state.tool_status = None + save_agent_response_to_session(state) + return + + # Fetch file search results only if response was not blocked + if config.selected_vector_dbs: + search_vector_stores_fallback(prompt, config.selected_vector_dbs, state) + # Save response to session save_agent_response_to_session(state) diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py b/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py index 8049522..f0971af 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/chat.py @@ -18,6 +18,7 @@ from llama_stack_ui.distribution.ui.modules.utils import ( get_suggestions_for_databases, get_vector_db_name, + fetch_available_shields, ) from llama_stack_ui.distribution.ui.page.playground.agent import ( agent_process_prompt, @@ -42,6 +43,10 @@ def render_tool_results(tool_results): def render_message(msg): """Render a single message in chat history.""" with st.chat_message(msg['role']): + if msg.get('guardrail_blocked'): + st.warning(msg['guardrail_blocked'], icon="🛡️") + return + # Display tool status if present if msg.get('tool_status'): st.markdown(msg['tool_status']) @@ -56,7 +61,8 @@ def render_message(msg): st.markdown(msg['reasoning']) # Display the final answer - st.markdown(msg['content']) + if msg.get('content'): + st.markdown(msg['content']) def render_history(): @@ -70,13 +76,20 @@ def render_history(): def fetch_models_and_tools(): - """Fetch and categorize models and toolgroups from LlamaStack.""" + """Fetch and categorize models, toolgroups, and shields from LlamaStack.""" client = llama_stack_api.client - # Fetch models + # Fetch available shields first so we can exclude them from the model list + shields_list = fetch_available_shields(client) + logger.debug("Available shields: %s", shields_list) + shields_set = set(shields_list) + + # Fetch models, excluding guardrail/shield models models = client.models.list() - model_list = [model.id for model in models if model.custom_metadata.get("model_type") == "llm"] - + model_list = [ + model.identifier for model in models + if model.api_model_type == "llm" and model.identifier not in shields_set + ] # Fetch and categorize toolgroups tool_groups = client.toolgroups.list() @@ -90,7 +103,7 @@ def fetch_models_and_tools(): builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] logger.debug("Built-in tools: %s", builtin_tools_list) - return model_list, builtin_tools_list, mcp_tools_list + return model_list, builtin_tools_list, mcp_tools_list, shields_list def render_toolgroup_selection(builtin_tools_list, mcp_tools_list, selected_vector_dbs, @@ -159,6 +172,14 @@ def reset_agent(): st.cache_resource.clear() +def reset_conversation(): + """Reset conversation messages without clearing widget configuration.""" + keys_to_clear = ["messages", "conversation_id", "show_more_questions", "selected_question", "direct_vector_dbs"] + for key in keys_to_clear: + st.session_state.pop(key, None) + st.cache_resource.clear() + + def create_vector_db_callbacks(processing_mode, vector_dbs): """Create callbacks for vector DB and toolgroup synchronization.""" def on_vector_db_change(): @@ -197,7 +218,35 @@ def on_toolgroup_change(): return on_vector_db_change, on_toolgroup_change -def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list): +def render_guardrails_selection(shields_list): + """Render guardrail selection UI in the sidebar.""" + st.subheader("Guardrails") + + if not shields_list: + st.caption("No guardrails available on this server.") + return [], [] + + input_shields = st.multiselect( + "Input Guardrails", + options=shields_list, + key="guardrail_input_selector", + on_change=reset_conversation, + help="Safety guardrails to check user input before processing.", + ) + + output_shields = st.multiselect( + "Output Guardrails", + options=shields_list, + key="guardrail_output_selector", + on_change=reset_conversation, + help="Safety guardrails to check assistant response after generation.", + ) + + return input_shields, output_shields + + +def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list, + shields_list): """Render sidebar configuration and return selected parameters.""" st.title("Configuration") st.subheader("Model") @@ -239,6 +288,9 @@ def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list) on_toolgroup_change, reset_agent ) + # Guardrails Selection + input_shields, output_shields = render_guardrails_selection(shields_list) + # Sampling Parameters st.subheader("Sampling Parameters") temperature = st.slider( @@ -287,6 +339,8 @@ def render_sidebar_configuration(model_list, builtin_tools_list, mcp_tools_list) 'max_infer_iters': max_infer_iters, 'max_tokens': max_tokens, 'system_prompt': system_prompt, + 'input_shields': input_shields, + 'output_shields': output_shields, } @@ -364,6 +418,13 @@ class SamplingParams: max_tokens: int +@dataclass +class GuardrailConfig: + """Configuration for safety guardrails.""" + input_shields: list + output_shields: list + + @dataclass class ChatConfig: """Configuration for chat processing.""" @@ -374,6 +435,7 @@ class ChatConfig: toolgroup_selection: list selected_vector_dbs: list sampling: SamplingParams + guardrails: GuardrailConfig # ============================================================================ @@ -386,12 +448,20 @@ class Containers: Note: Containers are created in visual display order (top to bottom). """ def __init__(self): - # Create containers in visual order: tools -> reasoning -> message - self.tool_status = st.container() - self.tool_results = st.container() + # Create containers in visual order: tools -> thinking -> reasoning -> message + self._tool_status_slot = st.empty() + self.tool_status = self._tool_status_slot.container() + self._tool_results_slot = st.empty() + self.tool_results = self._tool_results_slot.container() + self.thinking = st.container() self.reasoning = st.empty() self.message = st.empty() + def clear_tools(self): + """Clear all rendered tool status and results.""" + self._tool_status_slot.empty() + self._tool_results_slot.empty() + class ResponseState: """ State container for assistant response UI components and data. @@ -401,6 +471,9 @@ def __init__(self): # UI containers (grouped) self.containers = Containers() + # Thinking indicator state + self._thinking_active = False + # Reasoning state self.reasoning_text = "" self.reasoning_placeholder = None @@ -412,6 +485,9 @@ def __init__(self): # Response content self.full_response = "" + # Guardrail block message (set when a shield blocks the request/response) + self.guardrail_blocked = None + @property def has_reasoning(self): """Check if reasoning has been started.""" @@ -422,8 +498,21 @@ def tool_used(self): """Check if any tool has been used.""" return self.tool_status is not None + def show_thinking(self): + """Show a 'Thinking...' progress indicator.""" + self._thinking_active = True + self._thinking_placeholder = self.containers.thinking.empty() + self._thinking_placeholder.status("Thinking...", state="running") + + def dismiss_thinking(self): + """Dismiss the thinking indicator if active.""" + if self._thinking_active: + self._thinking_active = False + self._thinking_placeholder.empty() + def update_reasoning(self, delta_text): """Add reasoning text and update display.""" + self.dismiss_thinking() self.reasoning_text += delta_text # Create reasoning expander on first delta @@ -441,10 +530,12 @@ def finalize_reasoning(self): if self.reasoning_placeholder and self.reasoning_text: self.reasoning_placeholder.markdown(self.reasoning_text) - def update_message(self, delta_text): + def update_message(self, delta_text, display_fn=None): """Add message text and update display.""" + self.dismiss_thinking() self.full_response += delta_text - self.containers.message.markdown(self.full_response + "▌") + display_text = display_fn(self.full_response) if display_fn else self.full_response + self.containers.message.markdown(display_text + "▌") def finalize_message(self): """Remove cursor from message display.""" @@ -551,13 +642,13 @@ def tool_chat_page(): """Main chat page with RAG support in Direct and Agent-based modes.""" st.title("💬 Chat") - # Fetch models and tools - model_list, builtin_tools_list, mcp_tools_list = fetch_models_and_tools() + # Fetch models, tools, and shields + model_list, builtin_tools_list, mcp_tools_list, shields_list = fetch_models_and_tools() # Render sidebar and get configuration with st.sidebar: sidebar_config = render_sidebar_configuration( - model_list, builtin_tools_list, mcp_tools_list + model_list, builtin_tools_list, mcp_tools_list, shields_list ) # Initialize session state @@ -576,7 +667,11 @@ def tool_chat_page(): top_k=sidebar_config['top_k'], max_infer_iters=sidebar_config['max_infer_iters'], max_tokens=sidebar_config['max_tokens'], - ) + ), + guardrails=GuardrailConfig( + input_shields=sidebar_config['input_shields'], + output_shields=sidebar_config['output_shields'], + ), ) # Display chat history diff --git a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py index daff6a2..80dfd09 100644 --- a/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py +++ b/frontend/llama_stack_ui/distribution/ui/page/playground/direct.py @@ -14,7 +14,7 @@ import streamlit as st from llama_stack_ui.distribution.ui.modules.api import llama_stack_api -from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name +from llama_stack_ui.distribution.ui.modules.utils import clean_text, get_vector_db_name, run_input_shields, run_output_shields logger = logging.getLogger(__name__) @@ -51,7 +51,6 @@ def extract_text_from_search_result(result): def search_vector_store_direct(prompt, vector_db_id, vector_db_name, top_k, state): """Search vector store and extract context for Direct mode.""" - search_results = [] context_parts = [] display_results = [] @@ -71,6 +70,8 @@ def search_vector_store_direct(prompt, vector_db_id, vector_db_name, top_k, stat logger.debug("Search response: %s", search_response) # Extract search results from response + search_results = [] + if hasattr(search_response, 'data') and search_response.data: search_results = search_response.data elif hasattr(search_response, 'chunks') and search_response.chunks: @@ -89,8 +90,13 @@ def search_vector_store_direct(prompt, vector_db_id, vector_db_name, top_k, stat context_parts.append(f"[Source: {source}]: {text_content}") display_results.append({"source": source, "text": text_content}) + state.tool_results.append({ + 'title': f"📄 File Search Results from '{vector_db_name}'", + 'type': 'json', + 'content': display_results + }) with state.containers.tool_results: - with st.expander(f"📄 Search Results from '{vector_db_name}'", expanded=False): + with st.expander(f"📄 File Search Results from '{vector_db_name}'", expanded=False): st.json(display_results) logger.debug("Built context with %s documents", len(context_parts)) @@ -135,7 +141,7 @@ def stream_completions_direct(completion_response, state): """Stream chunks from Completions API and update state.""" for chunk in completion_response: logger.debug("Completion chunk: %s", chunk) - if hasattr(chunk, 'choices') and len(chunk.choices) > 0: + if hasattr(chunk, 'choices') and chunk.choices: delta = chunk.choices[0].delta # Handle reasoning content (for models that support it like R1) @@ -149,6 +155,16 @@ def stream_completions_direct(completion_response, state): def save_direct_response_to_session(state, all_search_results): """Save direct response to session state.""" + if state.guardrail_blocked: + response_dict = { + "role": "assistant", + "content": f"🛡️ {state.guardrail_blocked}", + "guardrail_blocked": state.guardrail_blocked, + "stop_reason": "end_of_message", + } + st.session_state.messages.append(response_dict) + return + state.finalize_reasoning() state.finalize_message() @@ -167,7 +183,7 @@ def save_direct_response_to_session(state, all_search_results): db_names = [name for name, _ in all_search_results] response_dict["tool_results"] = [ { - 'title': f'📄 Search Results from \'{name}\'', + 'title': f'📄 File Search Results from \'{name}\'', 'type': 'json', 'content': display } @@ -184,8 +200,16 @@ def save_direct_response_to_session(state, all_search_results): # Direct Mode - Main Function # ============================================================================ +def _get_live_shields(config): + """Read guardrail selections directly from widget state to avoid stale config.""" + input_shields = st.session_state.get("guardrail_input_selector", config.guardrails.input_shields) + output_shields = st.session_state.get("guardrail_output_selector", config.guardrails.output_shields) + return input_shields or [], output_shields or [] + + def direct_process_prompt(prompt, state, config): """Direct mode: Manual RAG with completions API.""" + input_shields, output_shields = _get_live_shields(config) context_parts = [] all_search_results = [] @@ -194,7 +218,24 @@ def direct_process_prompt(prompt, state, config): logger.debug("No vector DB selected - normal chat mode") try: - # Step 1: Search each selected vector store + # Step 0: Run input guardrails + if input_shields: + guardrail_status = state.containers.tool_status.empty() + guardrail_status.markdown("🛡️ :grey[_Running input guardrail check..._]") + is_blocked, violation_msg, blocked_shield = run_input_shields( + llama_stack_api.client, input_shields, prompt + ) + if is_blocked: + guardrail_status.empty() + blocked_msg = f"**Input Guardrail Triggered** (`{blocked_shield}`): {violation_msg}" + st.warning(blocked_msg, icon="🛡️") + state.guardrail_blocked = blocked_msg + state.full_response = "" + save_direct_response_to_session(state, []) + return + guardrail_status.empty() + + # Step 1: Search each selected vector store (renders results immediately) for vector_db in vector_dbs: vector_db_id = vector_db.id vector_db_name = get_vector_db_name(vector_db) @@ -205,6 +246,12 @@ def direct_process_prompt(prompt, state, config): all_search_results.append((vector_db_name, display)) context_parts.extend(parts) + # Update tool status to final state + if all_search_results: + db_names = [name for name, _ in all_search_results] + status_msg = f"🛠 :grey[_Searched vector stores: {', '.join(db_names)}_]" + state.tool_status = status_msg + # Step 2: Build messages (with or without RAG context) messages = build_rag_messages(prompt, context_parts, config.system_prompt) @@ -213,6 +260,7 @@ def direct_process_prompt(prompt, state, config): for i, msg in enumerate(messages): logger.debug(" Message %s (%s): %s...", i, msg['role'], msg['content'][:200]) + state.show_thinking() completion_response = llama_stack_api.client.chat.completions.create( model=config.model, messages=messages, @@ -224,7 +272,25 @@ def direct_process_prompt(prompt, state, config): # Step 4: Stream response and update UI stream_completions_direct(completion_response, state) - # Step 5: Save to session + # Step 5: Run output guardrails + if output_shields and state.full_response: + guardrail_status = state.containers.tool_status.empty() + guardrail_status.markdown("🛡️ :grey[_Running output guardrail check..._]") + is_blocked, violation_msg, blocked_shield = run_output_shields( + llama_stack_api.client, output_shields, prompt, state.full_response + ) + guardrail_status.empty() + if is_blocked: + blocked_msg = f"**Output Guardrail Triggered** (`{blocked_shield}`): {violation_msg}" + state.containers.clear_tools() + state.containers.message.empty() + st.warning(blocked_msg, icon="🛡️") + state.guardrail_blocked = blocked_msg + state.full_response = "" + save_direct_response_to_session(state, []) + return + + # Step 6: Save to session save_direct_response_to_session(state, all_search_results) except Exception as e: