diff --git a/src/react_agent/graph.py b/src/react_agent/graph.py index 9d2ca76..a919b75 100644 --- a/src/react_agent/graph.py +++ b/src/react_agent/graph.py @@ -7,7 +7,7 @@ from typing import Dict, List, Literal, cast from langchain_core.messages import AIMessage -from langgraph.graph import StateGraph +from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode from langgraph.runtime import Runtime @@ -42,7 +42,7 @@ async def call_model( ) # Get the model's response - response = cast( # type: ignore[redundant-cast] + response = cast( # type: ignore[redundant-cast] AIMessage, await model.ainvoke( [{"role": "system", "content": system_message}, *state.messages] @@ -74,7 +74,7 @@ async def call_model( # Set the entrypoint as `call_model` # This means that this node is the first one called -builder.add_edge("__start__", "call_model") +builder.add_edge(START, "call_model") def route_model_output(state: State) -> Literal["__end__", "tools"]: @@ -86,7 +86,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]: state (State): The current state of the conversation. Returns: - str: The name of the next node to call ("__end__" or "tools"). + str: The name of the next node to call (END or "tools"). """ last_message = state.messages[-1] if not isinstance(last_message, AIMessage): @@ -95,7 +95,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]: ) # If there is no tool call, then we finish if not last_message.tool_calls: - return "__end__" + return END # Otherwise we execute the requested actions return "tools"