Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/react_agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Dict, List, Literal, cast

from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.runtime import Runtime

Expand Down Expand Up @@ -42,7 +42,7 @@ async def call_model(
)

# Get the model's response
response = cast( # type: ignore[redundant-cast]
response = cast( # type: ignore[redundant-cast]
AIMessage,
await model.ainvoke(
[{"role": "system", "content": system_message}, *state.messages]
Expand Down Expand Up @@ -74,7 +74,7 @@ async def call_model(

# Set the entrypoint as `call_model`
# This means that this node is the first one called
builder.add_edge("__start__", "call_model")
builder.add_edge(START, "call_model")


def route_model_output(state: State) -> Literal["__end__", "tools"]:
Expand All @@ -86,7 +86,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
state (State): The current state of the conversation.

Returns:
str: The name of the next node to call ("__end__" or "tools").
str: The name of the next node to call (END or "tools").
"""
last_message = state.messages[-1]
if not isinstance(last_message, AIMessage):
Expand All @@ -95,7 +95,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
)
# If there is no tool call, then we finish
if not last_message.tool_calls:
return "__end__"
return END
# Otherwise we execute the requested actions
return "tools"

Expand Down