diff --git a/docs/content/docs/agents/agent-config.md b/docs/content/docs/agents/agent-config.md index 619d3c90..fe00a4b8 100644 --- a/docs/content/docs/agents/agent-config.md +++ b/docs/content/docs/agents/agent-config.md @@ -22,8 +22,7 @@ class AgentConfig(ArbitraryTypeBaseModel, Generic[InjectorT]): retrievers: RetrieversConfig injector: Type[InjectorT] registry: RegistryConfig - prompt_template: Optional[PromptTemplateConfig] = None - chat_prompt_template: Optional[ChatPromptTemplateConfig] = None + system_prompt: str tools: Optional[list[FunctionToolConfig | LLMToolConfig | AgentToolConfig]] = None agent_callback_handler: Optional[AgentCallbackHandlerConfig] = None ``` @@ -81,11 +80,10 @@ class AgentConfig(ArbitraryTypeBaseModel, Generic[InjectorT]): - `embeddings`: Embedding provider registry configuration - `model`: Database model registry configuration -#### 7. **prompt_template** / **chat_prompt_template** -- **Type**: `PromptTemplateConfig` or `ChatPromptTemplateConfig` -- **Required**: Exactly one must be provided -- **Description**: Prompt configuration for agent behavior and instructions -- **Validation**: The system ensures exactly one prompt type is configured +#### 7. **system_prompt** +- **Type**: `str` +- **Required**: Yes +- **Description**: The system prompt passed to the agent. May contain `{variable}` placeholders resolved via the agent's `_get_system_prompt_variables()` hook. #### 8. **tools** - **Type**: `Optional[list[FunctionToolConfig | LLMToolConfig | AgentToolConfig]]` diff --git a/docs/content/docs/agents/agent-types.md b/docs/content/docs/agents/agent-types.md index 9bf07c01..274669b6 100644 --- a/docs/content/docs/agents/agent-types.md +++ b/docs/content/docs/agents/agent-types.md @@ -84,15 +84,12 @@ class ExampleAgent(BaseAgent): ] ``` -Those arguments can be accessed by agent like this: +Those arguments can be accessed inside the agent. For example, `_get_system_prompt_variables()` is called before each invocation to resolve `{variable}` placeholders in the system prompt: ```python - def get_answer(self, input_text: str) -> str: - agent_output = self._agent_executor.invoke( - {"input": input_text, "product_type": self.PROMPT_INPUT.product_type} - ) - return agent_output["output"] + def _get_system_prompt_variables(self) -> dict: + return {"output_format": self.PROMPT_INPUT.output_format} ``` @@ -107,38 +104,57 @@ The `BaseToolCallingAgent` implements the standard tool calling pattern, leverag ```python class BaseToolCallingAgent(BaseAgent): def get_answer(self, input_text: str) -> str: - # Build and execute the agent - agent_executor = self._build_agent_executor() - response = agent_executor.invoke( - {"input": input_text}, - config=self._build_invoke_config() + history = self._injector.chat_history + + # Trim history to token budget and append current user message + agent = self._build_agent() + limited_history = self._build_limited_history(history) + input_messages = limited_history + [HumanMessage(content=input_text)] + + # Execute the agent with the full message list + result = agent.invoke({"messages": input_messages}, config=self._build_invoke_config()) + + # Slice off only the new messages produced this turn, then persist them + new_messages = result["messages"][len(limited_history):] + final_message = next( + m for m in reversed(new_messages) + if isinstance(m, AIMessage) and not m.tool_calls ) - return response["output"] + history.add_messages(new_messages) + return final_message.text def _build_tools(self) -> list[BaseTool]: - """Convert internal tools to LangChain tools""" + """Convert internal tools to LangChain BaseTool instances.""" return [tool.as_tool() for tool in self._tools] - def _build_memory(self) -> BaseMemory: - """Use limited memory for ReAct reasoning""" - return self._injector.chat_limited_memory - def _build_invoke_config(self) -> dict[str, Any]: - """Configure callback handlers""" + """Pass callback handler to the agent invocation if one is configured.""" if self._callback_handler: return {"callbacks": [self._callback_handler]} return {} - def _build_agent_executor(self) -> AgentExecutor: - """Create the LangChain agent executor""" - tools = self._build_tools() - agent = create_tool_calling_agent( - tools=tools, - llm=self._llm, - prompt=self._prompt, + def _build_limited_history(self, history: BaseChatMessageHistory) -> list[BaseMessage]: + """Trim conversation history to MAX_HISTORY_TOKENS using the LLM as the token counter.""" + return trim_messages( + history.messages, + strategy="last", + token_counter=self._llm, + max_tokens=MAX_HISTORY_TOKENS, + start_on=HumanMessage, + include_system=True, + allow_partial=False, ) - return AgentExecutor( - agent=agent, tools=tools, verbose=True, memory=self._build_memory(), return_intermediate_steps=True + + def _get_system_prompt(self) -> str: + """Resolve template variables in the system prompt string via _get_system_prompt_variables().""" + return self._system_prompt.format(**self._get_system_prompt_variables()) + + def _build_agent(self): + """Build a LangGraph agent with the configured LLM, tools, and system prompt.""" + return create_agent( + model=self._llm, + tools=self._build_tools(), + system_prompt=self._get_system_prompt(), ) ``` diff --git a/docs/content/docs/agents/builder.md b/docs/content/docs/agents/builder.md index d327ee76..127e48f9 100644 --- a/docs/content/docs/agents/builder.md +++ b/docs/content/docs/agents/builder.md @@ -56,10 +56,9 @@ def build(self) -> BaseAgent: # 3. Build tools and handlers tools = self._build_tools(default_llm=self._llm, injector=self._injector) agent_callback_handler = self._build_agent_callback_handler() - prompt_template = self._build_prompt_template() # 4. Create and configure agent - agent_instance = self._build_agent(tools, self._llm, prompt_template, agent_callback_handler) + agent_instance = self._build_agent(tools, self._llm, agent_callback_handler) self._inject_additional_arguments(agent_instance) return agent_instance @@ -77,7 +76,6 @@ def _build_agent( self, tools: list[BaseTool], llm: BaseLanguageModel, - prompt: PromptTemplate | ChatMessagePromptTemplate, callback_handler: BaseCallbackHandler, ) -> BaseAgent: """Build the final agent instance""" @@ -87,11 +85,6 @@ def _build_agent( def _build_injector(self) -> BaseInjector: """Build the dependency injection container""" pass - -@abstractmethod -def _build_prompt_template(self) -> BasePromptTemplate: - """Build the prompt template for the agent""" - pass ``` #### Registry and Repository Methods @@ -152,7 +145,7 @@ def _build_agent_tool(self, config: AgentToolConfig) -> BaseAgentTool: pass ``` -#### Memory and Callback Methods +#### Callback Methods ```python @abstractmethod @@ -164,16 +157,6 @@ def _build_agent_callback_handler(self) -> Optional[BaseCallbackHandler]: def _build_llm_callback_handlers(self) -> Optional[BaseCallbackHandler]: """Build LLM callback handlers""" pass - -@abstractmethod -def _build_chat_summary_memory(self) -> BaseMemory: - """Build summary-based chat memory""" - pass - -@abstractmethod -def _build_chat_limited_memory(self) -> BaseMemory: - """Build limited chat memory""" - pass ``` ### Runtime Configuration Injection diff --git a/docs/content/docs/agents/injector.md b/docs/content/docs/agents/injector.md index 7d893c35..57050e0f 100644 --- a/docs/content/docs/agents/injector.md +++ b/docs/content/docs/agents/injector.md @@ -38,12 +38,7 @@ class BaseInjector(ABC): @property @abstractmethod - def chat_summary_memory(self) -> BaseMemory: - pass - - @property - @abstractmethod - def chat_limited_memory(self) -> BaseMemory: + def chat_history(self) -> BaseChatMessageHistory: pass ``` @@ -58,14 +53,12 @@ class Injector(BaseInjector): document_retriever: BaseVectorStoreRetriever[DocumentChunk], product_retriever: BaseProductRetriever, repositories: RepositoriesInstances, - chat_summary_memory: SummaryChatMemory, - chat_limited_memory: LimitedChatMemory, + chat_history: BaseChatMessageHistory, ): super().__init__(repositories) self._document_retriever = document_retriever self._product_retriever = product_retriever - self._chat_summary_memory = chat_summary_memory - self._chat_limited_memory = chat_limited_memory + self._chat_history = chat_history @property def document_retriever(self) -> BaseVectorStoreRetriever[DocumentChunk]: @@ -76,12 +69,8 @@ class Injector(BaseInjector): return self._product_retriever @property - def chat_summary_memory(self) -> SummaryChatMemory: - return self._chat_summary_memory - - @property - def chat_limited_memory(self) -> LimitedChatMemory: - return self._chat_limited_memory + def chat_history(self) -> BaseChatMessageHistory: + return self._chat_history ``` ## Available Resources @@ -94,9 +83,9 @@ The document retriever provides access to document content through vector search The product retriever provides access to product information: -### 3. Chat Memory Systems +### 3. Chat History -The injector provides access to two types of memory systems Summary Chat Memory and Limited Chat Memory: +The injector provides access to the persistent conversation history via `chat_history: BaseChatMessageHistory`. See [Memory](./memory.md) for details on token limiting and persistence. ### 4. Repository Access @@ -150,15 +139,14 @@ class ExampleTool(BaseLLMTool): Agents receive the injector during construction and can access all resources: ```python -class ExampleAgent(BaseAgent): - def _build_agent_executor(self) -> AgentExecutor: - tools = self._build_tools() - agent = create_tool_calling_agent( - tools=tools, - llm=self._llm, - prompt=self._prompt, - ) - return AgentExecutor(agent=agent, tools=tools, memory=self.injector.chat_limited_memory) +class ExampleAgent(BaseToolCallingAgent): + def get_answer(self, input_text: str) -> str: + # Fetch relevant documents before invoking the agent + # and inject them as additional context into the user message + docs = self._injector.document_retriever.find_content_matching_query(input_text) + context = "\n".join(doc.content for doc in docs) + enriched_input = f"{input_text}\n\nContext:\n{context}" + return super().get_answer(enriched_input) ``` @@ -173,16 +161,14 @@ def _build_injector(self) -> BaseInjector: # Build individual components document_retriever = self._build_document_retriever() product_retriever = self._build_product_retriever() - chat_summary_memory = self._build_chat_summary_memory() - chat_limited_memory = self._build_chat_limited_memory() - + chat_history = self._build_chat_history() + # Create injector with all components return self._config.injector( product_retriever=product_retriever, document_retriever=document_retriever, repositories=self._repositories, - chat_summary_memory=chat_summary_memory, - chat_limited_memory=chat_limited_memory, + chat_history=chat_history, ) ``` @@ -206,12 +192,8 @@ class CustomInjector(BaseInjector): return self._build_custom_product_retriever() @property - def chat_summary_memory(self) -> BaseMemory: - return self._build_custom_summary_memory() - - @property - def chat_limited_memory(self) -> BaseMemory: - return self._build_custom_limited_memory() + def chat_history(self) -> BaseChatMessageHistory: + return self._build_custom_chat_history() @property def custom_service(self) -> CustomService: diff --git a/docs/content/docs/agents/memory.md b/docs/content/docs/agents/memory.md index ae13bf10..54daf8f5 100644 --- a/docs/content/docs/agents/memory.md +++ b/docs/content/docs/agents/memory.md @@ -2,201 +2,78 @@ sidebar_position: 3 --- -# Memory Types +# Memory -Enthusiast provides by default two memory management strategies to help agents maintain context and conversation history. Each memory type is designed for different use cases and performance requirements. +Enthusiast manages conversation memory through two complementary mechanisms: persistent storage to the database and token-limited context trimming before each agent invocation. ## Overview Memory in Enthusiast serves two main purposes: -1. **Conversation Persistence**: Storing and retrieving chat history from the database -2. **Context Management**: Providing relevant conversation context to agents while managing token limits +1. **Conversation Persistence**: Storing and retrieving the full chat history from the database +2. **Context Management**: Trimming history to a token budget before passing it to the LLM -## Available Memory Types +## PersistentChatHistory -### 1. Summary Chat Memory +`PersistentChatHistory` implements `BaseChatMessageHistory` and is responsible for reading and writing messages to the database. It is injected into the agent via the `Injector` through the `chat_history` property. -**Class**: `SummaryChatMemory` -**Base**: `ConversationSummaryBufferMemory` with intermediate steps persistence - -**Description**: This memory type automatically summarizes conversation history when it exceeds the token limit, ensuring agents always have relevant context without hitting token constraints. - -**Key Features**: -- Automatically summarizes long conversations -- Persists intermediate agent steps (tool calls, observations) -- Configurable token limit (default: 3000 tokens) -- Maintains conversation flow while optimizing memory usage - -**Best For**: -- Long-running conversations -- Agents that need to remember key points from extended discussions -- Scenarios where conversation context is important but token efficiency is required - -**Configuration**: ```python -SummaryChatMemory( - llm=language_model, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=persistent_history -) -``` +class PersistentChatHistory(BaseChatMessageHistory): + def __init__(self, conversation_repo: BaseConversationRepository, conversation_id: Any): + self._conversation = conversation_repo.get_by_id(conversation_id) -### 2. Limited Chat Memory + def add_message(self, message: BaseMessage) -> None: + ... -**Class**: `LimitedChatMemory` -**Base**: `ConversationTokenBufferMemory` with intermediate steps persistence + @property + def messages(self) -> list[BaseMessage]: + ... +``` -**Description**: This memory type maintains a fixed token limit for conversation history, automatically truncating older messages when the limit is exceeded. +Messages are persisted after each agent turn by calling `history.add_messages(new_messages)`, where `new_messages` contains the full turn: human input, intermediate tool call/result pairs, and the final AI response. -**Key Features**: -- Fixed token limit for conversation context -- Persists intermediate agent steps -- Configurable token limit (default: 3000 tokens) -- Predictable memory usage +## Token Limiting -**Best For**: -- Real-time applications with strict token budgets -- Scenarios where recent context is more important than historical context -- High-frequency chat applications +Before each invocation, conversation history is trimmed using `trim_messages` from `langchain_core.messages`. The default token limit is 3000. -**Configuration**: ```python -LimitedChatMemory( - llm=language_model, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=persistent_history +from langchain_core.messages import trim_messages, HumanMessage + +limited_history = trim_messages( + history.messages, + strategy="last", + token_counter=llm, + max_tokens=3000, + start_on=HumanMessage, + include_system=True, + allow_partial=False, ) ``` -## Configuration - -### Default Settings +This is handled automatically inside `BaseToolCallingAgent` — no manual configuration is required. -- **Token Limit**: 3000 tokens (configurable) -- **Memory Key**: "chat_history" -- **Output Key**: "output" -- **Message Return**: True (returns structured messages) +## Accessing Chat History -### Customization +The chat history is accessible inside agents and tools via the injector: -In need of customization, those classes may be changed inside builder's methods responsible for creating it. ```python - def _build_chat_summary_memory(self) -> SummaryChatMemory: - history = PersistentChatHistory(self._repositories.conversation, self._config.conversation_id) - return SummaryChatMemory( - llm=self._llm, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=history, - ) - - def _build_chat_limited_memory(self) -> LimitedChatMemory: - history = PersistentChatHistory(self._repositories.conversation, self._config.conversation_id) - return LimitedChatMemory( - llm=self._llm, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=history, - ) +class MyAgent(BaseToolCallingAgent): + def get_answer(self, input_text: str) -> str: + history = self._injector.chat_history + # history.messages → full conversation history + # history.add_messages([...]) → persist new messages ``` -## Additional memory -In order to add additional type of memory: -Create custom memory class and then, build custom Injector based on enthusiast-common interface - `BaseInjector`: -```python -class Injector(BaseInjector): - def __init__( - self, - document_retriever: BaseRetriever, - product_retriever: BaseRetriever, - repositories: RepositoriesInstances, - chat_summary_memory: SummaryChatMemory, - chat_limited_memory: LimitedChatMemory, - additional_memory: AdditionalMemoryClass, - ): - super().__init__(repositories) - self._document_retriever = document_retriever - self._product_retriever = product_retriever - self._chat_summary_memory = chat_summary_memory - self._chat_limited_memory = chat_limited_memory - self._additional_memory = additional_memory - - @property - def document_retriever(self) -> BaseRetriever: - return self._document_retriever +## Extending the Injector with Additional State - @property - def product_retriever(self) -> BaseRetriever: - return self._product_retriever +If you need to persist additional state beyond conversation messages, extend `BaseInjector` with a custom property backed by its own repository or service: - @property - def chat_summary_memory(self) -> SummaryChatMemory: - return self._chat_summary_memory - - @property - def chat_limited_memory(self) -> LimitedChatMemory: - return self._chat_limited_memory - - @property - def additional_memory(self) -> AdditionalMemory: - return self.additional_memory -``` -Add method to build memory class instance inside Builder: ```python - def _build_additional_memory(self) -> AdditionalMemory: - history = PersistentChatHistory(self._repositories.conversation, self._config.conversation_id) - return AdditionalMemory( - llm=self._llm, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=history, - ) -``` -Add it to injector: - -```python - def _build_injector(self) -> BaseInjector: - document_retriever = self._build_document_retriever() - product_retriever = self._build_product_retriever() - chat_summary_memory = self._build_chat_summary_memory() - chat_limited_memory = self._build_chat_limited_memory() - additional_memory = self._build_additional_memory() - return self._config.injector( - product_retriever=product_retriever, - document_retriever=document_retriever, - repositories=self._repositories, - chat_summary_memory=chat_summary_memory, - chat_limited_memory=chat_limited_memory, - additional_memory=additional_memory - ) -``` -## Usage Examples +class CustomInjector(BaseInjector): + def __init__(self, ..., custom_store: CustomStore): + super().__init__(...) + self._custom_store = custom_store -### Basic Memory Usage -All memory instances are accessible inside Agent class via `self.injector` -```python -from enthusiast_common.agents import BaseAgent -from langchain.agents import AgentExecutor, create_tool_calling_agent - -class MyAgent(BaseAgent): - def _build_agent_executor(self) -> AgentExecutor: - tools = self._build_tools() - agent = create_tool_calling_agent( - tools=tools, - llm=self._llm, - prompt=self._prompt, - ) - return AgentExecutor(agent=agent, tools=tools, memory=self.injector.chat_limited_memory) + @property + def custom_store(self) -> CustomStore: + return self._custom_store ``` diff --git a/docs/content/docs/agents/prompts.md b/docs/content/docs/agents/prompts.md index ba25fa90..60b56eba 100644 --- a/docs/content/docs/agents/prompts.md +++ b/docs/content/docs/agents/prompts.md @@ -16,97 +16,47 @@ Prompts in Enthusiast serve several key purposes: - **Output Formatting**: Define the expected structure and format of responses - **Context Management**: Handle conversation history and current context -## Supported Prompt Types - -Enthusiast supports two main prompt types, each designed for different use cases: - -### 1. PromptTemplate - -Single text template with variable placeholders - -### 2. ChatPromptTemplate - -Multi-message template with conversational interactions - ## Configuration -### Prompt Configuration Structure - -Prompts are configured through the `AgentConfig`: - -```python -class AgentConfig(ArbitraryTypeBaseModel, Generic[InjectorT]): - # ... other configuration fields ... - - prompt_template: Optional[PromptTemplateConfig] = None - chat_prompt_template: Optional[ChatPromptTemplateConfig] = None - -``` - -### PromptTemplate Configuration +The agent system prompt is a plain string passed via `AgentConfig.system_prompt`: ```python -class PromptTemplateConfig(ArbitraryTypeBaseModel): - input_variables: list[str] # List of variable names used in the template - template: str # The prompt template string +# prompt.py (conventional location) +MY_AGENT_SYSTEM_PROMPT = """ +You are a helpful assistant... +""" ``` -**Example Configuration**: ```python -prompt_template=PromptTemplateConfig( - input_variables=["tools", "tool_names", "input", "agent_scratchpad"], - template=EXAMPLE_AGENT_PROMPT_TEMPLATE -) +# config.py + +class ExampleConfig(BaseAgentConfigProvider): + def get_config(self, config_type: ConfigType = ConfigType.CONVERSATION) -> AgentConfigWithDefaults: + return AgentConfigWithDefaults( + system_prompt=MY_AGENT_SYSTEM_PROMPT, + agent_class=MyAgent, + tools=MyAgent.TOOLS, + ) ``` -### ChatPromptTemplate Configuration +## Template Variables -```python -class ChatPromptTemplateConfig(ArbitraryTypeBaseModel): - messages: Sequence[MessageLikeRepresentation] # List of message components -``` +If the system prompt contains `{variable}` placeholders, override `_get_system_prompt_variables()` in the agent class: -**Example Configuration**: ```python -chat_prompt_template=ChatPromptTemplateConfig( - messages=[ - ("system", "You are a sales support agent, and you know everything about a company and their products."), - ("placeholder", "{chat_history}"), - ("human", "{input}"), - ("placeholder", "{agent_scratchpad}"), - ] -) +class MyAgent(BaseToolCallingAgent): + def _get_system_prompt_variables(self) -> dict: + return {"output_format": self.PROMPT_INPUT.output_format} ``` -## Prompt Construction - -### Builder Integration - -The `AgentBuilder` automatically constructs the appropriate prompt type based on configuration: - -```python -def _build_prompt_template(self) -> BasePromptTemplate: - """Build the prompt template for the agent""" - if self._config.prompt_template: - # Use text-based prompt template - return PromptTemplate( - input_variables=self._config.prompt_template.input_variables, - template=self._config.prompt_template.template, - ) - else: - # Use chat-based prompt template - return ChatPromptTemplate.from_messages( - messages=self._config.chat_prompt_template.messages - ) -``` +The base implementation returns `{}`, so agents with static prompts require no override. ## Summary Prompts in Enthusiast provide a powerful and flexible foundation for agent behavior: -- **Multiple Types**: Support for both text-based and chat-based prompts -- **Context Management**: Rich context and conversation history support -- **Configuration-Driven**: Flexible configuration through the agent config system -- **Best Practices**: Established patterns for effective prompt design +- **Plain string**: System prompt is a `str` passed directly to the LLM via LangGraph +- **Context Management**: Conversation history and token limiting handled automatically by `BaseToolCallingAgent` +- **Configuration-Driven**: Prompt passed through the agent config system By understanding and effectively using the prompt system, developers can create agents that exhibit sophisticated reasoning, clear communication, and effective tool usage while maintaining flexibility and extensibility. diff --git a/docs/content/docs/customization/concept-product-search.md b/docs/content/docs/customization/concept-product-search.md index 755a8460..e670c7fd 100644 --- a/docs/content/docs/customization/concept-product-search.md +++ b/docs/content/docs/customization/concept-product-search.md @@ -146,13 +146,8 @@ class ProductSearchAgent(BaseToolCallingAgent): LLMToolConfig(tool_class=ProductVerificationTool), ] - def get_answer(self, input_text: str) -> str: - agent_executor = self._build_agent_executor() - agent_output = agent_executor.invoke( - {"input": input_text, "products_type": self.PROMPT_INPUT.products_type}, - config=self._build_invoke_config(), - ) - return agent_output["output"] + def _get_system_prompt_variables(self) -> dict: + return {"products_type": self.PROMPT_INPUT.products_type} ``` @@ -251,271 +246,11 @@ class DocumentRetriever(BaseVectorStoreRetriever): ``` -### Memory - -```python -import typing -from abc import ABC -from typing import Dict, Any - -from langchain.memory import ConversationBufferMemory, ConversationTokenBufferMemory, ConversationSummaryBufferMemory -from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage - -from enthusiast_common.repositories import BaseConversationRepository -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage, messages_from_dict - - -class PersistentChatHistory(BaseChatMessageHistory): - """ - A chat history implementation that persists messages in the database. - Inject it to agent's memory, to enable conversation persistence. - """ - - def __init__(self, conversation_repo: BaseConversationRepository, conversation_id: Any): - self._conversation = conversation_repo.get_by_id(conversation_id) - - def add_message(self, message: BaseMessage) -> None: - self._conversation.messages.create(role=message.type, text=message.content) - - @property - def messages(self) -> list[BaseMessage]: - messages = self._conversation.messages.filter(answer_failed=False).order_by("created_at") - message_dicts = [{"type": message.role, "data": {"content": message.text}} for message in messages] - return messages_from_dict(message_dicts) - - def clear(self) -> None: - self._conversation.messages.all().delete() - - -class PersistIntermediateStepsMixin(ABC): - """ - This mixin can be added to a ConversationBufferMemory class in order to persist agent's function calls. - """ - - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - self_as_conversation_memory = typing.cast(ConversationBufferMemory, self) - - human_message = HumanMessage(inputs["input"]) - self_as_conversation_memory.chat_memory.add_message(human_message) - - if "intermediate_steps" in outputs: - for agent_action, result in outputs["intermediate_steps"]: - self_as_conversation_memory.chat_memory.add_message(agent_action.messages[0]) - - function_message = FunctionMessage(name=agent_action.tool, content=result) - self_as_conversation_memory.chat_memory.add_message(function_message) - - ai_message = AIMessage(outputs["output"]) - self_as_conversation_memory.chat_memory.add_message(ai_message) - - - -class LimitedChatMemory(PersistIntermediateStepsMixin, ConversationTokenBufferMemory): - """ - This memory persists intermediate steps, and limits the amount of tokens passed back to the agent to - what's defined as max_token_limit. - """ - - pass - - -class SummaryChatMemory(PersistIntermediateStepsMixin, ConversationSummaryBufferMemory): - """ - This memory persists intermediate steps, and summarizes the history passed back to the agent if the history - exceeds the token limit. - """ - - pass -``` - -### Builder -```python -from typing import Optional - -from enthusiast_common.agents import BaseAgent -from enthusiast_common.builder import BaseAgentBuilder, RepositoriesInstances -from enthusiast_common.config import AgentConfig, LLMConfig -from enthusiast_common.injectors import BaseInjector -from enthusiast_common.registry import BaseDBModelsRegistry, BaseEmbeddingProviderRegistry, BaseLanguageModelRegistry -from enthusiast_common.retrievers import BaseVectorStoreRetriever -from enthusiast_common.tools import BaseAgentTool, BaseFunctionTool, BaseLLMTool -from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import ChatMessagePromptTemplate, PromptTemplate -from langchain_core.tools import BaseTool - -from .memory import PersistentChatHistory, SummaryChatMemory, LimitedChatMemory - - -class AgentBuilder(BaseAgentBuilder[AgentConfig]): - def _build_agent( - self, - tools: list[BaseTool], - llm: BaseLanguageModel, - prompt: PromptTemplate | ChatMessagePromptTemplate, - callback_handler: BaseCallbackHandler, - ) -> BaseAgent: - return self._config.agent_class( - tools=tools, - llm=llm, - prompt=prompt, - conversation_id=self._config.conversation_id, - callback_handler=callback_handler, - injector=self._injector, - ) - - def _build_llm_registry(self) -> BaseLanguageModelRegistry: - llm_registry_class = self._config.registry.llm.registry_class - data_set_repo = self._repositories.data_set - if providers := self._config.registry.llm.providers: - llm_registry = llm_registry_class(providers=providers) - else: - llm_registry = llm_registry_class(data_set_repo=data_set_repo) - return llm_registry - - def _build_db_models_registry(self) -> BaseDBModelsRegistry: - db_models_registry_class = self._config.registry.model.registry_class - if models_config := self._config.registry.model.models_config: - db_model_registry = db_models_registry_class(models_config=models_config) - else: - db_model_registry = db_models_registry_class() - return db_model_registry - - def _build_and_set_repositories(self, models_registry: BaseDBModelsRegistry) -> None: - repositories = {} - for name in self._config.repositories.__class__.model_fields.keys(): - repo_class = getattr(self._config.repositories, name) - model_class = models_registry.get_model_class_by_name(name) - repositories[name] = repo_class(model_class) - self._repositories = RepositoriesInstances(**repositories) - - def _build_embeddings_registry(self) -> BaseEmbeddingProviderRegistry: - embeddings_registry_class = self._config.registry.embeddings.registry_class - data_set_repo = self._repositories.data_set - if providers := self._config.registry.llm.providers: - embeddings_registry = embeddings_registry_class(providers=providers) - else: - embeddings_registry = embeddings_registry_class(data_set_repo=data_set_repo) - return embeddings_registry - - def _build_llm(self, llm_config: LLMConfig) -> BaseLanguageModel: - data_set_repo = self._repositories.data_set - llm_registry = self._build_llm_registry() - llm = self._config.llm.llm_class( - llm_registry=llm_registry, - callbacks=llm_config.callbacks, - streaming=llm_config.streaming, - data_set_repo=data_set_repo, - ) - return llm.create(self._data_set_id) - - def _build_default_llm(self) -> BaseLanguageModel: - llm_registry_class = self._config.registry.llm.registry_class - data_set_repo = self._repositories.data_set - if providers := self._config.registry.llm.providers: - llm_registry = llm_registry_class(providers=providers) - else: - llm_registry = llm_registry_class(data_set_repo=data_set_repo) - - llm = self._config.llm.llm_class( - llm_registry=llm_registry, - data_set_repo=data_set_repo, - ) - return llm.create(self._data_set_id) - - def _build_tools(self, default_llm: BaseLanguageModel, injector: BaseInjector) -> list[BaseTool]: - function_tools = self._build_function_tools() if self._config.function_tools else [] - llm_tools = self._build_llm_tools(default_llm, injector) if self._config.llm_tools else [] - agent_tools = self._build_agent_tools() if self._config.agent_tools else [] - return [*function_tools, *llm_tools, *agent_tools] - - def _build_function_tools(self) -> list[BaseFunctionTool]: - return [tool() for tool in self._config.function_tools] - - def _build_llm_tools(self, default_llm: BaseLanguageModel, injector: BaseInjector) -> list[BaseLLMTool]: - tools = [] - for tool_config in self._config.llm_tools: - llm = default_llm - data_set_id = tool_config.data_set_id or self._data_set_id - if tool_config.llm: - llm = tool_config.llm - tools.append( - tool_config.tool_class( - data_set_id=data_set_id, - llm=llm, - injector=injector, - ) - ) - return tools - - def _build_agent_tools(self) -> list[BaseAgentTool]: - return [tool_config.tool_class(agent=tool_config.agent) for tool_config in self._config.agent_tools] - - def _build_injector(self) -> BaseInjector: - document_retriever = self._build_document_retriever() - product_retriever = self._build_product_retriever() - return self._config.injector( - product_retriever=product_retriever, - document_retriever=document_retriever, - repositories=self._repositories, - chat_summary_memory=self._chat_summary_memory, - chat_limited_memory=self._chat_limited_memory, - ) - - def _build_agent_callback_handler(self) -> Optional[BaseCallbackHandler]: - if self._config.agent_callback_handler: - return self._config.agent_callback_handler.handler_class(**self._config.agent_callback_handler.args) - return None - - def _build_product_retriever(self) -> BaseVectorStoreRetriever: - return self._config.retrievers.product.retriever_class.create( - config=self._config, - data_set_id=self._data_set_id, - repositories=self._repositories, - embeddings_registry=self._embeddings_registry, - llm=self._llm, - ) - - def _build_document_retriever(self) -> BaseVectorStoreRetriever: - return self._config.retrievers.document.retriever_class.create( - config=self._config, - data_set_id=self._data_set_id, - repositories=self._repositories, - embeddings_registry=self._embeddings_registry, - llm=self._llm, - ) - - def _build_chat_summary_memory(self) -> SummaryChatMemory: - history = PersistentChatHistory(self._repositories.conversation, self._config.conversation_id) - return SummaryChatMemory( - llm=self._llm, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=history, - ) - - def _build_chat_limited_memory(self) -> LimitedChatMemory: - history = PersistentChatHistory(self._repositories.conversation, self._config.conversation_id) - return LimitedChatMemory( - llm=self._llm, - memory_key="chat_history", - return_messages=True, - max_token_limit=3000, - output_key="output", - chat_memory=history, - ) - -``` - ### Configuration Create configuration inside `config.py` file: ```python from enthusiast_common.config import AgentConfigWithDefaults from enthusiast_common.config.base import RetrieverConfig, RetrieversConfig -from enthusiast_common.config.prompts import ChatPromptTemplateConfig, Message, MessageRole from .agent import ProductSearchAgent from .prompt import PRODUCT_FINDER_AGENT_PROMPT @@ -524,17 +259,7 @@ from .retrievers import ProductVectorStoreRetriever, DocumentRetriever def get_config() -> AgentConfigWithDefaults: return AgentConfigWithDefaults( - prompt_template=ChatPromptTemplateConfig( - messages=[ - Message( - role=MessageRole.SYSTEM, - content=PRODUCT_FINDER_AGENT_PROMPT, - ), - Message(role=MessageRole.PLACEHOLDER, content="{chat_history}"), - Message(role=MessageRole.USER, content="{input}"), - Message(role=MessageRole.PLACEHOLDER, content="{agent_scratchpad}"), - ] - ), + system_prompt=PRODUCT_FINDER_AGENT_PROMPT, agent_class=ProductSearchAgent, tools=ProductSearchAgent.TOOLS, retrievers=RetrieversConfig( diff --git a/docs/content/docs/customization/custom-agent.md b/docs/content/docs/customization/custom-agent.md index f73ff787..dc72fd77 100644 --- a/docs/content/docs/customization/custom-agent.md +++ b/docs/content/docs/customization/custom-agent.md @@ -89,7 +89,6 @@ __all__ = ["ContextSearchTool"] 5. Create configuration inside `config.py` file: ```python from enthusiast_common.config import AgentConfigWithDefaults -from enthusiast_common.config.prompts import ChatPromptTemplateConfig, Message, MessageRole from .agent import ExampleDocumentContextAgent from .prompt import DOCUMENT_CONTEXT_AGENT_SYSTEM_PROMPT @@ -97,17 +96,7 @@ from .prompt import DOCUMENT_CONTEXT_AGENT_SYSTEM_PROMPT def get_config() -> AgentConfigWithDefaults: return AgentConfigWithDefaults( - prompt_template=ChatPromptTemplateConfig( - messages=[ - Message( - role=MessageRole.SYSTEM, - content=DOCUMENT_CONTEXT_AGENT_SYSTEM_PROMPT, - ), - Message(role=MessageRole.PLACEHOLDER, content="{chat_history}"), - Message(role=MessageRole.USER, content="{input}"), - Message(role=MessageRole.PLACEHOLDER, content="{agent_scratchpad}"), - ] - ), + system_prompt=DOCUMENT_CONTEXT_AGENT_SYSTEM_PROMPT, agent_class=ExampleDocumentContextAgent, tools=ExampleDocumentContextAgent.TOOLS, )