diff --git a/frontend/src/components/conversation-view/chat-session.tsx b/frontend/src/components/conversation-view/chat-session.tsx index 880f7efa..04cfa015 100644 --- a/frontend/src/components/conversation-view/chat-session.tsx +++ b/frontend/src/components/conversation-view/chat-session.tsx @@ -13,23 +13,52 @@ import { AttachmentBubble } from "@/components/conversation-view/attachment-bubb import { MessageError } from "@/components/conversation-view/message-error.tsx"; import { MessageBubbleTyping } from "@/components/conversation-view/message-bubble-typing.tsx"; import { ChatWindow } from "@/components/conversation-view/chat-window.tsx"; +import { ProductWidget } from "@/components/conversation-view/product-widget.tsx"; -import type { AgentDetails } from "@/lib/types"; -import type { Conversation as ConversationSchema } from '@/lib/types.ts'; -import type { Message, MessageFile } from "@/components/conversation-view/message-composer.tsx"; +import type {AgentDetails, Conversation as ConversationSchema} from '@/lib/types.ts'; +import type {Message, MessageFile} from "@/components/conversation-view/message-composer.tsx"; +import { Message as ApiMessage } from "src/lib/types.ts" -export type FileMessageProps = { +export type BaseMessageProps = { id: number | null; + type: "file" | "human" | "ai" | "system" | "widget"; +} + +export type FileMessageProps = BaseMessageProps & { type: "file"; file_name: string; file_type: string; }; +export type WidgetMessageProps = BaseMessageProps & { + type: "widget"; + widget_data: WidgetData; + isStreaming: boolean; +}; -export type MessageProps = { +export type TextMessage = BaseMessageProps & { type: "human" | "ai" | "system"; text: string; - id: number | null; -} | FileMessageProps; + isStreaming?: boolean; +} +export type MessageProps = FileMessageProps | WidgetMessageProps | TextMessage; + +export interface ProductData { + id: number; + name: string; + sku: string; + slug: string; + description: string; + categories: string; + properties: string; + price?: number; +} + +export interface WidgetData { + widget_type: string; + message?: string; + data?: ProductData[]; +} + interface ChatSessionProps { pendingMessage?: Message; @@ -155,6 +184,67 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) { action: () => { setAgentAction(data.data.output); }, + product_widget_start: () => { + setIsLoading(false); + setIsAgentLoading(false); + + let widgetData; + try { + widgetData = typeof data.data === 'string' ? JSON.parse(data.data) : data.data; + } catch { + widgetData = { message: String(data.data), number_of_products: 0 }; + } + + const newWidget: WidgetMessageProps = { + type: "widget", + id: null, + widget_data: { + widget_type: "product_widget", + message: widgetData.message || '', + data: [], + }, + isStreaming: true, + }; + + setMessages((prevMessages) => [...prevMessages, newWidget]); + scrollToLastMessage(); + }, + product_widget_end: () => { + setMessages((prevMessages) => { + const lastMessage = prevMessages[prevMessages.length - 1]; + if (lastMessage && lastMessage.type === "widget" && lastMessage.isStreaming) { + return [ + ...prevMessages.slice(0, -1), + { ...lastMessage, isStreaming: false } + ]; + } + return prevMessages; + }); + }, + product_widget_product: () => { + const productData = data.data.chunk as ProductData; + + setMessages((prevMessages) => { + const lastMessage = prevMessages[prevMessages.length - 1]; + + if (lastMessage && lastMessage.type === "widget" && lastMessage.widget_data) { + const updatedData = [...(lastMessage.widget_data.data || []), productData]; + return [ + ...prevMessages.slice(0, -1), + { + ...lastMessage, + widget_data: { + ...lastMessage.widget_data, + data: updatedData + } + } + ]; + } + return prevMessages; + }); + + scrollToLastMessage(); + }, error: () => { setIsLoading(false); setIsAgentLoading(false); @@ -219,6 +309,23 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) { }); }; + const createMessageFromResponse = (response: ApiMessage): MessageProps => { + if (response.widget_data) { + return { + type: "widget", + id: response.id, + widget_data: response.widget_data as WidgetData, + isStreaming: false, + } as WidgetMessageProps; + } else { + return { + type: response.type as "ai" | "system", + text: response.text, + id: response.id, + } as TextMessage; + } + }; + const updateTaskStatus = async (conversationId: number, taskHandle: TaskHandle, streaming: boolean) => { if (streaming) { return; @@ -229,10 +336,8 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) { if (response) { setIsAgentLoading(false); - setMessages((prev) => [ - ...prev, - { type: "ai", text: response.text, id: response.id } - ]); + const message = createMessageFromResponse(response); + setMessages((prev) => [...prev, message]); setIsLoading(false); scrollToLastMessage(); } else { @@ -317,7 +422,6 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
{messages.map((message, index) => { const inMessageGroup = isUserMessage(message) && index > 0 && isUserMessage(messages[index - 1]); - if (message.type === "system") { return ( @@ -329,20 +433,42 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) { ); } - - if (message.type === 'ai' && message.text === '') { - return (); + if (message.type === 'ai' && message.text === '' && index === messages.length) { + return (); } - - return ( - - ); + ); + } + if (message.type === "widget" && message.widget_data?.widget_type === "product_widget") { + return ( + + ); + } + if (message.type === "human") { + return ( + + ); + } })} {isAgentLoading && }
diff --git a/frontend/src/components/conversation-view/message-bubble.tsx b/frontend/src/components/conversation-view/message-bubble.tsx index 792bb9d8..c7c23264 100644 --- a/frontend/src/components/conversation-view/message-bubble.tsx +++ b/frontend/src/components/conversation-view/message-bubble.tsx @@ -19,7 +19,14 @@ export function MessageBubble({ text, variant, questionId, inMessageGroup }: Mes useEffect(() => { const processText = async () => { - const rawHtml = await marked(text); + const renderer = new marked.Renderer(); + renderer.link = ({ href, title, tokens }) => { + const text = tokens.map(token => token.raw).join(''); + const titleAttr = title ? ` title="${title}"` : ''; + return `${text}`; + }; + + const rawHtml = await marked(text, { renderer }); const cleanHtml = DOMPurify.sanitize(rawHtml); setSanitizedHtml(cleanHtml); }; diff --git a/frontend/src/components/conversation-view/product-card.tsx b/frontend/src/components/conversation-view/product-card.tsx new file mode 100644 index 00000000..09d88e3e --- /dev/null +++ b/frontend/src/components/conversation-view/product-card.tsx @@ -0,0 +1,103 @@ +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Badge } from "@/components/ui/badge"; + +export interface ProductCardProps { + id: number; + name: string; + sku: string; + slug: string; + description: string; + categories: string; + properties: string; + price?: number; +} + +export function ProductCard({ + name, + sku, + description, + categories, + properties, + price +}: ProductCardProps) { + let parsedCategories: string[] = []; + if (categories) { + try { + const parsed = JSON.parse(categories); + parsedCategories = Array.isArray(parsed) ? parsed : [String(parsed)]; + } catch { + parsedCategories = categories.split(',').map(c => c.trim()).filter(Boolean); + } + } + + let parsedProperties: Record = {}; + if (properties) { + try { + const parsed = JSON.parse(properties); + if (typeof parsed === 'object' && parsed !== null) { + parsedProperties = parsed; + } + } catch { + properties.split(';').forEach(pair => { + const [key, value] = pair.split('->').map(s => s.trim()); + if (key && value) { + parsedProperties[key] = value; + } + }); + } + } + + return ( + + +
+ {name} + {price !== undefined && ( + + ${price.toFixed(2)} + + )} +
+ + SKU: {sku} + +
+ + {description && ( +

+ {description} +

+ )} + + {parsedCategories.length > 0 && ( +
+
+ {parsedCategories.slice(0, 3).map((category: string, index: number) => ( + + {category} + + ))} + {parsedCategories.length > 3 && ( + + +{parsedCategories.length - 3} more + + )} +
+
+ )} + + {Object.keys(parsedProperties).length > 0 && ( +
+ {Object.entries(parsedProperties).slice(0, 3).map(([key, value], index) => ( +
+ {key}: + {String(value)} +
+ ))} +
+ )} +
+
+ ); +} + diff --git a/frontend/src/components/conversation-view/product-widget.tsx b/frontend/src/components/conversation-view/product-widget.tsx new file mode 100644 index 00000000..a7a204e8 --- /dev/null +++ b/frontend/src/components/conversation-view/product-widget.tsx @@ -0,0 +1,116 @@ +import { useState, useEffect } from 'react'; +import { ProductCard, ProductCardProps } from './product-card'; +import { BaseBubble } from './base-bubble'; +import { marked } from 'marked'; +import DOMPurify from 'dompurify'; + +export interface ProductWidgetProps { + message?: string; + products: ProductCardProps[]; + isStreaming?: boolean; + expectedProductCount?: number; +} + +function ProductPlaceholder() { + return ( +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ); +} + +export function ProductWidget({ + message, + products, + isStreaming = false, + expectedProductCount = 0 +}: ProductWidgetProps) { + const [sanitizedHtml, setSanitizedHtml] = useState(""); + const [shouldAnimate, setShouldAnimate] = useState(isStreaming); + + useEffect(() => { + if (isStreaming) { + setShouldAnimate(true); + } + }, [isStreaming]); + + useEffect(() => { + if (message) { + const processText = async () => { + const renderer = new marked.Renderer(); + renderer.link = ({ href, title, tokens }) => { + const text = tokens.map(token => token.raw).join(''); + const titleAttr = title ? ` title="${title}"` : ''; + return `${text}`; + }; + + const rawHtml = await marked(message, { renderer }); + const cleanHtml = DOMPurify.sanitize(rawHtml); + setSanitizedHtml(cleanHtml); + }; + + processText(); + } + }, [message]); + + const totalSlots = expectedProductCount || products.length; + const productsReceived = products.length; + const productDelayMS = 500 + + return ( + + {message && ( +
+ )} + +
+ {products.map((product, index) => ( +
+ {shouldAnimate ? ( + <> +
+ +
+
+ +
+ + ) : ( + + )} +
+ ))} + + {isStreaming && Array.from({ length: Math.max(0, totalSlots - productsReceived) }).map((_, index) => ( + + ))} +
+ + ); +} + diff --git a/frontend/src/components/ui/badge.tsx b/frontend/src/components/ui/badge.tsx new file mode 100644 index 00000000..690e3de6 --- /dev/null +++ b/frontend/src/components/ui/badge.tsx @@ -0,0 +1,37 @@ +import * as React from "react" +import { cva, type VariantProps } from "class-variance-authority" + +import { cn } from "@/lib/utils" + +const badgeVariants = cva( + "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2", + { + variants: { + variant: { + default: + "border-transparent bg-primary text-primary-foreground hover:bg-primary/80", + secondary: + "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80", + destructive: + "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80", + outline: "text-foreground", + }, + }, + defaultVariants: { + variant: "default", + }, + } +) + +export interface BadgeProps + extends React.HTMLAttributes, + VariantProps {} + +function Badge({ className, variant, ...props }: BadgeProps) { + return ( +
+ ) +} + +export { Badge, badgeVariants } + diff --git a/frontend/src/components/ui/card.tsx b/frontend/src/components/ui/card.tsx new file mode 100644 index 00000000..5e494017 --- /dev/null +++ b/frontend/src/components/ui/card.tsx @@ -0,0 +1,80 @@ +import * as React from "react" + +import { cn } from "@/lib/utils" + +const Card = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +Card.displayName = "Card" + +const CardHeader = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardHeader.displayName = "CardHeader" + +const CardTitle = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardTitle.displayName = "CardTitle" + +const CardDescription = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardDescription.displayName = "CardDescription" + +const CardContent = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardContent.displayName = "CardContent" + +const CardFooter = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardFooter.displayName = "CardFooter" + +export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent } + diff --git a/frontend/src/index.css b/frontend/src/index.css index 2f3351a7..97f662cb 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -80,3 +80,62 @@ @apply bg-background text-foreground; } } + +@layer components { + .primary a, + .secondary a { + @apply text-blue-600 underline underline-offset-2 hover:text-blue-800 transition-colors; + } + + .primary a:visited { + @apply text-purple-600 hover:text-purple-800; + } + + .secondary a:visited { + @apply text-purple-600 hover:text-purple-800; + } + + @keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + .animate-fadeIn { + opacity: 0; + animation: fadeIn 0.5s ease-out forwards; + } + + @keyframes fadeOut { + from { + opacity: 1; + } + to { + opacity: 0; + visibility: hidden; + } + } + + .animate-fadeOut { + animation: fadeOut 0.5s ease-out forwards; + } + + @keyframes shimmer { + 0% { + background-position: -200% 0; + } + 100% { + background-position: 200% 0; + } + } + + .animate-shimmer { + background-size: 200% 100%; + animation: shimmer 1.5s ease-in-out infinite; + } +} diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 5a4b4e02..879890fe 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -81,6 +81,11 @@ export type Message = { id: number; type: string; text: string; + widget_data?: { + widget_type: string; + message?: string; + data?: any[]; // eslint-disable-line @typescript-eslint/no-explicit-any + }; } export type ServiceAccount = { diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py index 11384625..03f85b57 100644 --- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py +++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py @@ -3,10 +3,12 @@ from .tools import ProductExamplesTool from .tools import ProductSQLSearchTool +from .tools import PresentProductsTool class ProductSearchToolCallingAgent(BaseToolCallingAgent): TOOLS = [ LLMToolConfig(tool_class=ProductSQLSearchTool), LLMToolConfig(tool_class=ProductExamplesTool), + LLMToolConfig(tool_class=PresentProductsTool), ] diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py index 36281252..33b46549 100644 --- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py +++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py @@ -1,12 +1,11 @@ -from enthusiast_agent_re_act import BaseReActAgent, StructuredReActOutputParser +from enthusiast_agent_re_act import BaseReActAgent from enthusiast_common.config.base import LLMToolConfig -from langchain.agents import AgentExecutor, create_react_agent -from langchain_core.tools import render_text_description_and_args -from ..tools.product_search_tool import ProductSearchTool +from ..tools import PresentProductsTool, ProductSearchTool class ProductSearchReActAgent(BaseReActAgent): TOOLS = [ LLMToolConfig(tool_class=ProductSearchTool), + LLMToolConfig(tool_class=PresentProductsTool), ] diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py index 74061679..057e65a5 100644 --- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py +++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py @@ -1,5 +1,6 @@ from .product_examples_tool import ProductExamplesTool from .product_search_tool import ProductSearchTool from .product_sql_search_tool import ProductSQLSearchTool +from .present_product_tool import PresentProductsTool -__all__ = ["ProductExamplesTool", "ProductSearchTool", "ProductSQLSearchTool"] +__all__ = ["ProductExamplesTool", "ProductSearchTool", "ProductSQLSearchTool", "PresentProductsTool"] diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py new file mode 100644 index 00000000..2266f118 --- /dev/null +++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py @@ -0,0 +1,41 @@ +import json + +from django.forms import model_to_dict +from enthusiast_common.tools import BaseWidgetResponseLLMTool +from enthusiast_common.tools.interfaces import BaseWidgetResponseSerializer +from pydantic import BaseModel, Field + + +class PresentProductsToolSerializer(BaseWidgetResponseSerializer): + def serialize(self): + return model_to_dict(self.data) + + +class PresentProductsToolInput(BaseModel): + products_ids: str = Field(description="Comma separated list of product ids") + message_to_user: str = Field(description="Message to user that will go with products.") + + +class PresentProductsTool(BaseWidgetResponseLLMTool): + NAME = "present_products_tool" + DESCRIPTION = "IMPORTANT: *Always use this tool to show products to the user!*" + ARGS_SCHEMA = PresentProductsToolInput + RETURN_DIRECT = True + SERIALIZER_CLASS = PresentProductsToolSerializer + + def run(self, products_ids: str, message_to_user: str): + ids = products_ids.split(",") + if self._streaming: + self._injector.callbacks_handler.on_product_widget_start( + json.dumps({"number_of_products": len(ids), "message": message_to_user}) + ) + all_products = [] + for id in ids: + product = self._injector.repositories.product.get_by_id(int(id)) + serialized_product = self.SERIALIZER_CLASS(product).serialize() + all_products.append(serialized_product) + if self._streaming: + self._injector.callbacks_handler.on_product_widget_product(serialized_product) + if self._streaming: + self._injector.callbacks_handler.on_product_widget_end() + return json.dumps({"widget_type": "product_widget", "message": message_to_user, "data": all_products}) diff --git a/plugins/enthusiast-common/enthusiast_common/builder/base.py b/plugins/enthusiast-common/enthusiast_common/builder/base.py index edbf96f6..0f90ac53 100644 --- a/plugins/enthusiast-common/enthusiast_common/builder/base.py +++ b/plugins/enthusiast-common/enthusiast_common/builder/base.py @@ -28,6 +28,7 @@ def __init__(self, config: ConfigT, conversation_id: Any, streaming: bool = Fals self._data_set_id = None self._injector = None self._prompt = None + self._agent_callback_handler = None self._config = config self.conversation_id = conversation_id self.streaming = streaming @@ -40,11 +41,12 @@ def build(self) -> BaseAgent: self._embeddings_registry = self._build_embeddings_registry() self._llm = self._build_llm(self._config.llm) self._default_llm = self._build_default_llm() + self._agent_callback_handler = self._build_agent_callback_handler() self._injector = self._build_injector() tools = self._build_tools(default_llm=self._default_llm, injector=self._injector) - agent_callback_handler = self._build_agent_callback_handler() + self._prompt = self._build_prompt_template() - agent_instance = self._build_agent(tools, self._llm, agent_callback_handler) + agent_instance = self._build_agent(tools, self._llm, self._agent_callback_handler) self._inject_additional_arguments(agent_instance) return agent_instance diff --git a/plugins/enthusiast-common/enthusiast_common/injectors/base.py b/plugins/enthusiast-common/enthusiast_common/injectors/base.py index 2ca80273..ace10cab 100644 --- a/plugins/enthusiast-common/enthusiast_common/injectors/base.py +++ b/plugins/enthusiast-common/enthusiast_common/injectors/base.py @@ -3,7 +3,8 @@ from enthusiast_common.connectors import ECommercePlatformConnector from enthusiast_common.retrievers import BaseProductRetriever, BaseVectorStoreRetriever -from enthusiast_common.structures import RepositoriesInstances, DocumentChunkDetails +from enthusiast_common.structures import DocumentChunkDetails, RepositoriesInstances +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.memory import BaseMemory @@ -35,3 +36,8 @@ def chat_summary_memory(self) -> BaseMemory: @abstractmethod def chat_limited_memory(self) -> BaseMemory: pass + + @property + @abstractmethod + def callbacks_handler(self) -> BaseCallbackHandler: + pass diff --git a/plugins/enthusiast-common/enthusiast_common/tools/__init__.py b/plugins/enthusiast-common/enthusiast_common/tools/__init__.py index 3c9f80bf..91741ffc 100644 --- a/plugins/enthusiast-common/enthusiast_common/tools/__init__.py +++ b/plugins/enthusiast-common/enthusiast_common/tools/__init__.py @@ -1,3 +1,3 @@ -from .base import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool +from .base import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool, BaseWidgetResponseLLMTool -__all__ = ["BaseLLMTool", "BaseAgentTool", "BaseFunctionTool", "BaseFileTool"] +__all__ = ["BaseLLMTool", "BaseAgentTool", "BaseFunctionTool", "BaseFileTool", "BaseWidgetResponseLLMTool"] diff --git a/plugins/enthusiast-common/enthusiast_common/tools/base.py b/plugins/enthusiast-common/enthusiast_common/tools/base.py index 0ea3b4a2..6a862019 100644 --- a/plugins/enthusiast-common/enthusiast_common/tools/base.py +++ b/plugins/enthusiast-common/enthusiast_common/tools/base.py @@ -14,6 +14,7 @@ from ..injectors import BaseInjector from ..utils import RequiredFieldsModel, validate_required_vars +from .interfaces import BaseWidgetResponseSerializer class ToolMeta(ABCMeta): @@ -99,6 +100,14 @@ def __init__( self._llm_registry = llm_registry +class BaseWidgetResponseLLMTool(BaseLLMTool, ABC): + SERIALIZER_CLASS: Type[BaseWidgetResponseSerializer] + + def __init__(self, data_set_id: Any, llm: BaseLanguageModel, injector: BaseInjector, streaming: bool = False): + super().__init__(data_set_id, llm, injector) + self._streaming = streaming + + class BaseAgentTool(BaseTool, ABC): def __init__( self, diff --git a/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py b/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py new file mode 100644 index 00000000..58b435ec --- /dev/null +++ b/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseWidgetResponseSerializer(ABC): + def __init__(self, data: Any): + self.data = data + + @abstractmethod + def serialize(self): + pass diff --git a/server/agent/consumers.py b/server/agent/consumers.py index 9d12fa6d..899c622b 100644 --- a/server/agent/consumers.py +++ b/server/agent/consumers.py @@ -25,6 +25,9 @@ async def chat_message(self, event): "message_created": self.handle_message_created, "action": self.handle_action, "error": self.handle_error, + "product_widget_start": self.handle_product_widget_start, + "product_widget_end": self.handle_product_widget_end, + "product_widget_product": self.handle_product_widget_product, } handler = handlers.get(event_type) @@ -51,6 +54,15 @@ async def handle_action(self, event): async def handle_error(self, event): await self.send(json.dumps({"event": "error", "data": {"output": event.get("output")}})) + async def handle_product_widget_start(self, event): + await self.send(json.dumps({"event": "product_widget_start", "data": event.get("data")})) + + async def handle_product_widget_end(self, event): + await self.send(json.dumps({"event": "product_widget_end", "data": event.get("data")})) + + async def handle_product_widget_product(self, event): + await self.send(json.dumps({"event": "product_widget_product", "data": {"chunk": event.get("data")}})) + async def save_message(self, output): from .models import Conversation, Message diff --git a/server/agent/core/agents/default_config.py b/server/agent/core/agents/default_config.py index 61272044..6bba9476 100644 --- a/server/agent/core/agents/default_config.py +++ b/server/agent/core/agents/default_config.py @@ -20,6 +20,7 @@ from agent.core.callbacks import ( AgentActionWebsocketCallbackHandler, + AgentWidgetWebsocketCallbackHandler, ConversationWebSocketCallbackHandler, ReactAgentWebsocketCallbackHandler, ) @@ -59,7 +60,7 @@ def get_default_config(type: AgentType) -> DefaultAgentConfig: agent_callback_handler_config = ( AgentCallbackHandlerConfig(handler_class=AgentActionWebsocketCallbackHandler) if type == AgentType.RE_ACT - else None + else AgentCallbackHandlerConfig(handler_class=AgentWidgetWebsocketCallbackHandler) ) return DefaultAgentConfig( repositories=RepositoriesConfig( diff --git a/server/agent/core/builder.py b/server/agent/core/builder.py index 31297b6a..2844adac 100644 --- a/server/agent/core/builder.py +++ b/server/agent/core/builder.py @@ -16,7 +16,13 @@ from enthusiast_common.injectors import BaseInjector from enthusiast_common.registry import BaseDBModelsRegistry, BaseEmbeddingProviderRegistry, BaseLanguageModelRegistry from enthusiast_common.retrievers import BaseRetriever -from enthusiast_common.tools import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool +from enthusiast_common.tools import ( + BaseAgentTool, + BaseFileTool, + BaseFunctionTool, + BaseLLMTool, + BaseWidgetResponseLLMTool, +) from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import ChatPromptTemplate, PromptTemplate @@ -109,6 +115,11 @@ def _build_tools(self, default_llm: BaseLanguageModel, injector: BaseInjector) - if isinstance(tool_config, FunctionToolConfig): tools.append(self._build_function_tool(config=tool_config)) elif isinstance(tool_config, LLMToolConfig): + if issubclass(tool_config.tool_class, BaseWidgetResponseLLMTool): + tools.append( + self._build_widget_llm_tool(config=tool_config, default_llm=default_llm, injector=injector) + ) + continue tools.append(self._build_llm_tool(config=tool_config, injector=injector, default_llm=default_llm)) elif isinstance(tool_config, AgentToolConfig): tools.append(self._build_agent_tool(config=tool_config)) @@ -127,11 +138,15 @@ def _build_llm_tool( llm = default_llm if config.llm: llm = config.llm - return config.tool_class( - data_set_id=self._data_set_id, - llm=llm, - injector=injector, - ) + return config.tool_class(data_set_id=self._data_set_id, llm=llm, injector=injector) + + def _build_widget_llm_tool( + self, config: LLMToolConfig, default_llm: BaseLanguageModel, injector: BaseInjector + ) -> BaseLLMTool: + llm = default_llm + if config.llm: + llm = config.llm + return config.tool_class(data_set_id=self._data_set_id, llm=llm, injector=injector, streaming=self.streaming) def _build_file_tool( self, config: FileToolConfig, default_llm: BaseLanguageModel, injector: BaseInjector @@ -165,6 +180,7 @@ def _build_injector(self) -> BaseInjector: repositories=self._repositories, chat_summary_memory=chat_summary_memory, chat_limited_memory=chat_limited_memory, + callbacks_handler=self._agent_callback_handler, ) def _build_agent_callback_handler(self) -> Optional[BaseCallbackHandler]: diff --git a/server/agent/core/callbacks.py b/server/agent/core/callbacks.py index b0282f04..b79b9ffa 100644 --- a/server/agent/core/callbacks.py +++ b/server/agent/core/callbacks.py @@ -91,7 +91,38 @@ def on_llm_new_token(self, token: str, **kwargs): ) -class AgentActionWebsocketCallbackHandler(BaseWebSocketHandler): +class AgentWidgetWebsocketCallbackHandler(BaseWebSocketHandler): + def on_product_widget_start(self, data, **kwargs): + self.send_message( + { + "type": "chat_message", + "event": "product_widget_start", + "run_id": self.run_id, + "data": data, + } + ) + + def on_product_widget_product(self, product, **kwargs): + self.send_message( + { + "type": "chat_message", + "event": "product_widget_product", + "run_id": self.run_id, + "data": product, + }, + ) + + def on_product_widget_end(self): + self.send_message( + { + "type": "chat_message", + "event": "product_widget_end", + "run_id": self.run_id, + } + ) + + +class AgentActionWebsocketCallbackHandler(AgentWidgetWebsocketCallbackHandler): def on_agent_action( self, action: AgentAction, diff --git a/server/agent/core/injector.py b/server/agent/core/injector.py index e72d68a4..600368ea 100644 --- a/server/agent/core/injector.py +++ b/server/agent/core/injector.py @@ -4,6 +4,7 @@ from enthusiast_common.injectors import BaseInjector from enthusiast_common.retrievers import BaseProductRetriever, BaseVectorStoreRetriever from enthusiast_common.structures import RepositoriesInstances +from langchain_core.callbacks import BaseCallbackHandler from agent.core.memory import SummaryChatMemory from agent.core.memory.limited_chat_memory import LimitedChatMemory @@ -19,6 +20,7 @@ def __init__( repositories: RepositoriesInstances, chat_summary_memory: SummaryChatMemory, chat_limited_memory: LimitedChatMemory, + callbacks_handler: BaseCallbackHandler, ): super().__init__(repositories) self._document_retriever = document_retriever @@ -26,6 +28,7 @@ def __init__( self._ecommerce_platform_connector = ecommerce_platform_connector self._chat_summary_memory = chat_summary_memory self._chat_limited_memory = chat_limited_memory + self._callbacks_handler = callbacks_handler @property def document_retriever(self) -> BaseVectorStoreRetriever[DocumentChunk]: @@ -46,3 +49,7 @@ def chat_summary_memory(self) -> SummaryChatMemory: @property def chat_limited_memory(self) -> LimitedChatMemory: return self._chat_limited_memory + + @property + def callbacks_handler(self) -> BaseCallbackHandler: + return self._callbacks_handler diff --git a/server/agent/core/memory/persist_intermediate_steps_mixin.py b/server/agent/core/memory/persist_intermediate_steps_mixin.py index 93e7d396..02396638 100644 --- a/server/agent/core/memory/persist_intermediate_steps_mixin.py +++ b/server/agent/core/memory/persist_intermediate_steps_mixin.py @@ -1,5 +1,6 @@ +import json from abc import ABC -from typing import Any, Dict, cast +from typing import Any, Dict, Literal, cast from langchain.agents.output_parsers.tools import ToolAgentAction from langchain.memory import ConversationBufferMemory @@ -8,6 +9,11 @@ from agent.models import Message +class WidgetMessage(BaseMessage): + type: Literal["widget"] = "widget" + widget_data: dict + + class PersistIntermediateStepsMixin(ABC): """ This mixin can be added to a ConversationBufferMemory class in order to persist agent's function calls. @@ -30,9 +36,16 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: type=Message.MessageType.INTERMEDIATE_STEP, content=agent_action.messages[0].content ) self_as_conversation_memory.chat_memory.add_message(intermediate_step_message) - - function_message = FunctionMessage(name=agent_action.tool, content=result) - self_as_conversation_memory.chat_memory.add_message(function_message) - - ai_message = AIMessage(outputs["output"]) - self_as_conversation_memory.chat_memory.add_message(ai_message) + if result: + function_message = FunctionMessage(name=agent_action.tool, content=result) + self_as_conversation_memory.chat_memory.add_message(function_message) + if outputs.get("output") is not None: + try: + json_object = json.loads(outputs["output"]) + ai_message = WidgetMessage( + content=f"Widget representation of data: {json_object['data']}", widget_data=json_object + ) + self_as_conversation_memory.chat_memory.add_message(ai_message) + except Exception: + ai_message = AIMessage(outputs["output"]) + self_as_conversation_memory.chat_memory.add_message(ai_message) diff --git a/server/agent/core/memory/persistent_chat_history.py b/server/agent/core/memory/persistent_chat_history.py index 6afe76e4..9b6af2af 100644 --- a/server/agent/core/memory/persistent_chat_history.py +++ b/server/agent/core/memory/persistent_chat_history.py @@ -16,7 +16,10 @@ def __init__(self, conversation_repo: BaseConversationRepository, conversation_i def add_message(self, message: BaseMessage) -> None: self._conversation.messages.create( - type=message.type, text=message.content, function_name=getattr(message, "name", None) + type=message.type, + text=message.content, + function_name=getattr(message, "name", None), + widget_data=getattr(message, "widget_data", None), ) @property diff --git a/server/agent/migrations/0024_message_widget_data_alter_message_type.py b/server/agent/migrations/0024_message_widget_data_alter_message_type.py new file mode 100644 index 00000000..7bca93f9 --- /dev/null +++ b/server/agent/migrations/0024_message_widget_data_alter_message_type.py @@ -0,0 +1,33 @@ +# Generated by Django 5.2.4 on 2025-10-20 11:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("agent", "0024_alter_agent_description"), + ] + + operations = [ + migrations.AddField( + model_name="message", + name="widget_data", + field=models.JSONField(blank=True, null=True), + ), + migrations.AlterField( + model_name="message", + name="type", + field=models.CharField( + choices=[ + ("function", "Function"), + ("human", "Human"), + ("ai", "Ai"), + ("system", "System"), + ("intermediate_step", "Intermediate Step"), + ("file", "File"), + ("widget", "Widget"), + ], + max_length=50, + ), + ), + ] diff --git a/server/agent/models/message.py b/server/agent/models/message.py index 0767c450..7e785eae 100644 --- a/server/agent/models/message.py +++ b/server/agent/models/message.py @@ -12,6 +12,7 @@ class MessageType(models.TextChoices): SYSTEM = "system" INTERMEDIATE_STEP = "intermediate_step" FILE = "file" + WIDGET = "widget" conversation = models.ForeignKey(Conversation, related_name="messages", on_delete=models.PROTECT) created_at = models.DateTimeField(auto_now_add=True) @@ -25,6 +26,7 @@ class MessageType(models.TextChoices): function_name = models.CharField(max_length=50, blank=True, null=True) file_name = models.CharField(max_length=256, blank=True, null=True) file_type = models.CharField(max_length=50, blank=True, null=True) + widget_data = models.JSONField(blank=True, null=True) @classmethod def internal_message_types(cls): @@ -33,12 +35,13 @@ def internal_message_types(cls): @property def langchain_type(self): langchain_type_mapping = { - self.MessageType.FUNCTION: "ai", self.MessageType.FILE: "human", - self.MessageType.INTERMEDIATE_STEP: "ai", self.MessageType.HUMAN: "human", - self.MessageType.AI: "ai", self.MessageType.SYSTEM: "system", + self.MessageType.INTERMEDIATE_STEP: "ai", + self.MessageType.AI: "ai", + self.MessageType.FUNCTION: "ai", + self.MessageType.WIDGET: "ai", } return langchain_type_mapping[self.type] diff --git a/server/agent/serializers/conversation.py b/server/agent/serializers/conversation.py index 66900cd0..c3265b53 100644 --- a/server/agent/serializers/conversation.py +++ b/server/agent/serializers/conversation.py @@ -88,10 +88,10 @@ class MessagesSerializer(serializers.ModelSerializer): class Meta: model = Message - fields = ["id", "text", "type", "file_type", "file_name"] + fields = ["id", "text", "type", "file_type", "file_name", "widget_data"] def get_text(self, obj: Message): - if obj.type == Message.MessageType.FILE: + if obj.type == Message.MessageType.FILE or obj.type == Message.MessageType.WIDGET: return "" else: return obj.text diff --git a/server/agent/views.py b/server/agent/views.py index 245a256b..f3f917ff 100644 --- a/server/agent/views.py +++ b/server/agent/views.py @@ -126,7 +126,7 @@ def post(self, request, conversation_id): conversation_id=conversation.id, created_at=datetime.now(), type=Message.MessageType.FILE, - file_name=file.file.name.split(".")[0], + file_name=file.file.name.split(".")[-1], file_type=file.file.name.split(".")[-1], text=f"Uploaded {file.file.name} with id: {file.pk}", )