diff --git a/TOKENUSAGETRACKER_IMPLEMENTATION.md b/TOKENUSAGETRACKER_IMPLEMENTATION.md new file mode 100644 index 0000000..1cf320f --- /dev/null +++ b/TOKENUSAGETRACKER_IMPLEMENTATION.md @@ -0,0 +1,200 @@ +# TokenUsageTracker Implementation for SDK Mode + +## Overview + +This implementation addresses issues [#37](https://github.ibm.com/research-rpa/cuga-internal-evaluation/issues/37) and [cuga-agent#71](https://github.com/cuga-project/cuga-agent/issues/71) by adding TokenUsageTracker-like functionality to SDK mode (CugaAgent) evaluations. + +## Problem Statement + +Previously, benchmarks invoked CUGA in two different ways with inconsistent trajectory outputs: + +- **SDK mode** (BPO, M3, Oak): Used `CugaAgent` with minimal trajectory data +- **AgentRunner mode** (AppWorld): Used full agent loop with rich trajectory data including all LLM prompts + +This inconsistency made cross-benchmark trajectory analysis and visualization unreliable. + +## Solution + +Created a LangChain callback handler (`SDKTokenUsageTrackerCallback`) that mimics the behavior of `TokenUsageTracker` from cuga-agent's `agent_loop.py`. This callback: + +1. Captures system and user prompts when LLM calls start (`on_llm_start`) +2. Captures assistant responses when LLM calls complete (`on_llm_end`) +3. Forwards all captured data to `ActivityTracker` via `collect_prompt()` + +## Files Changed + +### New Files + +1. **`benchmarks/helpers/token_usage_tracker_callback.py`** + - New module containing `SDKTokenUsageTrackerCallback` class + - Factory function `create_token_usage_tracker_callback(tracker)` + - Implements LangChain `AsyncCallbackHandler` interface + +### Modified Files + +1. **`benchmarks/helpers/sdk_eval_helpers.py`** + - Updated `setup_agent_with_tools()` to automatically add TokenUsageTracker callback + - New parameter: `enable_token_usage_tracker` (default: True) + - Automatically enabled for all benchmarks using this helper + +2. 
**`benchmarks/m3/eval_m3.py`** + - Added TokenUsageTracker callback to CugaAgent creation (line ~1392) + - Uses global ActivityTracker singleton + +3. **`benchmarks/m3/eval_m3_task_1_enterprise_style.py`** + - Added TokenUsageTracker callback to CugaAgent creation (line ~320) + - Uses existing tracker instance + +## Affected Benchmarks + +All four benchmarks (BPO, M3, Oak Health Insurance, AppWorld), covering the six evaluation scripts listed below, now have TokenUsageTracker support: + +### ✅ Automatically Enabled (via `setup_agent_with_tools`) + +1. **BPO** (`benchmarks/bpo/eval_bench_sdk.py`) + - Uses `setup_agent_with_tools()` → automatically gets callback + +2. **M3 Multi-turn** (`benchmarks/m3/eval_m3_multiturn.py`) + - Uses `setup_agent_with_tools()` → automatically gets callback + +3. **Oak Health Insurance** (`benchmarks/oak_health_insurance/eval_bench_sdk.py`) + - Uses `setup_agent_with_tools()` → automatically gets callback + +4. **AppWorld SDK** (`benchmarks/appworld/eval_appworld_sdk.py`) + - Uses `setup_agent_with_tools()` → automatically gets callback + +### ✅ Manually Added + +5. **M3 Single-turn** (`benchmarks/m3/eval_m3.py`) + - Manually added callback during CugaAgent creation + +6. 
**M3 Task 1 Enterprise** (`benchmarks/m3/eval_m3_task_1_enterprise_style.py`) + - Manually added callback during CugaAgent creation + +## Usage + +### For New Benchmarks + +If using `setup_agent_with_tools()`, TokenUsageTracker is automatically enabled: + +```python +from benchmarks.helpers import setup_agent_with_tools + +# Automatically includes TokenUsageTracker callback +agent, langfuse_handler = await setup_agent_with_tools() +``` + +To disable (not recommended): + +```python +agent, langfuse_handler = await setup_agent_with_tools( + enable_token_usage_tracker=False +) +``` + +### For Direct CugaAgent Creation + +If creating CugaAgent directly, add the callback manually: + +```python +from cuga.backend.activity_tracker.tracker import ActivityTracker +from cuga.sdk import CugaAgent +from benchmarks.helpers.token_usage_tracker_callback import create_token_usage_tracker_callback + +tracker = ActivityTracker() # Singleton +callbacks = [langfuse_handler] if langfuse_handler else [] + +# Add TokenUsageTracker callback +try: + token_tracker_callback = create_token_usage_tracker_callback(tracker) + callbacks.append(token_tracker_callback) + logger.info("✅ TokenUsageTracker callback enabled") +except Exception as e: + logger.warning(f"Failed to enable TokenUsageTracker callback: {e}") + +agent = CugaAgent( + tool_provider=tool_provider, + callbacks=callbacks, +) +``` + +## Expected Impact + +### Before + +SDK mode trajectories had: +- Empty `prompts` fields on most steps (except manually added UserPrompt) +- Minimal step granularity (5 steps: Raw_Assistant_Response, Assistant_nl, FinalAnswerAgent, UserPrompt, EvaluationResult) +- Token tracking via Langfuse only + +### After + +SDK mode trajectories now have: +- **Full LLM conversation history** in `prompts` fields +- System prompts, user prompts, and assistant responses captured for every LLM call +- Same prompt richness as AgentRunner mode +- Compatible with cuga-viz and other trajectory analysis tools + +### 
Trajectory File Structure + +Each step in the trajectory now includes: + +```json +{ + "name": "UserPrompt", + "data": "...", + "prompts": [ + { + "role": "system", + "value": "You are a helpful assistant..." + }, + { + "role": "user", + "value": "What is the weather today?" + }, + { + "role": "assistant", + "value": "I'll check the weather for you..." + } + ] +} +``` + +## Testing + +To verify the implementation: + +1. Run any SDK-mode evaluation: + ```bash + ./benchmarks/bpo/eval.sh --task 1 + ``` + +2. Check the trajectory file in `benchmarks/bpo/logging/trajectory_data/` + +3. Verify that steps now include `prompts` arrays with system/user/assistant messages + +4. Compare with previous trajectories to see the enrichment + +## Limitations + +This is a **workaround** until the upstream fix in cuga-agent#71 is implemented. The callback approach: + +- ✅ Captures all LLM prompts and responses +- ✅ Works with existing SDK code without breaking changes +- ✅ Compatible with all benchmarks +- ⚠️ Does not provide the same step-by-step granularity as AgentRunner (5 steps vs 50+ steps) +- ⚠️ Requires manual addition for direct CugaAgent instantiations + +## Future Work + +Once [cuga-agent#71](https://github.com/cuga-project/cuga-agent/issues/71) is implemented: + +1. Remove this workaround callback +2. Update to use native TokenUsageTracker from cuga-agent +3. 
Achieve full parity between SDK and AgentRunner modes + +## Related Issues + +- [cuga-internal-evaluation#37](https://github.ibm.com/research-rpa/cuga-internal-evaluation/issues/37) - Standardize CUGA invocation mode across benchmarks +- [cuga-agent#71](https://github.com/cuga-project/cuga-agent/issues/71) - Instrument CugaAgent SDK with TokenUsageTracker +- [cuga-internal-evaluation#31](https://github.ibm.com/research-rpa/cuga-internal-evaluation/issues/31) - Changes and fixes to new evaluation framework (closed) diff --git a/benchmarks/helpers/sdk_eval_helpers.py b/benchmarks/helpers/sdk_eval_helpers.py index c96e3e6..0837f97 100644 --- a/benchmarks/helpers/sdk_eval_helpers.py +++ b/benchmarks/helpers/sdk_eval_helpers.py @@ -275,12 +275,14 @@ def _langfuse_trace_root_log_message(agent: Any) -> str: async def setup_agent_with_tools( special_instructions: Optional[str] = None, extra_callbacks: Optional[List[Any]] = None, + enable_token_usage_tracker: bool = True, ) -> tuple[CugaAgent, Optional[Any]]: """Set up CugaAgent with tools and Langfuse tracing. Args: special_instructions: Optional special instructions to pass to the agent extra_callbacks: Optional additional LangChain callbacks (e.g. 
TokenUsageCallback) + enable_token_usage_tracker: Whether to enable TokenUsageTracker-like callback for rich trajectories (default: True) Returns: Tuple of (agent, langfuse_handler) @@ -300,6 +302,21 @@ async def setup_agent_with_tools( logger.info("ℹ️ Langfuse not available (optional)") callbacks = list(extra_callbacks) if extra_callbacks else [] + + # Add TokenUsageTracker-like callback for rich trajectory capture + if enable_token_usage_tracker: + try: + from cuga.backend.activity_tracker.tracker import ActivityTracker + + from benchmarks.helpers.token_usage_tracker_callback import create_token_usage_tracker_callback + + tracker = ActivityTracker() # Singleton - returns existing instance + token_tracker_callback = create_token_usage_tracker_callback(tracker) + callbacks.append(token_tracker_callback) + logger.info("✅ TokenUsageTracker callback enabled for rich trajectory capture") + except Exception as e: + logger.warning(f"Failed to enable TokenUsageTracker callback: {e}") + agent_kwargs = {"tool_provider": tool_provider} if callbacks: agent_kwargs["callbacks"] = callbacks diff --git a/benchmarks/helpers/token_usage_tracker_callback.py b/benchmarks/helpers/token_usage_tracker_callback.py new file mode 100644 index 0000000..2cc7841 --- /dev/null +++ b/benchmarks/helpers/token_usage_tracker_callback.py @@ -0,0 +1,177 @@ +"""TokenUsageTracker-like callback for SDK mode to capture LLM prompts and responses. + +This module provides a LangChain callback handler that mimics the behavior of +TokenUsageTracker from cuga-agent, enabling SDK mode (CugaAgent) to produce +trajectory files with the same prompt richness as AgentRunner mode. 
class SDKTokenUsageTrackerCallback(AsyncCallbackHandler):
    """LangChain callback handler that forwards prompts and responses to a tracker.

    The handler records every message sent to the model and every generated
    response by calling ``tracker.collect_prompt(role=..., value=...)``, and
    records LLM failures via ``tracker.collect_step(...)``.

    Hooks implemented:
    - ``on_chat_model_start``: structured chat messages (system / user / assistant)
    - ``on_llm_start``: plain-text completion prompts (and any structured
      messages passed through ``kwargs``)
    - ``on_llm_end``: assistant responses
    - ``on_llm_error``: error information

    Why ``on_chat_model_start`` matters: LangChain dispatches chat-model
    invocations to ``on_chat_model_start``. When a handler does not implement
    it, the event is downgraded to ``on_llm_start`` with each conversation
    flattened into a single string, which loses the per-message roles. Handling
    the chat hook directly keeps the roles intact.
    """

    # Maps LangChain message ``type`` values, and already-normalised role names,
    # to the role labels handed to ``collect_prompt``. Anything not listed here
    # (e.g. tool/function messages) is skipped, as before.
    _ROLE_ALIASES: Dict[str, str] = {
        "system": "system",
        "human": "user",
        "user": "user",
        "ai": "assistant",
        "assistant": "assistant",
    }

    # Length threshold for the plain-string fallback in ``on_llm_start``:
    # longer prompts are labelled "system", shorter ones "user". This is a
    # heuristic only; structured messages never go through it.
    _SYSTEM_PROMPT_MIN_CHARS: int = 500

    def __init__(self, tracker: Any):
        """Initialize the callback.

        Args:
            tracker: Object exposing ``collect_prompt(role=, value=)`` and
                ``collect_step(step)`` (an ActivityTracker instance).
        """
        super().__init__()
        self.tracker = tracker
        # Number of LLM invocations seen so far; used only in log messages.
        self._call_count = 0
        logger.debug("SDKTokenUsageTrackerCallback initialized")

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into a single string.

        Message content may be a plain string or a list of content blocks
        (strings or dicts). Text blocks contribute their ``"text"`` value;
        any other block is kept via ``str()`` so no information is dropped.
        NOTE(review): assumes ``collect_prompt`` expects a string ``value`` —
        confirm against ActivityTracker.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
                else:
                    parts.append(str(block))
            return "\n".join(parts)
        return str(content)

    def _record_message(self, message: Any) -> bool:
        """Forward one message to the tracker.

        Accepts either a ``BaseMessage`` or a dict-style message with
        ``role``/``type`` and ``content`` keys.

        Returns:
            True if the message was forwarded, False if its shape or role is
            not one this handler records.
        """
        if isinstance(message, BaseMessage):
            raw_role: Any = message.type
            raw_content: Any = message.content
        elif isinstance(message, dict):
            raw_role = message.get("role", message.get("type", "unknown"))
            # Same fallback as before: a dict without "content" is recorded whole.
            raw_content = message.get("content", str(message))
        else:
            return False

        if not isinstance(raw_role, str):
            return False
        role = self._ROLE_ALIASES.get(raw_role)
        if role is None:
            return False

        text = self._content_to_text(raw_content)
        self.tracker.collect_prompt(role=role, value=text)
        logger.debug(f"[SDKTokenUsageTracker] Captured {role} prompt ({len(text)} chars)")
        return True

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Capture structured chat messages when a chat-model call starts.

        Args:
            serialized: Serialized model description (unused).
            messages: One list of messages per conversation in the batch.
            run_id: Identifier of this LLM run (unused).
            parent_run_id: Identifier of the parent run, if any (unused).
            tags: Run tags (unused).
            metadata: Run metadata (unused).
            **kwargs: Additional callback arguments (unused).
        """
        self._call_count += 1
        logger.debug(f"[SDKTokenUsageTracker] LLM call #{self._call_count} started")

        for conversation in messages or []:
            for message in conversation:
                self._record_message(message)

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Capture prompts when a (non-chat) LLM call starts.

        Structured messages found in ``kwargs`` take precedence. Otherwise each
        plain prompt string is recorded, labelled by the length heuristic.
        """
        self._call_count += 1
        logger.debug(f"[SDKTokenUsageTracker] LLM call #{self._call_count} started")

        # ``invocation_params`` may be absent or explicitly None; guard both.
        invocation_params = kwargs.get("invocation_params") or {}
        messages = invocation_params.get("messages") if isinstance(invocation_params, dict) else None
        if not messages:
            messages = kwargs.get("messages") or []

        if messages:
            for message in messages:
                self._record_message(message)
            return

        # Fallback: no structured messages, so only raw prompt strings are available.
        for prompt in prompts or []:
            if len(prompt) > self._SYSTEM_PROMPT_MIN_CHARS:
                self.tracker.collect_prompt(role="system", value=prompt)
                logger.debug(f"[SDKTokenUsageTracker] Captured prompt as system ({len(prompt)} chars)")
            else:
                self.tracker.collect_prompt(role="user", value=prompt)
                logger.debug(f"[SDKTokenUsageTracker] Captured prompt as user ({len(prompt)} chars)")

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> None:
        """Capture the assistant response when an LLM call completes.

        Every non-empty generation text is forwarded to the tracker as an
        ``assistant`` prompt. Token usage, when present in ``llm_output``, is
        only logged.
        """
        logger.debug(f"[SDKTokenUsageTracker] LLM call #{self._call_count} completed")

        for generation_list in response.generations or []:
            for generation in generation_list:
                text = generation.text if hasattr(generation, "text") else str(generation)
                if text:
                    self.tracker.collect_prompt(role="assistant", value=text)
                    logger.debug(
                        f"[SDKTokenUsageTracker] Captured assistant response ({len(text)} chars)"
                    )

        if response.llm_output:
            token_usage = response.llm_output.get("token_usage", {})
            if token_usage:
                logger.debug(f"[SDKTokenUsageTracker] Token usage: {token_usage}")

    async def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``LLM_Error`` step on the tracker when an LLM call fails."""
        logger.warning(f"[SDKTokenUsageTracker] LLM call #{self._call_count} failed: {error}")

        # Imported lazily so this module can be imported without cuga installed.
        from cuga.backend.activity_tracker.tracker import Step

        error_msg = f"LLM Error: {str(error)}"
        self.tracker.collect_step(Step(name="LLM_Error", data=error_msg))


def create_token_usage_tracker_callback(tracker: Any) -> SDKTokenUsageTrackerCallback:
    """Create a TokenUsageTracker-like callback bound to ``tracker``.

    Args:
        tracker: ActivityTracker instance that will receive prompts and steps.

    Returns:
        A new SDKTokenUsageTrackerCallback instance.

    Example:
        >>> from cuga.backend.activity_tracker.tracker import ActivityTracker
        >>> tracker = ActivityTracker()
        >>> callback = create_token_usage_tracker_callback(tracker)
        >>> agent = CugaAgent(tool_provider=provider, callbacks=[langfuse_handler, callback])
    """
    return SDKTokenUsageTrackerCallback(tracker)


# Made with Bob
langfuse_handler: - agent = CugaAgent( - tool_provider=tool_provider, - callbacks=[langfuse_handler], + callbacks = [langfuse_handler] if langfuse_handler else [] + + # Add TokenUsageTracker callback for rich trajectory capture + try: + from benchmarks.helpers.token_usage_tracker_callback import ( + create_token_usage_tracker_callback, ) - else: - agent = CugaAgent(tool_provider=tool_provider) + + token_tracker_callback = create_token_usage_tracker_callback(tracker) + callbacks.append(token_tracker_callback) + logger.info("✅ TokenUsageTracker callback enabled for rich trajectory capture") + except Exception as e: + logger.warning(f"Failed to enable TokenUsageTracker callback: {e}") + + agent = CugaAgent( + tool_provider=tool_provider, + callbacks=callbacks if callbacks else None, + ) logger.info(f"✅ Agent created with {len(tools)} tools via DirectLangChainToolsProvider") logger.warning("⚠️ Code executor timeout is 30s (hardcoded in CUGA). Slow queries may timeout.") diff --git a/pyproject.toml b/pyproject.toml index c3d1b43..1cc4766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ # but langchain-docling 2.0.0 (a transitive of cuga) still imports it. # Pin to 2.91.x until langchain-docling catches up with the slim layout. 
"docling>=2.26,<2.92", + "langchain-core>=1.3.3", + "python-multipart>=0.0.27", ] [tool.uv] diff --git a/uv.lock b/uv.lock index f34ae3d..3d14bee 100644 --- a/uv.lock +++ b/uv.lock @@ -790,11 +790,13 @@ dependencies = [ { name = "cugaviz" }, { name = "docling" }, { name = "griffelib" }, + { name = "langchain-core" }, { name = "langfuse" }, { name = "loguru" }, { name = "pyarrow" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "python-multipart" }, { name = "rapidfuzz" }, { name = "sentence-transformers" }, ] @@ -822,11 +824,13 @@ requires-dist = [ { name = "cugaviz" }, { name = "docling", specifier = ">=2.26,<2.92" }, { name = "griffelib" }, + { name = "langchain-core", specifier = ">=1.3.3" }, { name = "langfuse", specifier = ">=3" }, { name = "loguru" }, { name = "pyarrow", specifier = ">=23.0.0" }, { name = "pydantic" }, { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "python-multipart", specifier = ">=0.0.27" }, { name = "rapidfuzz" }, { name = "sentence-transformers" }, ]