Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions deploy/helm/rag/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ apiVersion: v2
name: rag
description: A Helm chart for Kubernetes
type: application
version: 0.2.38
appVersion: "0.2.38"
version: 0.2.39
appVersion: "0.2.39"

dependencies:
- name: pgvector
Expand All @@ -15,7 +15,7 @@ dependencies:
repository: https://rh-ai-quickstart.github.io/ai-architecture-charts
condition: llm-service.enabled
- name: configure-pipeline
version: 0.5.6
version: 0.5.7
repository: https://rh-ai-quickstart.github.io/ai-architecture-charts
condition: configure-pipeline.enabled
- name: ingestion-pipeline
Expand Down
2 changes: 1 addition & 1 deletion deploy/helm/rag/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ replicaCount: 1
image:
repository: quay.io/rh-ai-quickstart/llamastack-dist-ui
pullPolicy: Always
tag: 0.2.38
tag: 0.2.39

service:
type: ClusterIP
Expand Down
159 changes: 139 additions & 20 deletions frontend/llama_stack_ui/distribution/ui/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
# the root directory of this source tree.

import base64
import io
import json
import logging
import os
import re

import pandas as pd
import streamlit as st

logger = logging.getLogger(__name__)

"""
Utility functions for file processing and data conversion in the UI.
Expand Down Expand Up @@ -63,6 +66,37 @@ def clean_text(text):
return re.sub(r'\s+', ' ', text).strip()


def strip_file_citations(text):
    """
    Remove file citation markers injected by the Responses API file_search tool.

    Three marker shapes are stripped:
      * bare file ID references, e.g. ``file<file-abc123>``
      * token-style markers, e.g. ``<|file-abc123|>``
      * bracket-style annotation markers, e.g. ``【4:0†source】``

    Only the spaces left behind by a removed marker are cleaned up. Spaces
    elsewhere in the text are preserved, so indentation inside code blocks,
    nested Markdown lists and two-space hard line breaks are not damaged.

    Args:
        text: Raw response text potentially containing citation markers

    Returns:
        str: Text with citation markers removed ("" for None or empty input)
    """
    if not text:
        return ""

    marker = r'(?:file<[^>]+>|<\|file-[^|]*\|>|【[^】]*†[^】]*】)'

    # A marker (or a run of adjacent markers) that has spaces on both sides
    # would leave a doubled space once deleted; drop the leading spaces
    # together with the marker and keep the single trailing one.
    text = re.sub(r' +' + marker + r'+(?= )', '', text)

    # Any remaining marker is removed on its own, leaving neighbouring
    # characters exactly as they were.
    text = re.sub(marker, '', text)
    return text


def strip_file_citations_streaming(text):
    """
    Strip citations for streaming display.

    Removes complete citation markers (via ``strip_file_citations``) and then
    trims a trailing, not-yet-complete marker so that citation fragments do
    not briefly flash in the UI while more text is still arriving.

    Note that an unclosed ``【`` hides everything after it until the closing
    ``】`` arrives (or the text is re-rendered without this function).

    Args:
        text: Accumulated response text received so far

    Returns:
        str: Text safe to display mid-stream
    """
    text = strip_file_citations(text)

    # Token-style marker cut off anywhere from "<|" up to "<|file-...|".
    # The optional trailing "\|" covers the moment when the closing pipe has
    # arrived but the final ">" has not.
    text = re.sub(r'<\|(?:f(?:i(?:l(?:e(?:-[^|]*\|?)?)?)?)?)?\s*$', '', text)

    # Bare file reference whose closing ">" has not arrived yet.
    text = re.sub(r'\bfile<[^>]*$', '', text)

    # Bracket-style annotation whose closing "】" has not arrived yet.
    text = re.sub(r'【[^】]*$', '', text)
    return text


def get_vector_db_name(vector_db):
"""
Get the display name for a vector database.
Expand Down Expand Up @@ -94,6 +128,101 @@ def get_question_suggestions():
return {}


def fetch_available_shields(client):
    """
    Look up the safety shields registered on the LlamaStack server.

    This is a best-effort query: any exception raised while listing the
    shields or reading their identifiers is logged at debug level and an
    empty list is returned instead.

    Args:
        client: LlamaStack client instance

    Returns:
        List of shield identifier strings (empty when none are available
        or the lookup fails)
    """
    identifiers = []
    try:
        registered = client.shields.list()
        if registered:
            identifiers = [shield.identifier for shield in registered]
    except Exception as exc:
        logger.debug("Failed to fetch shields: %s", exc)
    return identifiers


def run_input_shields(client, shield_ids, user_message):
    """
    Run input safety shields on the user's message before processing.

    Shields are run in the order given and evaluation stops at the first
    shield that reports a violation. A shield that raises is logged and
    skipped (fail-open), so an unavailable shield never blocks the user.

    Args:
        client: LlamaStack client instance
        shield_ids: List of shield identifiers to run
        user_message: The user's input text

    Returns:
        Tuple of (is_blocked: bool, violation_message: str or None, shield_id: str or None).
        When blocked, violation_message is always a non-empty string.
    """
    if not shield_ids:
        return False, None, None

    for shield_id in shield_ids:
        try:
            logger.debug("Running input shield: %s", shield_id)
            shield_response = client.safety.run_shield(
                shield_id=shield_id,
                messages=[{"role": "user", "content": user_message}],
                params={},
            )
            logger.debug("Input shield %s response: %s", shield_id, shield_response)
            violation = getattr(shield_response, "violation", None)
            if violation:
                # user_message may be present but None/empty; fall back to a
                # generic message so callers always have text to display.
                violation_msg = (
                    getattr(violation, "user_message", None)
                    or "Content blocked by safety guardrail"
                )
                logger.warning("Input blocked by shield %s: %s", shield_id, violation_msg)
                return True, violation_msg, shield_id
            logger.debug("Input shield %s passed (no violation)", shield_id)
        except Exception as e:
            # Deliberate best-effort: a failing shield is treated as "no violation".
            logger.warning("Error running input shield %s: %s", shield_id, e)
    return False, None, None


def run_output_shields(client, shield_ids, user_message, assistant_response):
    """
    Run output safety shields on the assistant's response after generation.

    Each shield receives the user prompt followed by the assistant response.
    Shields are run in the order given and evaluation stops at the first
    shield that reports a violation. A shield that raises is logged and
    skipped (fail-open), so an unavailable shield never blocks the response.

    Args:
        client: LlamaStack client instance
        shield_ids: List of shield identifiers to run
        user_message: The original user prompt
        assistant_response: The generated assistant response text

    Returns:
        Tuple of (is_blocked: bool, violation_message: str or None, shield_id: str or None).
        When blocked, violation_message is always a non-empty string.
    """
    if not shield_ids:
        return False, None, None

    for shield_id in shield_ids:
        try:
            logger.debug("Running output shield: %s", shield_id)
            shield_response = client.safety.run_shield(
                shield_id=shield_id,
                messages=[
                    {"role": "user", "content": user_message},
                    {"role": "assistant", "content": assistant_response},
                ],
                params={},
            )
            logger.debug("Output shield %s response: %s", shield_id, shield_response)
            violation = getattr(shield_response, "violation", None)
            if violation:
                # user_message may be present but None/empty; fall back to a
                # generic message so callers always have text to display.
                violation_msg = (
                    getattr(violation, "user_message", None)
                    or "Response blocked by safety guardrail"
                )
                logger.warning("Output blocked by shield %s: %s", shield_id, violation_msg)
                return True, violation_msg, shield_id
            logger.debug("Output shield %s passed (no violation)", shield_id)
        except Exception as e:
            # Deliberate best-effort: a failing shield is treated as "no violation".
            logger.warning("Error running output shield %s: %s", shield_id, e)
    return False, None, None


def get_suggestions_for_databases(selected_dbs, all_vector_dbs):
"""
Get combined question suggestions for selected databases.
Expand All @@ -111,32 +240,22 @@ def get_suggestions_for_databases(selected_dbs, all_vector_dbs):
if not suggestions_map:
return []

# Build a mapping from displayed DB name to the full DB object so we can
# resolve all possible identifiers used by different backend versions.
db_name_to_obj = {
get_vector_db_name(vdb): vdb
# Create a mapping from vector_db_name to id
db_name_to_id = {
get_vector_db_name(vdb): vdb.id
for vdb in all_vector_dbs
}

for db_name in selected_dbs:
# Try several keys because the selected UI name may differ from the
# suggestion map key (e.g. vector_store_name/identifier/id/display name).
vdb = db_name_to_obj.get(db_name)
candidate_keys = []
if vdb:
candidate_keys.extend([
getattr(vdb, "vector_store_name", None),
getattr(vdb, "identifier", None),
getattr(vdb, "id", None),
getattr(vdb, "name", None),
])
candidate_keys.append(db_name)
# Get the id for this database name
db_id = db_name_to_id.get(db_name)

# Try both the id and the db_name as keys in the suggestions map
questions = None
for key in candidate_keys:
if key and key in suggestions_map:
questions = suggestions_map[key]
break
if db_id and db_id in suggestions_map:
questions = suggestions_map[db_id]
elif db_name in suggestions_map:
questions = suggestions_map[db_name]

if questions:
for question in questions:
Expand Down
Loading
Loading