diff --git a/frontend/src/components/conversation-view/chat-session.tsx b/frontend/src/components/conversation-view/chat-session.tsx
index 880f7efa..04cfa015 100644
--- a/frontend/src/components/conversation-view/chat-session.tsx
+++ b/frontend/src/components/conversation-view/chat-session.tsx
@@ -13,23 +13,52 @@ import { AttachmentBubble } from "@/components/conversation-view/attachment-bubb
import { MessageError } from "@/components/conversation-view/message-error.tsx";
import { MessageBubbleTyping } from "@/components/conversation-view/message-bubble-typing.tsx";
import { ChatWindow } from "@/components/conversation-view/chat-window.tsx";
+import { ProductWidget } from "@/components/conversation-view/product-widget.tsx";
-import type { AgentDetails } from "@/lib/types";
-import type { Conversation as ConversationSchema } from '@/lib/types.ts';
-import type { Message, MessageFile } from "@/components/conversation-view/message-composer.tsx";
+import type {AgentDetails, Conversation as ConversationSchema} from '@/lib/types.ts';
+import type {Message, MessageFile} from "@/components/conversation-view/message-composer.tsx";
+import { Message as ApiMessage } from "src/lib/types.ts"
-export type FileMessageProps = {
+export type BaseMessageProps = {
id: number | null;
+ type: "file" | "human" | "ai" | "system" | "widget";
+}
+
+export type FileMessageProps = BaseMessageProps & {
type: "file";
file_name: string;
file_type: string;
};
+export type WidgetMessageProps = BaseMessageProps & {
+ type: "widget";
+ widget_data: WidgetData;
+ isStreaming: boolean;
+};
-export type MessageProps = {
+export type TextMessage = BaseMessageProps & {
type: "human" | "ai" | "system";
text: string;
- id: number | null;
-} | FileMessageProps;
+ isStreaming?: boolean;
+}
+export type MessageProps = FileMessageProps | WidgetMessageProps | TextMessage;
+
+export interface ProductData {
+ id: number;
+ name: string;
+ sku: string;
+ slug: string;
+ description: string;
+ categories: string;
+ properties: string;
+ price?: number;
+}
+
+export interface WidgetData {
+ widget_type: string;
+ message?: string;
+ data?: ProductData[];
+}
+
interface ChatSessionProps {
pendingMessage?: Message;
@@ -155,6 +184,67 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
action: () => {
setAgentAction(data.data.output);
},
+ product_widget_start: () => {
+ setIsLoading(false);
+ setIsAgentLoading(false);
+
+ let widgetData;
+ try {
+ widgetData = typeof data.data === 'string' ? JSON.parse(data.data) : data.data;
+ } catch {
+ widgetData = { message: String(data.data), number_of_products: 0 };
+ }
+
+ const newWidget: WidgetMessageProps = {
+ type: "widget",
+ id: null,
+ widget_data: {
+ widget_type: "product_widget",
+ message: widgetData.message || '',
+ data: [],
+ },
+ isStreaming: true,
+ };
+
+ setMessages((prevMessages) => [...prevMessages, newWidget]);
+ scrollToLastMessage();
+ },
+ product_widget_end: () => {
+ setMessages((prevMessages) => {
+ const lastMessage = prevMessages[prevMessages.length - 1];
+ if (lastMessage && lastMessage.type === "widget" && lastMessage.isStreaming) {
+ return [
+ ...prevMessages.slice(0, -1),
+ { ...lastMessage, isStreaming: false }
+ ];
+ }
+ return prevMessages;
+ });
+ },
+ product_widget_product: () => {
+ const productData = data.data.chunk as ProductData;
+
+ setMessages((prevMessages) => {
+ const lastMessage = prevMessages[prevMessages.length - 1];
+
+ if (lastMessage && lastMessage.type === "widget" && lastMessage.widget_data) {
+ const updatedData = [...(lastMessage.widget_data.data || []), productData];
+ return [
+ ...prevMessages.slice(0, -1),
+ {
+ ...lastMessage,
+ widget_data: {
+ ...lastMessage.widget_data,
+ data: updatedData
+ }
+ }
+ ];
+ }
+ return prevMessages;
+ });
+
+ scrollToLastMessage();
+ },
error: () => {
setIsLoading(false);
setIsAgentLoading(false);
@@ -219,6 +309,23 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
});
};
+ const createMessageFromResponse = (response: ApiMessage): MessageProps => {
+ if (response.widget_data) {
+ return {
+ type: "widget",
+ id: response.id,
+ widget_data: response.widget_data as WidgetData,
+ isStreaming: false,
+ } as WidgetMessageProps;
+ } else {
+ return {
+ type: response.type as "ai" | "system",
+ text: response.text,
+ id: response.id,
+ } as TextMessage;
+ }
+ };
+
const updateTaskStatus = async (conversationId: number, taskHandle: TaskHandle, streaming: boolean) => {
if (streaming) {
return;
@@ -229,10 +336,8 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
if (response) {
setIsAgentLoading(false);
- setMessages((prev) => [
- ...prev,
- { type: "ai", text: response.text, id: response.id }
- ]);
+ const message = createMessageFromResponse(response);
+ setMessages((prev) => [...prev, message]);
setIsLoading(false);
scrollToLastMessage();
} else {
@@ -317,7 +422,6 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
{messages.map((message, index) => {
const inMessageGroup = isUserMessage(message) && index > 0 && isUserMessage(messages[index - 1]);
-
if (message.type === "system") {
return (
@@ -329,20 +433,42 @@ export function ChatSession({ pendingMessage }: ChatSessionProps) {
);
}
-
- if (message.type === 'ai' && message.text === '') {
- return (
);
+ if (message.type === 'ai' && message.text === '' && index === messages.length) {
+ return (
);
}
-
- return (
-
- );
+ );
+ }
+ if (message.type === "widget" && message.widget_data?.widget_type === "product_widget") {
+ return (
+
+ );
+ }
+ if (message.type === "human") {
+ return (
+
+ );
+ }
})}
{isAgentLoading &&
}
diff --git a/frontend/src/components/conversation-view/message-bubble.tsx b/frontend/src/components/conversation-view/message-bubble.tsx
index 792bb9d8..c7c23264 100644
--- a/frontend/src/components/conversation-view/message-bubble.tsx
+++ b/frontend/src/components/conversation-view/message-bubble.tsx
@@ -19,7 +19,14 @@ export function MessageBubble({ text, variant, questionId, inMessageGroup }: Mes
useEffect(() => {
const processText = async () => {
- const rawHtml = await marked(text);
+ const renderer = new marked.Renderer();
+ renderer.link = ({ href, title, tokens }) => {
+ const text = tokens.map(token => token.raw).join('');
+ const titleAttr = title ? ` title="${title}"` : '';
+ return `
${text}`;
+ };
+
+ const rawHtml = await marked(text, { renderer });
const cleanHtml = DOMPurify.sanitize(rawHtml);
setSanitizedHtml(cleanHtml);
};
diff --git a/frontend/src/components/conversation-view/product-card.tsx b/frontend/src/components/conversation-view/product-card.tsx
new file mode 100644
index 00000000..09d88e3e
--- /dev/null
+++ b/frontend/src/components/conversation-view/product-card.tsx
@@ -0,0 +1,103 @@
+import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
+import { Badge } from "@/components/ui/badge";
+
+export interface ProductCardProps {
+ id: number;
+ name: string;
+ sku: string;
+ slug: string;
+ description: string;
+ categories: string;
+ properties: string;
+ price?: number;
+}
+
+export function ProductCard({
+ name,
+ sku,
+ description,
+ categories,
+ properties,
+ price
+}: ProductCardProps) {
+ let parsedCategories: string[] = [];
+ if (categories) {
+ try {
+ const parsed = JSON.parse(categories);
+ parsedCategories = Array.isArray(parsed) ? parsed : [String(parsed)];
+ } catch {
+ parsedCategories = categories.split(',').map(c => c.trim()).filter(Boolean);
+ }
+ }
+
+ let parsedProperties: Record
= {};
+ if (properties) {
+ try {
+ const parsed = JSON.parse(properties);
+ if (typeof parsed === 'object' && parsed !== null) {
+ parsedProperties = parsed;
+ }
+ } catch {
+ properties.split(';').forEach(pair => {
+ const [key, value] = pair.split('->').map(s => s.trim());
+ if (key && value) {
+ parsedProperties[key] = value;
+ }
+ });
+ }
+ }
+
+ return (
+
+
+
+ {name}
+ {price !== undefined && (
+
+ ${price.toFixed(2)}
+
+ )}
+
+
+ SKU: {sku}
+
+
+
+ {description && (
+
+ {description}
+
+ )}
+
+ {parsedCategories.length > 0 && (
+
+
+ {parsedCategories.slice(0, 3).map((category: string, index: number) => (
+
+ {category}
+
+ ))}
+ {parsedCategories.length > 3 && (
+
+ +{parsedCategories.length - 3} more
+
+ )}
+
+
+ )}
+
+ {Object.keys(parsedProperties).length > 0 && (
+
+ {Object.entries(parsedProperties).slice(0, 3).map(([key, value], index) => (
+
+ {key}:
+ {String(value)}
+
+ ))}
+
+ )}
+
+
+ );
+}
+
diff --git a/frontend/src/components/conversation-view/product-widget.tsx b/frontend/src/components/conversation-view/product-widget.tsx
new file mode 100644
index 00000000..a7a204e8
--- /dev/null
+++ b/frontend/src/components/conversation-view/product-widget.tsx
@@ -0,0 +1,116 @@
+import { useState, useEffect } from 'react';
+import { ProductCard, ProductCardProps } from './product-card';
+import { BaseBubble } from './base-bubble';
+import { marked } from 'marked';
+import DOMPurify from 'dompurify';
+
+export interface ProductWidgetProps {
+ message?: string;
+ products: ProductCardProps[];
+ isStreaming?: boolean;
+ expectedProductCount?: number;
+}
+
+function ProductPlaceholder() {
+ return (
+
+ );
+}
+
+export function ProductWidget({
+ message,
+ products,
+ isStreaming = false,
+ expectedProductCount = 0
+}: ProductWidgetProps) {
+ const [sanitizedHtml, setSanitizedHtml] = useState("");
+ const [shouldAnimate, setShouldAnimate] = useState(isStreaming);
+
+ useEffect(() => {
+ if (isStreaming) {
+ setShouldAnimate(true);
+ }
+ }, [isStreaming]);
+
+ useEffect(() => {
+ if (message) {
+ const processText = async () => {
+ const renderer = new marked.Renderer();
+ renderer.link = ({ href, title, tokens }) => {
+ const text = tokens.map(token => token.raw).join('');
+ const titleAttr = title ? ` title="${title}"` : '';
+ return `${text}`;
+ };
+
+ const rawHtml = await marked(message, { renderer });
+ const cleanHtml = DOMPurify.sanitize(rawHtml);
+ setSanitizedHtml(cleanHtml);
+ };
+
+ processText();
+ }
+ }, [message]);
+
+ const totalSlots = expectedProductCount || products.length;
+ const productsReceived = products.length;
+ const productDelayMS = 500
+
+ return (
+
+ {message && (
+
+ )}
+
+
+ {products.map((product, index) => (
+
+ {shouldAnimate ? (
+ <>
+
+
+ >
+ ) : (
+
+ )}
+
+ ))}
+
+ {isStreaming && Array.from({ length: Math.max(0, totalSlots - productsReceived) }).map((_, index) => (
+
+ ))}
+
+
+ );
+}
+
diff --git a/frontend/src/components/ui/badge.tsx b/frontend/src/components/ui/badge.tsx
new file mode 100644
index 00000000..690e3de6
--- /dev/null
+++ b/frontend/src/components/ui/badge.tsx
@@ -0,0 +1,37 @@
+import * as React from "react"
+import { cva, type VariantProps } from "class-variance-authority"
+
+import { cn } from "@/lib/utils"
+
+const badgeVariants = cva(
+ "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
+ {
+ variants: {
+ variant: {
+ default:
+ "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
+ secondary:
+ "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
+ destructive:
+ "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
+ outline: "text-foreground",
+ },
+ },
+ defaultVariants: {
+ variant: "default",
+ },
+ }
+)
+
+export interface BadgeProps
+ extends React.HTMLAttributes,
+ VariantProps {}
+
+function Badge({ className, variant, ...props }: BadgeProps) {
+ return (
+
+ )
+}
+
+export { Badge, badgeVariants }
+
diff --git a/frontend/src/components/ui/card.tsx b/frontend/src/components/ui/card.tsx
new file mode 100644
index 00000000..5e494017
--- /dev/null
+++ b/frontend/src/components/ui/card.tsx
@@ -0,0 +1,80 @@
+import * as React from "react"
+
+import { cn } from "@/lib/utils"
+
+const Card = React.forwardRef<
+ HTMLDivElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+Card.displayName = "Card"
+
+const CardHeader = React.forwardRef<
+ HTMLDivElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+CardHeader.displayName = "CardHeader"
+
+const CardTitle = React.forwardRef<
+ HTMLParagraphElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+CardTitle.displayName = "CardTitle"
+
+const CardDescription = React.forwardRef<
+ HTMLParagraphElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+CardDescription.displayName = "CardDescription"
+
+const CardContent = React.forwardRef<
+ HTMLDivElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+CardContent.displayName = "CardContent"
+
+const CardFooter = React.forwardRef<
+ HTMLDivElement,
+ React.HTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+CardFooter.displayName = "CardFooter"
+
+export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent }
+
diff --git a/frontend/src/index.css b/frontend/src/index.css
index 2f3351a7..97f662cb 100644
--- a/frontend/src/index.css
+++ b/frontend/src/index.css
@@ -80,3 +80,62 @@
@apply bg-background text-foreground;
}
}
+
+@layer components {
+ .primary a,
+ .secondary a {
+ @apply text-blue-600 underline underline-offset-2 hover:text-blue-800 transition-colors;
+ }
+
+ .primary a:visited {
+ @apply text-purple-600 hover:text-purple-800;
+ }
+
+ .secondary a:visited {
+ @apply text-purple-600 hover:text-purple-800;
+ }
+
+ @keyframes fadeIn {
+ from {
+ opacity: 0;
+ transform: translateY(10px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+ }
+
+ .animate-fadeIn {
+ opacity: 0;
+ animation: fadeIn 0.5s ease-out forwards;
+ }
+
+ @keyframes fadeOut {
+ from {
+ opacity: 1;
+ }
+ to {
+ opacity: 0;
+ visibility: hidden;
+ }
+ }
+
+ .animate-fadeOut {
+ animation: fadeOut 0.5s ease-out forwards;
+ }
+
+ @keyframes shimmer {
+ 0% {
+ background-position: -200% 0;
+ }
+ 100% {
+ background-position: 200% 0;
+ }
+ }
+
+ .animate-shimmer {
+ background-size: 200% 100%;
+ animation: shimmer 1.5s ease-in-out infinite;
+ }
+}
diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts
index 5a4b4e02..879890fe 100644
--- a/frontend/src/lib/types.ts
+++ b/frontend/src/lib/types.ts
@@ -81,6 +81,11 @@ export type Message = {
id: number;
type: string;
text: string;
+ widget_data?: {
+ widget_type: string;
+ message?: string;
+ data?: any[]; // eslint-disable-line @typescript-eslint/no-explicit-any
+ };
}
export type ServiceAccount = {
diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py
index 11384625..03f85b57 100644
--- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py
+++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/agent.py
@@ -3,10 +3,12 @@
from .tools import ProductExamplesTool
from .tools import ProductSQLSearchTool
+from .tools import PresentProductsTool
class ProductSearchToolCallingAgent(BaseToolCallingAgent):
TOOLS = [
LLMToolConfig(tool_class=ProductSQLSearchTool),
LLMToolConfig(tool_class=ProductExamplesTool),
+ LLMToolConfig(tool_class=PresentProductsTool),
]
diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py
index 36281252..33b46549 100644
--- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py
+++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/re_act/agent.py
@@ -1,12 +1,11 @@
-from enthusiast_agent_re_act import BaseReActAgent, StructuredReActOutputParser
+from enthusiast_agent_re_act import BaseReActAgent
from enthusiast_common.config.base import LLMToolConfig
-from langchain.agents import AgentExecutor, create_react_agent
-from langchain_core.tools import render_text_description_and_args
-from ..tools.product_search_tool import ProductSearchTool
+from ..tools import PresentProductsTool, ProductSearchTool
class ProductSearchReActAgent(BaseReActAgent):
TOOLS = [
LLMToolConfig(tool_class=ProductSearchTool),
+ LLMToolConfig(tool_class=PresentProductsTool),
]
diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py
index 74061679..057e65a5 100644
--- a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py
+++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/__init__.py
@@ -1,5 +1,6 @@
from .product_examples_tool import ProductExamplesTool
from .product_search_tool import ProductSearchTool
from .product_sql_search_tool import ProductSQLSearchTool
+from .present_product_tool import PresentProductsTool
-__all__ = ["ProductExamplesTool", "ProductSearchTool", "ProductSQLSearchTool"]
+__all__ = ["ProductExamplesTool", "ProductSearchTool", "ProductSQLSearchTool", "PresentProductsTool"]
diff --git a/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py
new file mode 100644
index 00000000..2266f118
--- /dev/null
+++ b/plugins/enthusiast-agent-product-search/src/enthusiast_agent_product_search/tools/present_product_tool.py
@@ -0,0 +1,41 @@
+import json
+
+from django.forms import model_to_dict
+from enthusiast_common.tools import BaseWidgetResponseLLMTool
+from enthusiast_common.tools.interfaces import BaseWidgetResponseSerializer
+from pydantic import BaseModel, Field
+
+
+class PresentProductsToolSerializer(BaseWidgetResponseSerializer):
+ def serialize(self):
+ return model_to_dict(self.data)
+
+
+class PresentProductsToolInput(BaseModel):
+ products_ids: str = Field(description="Comma separated list of product ids")
+ message_to_user: str = Field(description="Message to user that will go with products.")
+
+
+class PresentProductsTool(BaseWidgetResponseLLMTool):
+ NAME = "present_products_tool"
+ DESCRIPTION = "IMPORTANT: *Always use this tool to show products to the user!*"
+ ARGS_SCHEMA = PresentProductsToolInput
+ RETURN_DIRECT = True
+ SERIALIZER_CLASS = PresentProductsToolSerializer
+
+ def run(self, products_ids: str, message_to_user: str):
+ ids = products_ids.split(",")
+ if self._streaming:
+ self._injector.callbacks_handler.on_product_widget_start(
+ json.dumps({"number_of_products": len(ids), "message": message_to_user})
+ )
+ all_products = []
+ for id in ids:
+ product = self._injector.repositories.product.get_by_id(int(id))
+ serialized_product = self.SERIALIZER_CLASS(product).serialize()
+ all_products.append(serialized_product)
+ if self._streaming:
+ self._injector.callbacks_handler.on_product_widget_product(serialized_product)
+ if self._streaming:
+ self._injector.callbacks_handler.on_product_widget_end()
+ return json.dumps({"widget_type": "product_widget", "message": message_to_user, "data": all_products})
diff --git a/plugins/enthusiast-common/enthusiast_common/builder/base.py b/plugins/enthusiast-common/enthusiast_common/builder/base.py
index edbf96f6..0f90ac53 100644
--- a/plugins/enthusiast-common/enthusiast_common/builder/base.py
+++ b/plugins/enthusiast-common/enthusiast_common/builder/base.py
@@ -28,6 +28,7 @@ def __init__(self, config: ConfigT, conversation_id: Any, streaming: bool = Fals
self._data_set_id = None
self._injector = None
self._prompt = None
+ self._agent_callback_handler = None
self._config = config
self.conversation_id = conversation_id
self.streaming = streaming
@@ -40,11 +41,12 @@ def build(self) -> BaseAgent:
self._embeddings_registry = self._build_embeddings_registry()
self._llm = self._build_llm(self._config.llm)
self._default_llm = self._build_default_llm()
+ self._agent_callback_handler = self._build_agent_callback_handler()
self._injector = self._build_injector()
tools = self._build_tools(default_llm=self._default_llm, injector=self._injector)
- agent_callback_handler = self._build_agent_callback_handler()
+
self._prompt = self._build_prompt_template()
- agent_instance = self._build_agent(tools, self._llm, agent_callback_handler)
+ agent_instance = self._build_agent(tools, self._llm, self._agent_callback_handler)
self._inject_additional_arguments(agent_instance)
return agent_instance
diff --git a/plugins/enthusiast-common/enthusiast_common/injectors/base.py b/plugins/enthusiast-common/enthusiast_common/injectors/base.py
index 2ca80273..ace10cab 100644
--- a/plugins/enthusiast-common/enthusiast_common/injectors/base.py
+++ b/plugins/enthusiast-common/enthusiast_common/injectors/base.py
@@ -3,7 +3,8 @@
from enthusiast_common.connectors import ECommercePlatformConnector
from enthusiast_common.retrievers import BaseProductRetriever, BaseVectorStoreRetriever
-from enthusiast_common.structures import RepositoriesInstances, DocumentChunkDetails
+from enthusiast_common.structures import DocumentChunkDetails, RepositoriesInstances
+from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.memory import BaseMemory
@@ -35,3 +36,8 @@ def chat_summary_memory(self) -> BaseMemory:
@abstractmethod
def chat_limited_memory(self) -> BaseMemory:
pass
+
+ @property
+ @abstractmethod
+ def callbacks_handler(self) -> BaseCallbackHandler:
+ pass
diff --git a/plugins/enthusiast-common/enthusiast_common/tools/__init__.py b/plugins/enthusiast-common/enthusiast_common/tools/__init__.py
index 3c9f80bf..91741ffc 100644
--- a/plugins/enthusiast-common/enthusiast_common/tools/__init__.py
+++ b/plugins/enthusiast-common/enthusiast_common/tools/__init__.py
@@ -1,3 +1,3 @@
-from .base import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool
+from .base import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool, BaseWidgetResponseLLMTool
-__all__ = ["BaseLLMTool", "BaseAgentTool", "BaseFunctionTool", "BaseFileTool"]
+__all__ = ["BaseLLMTool", "BaseAgentTool", "BaseFunctionTool", "BaseFileTool", "BaseWidgetResponseLLMTool"]
diff --git a/plugins/enthusiast-common/enthusiast_common/tools/base.py b/plugins/enthusiast-common/enthusiast_common/tools/base.py
index 0ea3b4a2..6a862019 100644
--- a/plugins/enthusiast-common/enthusiast_common/tools/base.py
+++ b/plugins/enthusiast-common/enthusiast_common/tools/base.py
@@ -14,6 +14,7 @@
from ..injectors import BaseInjector
from ..utils import RequiredFieldsModel, validate_required_vars
+from .interfaces import BaseWidgetResponseSerializer
class ToolMeta(ABCMeta):
@@ -99,6 +100,14 @@ def __init__(
self._llm_registry = llm_registry
+class BaseWidgetResponseLLMTool(BaseLLMTool, ABC):
+ SERIALIZER_CLASS: Type[BaseWidgetResponseSerializer]
+
+ def __init__(self, data_set_id: Any, llm: BaseLanguageModel, injector: BaseInjector, streaming: bool = False):
+ super().__init__(data_set_id, llm, injector)
+ self._streaming = streaming
+
+
class BaseAgentTool(BaseTool, ABC):
def __init__(
self,
diff --git a/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py b/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py
new file mode 100644
index 00000000..58b435ec
--- /dev/null
+++ b/plugins/enthusiast-common/enthusiast_common/tools/interfaces.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseWidgetResponseSerializer(ABC):
+ def __init__(self, data: Any):
+ self.data = data
+
+ @abstractmethod
+ def serialize(self):
+ pass
diff --git a/server/agent/consumers.py b/server/agent/consumers.py
index 9d12fa6d..899c622b 100644
--- a/server/agent/consumers.py
+++ b/server/agent/consumers.py
@@ -25,6 +25,9 @@ async def chat_message(self, event):
"message_created": self.handle_message_created,
"action": self.handle_action,
"error": self.handle_error,
+ "product_widget_start": self.handle_product_widget_start,
+ "product_widget_end": self.handle_product_widget_end,
+ "product_widget_product": self.handle_product_widget_product,
}
handler = handlers.get(event_type)
@@ -51,6 +54,15 @@ async def handle_action(self, event):
async def handle_error(self, event):
await self.send(json.dumps({"event": "error", "data": {"output": event.get("output")}}))
+ async def handle_product_widget_start(self, event):
+ await self.send(json.dumps({"event": "product_widget_start", "data": event.get("data")}))
+
+ async def handle_product_widget_end(self, event):
+ await self.send(json.dumps({"event": "product_widget_end", "data": event.get("data")}))
+
+ async def handle_product_widget_product(self, event):
+ await self.send(json.dumps({"event": "product_widget_product", "data": {"chunk": event.get("data")}}))
+
async def save_message(self, output):
from .models import Conversation, Message
diff --git a/server/agent/core/agents/default_config.py b/server/agent/core/agents/default_config.py
index 61272044..6bba9476 100644
--- a/server/agent/core/agents/default_config.py
+++ b/server/agent/core/agents/default_config.py
@@ -20,6 +20,7 @@
from agent.core.callbacks import (
AgentActionWebsocketCallbackHandler,
+ AgentWidgetWebsocketCallbackHandler,
ConversationWebSocketCallbackHandler,
ReactAgentWebsocketCallbackHandler,
)
@@ -59,7 +60,7 @@ def get_default_config(type: AgentType) -> DefaultAgentConfig:
agent_callback_handler_config = (
AgentCallbackHandlerConfig(handler_class=AgentActionWebsocketCallbackHandler)
if type == AgentType.RE_ACT
- else None
+ else AgentCallbackHandlerConfig(handler_class=AgentWidgetWebsocketCallbackHandler)
)
return DefaultAgentConfig(
repositories=RepositoriesConfig(
diff --git a/server/agent/core/builder.py b/server/agent/core/builder.py
index 31297b6a..2844adac 100644
--- a/server/agent/core/builder.py
+++ b/server/agent/core/builder.py
@@ -16,7 +16,13 @@
from enthusiast_common.injectors import BaseInjector
from enthusiast_common.registry import BaseDBModelsRegistry, BaseEmbeddingProviderRegistry, BaseLanguageModelRegistry
from enthusiast_common.retrievers import BaseRetriever
-from enthusiast_common.tools import BaseAgentTool, BaseFileTool, BaseFunctionTool, BaseLLMTool
+from enthusiast_common.tools import (
+ BaseAgentTool,
+ BaseFileTool,
+ BaseFunctionTool,
+ BaseLLMTool,
+ BaseWidgetResponseLLMTool,
+)
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
@@ -109,6 +115,11 @@ def _build_tools(self, default_llm: BaseLanguageModel, injector: BaseInjector) -
if isinstance(tool_config, FunctionToolConfig):
tools.append(self._build_function_tool(config=tool_config))
elif isinstance(tool_config, LLMToolConfig):
+ if issubclass(tool_config.tool_class, BaseWidgetResponseLLMTool):
+ tools.append(
+ self._build_widget_llm_tool(config=tool_config, default_llm=default_llm, injector=injector)
+ )
+ continue
tools.append(self._build_llm_tool(config=tool_config, injector=injector, default_llm=default_llm))
elif isinstance(tool_config, AgentToolConfig):
tools.append(self._build_agent_tool(config=tool_config))
@@ -127,11 +138,15 @@ def _build_llm_tool(
llm = default_llm
if config.llm:
llm = config.llm
- return config.tool_class(
- data_set_id=self._data_set_id,
- llm=llm,
- injector=injector,
- )
+ return config.tool_class(data_set_id=self._data_set_id, llm=llm, injector=injector)
+
+ def _build_widget_llm_tool(
+ self, config: LLMToolConfig, default_llm: BaseLanguageModel, injector: BaseInjector
+ ) -> BaseLLMTool:
+ llm = default_llm
+ if config.llm:
+ llm = config.llm
+ return config.tool_class(data_set_id=self._data_set_id, llm=llm, injector=injector, streaming=self.streaming)
def _build_file_tool(
self, config: FileToolConfig, default_llm: BaseLanguageModel, injector: BaseInjector
@@ -165,6 +180,7 @@ def _build_injector(self) -> BaseInjector:
repositories=self._repositories,
chat_summary_memory=chat_summary_memory,
chat_limited_memory=chat_limited_memory,
+ callbacks_handler=self._agent_callback_handler,
)
def _build_agent_callback_handler(self) -> Optional[BaseCallbackHandler]:
diff --git a/server/agent/core/callbacks.py b/server/agent/core/callbacks.py
index b0282f04..b79b9ffa 100644
--- a/server/agent/core/callbacks.py
+++ b/server/agent/core/callbacks.py
@@ -91,7 +91,38 @@ def on_llm_new_token(self, token: str, **kwargs):
)
-class AgentActionWebsocketCallbackHandler(BaseWebSocketHandler):
+class AgentWidgetWebsocketCallbackHandler(BaseWebSocketHandler):
+ def on_product_widget_start(self, data, **kwargs):
+ self.send_message(
+ {
+ "type": "chat_message",
+ "event": "product_widget_start",
+ "run_id": self.run_id,
+ "data": data,
+ }
+ )
+
+ def on_product_widget_product(self, product, **kwargs):
+ self.send_message(
+ {
+ "type": "chat_message",
+ "event": "product_widget_product",
+ "run_id": self.run_id,
+ "data": product,
+ },
+ )
+
+ def on_product_widget_end(self):
+ self.send_message(
+ {
+ "type": "chat_message",
+ "event": "product_widget_end",
+ "run_id": self.run_id,
+ }
+ )
+
+
+class AgentActionWebsocketCallbackHandler(AgentWidgetWebsocketCallbackHandler):
def on_agent_action(
self,
action: AgentAction,
diff --git a/server/agent/core/injector.py b/server/agent/core/injector.py
index e72d68a4..600368ea 100644
--- a/server/agent/core/injector.py
+++ b/server/agent/core/injector.py
@@ -4,6 +4,7 @@
from enthusiast_common.injectors import BaseInjector
from enthusiast_common.retrievers import BaseProductRetriever, BaseVectorStoreRetriever
from enthusiast_common.structures import RepositoriesInstances
+from langchain_core.callbacks import BaseCallbackHandler
from agent.core.memory import SummaryChatMemory
from agent.core.memory.limited_chat_memory import LimitedChatMemory
@@ -19,6 +20,7 @@ def __init__(
repositories: RepositoriesInstances,
chat_summary_memory: SummaryChatMemory,
chat_limited_memory: LimitedChatMemory,
+ callbacks_handler: BaseCallbackHandler,
):
super().__init__(repositories)
self._document_retriever = document_retriever
@@ -26,6 +28,7 @@ def __init__(
self._ecommerce_platform_connector = ecommerce_platform_connector
self._chat_summary_memory = chat_summary_memory
self._chat_limited_memory = chat_limited_memory
+ self._callbacks_handler = callbacks_handler
@property
def document_retriever(self) -> BaseVectorStoreRetriever[DocumentChunk]:
@@ -46,3 +49,7 @@ def chat_summary_memory(self) -> SummaryChatMemory:
@property
def chat_limited_memory(self) -> LimitedChatMemory:
return self._chat_limited_memory
+
+ @property
+ def callbacks_handler(self) -> BaseCallbackHandler:
+ return self._callbacks_handler
diff --git a/server/agent/core/memory/persist_intermediate_steps_mixin.py b/server/agent/core/memory/persist_intermediate_steps_mixin.py
index 93e7d396..02396638 100644
--- a/server/agent/core/memory/persist_intermediate_steps_mixin.py
+++ b/server/agent/core/memory/persist_intermediate_steps_mixin.py
@@ -1,5 +1,6 @@
+import json
from abc import ABC
-from typing import Any, Dict, cast
+from typing import Any, Dict, Literal, cast
from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain.memory import ConversationBufferMemory
@@ -8,6 +9,11 @@
from agent.models import Message
+class WidgetMessage(BaseMessage):
+ type: Literal["widget"] = "widget"
+ widget_data: dict
+
+
class PersistIntermediateStepsMixin(ABC):
"""
This mixin can be added to a ConversationBufferMemory class in order to persist agent's function calls.
@@ -30,9 +36,16 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
type=Message.MessageType.INTERMEDIATE_STEP, content=agent_action.messages[0].content
)
self_as_conversation_memory.chat_memory.add_message(intermediate_step_message)
-
- function_message = FunctionMessage(name=agent_action.tool, content=result)
- self_as_conversation_memory.chat_memory.add_message(function_message)
-
- ai_message = AIMessage(outputs["output"])
- self_as_conversation_memory.chat_memory.add_message(ai_message)
+ if result:
+ function_message = FunctionMessage(name=agent_action.tool, content=result)
+ self_as_conversation_memory.chat_memory.add_message(function_message)
+ if outputs.get("output") is not None:
+ try:
+ json_object = json.loads(outputs["output"])
+ ai_message = WidgetMessage(
+ content=f"Widget representation of data: {json_object['data']}", widget_data=json_object
+ )
+ self_as_conversation_memory.chat_memory.add_message(ai_message)
+ except Exception:
+ ai_message = AIMessage(outputs["output"])
+ self_as_conversation_memory.chat_memory.add_message(ai_message)
diff --git a/server/agent/core/memory/persistent_chat_history.py b/server/agent/core/memory/persistent_chat_history.py
index 6afe76e4..9b6af2af 100644
--- a/server/agent/core/memory/persistent_chat_history.py
+++ b/server/agent/core/memory/persistent_chat_history.py
@@ -16,7 +16,10 @@ def __init__(self, conversation_repo: BaseConversationRepository, conversation_i
def add_message(self, message: BaseMessage) -> None:
self._conversation.messages.create(
- type=message.type, text=message.content, function_name=getattr(message, "name", None)
+ type=message.type,
+ text=message.content,
+ function_name=getattr(message, "name", None),
+ widget_data=getattr(message, "widget_data", None),
)
@property
diff --git a/server/agent/migrations/0024_message_widget_data_alter_message_type.py b/server/agent/migrations/0024_message_widget_data_alter_message_type.py
new file mode 100644
index 00000000..7bca93f9
--- /dev/null
+++ b/server/agent/migrations/0024_message_widget_data_alter_message_type.py
@@ -0,0 +1,33 @@
+# Generated by Django 5.2.4 on 2025-10-20 11:06
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("agent", "0024_alter_agent_description"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="message",
+ name="widget_data",
+ field=models.JSONField(blank=True, null=True),
+ ),
+ migrations.AlterField(
+ model_name="message",
+ name="type",
+ field=models.CharField(
+ choices=[
+ ("function", "Function"),
+ ("human", "Human"),
+ ("ai", "Ai"),
+ ("system", "System"),
+ ("intermediate_step", "Intermediate Step"),
+ ("file", "File"),
+ ("widget", "Widget"),
+ ],
+ max_length=50,
+ ),
+ ),
+ ]
diff --git a/server/agent/models/message.py b/server/agent/models/message.py
index 0767c450..7e785eae 100644
--- a/server/agent/models/message.py
+++ b/server/agent/models/message.py
@@ -12,6 +12,7 @@ class MessageType(models.TextChoices):
SYSTEM = "system"
INTERMEDIATE_STEP = "intermediate_step"
FILE = "file"
+ WIDGET = "widget"
conversation = models.ForeignKey(Conversation, related_name="messages", on_delete=models.PROTECT)
created_at = models.DateTimeField(auto_now_add=True)
@@ -25,6 +26,7 @@ class MessageType(models.TextChoices):
function_name = models.CharField(max_length=50, blank=True, null=True)
file_name = models.CharField(max_length=256, blank=True, null=True)
file_type = models.CharField(max_length=50, blank=True, null=True)
+ widget_data = models.JSONField(blank=True, null=True)
@classmethod
def internal_message_types(cls):
@@ -33,12 +35,13 @@ def internal_message_types(cls):
@property
def langchain_type(self):
langchain_type_mapping = {
- self.MessageType.FUNCTION: "ai",
self.MessageType.FILE: "human",
- self.MessageType.INTERMEDIATE_STEP: "ai",
self.MessageType.HUMAN: "human",
- self.MessageType.AI: "ai",
self.MessageType.SYSTEM: "system",
+ self.MessageType.INTERMEDIATE_STEP: "ai",
+ self.MessageType.AI: "ai",
+ self.MessageType.FUNCTION: "ai",
+ self.MessageType.WIDGET: "ai",
}
return langchain_type_mapping[self.type]
diff --git a/server/agent/serializers/conversation.py b/server/agent/serializers/conversation.py
index 66900cd0..c3265b53 100644
--- a/server/agent/serializers/conversation.py
+++ b/server/agent/serializers/conversation.py
@@ -88,10 +88,10 @@ class MessagesSerializer(serializers.ModelSerializer):
class Meta:
model = Message
- fields = ["id", "text", "type", "file_type", "file_name"]
+ fields = ["id", "text", "type", "file_type", "file_name", "widget_data"]
def get_text(self, obj: Message):
- if obj.type == Message.MessageType.FILE:
+ if obj.type == Message.MessageType.FILE or obj.type == Message.MessageType.WIDGET:
return ""
else:
return obj.text
diff --git a/server/agent/views.py b/server/agent/views.py
index 245a256b..f3f917ff 100644
--- a/server/agent/views.py
+++ b/server/agent/views.py
@@ -126,7 +126,7 @@ def post(self, request, conversation_id):
conversation_id=conversation.id,
created_at=datetime.now(),
type=Message.MessageType.FILE,
- file_name=file.file.name.split(".")[0],
+ file_name=file.file.name.split(".")[-1],
file_type=file.file.name.split(".")[-1],
text=f"Uploaded {file.file.name} with id: {file.pk}",
)