Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions agentflow_cli/src/app/core/middleware/request_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ async def dispatch(self, request: Request, call_next):
"error": {
"code": "REQUEST_TOO_LARGE",
"message": (
f"Request body too large. "
f"Maximum size is {self.max_size_mb:.1f}MB"
f"Request body too large. Maximum size is {self.max_size_mb:.1f}MB"
),
"max_size_bytes": self.max_size,
"max_size_mb": self.max_size_mb,
Expand Down
37 changes: 36 additions & 1 deletion agentflow_cli/src/app/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@


async def load_graph(path: str) -> CompiledGraph | None:
if ":" not in path:
raise ValueError(f"Invalid graph path format '{path}'. Expected 'module:attribute'.")

module_name_importable, function_name = path.split(":")

try:
Expand All @@ -44,6 +47,12 @@ async def load_graph(path: str) -> CompiledGraph | None:
else:
raise TypeError("Loaded object is not a CompiledGraph.")

except ModuleNotFoundError as e:
logger.error(f"Module not found when loading graph from {path}: {e}")
raise ModuleNotFoundError(f"Module not found for graph path '{path}': {e}")
except AttributeError as e:
logger.error(f"Attribute not found when loading graph from {path}: {e}")
raise AttributeError(f"Attribute not found for graph path '{path}': {e}")
except Exception as e:
logger.error(f"Error loading graph from {path}: {e}")
raise Exception(f"Failed to load graph from {path}: {e}")
Expand All @@ -55,6 +64,9 @@ def load_checkpointer(path: str | None) -> BaseCheckpointer | None:
if not path:
return None

if ":" not in path:
raise ValueError(f"Invalid checkpointer path format '{path}'. Expected 'module:attribute'.")

module_name_importable, function_name = path.split(":")

try:
Expand All @@ -69,6 +81,12 @@ def load_checkpointer(path: str | None) -> BaseCheckpointer | None:
logger.info(f"Successfully loaded BaseCheckpointer '{function_name}' from {path}.")
else:
raise TypeError("Loaded object is not a BaseCheckpointer.")
except ModuleNotFoundError as e:
logger.error(f"Module not found when loading BaseCheckpointer from {path}: {e}")
raise ModuleNotFoundError(f"Module not found for checkpointer path '{path}': {e}")
except AttributeError as e:
logger.error(f"Attribute not found when loading BaseCheckpointer from {path}: {e}")
raise AttributeError(f"Attribute not found for checkpointer path '{path}': {e}")
except Exception as e:
logger.error(f"Error loading BaseCheckpointer from {path}: {e}")
raise Exception(f"Failed to load BaseCheckpointer from {path}: {e}")
Expand All @@ -80,6 +98,9 @@ def load_store(path: str | None) -> BaseStore | None:
if not path:
return None

if ":" not in path:
raise ValueError(f"Invalid store path format '{path}'. Expected 'module:attribute'.")

module_name_importable, function_name = path.split(":")

try:
Expand All @@ -94,6 +115,12 @@ def load_store(path: str | None) -> BaseStore | None:
logger.info(f"Successfully loaded graph '{function_name}' from {path}.")
else:
raise TypeError("Loaded object is not a BaseStore.")
except ModuleNotFoundError as e:
logger.error(f"Module not found when loading BaseStore from {path}: {e}")
raise ModuleNotFoundError(f"Module not found for store path '{path}': {e}")
except AttributeError as e:
logger.error(f"Attribute not found when loading BaseStore from {path}: {e}")
raise AttributeError(f"Attribute not found for store path '{path}': {e}")
except Exception as e:
logger.error(f"Error loading BaseStore from {path}: {e}")
raise Exception(f"Failed to load BaseStore from {path}: {e}")
Expand Down Expand Up @@ -134,13 +161,15 @@ def load_auth(path: str | None) -> BaseAuth | None:
if not path:
return None

if ":" not in path:
raise ValueError(f"Invalid auth path format '{path}'. Expected 'module:attribute'.")

module_name_importable, function_name = path.split(":")

try:
module = importlib.import_module(module_name_importable)
entry_point_obj = getattr(module, function_name)

# If it's a class, instantiate it; if it's an instance, use as is
if inspect.isclass(entry_point_obj) and issubclass(entry_point_obj, BaseAuth):
auth = entry_point_obj()
elif isinstance(entry_point_obj, BaseAuth):
Expand All @@ -149,6 +178,12 @@ def load_auth(path: str | None) -> BaseAuth | None:
raise TypeError("Loaded object is not a subclass or instance of BaseAuth.")

logger.info(f"Successfully loaded BaseAuth '{function_name}' from {path}.")
except ModuleNotFoundError as e:
logger.error(f"Module not found when loading BaseAuth from {path}: {e}")
raise ModuleNotFoundError(f"Module not found for auth path '{path}': {e}")
except AttributeError as e:
logger.error(f"Attribute not found when loading BaseAuth from {path}: {e}")
raise AttributeError(f"Attribute not found for auth path '{path}': {e}")
except Exception as e:
logger.error(f"Error loading BaseAuth from {path}: {e}")
raise Exception(f"Failed to load BaseAuth from {path}: {e}")
Expand Down
48 changes: 45 additions & 3 deletions agentflow_cli/src/app/routers/checkpointer/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

from agentflow.state import Message
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from injectq.integrations import InjectAPI

from agentflow_cli.src.app.core.auth.permissions import RequirePermission
Expand All @@ -26,6 +26,26 @@
router = APIRouter(tags=["checkpointer"])


def validate_thread_id(thread_id: int | str) -> None:
if isinstance(thread_id, str):
if not thread_id.strip():
raise HTTPException(
status_code=422,
detail="thread_id cannot be empty or whitespace"
)
elif isinstance(thread_id, int):
if thread_id < 1:
raise HTTPException(
status_code=422,
detail="thread_id must be a non-negative integer"
)
else:
raise HTTPException(
status_code=422,
detail="thread_id must be a string or integer"
)


@router.get(
"/v1/threads/{thread_id}/state",
status_code=status.HTTP_200_OK,
Expand All @@ -48,6 +68,8 @@ async def get_state(
Returns:
State response with state data or error
"""
validate_thread_id(thread_id)

config = {"thread_id": thread_id}

result = await service.get_state(
Expand Down Expand Up @@ -86,11 +108,11 @@ async def put_state(
Returns:
Success response or error
"""
validate_thread_id(thread_id)
config = {"thread_id": thread_id}
if payload.config:
config.update(payload.config)

# State is provided as dict; service will handle merging/reconstruction
res = await service.put_state(
config,
user,
Expand Down Expand Up @@ -127,6 +149,7 @@ async def clear_state(
Returns:
Success response or error
"""
validate_thread_id(thread_id)
config = {"thread_id": thread_id}

res = await service.clear_state(
Expand Down Expand Up @@ -168,7 +191,10 @@ async def put_messages(
Returns:
Success response or error
"""
# Convert message dicts to Message objects if needed
validate_thread_id(thread_id)
if not payload.messages:
raise HTTPException(status_code=422, detail="messages must not be empty")

config = {"thread_id": thread_id}
if payload.config:
config.update(payload.config)
Expand Down Expand Up @@ -213,6 +239,10 @@ async def get_message(
Returns:
Message response with message data or error
"""
validate_thread_id(thread_id)
if not message_id or (isinstance(message_id, str) and not str(message_id).strip()):
raise HTTPException(status_code=422, detail="message_id is required and cannot be empty")

config = {"thread_id": thread_id}

result = await service.get_message(
Expand Down Expand Up @@ -255,6 +285,12 @@ async def list_messages(
Returns:
Messages list response with messages data or error
"""
validate_thread_id(thread_id)
if offset is not None and offset < 0:
raise HTTPException(status_code=422, detail="offset must be >= 0")
if limit is not None and limit <= 0:
raise HTTPException(status_code=422, detail="limit must be > 0")

config = {"thread_id": thread_id}

result = await service.get_messages(
Expand Down Expand Up @@ -297,6 +333,10 @@ async def delete_message(
Returns:
Success response or error
"""
validate_thread_id(thread_id)
if not message_id or (isinstance(message_id, str) and not str(message_id).strip()):
raise HTTPException(status_code=422, detail="message_id is required and cannot be empty")

config = {"thread_id": thread_id}
if payload.config:
config.update(payload.config)
Expand Down Expand Up @@ -340,6 +380,7 @@ async def get_thread(
Returns:
Thread response with thread data or error
"""
validate_thread_id(thread_id)
result = await service.get_thread(
{"thread_id": thread_id},
user,
Expand Down Expand Up @@ -415,6 +456,7 @@ async def delete_thread(
Returns:
Success response or error
"""
validate_thread_id(thread_id)
config = {"thread_id": thread_id}
if payload.config:
config.update(payload.config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from agentflow.checkpointer import BaseCheckpointer
from agentflow.state import AgentState, Message
from fastapi import HTTPException
from injectq import inject, singleton

from agentflow_cli.src.app.core import logger
Expand All @@ -26,7 +27,7 @@ def __init__(self, checkpointer: BaseCheckpointer):

def _config(self, config: dict[str, Any] | None, user: dict) -> dict[str, Any]:
if not self.checkpointer:
raise ValueError("Checkpointer is not configured")
raise HTTPException(status_code=503, detail="Checkpointer service is not available")

cfg: dict[str, Any] = dict(config or {})
cfg["user"] = user
Expand Down
32 changes: 29 additions & 3 deletions agentflow_cli/src/app/routers/graph/schemas/graph_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from agentflow.state import Message
from agentflow.utils import ResponseGranularity
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator


class GraphInputSchema(BaseModel):
Expand All @@ -23,8 +23,18 @@ class GraphInputSchema(BaseModel):
)
recursion_limit: int = Field(
default=25,
ge=1,
le=100,
description="Maximum recursion limit for graph execution",
)

@field_validator("messages")
@classmethod
def messages_must_not_be_empty(cls, v: list[Message]) -> list[Message]:
if not v:
raise ValueError("messages must contain at least one message")
return v

response_granularity: ResponseGranularity = Field(
default=ResponseGranularity.LOW,
description="Granularity of the response (full, partial, low)",
Expand Down Expand Up @@ -111,7 +121,15 @@ class GraphSchema(BaseModel):
class GraphStopSchema(BaseModel):
"""Schema for stopping graph execution."""

thread_id: str = Field(..., description="Thread ID to stop execution for")
thread_id: str = Field(..., min_length=1, description="Thread ID to stop execution for")

@field_validator("thread_id")
@classmethod
def thread_id_not_blank(cls, v: str) -> str:
if not v.strip():
raise ValueError("thread_id cannot be empty or whitespace")
return v.strip()

config: dict[str, Any] | None = Field(
default=None, description="Optional configuration for the stop operation"
)
Expand All @@ -137,7 +155,15 @@ class GraphSetupSchema(BaseModel):
class FixGraphRequestSchema(BaseModel):
"""Schema for fixing graph state by removing messages with empty tool call content."""

thread_id: str = Field(..., description="Thread ID to fix the graph state for")
thread_id: str = Field(..., min_length=1, description="Thread ID to fix the graph state for")

@field_validator("thread_id")
@classmethod
def thread_id_not_blank(cls, v: str) -> str:
if not v.strip():
raise ValueError("thread_id cannot be empty or whitespace")
return v.strip()

config: dict[str, Any] | None = Field(
default=None, description="Optional configuration for the fix operation"
)
Expand Down
24 changes: 20 additions & 4 deletions agentflow_cli/src/app/routers/graph/services/graph_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ async def stop_graph(
logger.info(f"Graph stop completed for thread {thread_id}: {result}")
return result

except ValueError as e:
logger.warning(f"Graph stop input validation failed for thread {thread_id}: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Graph stop failed for thread {thread_id}: {e}")
raise HTTPException(
Expand All @@ -166,8 +169,8 @@ async def _prepare_input(
):
is_new_thread = False
config = graph_input.config or {}
if "thread_id" in config:
thread_id = config["thread_id"]
if config.get("thread_id") and str(config["thread_id"]).strip():
thread_id = str(config["thread_id"]).strip()
else:
thread_id = await InjectQ.get_instance().atry_get("generated_id") or str(uuid4())
is_new_thread = True
Expand Down Expand Up @@ -258,6 +261,9 @@ async def invoke_graph(
meta=meta,
)

except ValueError as e:
logger.warning(f"Graph input validation failed: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Graph execution failed: {e}")
raise HTTPException(status_code=500, detail=f"Graph execution failed: {e!s}")
Expand Down Expand Up @@ -334,26 +340,33 @@ async def stream_graph(
+ "\n"
)

except ValueError as e:
logger.warning(f"Graph stream input validation failed: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Graph streaming failed: {e}")
raise HTTPException(status_code=500, detail=f"Graph streaming failed: {e!s}")

async def graph_details(self) -> GraphSchema:
try:
logger.info("Getting graph details")
# Fetch and return graph details
res = self._graph.generate_graph()
return GraphSchema(**res)
except ValueError as e:
logger.warning(f"Graph details validation failed: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Failed to get graph details: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get graph details: {e!s}")

async def get_state_schema(self) -> dict:
try:
logger.info("Getting state schema")
# Fetch and return state schema
res: BaseModel = self._graph._state
return res.model_json_schema()
except ValueError as e:
logger.warning(f"State schema validation failed: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Failed to get state schema: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get state schema: {e!s}")
Expand Down Expand Up @@ -451,6 +464,9 @@ async def fix_graph(
"removed_count": removed_count,
"state": state.model_dump_json(serialize_as_any=True),
}
except ValueError as e:
logger.warning(f"Fix graph input validation failed: {e}")
raise HTTPException(status_code=422, detail=str(e))
except Exception as e:
logger.error(f"Fix graph operation failed: {e}")
raise HTTPException(status_code=500, detail=f"Fix graph operation failed: {e!s}")
Expand Down
Loading