diff --git a/EXTENSION_POINTS.md b/EXTENSION_POINTS.md index 2522d7f..f215b4e 100644 --- a/EXTENSION_POINTS.md +++ b/EXTENSION_POINTS.md @@ -142,9 +142,10 @@ string is a minor bump. ### 2. `runtime_context` -**Version:** `1.0.0` +**Version:** `1.1.0` **Default:** `current_context_id` returns `None`; `with_context` yields -the callable's result without binding any state (single-scope mode). +the callable's result without binding any state (single-scope mode); +`bind_context` returns a no-op async context manager. ```python from typing import Protocol, Callable, TypeVar @@ -154,6 +155,10 @@ T = TypeVar("T") class RuntimeContext(Protocol): def current_context_id(self, request) -> str | None: ... def with_context(self, context_id: str, fn: Callable[[], T]) -> T: ... + # Optional (added in 1.1.0). If present, callers MAY use it via + # the public `bind_context(...)` shim to scope per-request binding + # across awaits. + # def bind_context(self, context_id: str) -> AsyncContextManager[None]: ... ``` `request` is the framework-native request object (FastAPI / Starlette @@ -183,6 +188,14 @@ evo_extension_points.replace("runtime_context", MyRuntimeContext()) from `str | None`, is a major bump. Adding sibling helpers is a minor bump. +**Why `bind_context` is async-only.** Per-request context binding has to +survive every `await` between the request handler and the next operation +that may observe the bound state. `with_context(fn)` is synchronous by +contract — if `fn` returns a coroutine, the consumer would reset its +binding before the caller awaits it. The dedicated async context manager +keeps the binding alive across awaits and guarantees deterministic reset +on exit (including on exception). + ### 3. `usage_reporter` **Version:** `1.0.0` @@ -319,4 +332,5 @@ document itself is unversioned. `reset(name)`. - `capability_gate` `1.0.0` — Initial contract. - `runtime_context` `1.0.0` — Initial contract. +- `runtime_context` `1.1.0` — Adds optional `bind_context(context_id) -> AsyncContextManager[None]` sibling for scoping per-request binding across awaits. - `usage_reporter` `1.0.0` — Initial contract. diff --git a/src/api/a2a_routes.py b/src/api/a2a_routes.py index 11f6804..2a74eef 100644 --- a/src/api/a2a_routes.py +++ b/src/api/a2a_routes.py @@ -1057,6 +1057,7 @@ async def handle_message_send( files=files if files else None, metadata=metadata, user_id=user_id, # Pass contact_id as user_id + request=request, ) final_response = result.get("final_response", "No response") diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index 4be13a6..491cb8c 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -646,6 +646,7 @@ async def chat( db, session_id=session_id, files=payload.files, + request=request, ) return success_response( diff --git a/src/evo_extension_points/runtime_context.py b/src/evo_extension_points/runtime_context.py index 7746acb..0850e2f 100644 --- a/src/evo_extension_points/runtime_context.py +++ b/src/evo_extension_points/runtime_context.py @@ -2,16 +2,19 @@ Community default: ``current_context_id`` returns ``None``; ``with_context`` yields the callable's result without binding any state -(single-scope mode). +(single-scope mode); ``bind_context`` returns a no-op async context +manager so consumers can wire per-request context binding without the +community having to know what binding means. """ from __future__ import annotations -from typing import Any, Callable, Protocol, TypeVar, runtime_checkable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Callable, Protocol, TypeVar, runtime_checkable from . import registry -VERSION: str = "1.0.0" +VERSION: str = "1.1.0" T = TypeVar("T") @@ -35,6 +38,11 @@ def with_context(self, context_id: str, fn: Callable[[], T]) -> T: registry._register_protocol("runtime_context", RuntimeContext) +@asynccontextmanager +async def _null_bind(_context_id: str) -> AsyncIterator[None]: + yield + + def current_context_id(source: Any = None) -> str | None: impl = registry.impl_for("runtime_context") or _DEFAULT return impl.current_context_id(source) @@ -43,3 +51,18 @@ def current_context_id(source: Any = None) -> str | None: def with_context(context_id: str, fn: Callable[[], T]) -> T: impl = registry.impl_for("runtime_context") or _DEFAULT return impl.with_context(context_id, fn) + + +def bind_context(context_id: str) -> AbstractAsyncContextManager[None]: + """Return an async context manager that binds ``context_id`` for the + enclosed scope. + + Optional EP method (added in 1.1.0). The community default is a no-op + async CM; consumers may expose a ``bind_context(context_id)`` + method on their impl to participate. Callers must use ``async with``. + """ + impl = registry.impl_for("runtime_context") or _DEFAULT + bind = getattr(impl, "bind_context", None) + if bind is None: + return _null_bind(context_id) + return bind(context_id) diff --git a/src/services/adk/agent_runner.py b/src/services/adk/agent_runner.py index 3a1970d..67bf8d0 100644 --- a/src/services/adk/agent_runner.py +++ b/src/services/adk/agent_runner.py @@ -34,7 +34,7 @@ from src.services.adk.runners.streaming_runner import StreamingRunner from src.services.adk.runners.live_runner import LiveRunner from sqlalchemy.orm import Session -from typing import Optional, AsyncGenerator, Dict, Any +from typing import Any, AsyncGenerator, Dict, Optional async def run_agent( @@ -50,6 +50,7 @@ async def run_agent( files: Optional[list] = None, metadata: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None, + request: Any = None, ) -> Dict[str, Any]: """Execute a non-streaming agent request.""" runner = StandardRunner(db) @@ -65,6 +66,7 @@ async def run_agent( files=files, metadata=metadata, user_id=user_id, + request=request, ) diff --git a/src/services/adk/runners/standard_runner.py b/src/services/adk/runners/standard_runner.py index 5ac760e..6537c2b 100644 --- a/src/services/adk/runners/standard_runner.py +++ b/src/services/adk/runners/standard_runner.py @@ -44,6 +44,7 @@ runtime_context, usage_reporter, ) +from contextlib import nullcontext import uuid logger = setup_logger(__name__) @@ -69,6 +70,7 @@ async def run_agent( files: Optional[list] = None, metadata: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None, + request: Any = None, ) -> Dict[str, Any]: """Execute a non-streaming agent request.""" try: @@ -97,451 +99,456 @@ async def run_agent( # Extension point: runtime context resolution. Default returns # None; consumer overrides return an operational context id that # is logged here and (in a follow-up) propagated into metrics. - context_id = runtime_context.current_context_id(metadata) + context_id = runtime_context.current_context_id(request if request is not None else metadata) if context_id: logger.info( f"runtime_context resolved id={context_id!r}" f" for agent={agent_id}" ) - # Get and build agent - root_agent, state_params = await self.utils.get_and_build_agent(agent_id) - - # Setup session - adk_session_id = self.utils.create_session_id( - external_id, agent_id, session_id - ) + # Bind the resolved operational context for the rest of the + # request. Community default is a no-op async CM; a registered + # consumer may keep the context bound across awaits. + async with (runtime_context.bind_context(context_id) + if context_id else nullcontext()): + # Get and build agent + root_agent, state_params = await self.utils.get_and_build_agent(agent_id) + + # Setup session + adk_session_id = self.utils.create_session_id( + external_id, agent_id, session_id + ) - # Configure Runner - agent_runner = self.utils.create_runner( - root_agent, agent_id, session_service, artifacts_service, memory_service - ) + # Configure Runner + agent_runner = self.utils.create_runner( + root_agent, agent_id, session_service, artifacts_service, memory_service + ) - # Get or create session - # Use user_id (contact_id) if provided, otherwise fallback to external_id (conversation UUID) - effective_user_id = user_id if user_id else external_id - session = await self.utils.get_or_create_session( - session_service, agent_id, effective_user_id, adk_session_id - ) + # Get or create session + # Use user_id (contact_id) if provided, otherwise fallback to external_id (conversation UUID) + effective_user_id = user_id if user_id else external_id + session = await self.utils.get_or_create_session( + session_service, agent_id, effective_user_id, adk_session_id + ) - # Setup session state - await self.utils.setup_session_state( - session_service, session, message, state_params, metadata - ) + # Setup session state + await self.utils.setup_session_state( + session_service, session, message, state_params, metadata + ) - # Process files - file_parts, transcribed_texts = await self.utils.process_files( - files, artifacts_service, agent_id, external_id, adk_session_id - ) + # Process files + file_parts, transcribed_texts = await self.utils.process_files( + files, artifacts_service, agent_id, external_id, adk_session_id + ) - # Save user message to memory individually (FIFO) before processing - if memory_service and hasattr(memory_service, "add_event_to_memory"): - try: - # Get agent from database to extract config - from src.services.agent_service import get_agent - agent = await get_agent(self.db, agent_id) + # Save user message to memory individually (FIFO) before processing + if memory_service and hasattr(memory_service, "add_event_to_memory"): + try: + # Get agent from database to extract config + from src.services.agent_service import get_agent + agent = await get_agent(self.db, agent_id) - # Check if load_memory is enabled - load_memory_enabled = False - memory_base_config_id = None - short_term_max_messages = None - compression_interval = None - - if agent: - if agent.config: + # Check if load_memory is enabled + load_memory_enabled = False + memory_base_config_id = None + short_term_max_messages = None + compression_interval = None + + if agent: + if agent.config: + agent_config = agent.config if isinstance(agent.config, dict) else {} + if isinstance(agent_config, dict): + load_memory_enabled = agent_config.get("load_memory", False) + memory_base_config_id = agent_config.get("memory_base_config_id") + short_term_max_messages = agent_config.get("memory_short_term_max_messages") + compression_interval = agent_config.get("memory_medium_term_compression_interval") + + # Only save to memory if load_memory is enabled + if load_memory_enabled: + # Combine message with transcribed texts + user_content = message + if transcribed_texts: + user_content += "\n\n" + "\n\n".join(transcribed_texts) + + if user_content.strip(): + await memory_service.add_event_to_memory( + app_name=agent_id, + user_id=effective_user_id, + role="user", + content=user_content, + memory_base_config_id=memory_base_config_id, + short_term_max_messages=short_term_max_messages, + compression_interval=compression_interval, + ) + except Exception as e: + logger.debug(f"Could not save user message to memory: {e}") + + # Preload memory if enabled (before processing user message) + if memory_service: + try: + from src.services.agent_service import get_agent + agent = await get_agent(self.db, agent_id) + if agent and agent.config: agent_config = agent.config if isinstance(agent.config, dict) else {} - if isinstance(agent_config, dict): - load_memory_enabled = agent_config.get("load_memory", False) - memory_base_config_id = agent_config.get("memory_base_config_id") - short_term_max_messages = agent_config.get("memory_short_term_max_messages") - compression_interval = agent_config.get("memory_medium_term_compression_interval") - - # Only save to memory if load_memory is enabled - if load_memory_enabled: - # Combine message with transcribed texts - user_content = message - if transcribed_texts: - user_content += "\n\n" + "\n\n".join(transcribed_texts) - - if user_content.strip(): - await memory_service.add_event_to_memory( - app_name=agent_id, - user_id=effective_user_id, - role="user", - content=user_content, - memory_base_config_id=memory_base_config_id, - short_term_max_messages=short_term_max_messages, - compression_interval=compression_interval, - ) - except Exception as e: - logger.debug(f"Could not save user message to memory: {e}") - - # Preload memory if enabled (before processing user message) - if memory_service: - try: - from src.services.agent_service import get_agent - agent = await get_agent(self.db, agent_id) - if agent and agent.config: - agent_config = agent.config if isinstance(agent.config, dict) else {} - if isinstance(agent_config, dict) and agent_config.get("preload_memory") and agent_config.get("load_memory"): - logger.info(f"Preloading memory for agent {agent_id}, user {effective_user_id}") - # Call memory load endpoint directly via HTTP to get medium_term summaries - from src.services.memory_service import HttpMemoryService - from src.config.settings import settings - import httpx + if isinstance(agent_config, dict) and agent_config.get("preload_memory") and agent_config.get("load_memory"): + logger.info(f"Preloading memory for agent {agent_id}, user {effective_user_id}") + # Call memory load endpoint directly via HTTP to get medium_term summaries + from src.services.memory_service import HttpMemoryService + from src.config.settings import settings + import httpx - if isinstance(memory_service, HttpMemoryService): - memory_base_config_id = agent_config.get("memory_base_config_id") - - # Call /memory/load endpoint directly (preload mode - empty query returns medium_term summaries) + if isinstance(memory_service, HttpMemoryService): + memory_base_config_id = agent_config.get("memory_base_config_id") + + # Call /memory/load endpoint directly (preload mode - empty query returns medium_term summaries) + try: + base_url = settings.KNOWLEDGE_SERVICE_URL.rstrip("/") + url = f"{base_url}/memory/load" + + params = { + "app_name": agent_id, + "user_id": effective_user_id, + "query": "", # Empty query loads medium_term summaries + "max_results": 10, + } + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + # Add service token for service-to-service authentication + if settings.KNOWLEDGE_SERVICE_API_TOKEN: + headers["X-Service-Token"] = settings.KNOWLEDGE_SERVICE_API_TOKEN + + if memory_base_config_id: + headers["x-memory-base-config-id"] = str(memory_base_config_id) + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url, params=params, headers=headers) + response.raise_for_status() + response_data = response.json() + + memory_results = response_data.get("memories", []) + total = response_data.get("total", 0) + + if memory_results: + logger.info(f"Preloaded {len(memory_results)} memory summaries for agent {agent_id}") + + # Add preloaded memories as system events to the session + # This makes them available to the LLM as context + from google.adk.events import Event + from google.genai.types import Content, Part + import time + + # Combine all memory summaries into a single context message + memory_context_parts = [] + memory_context_parts.append("Previous conversation context:\n\n") + + for idx, memory in enumerate(memory_results, 1): + memory_content = memory.get("content", "") + memory_metadata = memory.get("metadata", {}) + memory_timestamp = memory.get("timestamp") + + if memory_content: + memory_context_parts.append(f"--- Summary {idx} ---\n") + memory_context_parts.append(f"{memory_content}\n") + if memory_timestamp: + memory_context_parts.append(f"(Date: {memory_timestamp})\n") + memory_context_parts.append("\n") + + if len(memory_context_parts) > 1: # More than just the header + memory_context_text = "".join(memory_context_parts).strip() + + # Create a system event with the memory context + memory_event = Event( + invocation_id=f"preload_memory_{int(time.time())}", + author="system", + content=Content( + role="system", + parts=[Part(text=memory_context_text)] + ), + timestamp=time.time(), + ) + + # Add the event to the session + await session_service.append_event(session, memory_event) + logger.debug(f"Added {len(memory_results)} memory summaries to session context") + else: + logger.debug(f"No memory summaries found for preload (agent {agent_id}, user {effective_user_id})") + except Exception as e: + logger.warning(f"Could not preload memory: {e}") + + # Preload knowledge if enabled (before processing user message) + if isinstance(agent_config, dict) and agent_config.get("preload_knowledge") and agent_config.get("load_knowledge"): + logger.info(f"Preloading knowledge for agent {agent_id}") try: + from src.config.settings import settings + import httpx + + knowledge_tags = agent_config.get("knowledge_tags") + knowledge_base_config_id = agent_config.get("knowledge_base_config_id") + knowledge_max_results = agent_config.get("knowledge_max_results", 5) + + # Call /knowledge/search endpoint directly for preload base_url = settings.KNOWLEDGE_SERVICE_URL.rstrip("/") - url = f"{base_url}/memory/load" + url = f"{base_url}/knowledge/search" - params = { - "app_name": agent_id, - "user_id": effective_user_id, - "query": "", # Empty query loads medium_term summaries - "max_results": 10, + # Use a general query for preload context + payload = { + "query": "general context and information", + "tags": knowledge_tags or [], + "max_results": knowledge_max_results, } headers = { "Content-Type": "application/json", "Accept": "application/json", } - + # Add service token for service-to-service authentication if settings.KNOWLEDGE_SERVICE_API_TOKEN: headers["X-Service-Token"] = settings.KNOWLEDGE_SERVICE_API_TOKEN - - if memory_base_config_id: - headers["x-memory-base-config-id"] = str(memory_base_config_id) - + + if knowledge_base_config_id: + headers["x-knowledge-base-config-id"] = str(knowledge_base_config_id) + async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url, params=params, headers=headers) + response = await client.post(url, json=payload, headers=headers) response.raise_for_status() response_data = response.json() - - memory_results = response_data.get("memories", []) + + knowledge_results = response_data.get("results", []) total = response_data.get("total", 0) + + if knowledge_results: + logger.info(f"Preloaded {len(knowledge_results)} knowledge entries for agent {agent_id}") - if memory_results: - logger.info(f"Preloaded {len(memory_results)} memory summaries for agent {agent_id}") - - # Add preloaded memories as system events to the session - # This makes them available to the LLM as context + # Add preloaded knowledge as system events to the session from google.adk.events import Event from google.genai.types import Content, Part import time + + # Combine all knowledge entries into a single context message + knowledge_context_parts = [] + knowledge_context_parts.append("Preloaded knowledge base context:\n\n") + + for idx, result in enumerate(knowledge_results, 1): + knowledge = result.get("knowledge", {}) + knowledge_title = knowledge.get("title", "") + knowledge_content = knowledge.get("content", "") + knowledge_description = knowledge.get("description", "") - # Combine all memory summaries into a single context message - memory_context_parts = [] - memory_context_parts.append("Previous conversation context:\n\n") - - for idx, memory in enumerate(memory_results, 1): - memory_content = memory.get("content", "") - memory_metadata = memory.get("metadata", {}) - memory_timestamp = memory.get("timestamp") - - if memory_content: - memory_context_parts.append(f"--- Summary {idx} ---\n") - memory_context_parts.append(f"{memory_content}\n") - if memory_timestamp: - memory_context_parts.append(f"(Date: {memory_timestamp})\n") - memory_context_parts.append("\n") + if knowledge_content: + knowledge_context_parts.append(f"--- Knowledge Entry {idx} ---\n") + if knowledge_title: + knowledge_context_parts.append(f"Title: {knowledge_title}\n") + if knowledge_description: + knowledge_context_parts.append(f"Description: {knowledge_description}\n") + knowledge_context_parts.append(f"Content: {knowledge_content}\n") + knowledge_context_parts.append("\n") + + if len(knowledge_context_parts) > 1: # More than just the header + knowledge_context_text = "".join(knowledge_context_parts).strip() - if len(memory_context_parts) > 1: # More than just the header - memory_context_text = "".join(memory_context_parts).strip() - - # Create a system event with the memory context - memory_event = Event( - invocation_id=f"preload_memory_{int(time.time())}", + # Create a system event with the knowledge context + knowledge_event = Event( + invocation_id=f"preload_knowledge_{int(time.time())}", author="system", content=Content( role="system", - parts=[Part(text=memory_context_text)] + parts=[Part(text=knowledge_context_text)] ), timestamp=time.time(), ) - + # Add the event to the session - await session_service.append_event(session, memory_event) - logger.debug(f"Added {len(memory_results)} memory summaries to session context") + await session_service.append_event(session, knowledge_event) + logger.debug(f"Added {len(knowledge_results)} knowledge entries to session context") else: - logger.debug(f"No memory summaries found for preload (agent {agent_id}, user {effective_user_id})") + logger.debug(f"No knowledge entries found for preload (agent {agent_id})") except Exception as e: - logger.warning(f"Could not preload memory: {e}") - - # Preload knowledge if enabled (before processing user message) - if isinstance(agent_config, dict) and agent_config.get("preload_knowledge") and agent_config.get("load_knowledge"): - logger.info(f"Preloading knowledge for agent {agent_id}") - try: - from src.config.settings import settings - import httpx - - knowledge_tags = agent_config.get("knowledge_tags") - knowledge_base_config_id = agent_config.get("knowledge_base_config_id") - knowledge_max_results = agent_config.get("knowledge_max_results", 5) - - # Call /knowledge/search endpoint directly for preload - base_url = settings.KNOWLEDGE_SERVICE_URL.rstrip("/") - url = f"{base_url}/knowledge/search" - - # Use a general query for preload context - payload = { - "query": "general context and information", - "tags": knowledge_tags or [], - "max_results": knowledge_max_results, - } - - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - } - - # Add service token for service-to-service authentication - if settings.KNOWLEDGE_SERVICE_API_TOKEN: - headers["X-Service-Token"] = settings.KNOWLEDGE_SERVICE_API_TOKEN - - if knowledge_base_config_id: - headers["x-knowledge-base-config-id"] = str(knowledge_base_config_id) - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=payload, headers=headers) - response.raise_for_status() - response_data = response.json() - - knowledge_results = response_data.get("results", []) - total = response_data.get("total", 0) - - if knowledge_results: - logger.info(f"Preloaded {len(knowledge_results)} knowledge entries for agent {agent_id}") - - # Add preloaded knowledge as system events to the session - from google.adk.events import Event - from google.genai.types import Content, Part - import time - - # Combine all knowledge entries into a single context message - knowledge_context_parts = [] - knowledge_context_parts.append("Preloaded knowledge base context:\n\n") - - for idx, result in enumerate(knowledge_results, 1): - knowledge = result.get("knowledge", {}) - knowledge_title = knowledge.get("title", "") - knowledge_content = knowledge.get("content", "") - knowledge_description = knowledge.get("description", "") - - if knowledge_content: - knowledge_context_parts.append(f"--- Knowledge Entry {idx} ---\n") - if knowledge_title: - knowledge_context_parts.append(f"Title: {knowledge_title}\n") - if knowledge_description: - knowledge_context_parts.append(f"Description: {knowledge_description}\n") - knowledge_context_parts.append(f"Content: {knowledge_content}\n") - knowledge_context_parts.append("\n") - - if len(knowledge_context_parts) > 1: # More than just the header - knowledge_context_text = "".join(knowledge_context_parts).strip() - - # Create a system event with the knowledge context - knowledge_event = Event( - invocation_id=f"preload_knowledge_{int(time.time())}", - author="system", - content=Content( - role="system", - parts=[Part(text=knowledge_context_text)] - ), - timestamp=time.time(), - ) - - # Add the event to the session - await session_service.append_event(session, knowledge_event) - logger.debug(f"Added {len(knowledge_results)} knowledge entries to session context") - else: - logger.debug(f"No knowledge entries found for preload (agent {agent_id})") - except Exception as e: - logger.warning(f"Could not preload knowledge: {e}") - except Exception as e: - logger.debug(f"Could not check preload config: {e}") - - # Create content with transcribed audio if available - if transcribed_texts: - content = self.utils.create_content_with_transcribed_audio( - message, file_parts, transcribed_texts - ) - else: - content = self.utils.create_content(message, file_parts) + logger.warning(f"Could not preload knowledge: {e}") + except Exception as e: + logger.debug(f"Could not check preload config: {e}") + + # Create content with transcribed audio if available + if transcribed_texts: + content = self.utils.create_content_with_transcribed_audio( + message, file_parts, transcribed_texts + ) + else: + content = self.utils.create_content(message, file_parts) - # If content is None (empty message/transcription), skip processing - if content is None: - logger.info( - "No meaningful content to process, skipping agent execution" - ) - return { - "final_response": "No content to process", - "message_history": [], - } + # If content is None (empty message/transcription), skip processing + if content is None: + logger.info( + "No meaningful content to process, skipping agent execution" + ) + return { + "final_response": "No content to process", + "message_history": [], + } - # Run agent and collect response - final_response_text = "No response captured." - message_history = [] + # Run agent and collect response + final_response_text = "No response captured." + message_history = [] - try: - total_prompt_tokens = 0 - total_candidate_tokens = 0 - total_tokens = 0 + try: + total_prompt_tokens = 0 + total_candidate_tokens = 0 + total_tokens = 0 - events_async = agent_runner.run_async( - user_id=effective_user_id, - session_id=adk_session_id, - new_message=content, - ) + events_async = agent_runner.run_async( + user_id=effective_user_id, + session_id=adk_session_id, + new_message=content, + ) - async for event in events_async: - if event.usage_metadata: - total_prompt_tokens += ( - event.usage_metadata.prompt_token_count or 0 - ) - total_candidate_tokens += ( - event.usage_metadata.candidates_token_count or 0 - ) - total_tokens += event.usage_metadata.total_token_count or 0 - - if event.content and event.content.parts: - # Handle both Pydantic v2 (model_dump) and older versions - if hasattr(event, "model_dump"): - event_dict = event.model_dump() - elif hasattr(event, "dict"): - event_dict = event.dict() - else: - event_dict = event.__dict__ - event_dict = convert_sets(event_dict) - message_history.append(event_dict) + async for event in events_async: + if event.usage_metadata: + total_prompt_tokens += ( + event.usage_metadata.prompt_token_count or 0 + ) + total_candidate_tokens += ( + event.usage_metadata.candidates_token_count or 0 + ) + total_tokens += event.usage_metadata.total_token_count or 0 + + if event.content and event.content.parts: + # Handle both Pydantic v2 (model_dump) and older versions + if hasattr(event, "model_dump"): + event_dict = event.model_dump() + elif hasattr(event, "dict"): + event_dict = event.dict() + else: + event_dict = event.__dict__ + event_dict = convert_sets(event_dict) + message_history.append(event_dict) - # Save event to memory individually (FIFO) - if memory_service and hasattr(memory_service, "add_event_to_memory"): - try: - # Extract text from event - event_text = "" - if event.content.parts: - for part in event.content.parts: - if hasattr(part, "text") and part.text: - event_text += part.text + " " - event_text = event_text.strip() + # Save event to memory individually (FIFO) + if memory_service and hasattr(memory_service, "add_event_to_memory"): + try: + # Extract text from event + event_text = "" + if event.content.parts: + for part in event.content.parts: + if hasattr(part, "text") and part.text: + event_text += part.text + " " + event_text = event_text.strip() - if event_text: - # Get agent from database to check if load_memory is enabled - from src.services.agent_service import get_agent - agent = await get_agent(self.db, agent_id) + if event_text: + # Get agent from database to check if load_memory is enabled + from src.services.agent_service import get_agent + agent = await get_agent(self.db, agent_id) - # Check if load_memory is enabled - load_memory_enabled = False - memory_base_config_id = None - short_term_max_messages = None - compression_interval = None - - if agent: - if agent.config: - agent_config = agent.config if isinstance(agent.config, dict) else {} - if isinstance(agent_config, dict): - load_memory_enabled = agent_config.get("load_memory", False) - memory_base_config_id = agent_config.get("memory_base_config_id") - short_term_max_messages = agent_config.get("memory_short_term_max_messages") - compression_interval = agent_config.get("memory_medium_term_compression_interval") - - # Only save to memory if load_memory is enabled - if load_memory_enabled: - # Determine role (agent response) - role = "agent" - - await memory_service.add_event_to_memory( - app_name=agent_id, - user_id=effective_user_id, - role=role, - content=event_text, - memory_base_config_id=memory_base_config_id, - short_term_max_messages=short_term_max_messages, - compression_interval=compression_interval, - ) - except Exception as e: - logger.debug(f"Could not save event to memory: {e}") - - if ( - event.content - and event.content.parts - and event.content.parts[0].text - ): - final_response_text = event.content.parts[0].text - - if event.actions and event.actions.escalate: - final_response_text = f"Agent escalated: {event.error_message or 'No specific message.'}" - break - - logger.info( - f"Session tokens: {total_tokens} (prompt={total_prompt_tokens}," - f" candidates={total_candidate_tokens})" - ) - - try: - if not hasattr(root_agent, "model"): - model_str = "external" - else: - # Handle both Pydantic v2 (model_dump) and older versions - if hasattr(root_agent.model, "model_dump"): - model_dict = root_agent.model.model_dump() - elif hasattr(root_agent.model, "dict"): - model_dict = root_agent.model.dict() - else: - model_dict = root_agent.model.__dict__ - model_str = model_dict.get("model", str(root_agent.model)) - metrics_data = ExecutionMetricsCreate( - agent_id=uuid.UUID(agent_id), - session_id=adk_session_id, - user_id=effective_user_id, - llm_model=str(model_str), - prompt_tokens=total_prompt_tokens, - candidate_tokens=total_candidate_tokens, - cost=0.0, - total_tokens=total_tokens, + # Check if load_memory is enabled + load_memory_enabled = False + memory_base_config_id = None + short_term_max_messages = None + compression_interval = None + + if agent: + if agent.config: + agent_config = agent.config if isinstance(agent.config, dict) else {} + if isinstance(agent_config, dict): + load_memory_enabled = agent_config.get("load_memory", False) + memory_base_config_id = agent_config.get("memory_base_config_id") + short_term_max_messages = agent_config.get("memory_short_term_max_messages") + compression_interval = agent_config.get("memory_medium_term_compression_interval") + + # Only save to memory if load_memory is enabled + if load_memory_enabled: + # Determine role (agent response) + role = "agent" + + await memory_service.add_event_to_memory( + app_name=agent_id, + user_id=effective_user_id, + role=role, + content=event_text, + memory_base_config_id=memory_base_config_id, + short_term_max_messages=short_term_max_messages, + compression_interval=compression_interval, + ) + except Exception as e: + logger.debug(f"Could not save event to memory: {e}") + + if ( + event.content + and event.content.parts + and event.content.parts[0].text + ): + final_response_text = event.content.parts[0].text + + if event.actions and event.actions.escalate: + final_response_text = f"Agent escalated: {event.error_message or 'No specific message.'}" + break + + logger.info( + f"Session tokens: {total_tokens} (prompt={total_prompt_tokens}," + f" candidates={total_candidate_tokens})" ) - create_execution_metrics(self.db, metrics_data) - except Exception as e: - logger.error(f"Error creating execution metrics: {e}") - # Extension point: usage reporter. Always called after the - # local persistence above; default is a no-op. A misbehaving - # consumer cannot break the run — we swallow the exception - # and log with full context. - try: - usage_reporter.report_execution( - ExecutionMetrics( - execution_id=adk_session_id, + try: + if not hasattr(root_agent, "model"): + model_str = "external" + else: + # Handle both Pydantic v2 (model_dump) and older versions + if hasattr(root_agent.model, "model_dump"): + model_dict = root_agent.model.model_dump() + elif hasattr(root_agent.model, "dict"): + model_dict = root_agent.model.dict() + else: + model_dict = root_agent.model.__dict__ + model_str = model_dict.get("model", str(root_agent.model)) + metrics_data = ExecutionMetricsCreate( + agent_id=uuid.UUID(agent_id), + session_id=adk_session_id, + user_id=effective_user_id, + llm_model=str(model_str), prompt_tokens=total_prompt_tokens, candidate_tokens=total_candidate_tokens, - total_tokens=total_tokens, cost=0.0, + total_tokens=total_tokens, + ) + create_execution_metrics(self.db, metrics_data) + except Exception as e: + logger.error(f"Error creating execution metrics: {e}") + + # Extension point: usage reporter. Always called after the + # local persistence above; default is a no-op. A misbehaving + # consumer cannot break the run — we swallow the exception + # and log with full context. + try: + usage_reporter.report_execution( + ExecutionMetrics( + execution_id=adk_session_id, + prompt_tokens=total_prompt_tokens, + candidate_tokens=total_candidate_tokens, + total_tokens=total_tokens, + cost=0.0, + ) + ) + except Exception: + logger.exception( + "usage_reporter.report_execution failed for" + f" execution_id={adk_session_id!r}" + f" impl={ep_impl_for('usage_reporter')!r}" ) - ) - except Exception: - logger.exception( - "usage_reporter.report_execution failed for" - f" execution_id={adk_session_id!r}" - f" impl={ep_impl_for('usage_reporter')!r}" - ) - except Exception as e: - logger.error(f"Error processing request: {str(e)}") - raise InternalServerError(str(e)) from e + except Exception as e: + logger.error(f"Error processing request: {str(e)}") + raise InternalServerError(str(e)) from e - # Note: We no longer save the entire session to memory at the end - # Events are saved individually during execution (FIFO) + # Note: We no longer save the entire session to memory at the end + # Events are saved individually during execution (FIFO) - logger.info("Agent execution completed successfully") - return { - "final_response": final_response_text, - "message_history": message_history, - } + logger.info("Agent execution completed successfully") + return { + "final_response": final_response_text, + "message_history": message_history, + } except AgentNotFoundError as e: logger.error(f"Agent not found: {str(e)}") diff --git a/tests/integration/test_evo_extension_points_runner.py b/tests/integration/test_evo_extension_points_runner.py index 9b8fc19..6cb976b 100644 --- a/tests/integration/test_evo_extension_points_runner.py +++ b/tests/integration/test_evo_extension_points_runner.py @@ -125,3 +125,100 @@ def report_execution(self, metrics): evo_extension_points.replace("usage_reporter", BoomReporter()) result = _mirror_runner_hooks({}) assert result["reported"] is False + + +class TestBindContextWrap: + """Mirror the ``async with runtime_context.bind_context(...)`` wrap + added to ``StandardRunner.run_agent``. + + These tests do not execute the full runner; they reproduce the wrap + shape so a future change to it is caught alongside the EP contract. + """ + + @staticmethod + async def _mirror_bind_wrap(context_id): + from contextlib import nullcontext + + events = [] + async with ( + runtime_context.bind_context(context_id) + if context_id + else nullcontext() + ): + events.append("inside") + events.append("after") + return events + + @pytest.mark.asyncio + async def test_default_bind_is_noop_async_cm(self): + # No consumer registered; community default returns a no-op async CM. + events = await self._mirror_bind_wrap("ctx-1") + assert events == ["inside", "after"] + + @pytest.mark.asyncio + async def test_null_branch_when_context_id_is_none(self): + # nullcontext() must support async with on Python 3.10+. + events = await self._mirror_bind_wrap(None) + assert events == ["inside", "after"] + + @pytest.mark.asyncio + async def test_consumer_bind_context_is_invoked(self): + from contextlib import asynccontextmanager + + seen = {"enter": 0, "exit": 0, "id": None} + + class Consumer: + def current_context_id(self, source): + return None + + def with_context(self, context_id, fn): + return fn() + + @asynccontextmanager + async def bind_context(self, context_id): + seen["enter"] += 1 + seen["id"] = context_id + try: + yield + finally: + seen["exit"] += 1 + + evo_extension_points.replace("runtime_context", Consumer()) + events = await self._mirror_bind_wrap("ctx-from-consumer") + assert events == ["inside", "after"] + assert seen == {"enter": 1, "exit": 1, "id": "ctx-from-consumer"} + + @pytest.mark.asyncio + async def test_consumer_bind_context_resets_on_exception(self): + from contextlib import asynccontextmanager + + seen = {"enter": 0, "exit": 0} + + class Consumer: + def current_context_id(self, source): + return None + + def with_context(self, context_id, fn): + return fn() + + @asynccontextmanager + async def bind_context(self, context_id): + seen["enter"] += 1 + try: + yield + finally: + seen["exit"] += 1 + + evo_extension_points.replace("runtime_context", Consumer()) + + from contextlib import nullcontext + + with pytest.raises(RuntimeError): + async with ( + runtime_context.bind_context("ctx-x") + if "ctx-x" + else nullcontext() + ): + raise RuntimeError("boom in body") + + assert seen == {"enter": 1, "exit": 1}