Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,146 changes: 849 additions & 297 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ dependencies = [
"pytest-env (>=1.1.5,<2.0.0)",
"langfuse>=2.0.0",
"openlit>=1.35.0",
"numpy (>=1.26,<2.0)",
"scipy (>=1.11,<1.15)",
"sympy (>=1.12,<2.0)",
"numpy-financial (>=1.0.0,<2.0.0)",
"py-vollib (>=1.0.1,<2.0.0)",
"pyportfolioopt (>=1.6.0,<2.0.0)",
"empyrical (>=0.5.5,<0.6.0)",
"arch (>=6.0,<7.0)",
"statsmodels (>=0.14,<0.15)",
"python-dotenv (>=1.2.2,<2.0.0)",
]

[project.urls]
Expand All @@ -64,6 +74,7 @@ mypy = "^1.15.0"
ruff = ">=0.11.4,<0.12.0"
nbqa = { version = "^1.7.0", extras = ["toolchain"] }
pip-audit = "^2.7.1"
pytest-asyncio = "^1.3.0"

[tool.poetry.group.docs]
optional = true
Expand Down Expand Up @@ -179,6 +190,8 @@ filterwarnings = [
]
# Exclude legacy tests (imports are broken after code was moved)
norecursedirs = ["legacy"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage]
[tool.coverage.run]
Expand Down
47 changes: 46 additions & 1 deletion src/task_solver/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Message types for task solving debate system."""

from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Union


@dataclass
Expand Down Expand Up @@ -37,6 +37,7 @@ class AgentSolution:
round_number: int
capability_name: str
area_name: str
solution_type: str = "standard" # Discriminator for Union serialization

def to_dict(self) -> Dict[str, str]:
"""Convert to dictionary."""
Expand All @@ -52,6 +53,46 @@ def to_dict(self) -> Dict[str, str]:
}


@dataclass
class ToolAssistedAgentSolution:
    """Solution produced by an agent that may have executed code.

    Carries the same answer fields as a regular solution plus the code the
    agent ran and the output it produced.

    Note: ``code`` and ``code_output`` default to the empty string rather
    than ``None`` to avoid Union type issues with autogen_core serialization.
    """

    agent_id: str
    task_id: str
    thought: str
    final_answer: str
    numerical_answer: str
    round_number: int
    capability_name: str
    area_name: str
    # Discriminator for Union serialization
    solution_type: str = "tool_assisted"
    code: str = ""
    code_output: str = ""

    def to_dict(self) -> Dict[str, str]:
        """Return the solution as a flat dictionary of strings.

        ``round_number`` is stringified. ``code`` and ``code_output`` are
        included only when non-empty. ``solution_type`` is not included.
        """
        payload = {
            "agent_id": self.agent_id,
            "task_id": self.task_id,
            "thought": self.thought,
            "final_answer": self.final_answer,
            "numerical_answer": self.numerical_answer,
            "round_number": str(self.round_number),
            "capability_name": self.capability_name,
            "area_name": self.area_name,
        }
        # Optional tool fields are appended (in this order) only when set.
        tool_fields = {"code": self.code, "code_output": self.code_output}
        payload.update({key: text for key, text in tool_fields.items() if text})
        return payload


@dataclass
class AgentRevisionRequest:
"""Request for agent to revise solution based on other agents' solutions."""
Expand All @@ -73,6 +114,10 @@ class ConsensusCheck:
round_number: int


# Type alias for solutions that can appear in the debate.
# Either a plain AgentSolution or a ToolAssistedAgentSolution; each class
# carries its own `solution_type` field as the discriminator.
SolutionUnion = Union[AgentSolution, ToolAssistedAgentSolution]


@dataclass
class FinalSolution:
"""Final solution for a task."""
Expand Down
93 changes: 79 additions & 14 deletions src/task_solver/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
AgentRevisionRequest,
AgentSolution,
FinalSolution,
SolutionUnion,
Task,
TaskSolutionRequest,
ToolAssistedAgentSolution,
)
from src.utils.agentic_prompts import (
TASK_MODERATOR_CONSENSUS_PROMPT,
Expand Down Expand Up @@ -80,7 +82,8 @@ def __init__(
self._langfuse_client = langfuse_client

# Track solutions by task_id and round
self._solutions_buffer: Dict[int, List[AgentSolution]]
# Now properly typed with discriminator fields in solution classes
self._solutions_buffer: dict[int, list[SolutionUnion]] = {}
self._current_round = 0
self._final_solutions: FinalSolution
self._tasks: Task # Store original tasks for consensus checking
Expand Down Expand Up @@ -111,9 +114,13 @@ def _extract_consensus_components(
raise

def _check_simple_consensus(
self, solutions: List[AgentSolution]
self, solutions: list[SolutionUnion]
) -> tuple[bool, str, str]:
"""Check consensus; if all agents have the same final answer."""
"""Check consensus; if all agents have the same final answer.

Works with both AgentSolution and ToolAssistedAgentSolution since
both have final_answer and numerical_answer fields.
"""
if not solutions or len(solutions) < self._num_solvers:
return False, "", "null"

Expand Down Expand Up @@ -234,6 +241,74 @@ async def handle_agent_solution(
log.error(traceback.format_exc())
span.update(metadata={"error": error_msg})

@message_handler
async def handle_tool_assisted_agent_solution(
    self, message: ToolAssistedAgentSolution, ctx: MessageContext
) -> None:
    """Collect a tool-assisted agent's solution for the current round.

    The message is stored unchanged (including its ``code`` and
    ``code_output`` fields) in the per-round solutions buffer. When the
    buffer for the current round holds exactly ``self._num_solvers``
    solutions, the consensus step is triggered.

    Any exception raised while handling the message, including the
    round-mismatch error raised below, is caught by the enclosing
    ``except`` and is only logged and recorded on the span.

    Args:
        message: Solution sent by a tool-assisted agent.
        ctx: Message context, forwarded to the consensus step.
    """
    span_name = f"moderator_handle_tool_assisted_solution_{message.task_id}_round_{message.round_number}"
    with self._langfuse_client.start_as_current_span(name=span_name) as span:
        try:
            task_id = message.task_id
            round_num = message.round_number
            has_code = bool(message.code)
            has_code_output = bool(message.code_output)

            received = f"Moderator received tool-assisted solution from agent {message.agent_id} for task {task_id}, {message.capability_name}, {message.area_name} round {round_num}"
            log.info(received)
            log.debug(
                "Moderator: Tool-assisted solution has code: %s, has code_output: %s",
                has_code,
                has_code_output,
            )
            if has_code:
                log.debug(
                    "Moderator: Code length in received message: %d characters",
                    len(message.code),
                )
            span.update(
                metadata={
                    "solution_received": received,
                    "task_id": task_id,
                    "agent_id": message.agent_id,
                    "round": round_num,
                    "has_code": has_code,
                    "code_executed": has_code_output,
                }
            )

            # Reject solutions that belong to a different round.
            if round_num != self._current_round:
                mismatch = f"Moderator received solution from agent {message.agent_id} for task {task_id}, {message.capability_name}, {message.area_name} round {round_num} but current round is {self._current_round}"
                log.error(mismatch)
                span.update(metadata={"error": mismatch})
                raise Exception(mismatch)

            # Create the round's buffer on first use, then store the
            # ToolAssistedAgentSolution itself.
            round_buffer = self._solutions_buffer.setdefault(
                self._current_round, []
            )
            round_buffer.append(message)

            progress = f"{len(round_buffer)}/{self._num_solvers} solutions collected for round {self._current_round}"
            log.info(progress)
            span.update(metadata={"solutions_collected": progress})

            # Every solver has answered for this round: move on to consensus.
            if len(round_buffer) == self._num_solvers:
                await self._check_consensus_and_proceed(task_id, ctx)

        except Exception as e:
            error_msg = f"Error handling tool-assisted solution from agent {message.agent_id}: {str(e)}"
            log.error(error_msg)
            log.error(traceback.format_exc())
            span.update(metadata={"error": error_msg})

async def _check_consensus_and_proceed(
self, task_id: str, ctx: MessageContext
) -> None:
Expand Down Expand Up @@ -420,17 +495,7 @@ async def _save_final_solution(self, final_solution: FinalSolution) -> None:
"reasoning": final_solution.reasoning,
"consensus_reached": final_solution.consensus_reached,
"total_rounds": final_solution.total_rounds,
"all_solutions": [
{
"agent_id": sol["agent_id"],
"task_id": sol["task_id"],
"thought": sol["thought"],
"final_answer": sol["final_answer"],
"numerical_answer": sol["numerical_answer"],
"round_number": sol["round_number"],
}
for sol in final_solution.all_solutions
],
"all_solutions": final_solution.all_solutions, # Include all fields from to_dict()
}

with open(output_file, "w") as f:
Expand Down
Loading