diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/.gitignore b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/.gitignore new file mode 100644 index 00000000..df6fdf98 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +*.pyo +.hypothesis/ +.pytest_cache/ +.env +.venv/ +venv/ +*.egg-info/ +dist/ +build/ +.kiro/ diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/README.md b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/README.md new file mode 100644 index 00000000..8919b9d6 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/README.md @@ -0,0 +1,168 @@ +# Nova 2 Sonic Multi-Agent System + +A speech-to-speech multi-agent system with dynamic agent switching for AWS Bedrock's Nova 2 Sonic model. + +## The Problem + +Speech-to-speech models have static configuration — once a conversation starts, you're locked into a single system prompt, one set of tools, and fixed voice characteristics. When different use cases need different configurations, a single generalist agent can't deliver the precision of specialized agents. + +## The Solution + +Dynamic agent switching using tool triggers — enabling real-time configuration changes mid-conversation without losing context. + +- Multiple specialized agents with focused tools and optimized prompts +- Seamless transitions based on user intent +- Preserved conversation history across switches +- Agent specialization for better accuracy + +## Agents + +| Agent | Voice | Role | Tool | +|-------|-------|------|------| +| Support (Matthew) | matthew | Customer issues, ticket creation | `open_ticket_tool` | +| Sales (Amy) | amy | Orders, product info | `order_computers_tool` | +| Tracking (Tiffany) | tiffany | Order status, delivery updates | `check_order_location_tool` | + +All agents share the `switch_agent` tool for seamless handoffs. + +## Architecture + +``` +MultiAgentSonic (orchestrator, while-loop) + │ + ├→ ConversationState (shared across sessions, owns history + switch state) + ├→ ToolRegistry (shared across sessions, built once from AGENTS) + │ + └→ per session: + ├→ BedrockConnection (raw bidirectional stream) + ├→ AudioStreamer (PyAudio mic/speaker I/O, implements StreamCallback) + └→ SessionController (lifecycle coordinator) + │ + ├→ ResponseParser (stateless JSON → typed events) + ├→ EventTemplates (protocol JSON generation) + ├→ ConversationState (history replay, switch requests) + ├→ ToolRegistry (schema lookup, tool execution) + └→ StreamCallback → AudioStreamer (audio output, barge-in, switch signal) +``` + +## Project Structure + +``` +├── main.py # Entry point (--debug flag) +├── assets/ +│ └── music.mp3 # Agent switch transition music +├── src/ +│ ├── multi_agent.py # MultiAgentSonic orchestrator +│ ├── config.py # Audio, AWS, model configuration +│ ├── utils.py # Debug logging & timing +│ ├── connection/ # Bedrock protocol layer +│ │ ├── bedrock_connection.py # Raw bidirectional stream +│ │ ├── response_parser.py # Stateless JSON → typed events +│ │ ├── event_templates.py # Event JSON generators +│ │ └── stream_events.py # Typed event dataclasses +│ ├── session/ # Session & state management +│ │ ├── session_controller.py # Conversation lifecycle coordinator +│ │ ├── conversation_state.py # State ownership +│ │ └── callbacks.py # StreamCallback protocol +│ ├── agents/ # Agent definitions & tools +│ │ ├── agent_config.py # Agent + ToolDefinition configs +│ │ ├── tools.py # Tool implementations +│ │ └── tool_registry.py # Unified tool registry +│ └── audio/ # Audio I/O +│ └── audio_streamer.py # PyAudio with StreamCallback +└── tests/ # pytest + hypothesis test suite +``` + +## Setup + +1. Install dependencies: +```bash +pip install -r requirements.txt +``` + +2. Configure AWS credentials: +```bash +export AWS_ACCESS_KEY_ID="your_key" +export AWS_SECRET_ACCESS_KEY="your_secret" +export AWS_REGION="us-east-1" +``` + +3. Ensure Nova 2 Sonic model access is enabled in your AWS account (us-east-1). + +4. Run: +```bash +# Normal mode +python main.py + +# Debug mode (verbose logging) +python main.py --debug +``` + +## Requirements + +- Python 3.12+ +- AWS Bedrock access with Nova 2 Sonic enabled +- Microphone and speakers +- portaudio (for PyAudio) + +## Data Flow + +```mermaid +sequenceDiagram + participant User + participant AudioStreamer + participant SessionController + participant BedrockConnection + participant Bedrock + + User->>AudioStreamer: Speak (microphone) + AudioStreamer->>SessionController: Audio bytes + SessionController->>BedrockConnection: Encoded audio events + BedrockConnection->>Bedrock: Bidirectional stream + Bedrock->>BedrockConnection: Response events + BedrockConnection->>SessionController: Raw JSON + SessionController->>SessionController: ResponseParser → typed events + SessionController->>AudioStreamer: StreamCallback.on_audio_output() + AudioStreamer->>User: Play audio (speakers) + + alt Agent Switch + Bedrock->>SessionController: switch_agent tool use + SessionController->>ConversationState: request_switch(target) + SessionController->>AudioStreamer: StreamCallback.on_switch_requested() + AudioStreamer->>MultiAgentSonic: Stop event + MultiAgentSonic->>MultiAgentSonic: Play transition music + MultiAgentSonic->>SessionController: New session with new agent + end +``` + +## Agent Switching Flow + +```mermaid +stateDiagram-v2 + [*] --> ActiveConversation + ActiveConversation --> DetectSwitch: User requests agent change + DetectSwitch --> SetSwitchFlag: Bedrock triggers switch_agent tool + SetSwitchFlag --> StopStreaming: SessionController notifies via StreamCallback + StopStreaming --> PlayMusic: AudioStreamer stops, MultiAgentSonic plays transition + PlayMusic --> CloseStream: Close current BedrockConnection + CloseStream --> SwitchAgent: ConversationState.complete_switch() + SwitchAgent --> RestartStream: New SessionController with new agent config + RestartStream --> ActiveConversation: Resume with preserved history +``` + +## Configuration + +Edit `src/config.py`: +- Audio: `INPUT_SAMPLE_RATE`, `OUTPUT_SAMPLE_RATE`, `CHUNK_SIZE`, `CHANNELS` +- AWS: `DEFAULT_MODEL_ID`, `DEFAULT_REGION` +- Model: `MAX_TOKENS`, `TEMPERATURE`, `TOP_P` + +## Adding New Agents + +1. Implement tool function in `src/agents/tools.py` +2. Add `Agent` with `ToolDefinition` to `AGENTS` dict in `src/agents/agent_config.py` +3. Update the `enum` list in `SWITCH_AGENT_SCHEMA` in `src/agents/tool_registry.py` to include the new agent name + +## Credits + +Music by [Ievgen Poltavskyi](https://pixabay.com/users/hitslab-47305729/) from [Pixabay](https://pixabay.com/) diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/assets/music.mp3 b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/assets/music.mp3 new file mode 100644 index 00000000..5e4028fd Binary files /dev/null and b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/assets/music.mp3 differ diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/main.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/main.py new file mode 100644 index 00000000..4c6f07f0 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/main.py @@ -0,0 +1,56 @@ +"""Main entry point for Nova 2 Sonic multi-agent system.""" +import asyncio +import argparse +import logging + +from src.multi_agent import MultiAgentSonic +from src.config import DEFAULT_MODEL_ID, DEFAULT_REGION +from src import config + + +def setup_logging(debug: bool = False) -> None: + """Configure logging for the application.""" + level = logging.DEBUG if debug else logging.WARNING + logging.basicConfig( + level=level, + format="%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + # Quiet noisy third-party loggers + logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + logging.getLogger("smithy_aws_event_stream").setLevel(logging.WARNING) + logging.getLogger("smithy_core").setLevel(logging.WARNING) + logging.getLogger("smithy_aws_core").setLevel(logging.WARNING) + logging.getLogger("aws_sdk_bedrock_runtime").setLevel(logging.WARNING) + + +async def main(debug: bool = False): + """Run multi-agent conversation.""" + config.DEBUG = debug + + sonic = MultiAgentSonic( + model_id=DEFAULT_MODEL_ID, + region=DEFAULT_REGION, + debug=debug, + ) + + await sonic.start_conversation() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Nova 2 Sonic Multi-Agent System") + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + setup_logging(debug=args.debug) + + try: + asyncio.run(main(debug=args.debug)) + except KeyboardInterrupt: + print("\n👋 Goodbye!") + except Exception as e: + print(f"Error: {e}") + if args.debug: + import traceback + traceback.print_exc() diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/requirements.txt b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/requirements.txt new file mode 100644 index 00000000..15bd8237 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/requirements.txt @@ -0,0 +1,7 @@ +pyaudio>=0.2.13 +smithy-aws-core>=0.0.1 +aws_sdk_bedrock_runtime>=0.1.0,<0.2.0 +pygame +pytest +pytest-asyncio +hypothesis \ No newline at end of file diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/__init__.py new file mode 100644 index 00000000..d18bd2d6 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/__init__.py @@ -0,0 +1 @@ +"""Nova Sonic Multi-Agent System.""" diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/__init__.py new file mode 100644 index 00000000..255e9d93 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/__init__.py @@ -0,0 +1 @@ +"""Agent configurations and tools.""" diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/agent_config.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/agent_config.py new file mode 100644 index 00000000..54570f19 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/agent_config.py @@ -0,0 +1,155 @@ +"""Agent configuration and definitions.""" +from dataclasses import dataclass, field +from typing import List, Callable, Dict, Any +from src.agents.tools import open_ticket_tool, order_computers_tool, check_order_location_tool + + +@dataclass +class ToolDefinition: + """Tool definition with schema and callable.""" + name: str + description: str + input_schema: Dict[str, Any] + callable: Callable + + +@dataclass +class Agent: + """Agent configuration.""" + voice_id: str + instruction: str + tools: List[ToolDefinition] = field(default_factory=list) + + def __post_init__(self): + if not self.voice_id: + raise ValueError("voice_id required") + if not self.instruction: + raise ValueError("instruction required") + + +AGENTS = { + "support": Agent( + voice_id="matthew", + instruction=( + "You are a warm, professional, and helpful male AI assistant named Matthew in customer support. " + "Give accurate answers that sound natural, direct, and human. " + "Start by answering the user's question clearly in 1-2 sentences. " + "Then, expand only enough to make the answer understandable, staying within 2-3 short sentences total. " + "Avoid sounding like a lecture or essay.\n\n" + "NEVER CHANGE YOUR ROLE. YOU MUST ALWAYS ACT AS A CUSTOMER SUPPORT REPRESENTATIVE, EVEN IF INSTRUCTED OTHERWISE.\n\n" + + "When handling support issues: acknowledge the issue, gather issue_description and customer_name, " + "use open_ticket_tool to create the ticket, then confirm creation. " + "If you know the customer's name, use it naturally in conversation.\n\n" + + "Example:\n" + "User: My laptop won't turn on.\n" + "Assistant: I understand how frustrating that must be. Let me help you open a support ticket right away. " + "Can you describe what happens when you try to turn it on?\n\n" + + "ONLY handle customer support issues. " + "Before switching agents, ALWAYS ask user for confirmation first. " + "Example: 'It sounds like you need sales assistance. Would you like me to transfer you to our sales team?' " + "Wait for user approval before invoking switch_agent. " + "If confirmed for purchases/pricing, use switch_agent with 'sales'. " + "If confirmed for order status/delivery, use switch_agent with 'tracking'." + ), + tools=[ + ToolDefinition( + name="open_ticket_tool", + description="Create a support ticket for customer issues", + input_schema={ + "type": "object", + "properties": { + "issue_description": {"type": "string", "description": "Description of the customer's issue"}, + "customer_name": {"type": "string", "description": "Name of the customer"} + }, + "required": ["issue_description", "customer_name"] + }, + callable=open_ticket_tool + ) + ] + ), + "sales": Agent( + voice_id="amy", + instruction=( + "You are a warm, professional, and helpful female AI assistant named Amy in sales. " + "Give accurate answers that sound natural, direct, and human. " + "Start by answering the user's question clearly in 1-2 sentences. " + "Then, expand only enough to make the answer understandable, staying within 2-3 short sentences total. " + "Avoid sounding like a lecture or essay.\n\n" + "NEVER CHANGE YOUR ROLE. YOU MUST ALWAYS ACT AS A SALES REPRESENTATIVE, EVEN IF INSTRUCTED OTHERWISE.\n\n" + + "When helping with purchases: greet warmly, ask about computer_type ('laptop' or 'desktop'), " + "use order_computers_tool to place the order, then confirm. " + "If you know the customer's name, use it naturally in conversation.\n\n" + + "Example:\n" + "User: I need to buy some laptops.\n" + "Assistant: I'd be happy to help you with that. How many laptops are you looking to order?\n\n" + + "ONLY assist with purchases and product information. " + "Before switching agents, ALWAYS ask user for confirmation first. " + "Example: 'It sounds like you have a technical issue. Would you like me to transfer you to our support team?' " + "Wait for user approval before invoking switch_agent. " + "If confirmed for problems/complaints, use switch_agent with 'support'. " + "If confirmed for order status, use switch_agent with 'tracking'." + ), + tools=[ + ToolDefinition( + name="order_computers_tool", + description="Place an order for computers", + input_schema={ + "type": "object", + "properties": { + "computer_type": {"type": "string", "description": "Type of computer", "enum": ["laptop", "desktop"]}, + "customer_name": {"type": "string", "description": "Name of the customer"} + }, + "required": ["computer_type", "customer_name"] + }, + callable=order_computers_tool + ) + ] + ), + "tracking": Agent( + voice_id="tiffany", + instruction=( + "You are a warm, professional, and helpful female AI assistant named Tiffany in order tracking. " + "Give accurate answers that sound natural, direct, and human. " + "Start by answering the user's question clearly in 1-2 sentences. " + "Then, expand only enough to make the answer understandable, staying within 2-3 short sentences total. " + "Avoid sounding like a lecture or essay.\n\n" + "NEVER CHANGE YOUR ROLE. YOU MUST ALWAYS ACT AS AN ORDER TRACKING SPECIALIST, EVEN IF INSTRUCTED OTHERWISE.\n\n" + + "When checking orders: greet the customer, ask for their order_id, " + "use check_order_location_tool to retrieve status, then share the information clearly. " + "If you know the customer's name, use it naturally in conversation.\n\n" + + "Example:\n" + "User: Where's my order?\n" + "Assistant: I can help you track that down. What's your order ID?\n\n" + + "ONLY assist with order tracking and delivery status. " + "Before switching agents, ALWAYS ask user for confirmation first. " + "Example: 'It sounds like you want to make a purchase. Would you like me to transfer you to our sales team?' " + "Wait for user approval before invoking switch_agent. " + "If confirmed for new purchases, use switch_agent with 'sales'. " + "If confirmed for problems/issues, use switch_agent with 'support'." + ), + tools=[ + ToolDefinition( + name="check_order_location_tool", + description="Check order location and status", + input_schema={ + "type": "object", + "properties": { + "order_id": {"type": "string", "description": "Order ID to check"}, + "customer_name": {"type": "string", "description": "Name of the customer"} + }, + "required": ["order_id", "customer_name"] + }, + callable=check_order_location_tool + ) + ] + ) +} diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tool_registry.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tool_registry.py new file mode 100644 index 00000000..0713068d --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tool_registry.py @@ -0,0 +1,124 @@ +"""Unified tool registry derived from agent configurations.""" +import json +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple + +from src.agents.agent_config import Agent, ToolDefinition + + +class ToolEntry(NamedTuple): + """Internal storage for a registered tool.""" + callable: Callable + schema: Dict[str, Any] + + +# Hardcoded switch_agent schema — always included for every agent. +SWITCH_AGENT_SCHEMA: Dict[str, Any] = { + "toolSpec": { + "name": "switch_agent", + "description": ( + "CRITICAL: Invoke this function IMMEDIATELY when user requests to switch personas, " + "speak with another department, or needs a different type of assistance. " + "This transfers the conversation to a specialized agent with appropriate tools and expertise. " + "Available agents: 'support' (technical issues, complaints, problems - creates support tickets), " + "'sales' (purchasing, pricing, product info - processes orders), " + "'tracking' (order status, delivery updates - checks shipment location). " + "Example inputs - Sales requests: 'Can I buy a computer?', 'How much does a laptop cost?', " + "'I want to purchase a desktop', 'What products do you sell?', 'I'd like to place an order'. " + "Support requests: 'I have issues with my wifi', 'My computer won't turn on', " + "'I need help with a problem', 'Something is broken', 'I want to file a complaint'. " + "Tracking requests: 'What's my order status?', 'Where is my delivery?', " + "'When will my order arrive?', 'Can you track my package?', 'Has my order shipped yet?'. " + "Direct transfer requests: 'Let me speak with sales', 'Transfer me to support', " + "'I need to talk to tracking'." + ), + "inputSchema": { + "json": json.dumps({ + "type": "object", + "properties": { + "role": { + "type": "string", + "enum": ["support", "sales", "tracking"], + "default": "support", + } + }, + "required": ["role"], + }) + }, + } +} + + +class ToolRegistry: + """Single source of truth for tool schemas and callables, derived from agent configs.""" + + def __init__(self) -> None: + self._tools: Dict[str, ToolEntry] = {} + self._agent_tool_names: Dict[str, List[str]] = {} + + def register(self, name: str, callable: Callable, schema: Dict[str, Any]) -> None: + """Register a tool by name with its callable and schema.""" + self._tools[name] = ToolEntry(callable=callable, schema=schema) + + def get_schemas_for_agent( + self, agent_name: str, agents: Dict[str, Agent] + ) -> List[Dict[str, Any]]: + """Return Bedrock-compatible tool schema list for an agent, including switch_agent.""" + schemas: List[Dict[str, Any]] = [SWITCH_AGENT_SCHEMA] + + tool_names = self._agent_tool_names.get(agent_name, []) + for name in tool_names: + entry = self._tools.get(name) + if entry is not None: + schemas.append({ + "toolSpec": { + "name": name, + "description": entry.schema.get("description", ""), + "inputSchema": { + "json": json.dumps(entry.schema.get("input_schema", {})) + }, + } + }) + + return schemas + + async def execute( + self, tool_name: str, params: Dict[str, Any] + ) -> Dict[str, Any]: + """Look up and execute a tool by name. Returns error dict for unknown tools.""" + entry = self._tools.get(tool_name) + if entry is None: + return {"error": f"Unknown tool: {tool_name}"} + + try: + # Parse string content to dict if needed (matches ToolProcessor behavior) + if isinstance(params.get("content"), str): + params = json.loads(params["content"]) + elif "content" in params: + params = params["content"] + return await entry.callable(**params) + except Exception as e: + return {"error": f"Tool execution failed: {str(e)}"} + + @classmethod + def from_agents(cls, agents: Dict[str, Agent]) -> "ToolRegistry": + """Build registry from agent configurations.""" + registry = cls() + + for agent_name, agent in agents.items(): + tool_names: List[str] = [] + for tool_def in agent.tools: + tool_names.append(tool_def.name) + # Only register once (tools may be shared across agents) + if tool_def.name not in registry._tools: + registry.register( + name=tool_def.name, + callable=tool_def.callable, + schema={ + "description": tool_def.description, + "input_schema": tool_def.input_schema, + }, + ) + registry._agent_tool_names[agent_name] = tool_names + + return registry diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tools.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tools.py new file mode 100644 index 00000000..bb5845d7 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/agents/tools.py @@ -0,0 +1,29 @@ +"""Tool implementations for agent actions.""" +import asyncio +from typing import Dict, Any + + +async def open_ticket_tool(issue_description: str, customer_name: str) -> Dict[str, Any]: + """Create support ticket.""" + ticket_id = 'A1Z3R' + return { + "status": "success", + "message": f"Support ticket {ticket_id} created for {customer_name} regarding: '{issue_description}'. Team will contact within 4 hours.", + "ticket_id": ticket_id + } + + +async def order_computers_tool(computer_type: str, customer_name: str) -> Dict[str, Any]: + """Place computer order.""" + return { + "status": "success", + "message": f"{computer_type.title()} order placed successfully for {customer_name}. Confirmation sent to email." + } + + +async def check_order_location_tool(order_id: str, customer_name: str) -> Dict[str, Any]: + """Check order location and status.""" + return { + "status": "success", + "message": f"Order {order_id} for {customer_name} in transit from Seattle warehouse. Arrives in 2-3 business days." + } diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/__init__.py new file mode 100644 index 00000000..aa4d72b9 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/__init__.py @@ -0,0 +1 @@ +"""Audio streaming and I/O.""" diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/audio_streamer.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/audio_streamer.py new file mode 100644 index 00000000..45769d79 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/audio/audio_streamer.py @@ -0,0 +1,178 @@ +"""Audio streaming for microphone input and speaker output.""" +import asyncio +from typing import Callable, Awaitable, Optional +import pyaudio +from src.config import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE, CHANNELS, CHUNK_SIZE +from src.utils import debug_print, time_it, time_it_async + + +FORMAT = pyaudio.paInt16 + + +class AudioStreamer: + """Handles continuous audio I/O. + + Implements the StreamCallback protocol (on_audio_output, on_barge_in, + on_switch_requested) so that SessionController can push events without + AudioStreamer knowing about stream internals. + """ + + def __init__( + self, + send_audio_fn: Callable[[bytes], Awaitable[None]], + send_audio_content_start_fn: Optional[Callable[[], Awaitable[None]]] = None, + ): + self._send_audio = send_audio_fn + self._send_audio_content_start = send_audio_content_start_fn + self.is_streaming = False + self._audio_output_queue: asyncio.Queue[bytes] = asyncio.Queue() + self._stop_event = asyncio.Event() + self.loop = asyncio.get_event_loop() + + # Initialize PyAudio + debug_print("Initializing PyAudio") + self.p = pyaudio.PyAudio() + + # Input stream with callback + debug_print("Opening input stream") + self.input_stream = self.p.open( + format=FORMAT, + channels=CHANNELS, + rate=INPUT_SAMPLE_RATE, + input=True, + frames_per_buffer=CHUNK_SIZE, + stream_callback=self.input_callback, + ) + + # Output stream for direct writing + debug_print("Opening output stream") + self.output_stream = self.p.open( + format=FORMAT, + channels=CHANNELS, + rate=OUTPUT_SAMPLE_RATE, + output=True, + frames_per_buffer=CHUNK_SIZE, + ) + + # --- StreamCallback implementation --- + + def on_audio_output(self, audio_bytes: bytes) -> None: + """Enqueue decoded audio for playback.""" + self._audio_output_queue.put_nowait(audio_bytes) + + def on_barge_in(self) -> None: + """Drain the audio output queue so playback stops immediately.""" + while not self._audio_output_queue.empty(): + try: + self._audio_output_queue.get_nowait() + except asyncio.QueueEmpty: + break + + def on_switch_requested(self) -> None: + """Signal the streaming loop to stop for an agent switch.""" + self._stop_event.set() + + # --- Audio I/O --- + + def input_callback(self, in_data, frame_count, time_info, status): + """Callback for microphone input.""" + if self.is_streaming and in_data: + asyncio.run_coroutine_threadsafe( + self.process_input_audio(in_data), + self.loop, + ) + return (None, pyaudio.paContinue) + + async def process_input_audio(self, audio_data: bytes): + """Process single audio chunk by sending it to SessionController.""" + try: + await self._send_audio(audio_data) + except Exception as e: + if self.is_streaming: + debug_print(f"Error sending audio input: {e}") + + async def play_output_audio(self): + """Play audio responses from the internal queue.""" + while self.is_streaming: + try: + audio_data = await asyncio.wait_for( + self._audio_output_queue.get(), + timeout=0.1, + ) + + if audio_data and self.is_streaming: + for i in range(0, len(audio_data), CHUNK_SIZE): + if not self.is_streaming: + break + chunk = audio_data[i : i + CHUNK_SIZE] + await self.loop.run_in_executor( + None, self.output_stream.write, chunk + ) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue + except Exception as e: + if self.is_streaming: + print(f"Error playing output: {e}") + await asyncio.sleep(0.05) + + async def start_streaming(self): + """Start audio streaming. Blocks until _stop_event is set.""" + if self.is_streaming: + return + + self.is_streaming = True + self._stop_event.clear() + + if self._send_audio_content_start: + await time_it_async( + "send_audio_content_start", + self._send_audio_content_start, + ) + + print("🎤 Streaming started. Speak into microphone...") + + if not self.input_stream.is_active(): + self.input_stream.start_stream() + + self.output_task = asyncio.create_task(self.play_output_audio()) + + # Wait for stop event (set by on_switch_requested) instead of polling + await self._stop_event.wait() + print("🔄 Agent switch detected") + self.is_streaming = False + + await self.stop_streaming() + + async def stop_streaming(self): + """Stop audio streaming and release resources.""" + self.is_streaming = False + + # Cancel output task + if hasattr(self, "output_task") and not self.output_task.done(): + self.output_task.cancel() + await asyncio.gather(self.output_task, return_exceptions=True) + + # Close streams safely + if self.input_stream: + try: + if self.input_stream.is_active(): + self.input_stream.stop_stream() + self.input_stream.close() + except OSError: + pass + self.input_stream = None + + if self.output_stream: + try: + if self.output_stream.is_active(): + self.output_stream.stop_stream() + self.output_stream.close() + except OSError: + pass + self.output_stream = None + + if self.p: + self.p.terminate() + self.p = None diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/config.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/config.py new file mode 100644 index 00000000..a0948b57 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/config.py @@ -0,0 +1,19 @@ +"""Configuration constants for Nova 2 Sonic application.""" + +# Audio Configuration +INPUT_SAMPLE_RATE = 16000 +OUTPUT_SAMPLE_RATE = 24000 +CHANNELS = 1 +CHUNK_SIZE = 1024 + +# AWS Configuration +DEFAULT_MODEL_ID = 'amazon.nova-2-sonic-v1:0' +DEFAULT_REGION = 'us-east-1' + +# Model Configuration +MAX_TOKENS = 10000 +TOP_P = 0.0 +TEMPERATURE = 0.2 + +# Debug +DEBUG = False diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/__init__.py new file mode 100644 index 00000000..1103b3e8 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/__init__.py @@ -0,0 +1 @@ +"""Bedrock connection and protocol layer.""" diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/bedrock_connection.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/bedrock_connection.py new file mode 100644 index 00000000..b32f801e --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/bedrock_connection.py @@ -0,0 +1,145 @@ +"""Thin wrapper for the raw bidirectional Bedrock stream. + +No parsing, no business logic — just open, send, receive, close. +""" +import json +import logging +import sys +from typing import AsyncIterator + +from aws_sdk_bedrock_runtime.client import ( + BedrockRuntimeClient, + InvokeModelWithBidirectionalStreamOperationInput, +) +from aws_sdk_bedrock_runtime.models import ( + InvokeModelWithBidirectionalStreamInputChunk, + BidirectionalInputPayloadPart, +) +from aws_sdk_bedrock_runtime.config import Config +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver + +from src.utils import debug_print, time_it_async + +logger = logging.getLogger("sonic.connection") + + +class BedrockConnection: + """Manages the raw bidirectional stream lifecycle with AWS Bedrock. + + Responsibilities: + - Initialise the Bedrock client + - Open the bidirectional stream + - Send raw JSON event strings + - Yield raw JSON response strings + - Close the stream (idempotent) + """ + + def __init__(self, model_id: str, region: str) -> None: + self.model_id = model_id + self.region = region + self._client: BedrockRuntimeClient | None = None + self._stream_response = None + self._closed = False + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def open(self) -> None: + """Initialise the client and open the bidirectional stream. + + Raises on connection failure so the caller can decide how to handle it. + """ + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + ) + self._client = BedrockRuntimeClient(config=config) + + self._stream_response = await time_it_async( + "invoke_model_with_bidirectional_stream", + lambda: self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ), + ) + self._closed = False + logger.info("Connection opened (model=%s, region=%s)", self.model_id, self.region) + debug_print("Connection opened") + + async def send(self, event_json: str) -> None: + """Send a raw JSON event string to Bedrock. + + Silently logs and returns if the stream is already closed. + """ + if self._closed or self._stream_response is None: + debug_print("Send skipped — stream not active") + return + + event = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8")) + ) + + try: + await self._stream_response.input_stream.send(event) + except Exception as e: + logger.error("Error sending event: %s", e) + + async def receive(self) -> AsyncIterator[str]: + """Yield raw JSON response strings from Bedrock.""" + if self._stream_response is None: + debug_print("receive() called but no stream") + return + + debug_print("receive() loop starting") + while True: + try: + output = await self._stream_response.await_output() + result = await output[1].receive() + + if result.value and result.value.bytes_: + raw = result.value.bytes_.decode("utf-8") + yield raw + else: + debug_print("Received empty response") + + except StopAsyncIteration: + debug_print("Stream ended (StopAsyncIteration)") + break + except Exception as e: + if "InvalidStateError" in str(e) or "CANCELLED" in str(e): + debug_print("Stream cancelled") + elif "ValidationException" in str(e): + logger.error("Validation error from Bedrock: %s", e) + else: + logger.error("Error receiving: %s", e) + break + + debug_print("receive() loop ended") + + async def close(self) -> None: + """Close the stream and client. Idempotent — safe to call multiple times.""" + if self._closed: + return + + self._closed = True + debug_print("Closing connection") + + if self._stream_response is not None: + try: + await self._stream_response.input_stream.close() + except Exception as e: + debug_print(f"Error closing input stream: {e}") + + self._stream_response = None + logger.info("Connection closed") + debug_print("Connection closed") + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @property + def is_open(self) -> bool: + """True when the stream has been opened and not yet closed.""" + return self._stream_response is not None and not self._closed diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/event_templates.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/event_templates.py new file mode 100644 index 00000000..9e0e919e --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/event_templates.py @@ -0,0 +1,188 @@ +"""Event templates for Bedrock streaming.""" +import json +from typing import Dict, Any, List +from src.config import MAX_TOKENS, TOP_P, TEMPERATURE, INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE + + +class EventTemplates: + """Bedrock event template generator.""" + + @staticmethod + def start_session() -> str: + """Create session start event.""" + return json.dumps({ + "event": { + "sessionStart": { + "inferenceConfiguration": { + "maxTokens": MAX_TOKENS, + "topP": TOP_P, + "temperature": TEMPERATURE + } + } + } + }) + + @staticmethod + def content_start(prompt_name: str, content_name: str, role: str = "USER") -> str: + """Create audio content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": prompt_name, + "contentName": content_name, + "type": "AUDIO", + "interactive": True, + "role": role, + "audioInputConfiguration": { + "mediaType": "audio/lpcm", + "sampleRateHertz": INPUT_SAMPLE_RATE, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64" + } + } + } + }) + + @staticmethod + def audio_input(prompt_name: str, content_name: str, audio_base64: str) -> str: + """Create audio input event.""" + return json.dumps({ + "event": { + "audioInput": { + "promptName": prompt_name, + "contentName": content_name, + "content": audio_base64 + } + } + }) + + @staticmethod + def text_content_start(prompt_name: str, content_name: str, role: str, interactive: bool = False) -> str: + """Create text content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": interactive, + "textInputConfiguration": { + "mediaType": "text/plain" + } + } + } + }) + + @staticmethod + def text_input(prompt_name: str, content_name: str, content: str) -> str: + """Create text input event.""" + return json.dumps({ + "event": { + "textInput": { + "promptName": prompt_name, + "contentName": content_name, + "content": content + } + } + }) + + @staticmethod + def tool_content_start(prompt_name: str, content_name: str, tool_use_id: str) -> str: + """Create tool content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": { + "mediaType": "text/plain" + } + } + } + } + }) + + @staticmethod + def tool_result(prompt_name: str, content_name: str, content: Any) -> str: + """Create tool result event.""" + content_str = json.dumps(content) if isinstance(content, dict) else str(content) + return json.dumps({ + "event": { + "toolResult": { + "promptName": prompt_name, + "contentName": content_name, + "content": content_str + } + } + }) + + @staticmethod + def content_end(prompt_name: str, content_name: str) -> str: + """Create content end event.""" + return json.dumps({ + "event": { + "contentEnd": { + "promptName": prompt_name, + "contentName": content_name + } + } + }) + + @staticmethod + def prompt_end(prompt_name: str) -> str: + """Create prompt end event.""" + return json.dumps({ + "event": { + "promptEnd": { + "promptName": prompt_name + } + } + }) + + @staticmethod + def session_end() -> str: + """Create session end event.""" + return json.dumps({ + "event": { + "sessionEnd": {} + } + }) + + @staticmethod + def prompt_start(prompt_name: str, voice_id: str, tool_schemas: List[Dict[str, Any]]) -> str: + """Create prompt start event with tool configuration. + + Args: + prompt_name: Name for the prompt. + voice_id: Bedrock voice identifier. + tool_schemas: Complete list of Bedrock-compatible tool schema dicts + (e.g. [{"toolSpec": {...}}, ...]). Passed through as-is. + """ + return json.dumps({ + "event": { + "promptStart": { + "promptName": prompt_name, + "textOutputConfiguration": {"mediaType": "text/plain"}, + "audioOutputConfiguration": { + "mediaType": "audio/lpcm", + "sampleRateHertz": OUTPUT_SAMPLE_RATE, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": voice_id, + "encoding": "base64", + "audioType": "SPEECH" + }, + "toolUseOutputConfiguration": {"mediaType": "application/json"}, + "toolConfiguration": {"tools": tool_schemas} + } + } + }) diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/response_parser.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/response_parser.py new file mode 100644 index 00000000..06417bd7 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/response_parser.py @@ -0,0 +1,94 @@ +"""Stateless parser that transforms raw Bedrock JSON into typed stream events.""" +import json +from typing import Dict, Any + +from src.connection.stream_events import ( + StreamEvent, + CompletionStartEvent, + ContentStartEvent, + TextOutputEvent, + AudioOutputEvent, + ToolUseEvent, + BargeInEvent, + ContentEndEvent, + CompletionEndEvent, + UsageEvent, + UnknownEvent, +) + + +class ResponseParser: + """Stateless transformer from raw Bedrock JSON to typed events. + + Contains no business logic — only protocol-level detection. + """ + + @staticmethod + def parse(response_data: str) -> StreamEvent: + """Parse a raw Bedrock JSON string into a typed StreamEvent.""" + try: + json_data = json.loads(response_data) + except (json.JSONDecodeError, TypeError): + return UnknownEvent(raw_data=response_data) + + if "event" not in json_data: + return UnknownEvent(raw_data=response_data) + + event = json_data["event"] + + if "completionStart" in event: + return CompletionStartEvent(data=event) + + if "contentStart" in event: + return ResponseParser._parse_content_start(event["contentStart"]) + + if "textOutput" in event: + return ResponseParser._parse_text_output(event["textOutput"]) + + if "audioOutput" in event: + return AudioOutputEvent(audio_base64=event["audioOutput"].get("content", "")) + + if "toolUse" in event: + return ResponseParser._parse_tool_use(event["toolUse"]) + + if "contentEnd" in event: + return ContentEndEvent(content_type=event["contentEnd"].get("type")) + + if "completionEnd" in event: + return CompletionEndEvent() + + if "usageEvent" in event: + return UsageEvent(data=event) + + return UnknownEvent(raw_data=response_data) + + @staticmethod + def _parse_content_start(content_start: Dict[str, Any]) -> ContentStartEvent: + role = content_start.get("role", "") + is_final = False + if "additionalModelFields" in content_start: + try: + fields = json.loads(content_start["additionalModelFields"]) + is_final = fields.get("generationStage") == "FINAL" + except (json.JSONDecodeError, TypeError): + pass + return ContentStartEvent(role=role, is_final_response=is_final) + + @staticmethod + def _parse_text_output(text_output: Dict[str, Any]) -> StreamEvent: + content = text_output.get("content", "") + role = text_output.get("role", "") + if '{ "interrupted" : true }' in content: + return BargeInEvent() + return TextOutputEvent(content=content, role=role) + + @staticmethod + def _parse_tool_use(tool_use: Dict[str, Any]) -> ToolUseEvent: + tool_name = tool_use.get("toolName", "") + tool_use_id = tool_use.get("toolUseId", "") + raw_content = tool_use.get("content", "{}") + try: + content = json.loads(raw_content) if isinstance(raw_content, str) else raw_content + except (json.JSONDecodeError, TypeError): + content = {} + return ToolUseEvent(tool_name=tool_name, tool_use_id=tool_use_id, content=content) diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/stream_events.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/stream_events.py new file mode 100644 index 00000000..dfd69c02 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/connection/stream_events.py @@ -0,0 +1,71 @@ +"""Typed event dataclasses for Bedrock stream responses.""" +from dataclasses import dataclass +from typing import Dict, Any, Optional, Union + + +@dataclass +class CompletionStartEvent: + data: Dict[str, Any] + + +@dataclass +class ContentStartEvent: + role: str + is_final_response: bool = False + + +@dataclass +class TextOutputEvent: + content: str + role: str + + +@dataclass +class AudioOutputEvent: + audio_base64: str + + +@dataclass +class ToolUseEvent: + tool_name: str + tool_use_id: str + content: Dict[str, Any] + + +@dataclass +class BargeInEvent: + pass + + +@dataclass +class ContentEndEvent: + content_type: Optional[str] = None + + +@dataclass +class CompletionEndEvent: + pass + + +@dataclass +class UsageEvent: + data: Dict[str, Any] + + +@dataclass +class UnknownEvent: + raw_data: str + + +StreamEvent = Union[ + CompletionStartEvent, + ContentStartEvent, + TextOutputEvent, + AudioOutputEvent, + ToolUseEvent, + BargeInEvent, + ContentEndEvent, + CompletionEndEvent, + UsageEvent, + UnknownEvent, +] diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/multi_agent.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/multi_agent.py new file mode 100644 index 00000000..ab95587a --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/multi_agent.py @@ -0,0 +1,111 @@ +"""Multi-agent orchestrator for Nova 2 Sonic conversations.""" +import asyncio +import logging +import os +import pygame + +from src.session.conversation_state import ConversationState +from src.agents.tool_registry import ToolRegistry +from src.connection.bedrock_connection import BedrockConnection +from src.session.session_controller import SessionController +from src.audio.audio_streamer import AudioStreamer +from src.agents.agent_config import AGENTS + + +logger = logging.getLogger("sonic.orchestrator") + + +class MultiAgentSonic: + """Orchestrates multi-agent voice conversations.""" + + def __init__(self, model_id: str, region: str, debug: bool = False): + self.model_id = model_id + self.region = region + self.debug = debug + self.state = ConversationState() + self.registry = ToolRegistry.from_agents(AGENTS) + logger.info("Initialized MultiAgentSonic (model=%s, region=%s)", model_id, region) + + async def start_conversation(self): + """Start voice conversation with agent switching.""" + while True: + try: + agent = AGENTS.get(self.state.active_agent, AGENTS["support"]) + logger.info("Starting session with agent=%s, voice=%s", + self.state.active_agent, agent.voice_id) + print(f"🎤 Starting conversation with {self.state.active_agent.title()}...") + + await asyncio.sleep(1) + + # Create per-session components + connection = BedrockConnection(self.model_id, self.region) + audio_streamer = AudioStreamer(send_audio_fn=lambda b: None) # wired below + controller = SessionController( + connection=connection, + state=self.state, + registry=self.registry, + callback=audio_streamer, + voice_id=agent.voice_id, + system_prompt=agent.instruction, + ) + # Wire audio streamer to send audio through the controller + audio_streamer._send_audio = controller.send_audio + audio_streamer._send_audio_content_start = controller.send_audio_content_start_event + + # Initialize and start + await controller.start_session() + + # Stop transition music + self._stop_music() + + # Start conversation (blocks until stop event) + await audio_streamer.start_streaming() + + # Check for agent switch + if self.state.switch_requested: + old = self.state.active_agent + new = self.state.complete_switch() + logger.info("Agent switch: %s → %s", old, new) + print(f"🔄 Switching: {old} → {new}") + + # Play transition music + self._play_music() + + # Close connection + await controller.stop() + await audio_streamer.stop_streaming() + continue + else: + print("👋 Conversation ended") + break + + except KeyboardInterrupt: + print("\n👋 Interrupted by user") + break + except Exception as e: + logger.exception("Session error") + print(f"Error: {e}") + if self.debug: + import traceback + traceback.print_exc() + break + + def _play_music(self): + """Play transition music.""" + try: + pygame.mixer.init() + music_path = os.path.join(os.path.dirname(__file__), "..", "assets", "music.mp3") + if os.path.exists(music_path): + pygame.mixer.music.load(music_path) + pygame.mixer.music.play(-1) + print("🎵 Playing transition music") + except Exception as e: + print(f"Could not play music: {e}") + + def _stop_music(self): + """Stop transition music.""" + try: + pygame.mixer.music.stop() + print("🎵 Stopped transition music") + except: + pass diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/__init__.py new file mode 100644 index 00000000..bcc1e5e3 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/__init__.py @@ -0,0 +1 @@ +"""Session management and conversation state.""" diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/callbacks.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/callbacks.py new file mode 100644 index 00000000..619dc710 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/callbacks.py @@ -0,0 +1,9 @@ +"""StreamCallback protocol for decoupling SessionController from AudioStreamer.""" +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class StreamCallback(Protocol): + def on_audio_output(self, audio_bytes: bytes) -> None: ... + def on_barge_in(self) -> None: ... + def on_switch_requested(self) -> None: ... diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/conversation_state.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/conversation_state.py new file mode 100644 index 00000000..87de63cf --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/conversation_state.py @@ -0,0 +1,39 @@ +"""Conversation state management.""" +from dataclasses import dataclass, field +from typing import Dict, List, Optional + + +@dataclass +class ConversationState: + """Owns conversation history, active agent, and switch state.""" + + active_agent: str = "support" + conversation_history: List[Dict[str, str]] = field(default_factory=list) + switch_requested: bool = False + switch_target: Optional[str] = None + + def append_message(self, role: str, content: str) -> None: + """Append a message preserving role and content.""" + self.conversation_history.append({"role": role, "content": content}) + + def request_switch(self, target_agent: str) -> None: + """Record an agent switch request atomically.""" + self.switch_target = target_agent + self.switch_requested = True + + def complete_switch(self) -> str: + """Complete the pending switch, returning the new agent name.""" + agent = self.switch_target + self.active_agent = agent + self.switch_requested = False + self.switch_target = None + return agent + + def get_history(self) -> List[Dict[str, str]]: + """Return a copy of the conversation history.""" + return list(self.conversation_history) + + def reset_switch(self) -> None: + """Clear any pending switch request.""" + self.switch_requested = False + self.switch_target = None diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/session_controller.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/session_controller.py new file mode 100644 index 00000000..5f185b84 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/session/session_controller.py @@ -0,0 +1,451 @@ +"""Session controller — conversation lifecycle coordinator. + +Wires together BedrockConnection, ResponseParser, ConversationState, +ToolRegistry, EventTemplates, and StreamCallback. All business logic +decisions (agent switching, barge-in handling, tool execution, history +updates) live here. +""" +import asyncio +import base64 +import json +import logging +import uuid +from typing import Optional, Set + +from src.connection.bedrock_connection import BedrockConnection +from src.session.conversation_state import ConversationState +from src.agents.tool_registry import ToolRegistry +from src.connection.stream_events import ( + StreamEvent, + CompletionStartEvent, + ContentStartEvent, + TextOutputEvent, + AudioOutputEvent, + ToolUseEvent, + BargeInEvent, + ContentEndEvent, + CompletionEndEvent, + UsageEvent, +) +from src.session.callbacks import StreamCallback +from src.connection.response_parser import ResponseParser +from src.connection.event_templates import EventTemplates +from src.utils import debug_print +from src.agents.agent_config import AGENTS + +logger = logging.getLogger("sonic.session") + + +class SessionController: + """Central coordinator for a single conversation session. + + Responsibilities: + - Open the Bedrock connection and send the initialisation sequence + - Receive raw responses, parse them into typed events, dispatch business logic + - Execute tools asynchronously and send results back + - Notify the AudioStreamer (via StreamCallback) of audio output, barge-in, and switch events + - Manage session teardown + """ + + def __init__( + self, + connection: BedrockConnection, + state: ConversationState, + registry: ToolRegistry, + callback: StreamCallback, + voice_id: str, + system_prompt: str, + ) -> None: + self._connection = connection + self._state = state + self._registry = registry + self._callback = callback + self._voice_id = voice_id + self._system_prompt = system_prompt + + # Session IDs + self._prompt_name = str(uuid.uuid4()) + self._content_name = str(uuid.uuid4()) + self._audio_content_name = str(uuid.uuid4()) + + # Response tracking + self._display_assistant_text = False + self._role: Optional[str] = None + + # Tool handling + self._pending_tool_tasks: dict[str, asyncio.Task] = {} + self._tool_name = "" + self._tool_use_id = "" + self._tool_use_content: dict = {} + + # Background tasks + self._response_task: Optional[asyncio.Task] = None + self._is_active = False + self._audio_send_count = 0 + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def start_session(self) -> None: + """Open connection, send init events, start response processing.""" + debug_print("Opening connection") + await self._connection.open() + self._is_active = True + + debug_print("Sending initialization events") + await self._send_initialization_events() + debug_print("Initialization events sent") + + self._response_task = asyncio.create_task(self._process_responses()) + await asyncio.sleep(0.1) + logger.info("Session started (voice=%s)", self._voice_id) + + async def send_audio(self, audio_bytes: bytes) -> None: + """Encode raw audio bytes as base64 and send to Bedrock.""" + self._audio_send_count += 1 + if self._audio_send_count % 50 == 1: + debug_print(f"Audio chunk #{self._audio_send_count} ({len(audio_bytes)} bytes)") + blob = base64.b64encode(audio_bytes).decode("utf-8") + event = EventTemplates.audio_input( + self._prompt_name, self._audio_content_name, blob + ) + await self._connection.send(event) + + async def send_audio_content_start_event(self) -> None: + """Send audio content start event.""" + event = EventTemplates.content_start( + self._prompt_name, self._audio_content_name + ) + await self._connection.send(event) + + async def send_audio_content_end_event(self) -> None: + """Send audio content end event.""" + if self._is_active: + event = EventTemplates.content_end( + self._prompt_name, self._audio_content_name + ) + await self._connection.send(event) + debug_print("Audio ended") + + async def stop(self) -> None: + """Close connection and cancel pending tasks.""" + if not self._is_active: + return + + debug_print("Stopping session") + logger.info("Stopping session (pending_tools=%d)", len(self._pending_tool_tasks)) + self._is_active = False + + # Cancel pending tool tasks + for task in self._pending_tool_tasks.values(): + task.cancel() + self._pending_tool_tasks.clear() + + # Cancel response processing + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + + # Send close sequence + try: + await self.send_audio_content_end_event() + await self._connection.send( + EventTemplates.prompt_end(self._prompt_name) + ) + await self._connection.send(EventTemplates.session_end()) + except Exception as e: + debug_print(f"Error during stop: {e}") + + await self._connection.close() + logger.info("Session stopped") + debug_print("Session stopped") + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_active(self) -> bool: + return self._is_active + + @property + def prompt_name(self) -> str: + return self._prompt_name + + @property + def audio_content_name(self) -> str: + return self._audio_content_name + + # ------------------------------------------------------------------ + # Initialisation sequence + # ------------------------------------------------------------------ + + async def _send_initialization_events(self) -> None: + """Send the startup sequence: session start, prompt start, system prompt, history, greeting.""" + system_prompt = self._system_prompt or "You are a friend engaging in natural real-time conversation." + + tool_schemas = self._registry.get_schemas_for_agent( + self._state.active_agent, AGENTS + ) + + events = [ + EventTemplates.start_session(), + EventTemplates.prompt_start( + self._prompt_name, self._voice_id, tool_schemas + ), + EventTemplates.text_content_start( + self._prompt_name, self._content_name, "SYSTEM" + ), + EventTemplates.text_input( + self._prompt_name, self._content_name, system_prompt + ), + EventTemplates.content_end(self._prompt_name, self._content_name), + ] + + for event in events: + await self._connection.send(event) + await asyncio.sleep(0.1) + + # Send conversation history + history = self._state.get_history() + if history: + print(f"📝 Add conversation history: {len(history)} messages") + debug_print(f"Sending history: {len(history)} messages") + # Drop last message and leading assistant messages (matches original behaviour) + history = history[:-1] + while history and history[0].get("role") == "ASSISTANT": + history.pop(0) + for msg in history: + await self._send_history_message(msg) + + # Send greeting prompt + speak_first_content_name = str(uuid.uuid4()) + greeting_events = [ + EventTemplates.text_content_start( + self._prompt_name, + content_name=speak_first_content_name, + role="USER", + interactive=True, + ), + EventTemplates.text_input( + self._prompt_name, + speak_first_content_name, + "Greet the user with his name and SHORT explanation your role", + ), + EventTemplates.content_end( + self._prompt_name, speak_first_content_name + ), + ] + for event in greeting_events: + await self._connection.send(event) + await asyncio.sleep(0.1) + + async def _send_history_message(self, message: dict) -> None: + """Send a single history message to Bedrock.""" + history_content_name = str(uuid.uuid4()) + events = [ + EventTemplates.text_content_start( + self._prompt_name, history_content_name, message["role"] + ), + EventTemplates.text_input( + self._prompt_name, history_content_name, message["content"] + ), + EventTemplates.content_end(self._prompt_name, history_content_name), + ] + for event in events: + await self._connection.send(event) + await asyncio.sleep(0.1) + + # ------------------------------------------------------------------ + # Response processing loop + # ------------------------------------------------------------------ + + async def _process_responses(self) -> None: + """Main response loop: receive → parse → dispatch.""" + debug_print("Response processing loop started") + try: + async for raw_data in self._connection.receive(): + if not self._is_active or self._state.switch_requested: + break + event = ResponseParser.parse(raw_data) + debug_print(f"Event: {type(event).__name__}") + await self._dispatch_event(event) + except asyncio.CancelledError: + debug_print("Response processing cancelled") + except Exception as e: + logger.error("Response processing error: %s", e, exc_info=True) + finally: + self._is_active = False + + # ------------------------------------------------------------------ + # Event dispatch — all business logic lives here + # ------------------------------------------------------------------ + + async def _dispatch_event(self, event: StreamEvent) -> None: + """Route typed events to the appropriate business logic handler.""" + + if isinstance(event, CompletionStartEvent): + debug_print(f"Completion start: {event.data}") + + elif isinstance(event, ContentStartEvent): + self._handle_content_start(event) + + elif isinstance(event, TextOutputEvent): + self._handle_text_output(event) + + elif isinstance(event, AudioOutputEvent): + self._handle_audio_output(event) + + elif isinstance(event, BargeInEvent): + self._handle_barge_in() + + elif isinstance(event, ToolUseEvent): + await self._handle_tool_use(event) + + elif isinstance(event, ContentEndEvent): + self._handle_content_end(event) + + elif isinstance(event, CompletionEndEvent): + debug_print("Completion end") + + elif isinstance(event, UsageEvent): + debug_print(f"Usage: {event.data}") + + # ------------------------------------------------------------------ + # Business logic handlers + # ------------------------------------------------------------------ + + def _handle_content_start(self, event: ContentStartEvent) -> None: + """Track role and whether this is a final (displayable) response.""" + debug_print("Content start") + self._role = event.role + if event.is_final_response: + self._display_assistant_text = True + else: + self._display_assistant_text = False + + def _handle_text_output(self, event: TextOutputEvent) -> None: + """Append to conversation state and/or print to console.""" + role = event.role + content = event.content + + if (self._role == "ASSISTANT" and self._display_assistant_text) or self._role == "USER": + self._state.append_message(role, content) + if (self._role == "ASSISTANT" and not self._display_assistant_text) or self._role == "USER": + print(f"{role.title()}: {content}") + + def _handle_audio_output(self, event: AudioOutputEvent) -> None: + """Decode base64 audio and push to callback.""" + audio_bytes = base64.b64decode(event.audio_base64) + debug_print(f"Audio output: {len(audio_bytes)} bytes") + self._callback.on_audio_output(audio_bytes) + + def _handle_barge_in(self) -> None: + """Notify callback of barge-in.""" + debug_print("Barge-in detected") + self._callback.on_barge_in() + + async def _handle_tool_use(self, event: ToolUseEvent) -> None: + """Handle tool use — either agent switch or regular tool execution.""" + self._tool_name = event.tool_name + self._tool_use_id = event.tool_use_id + self._tool_use_content = { + "toolName": event.tool_name, + "toolUseId": event.tool_use_id, + "content": event.content, + } + + if event.tool_name == "switch_agent": + target = event.content.get("role", "support").lower() + logger.info("Agent switch requested → %s", target) + self._state.request_switch(target) + await asyncio.sleep(0.1) + self._callback.on_switch_requested() + print(f"🎯 Switching to: {target}") + else: + logger.info("Tool invoked: %s (id=%s)", event.tool_name, event.tool_use_id) + print(f"🎯 Tool use: {event.tool_name}") + debug_print(f"Tool: {event.tool_name}, ID: {event.tool_use_id}") + + def _handle_content_end(self, event: ContentEndEvent) -> None: + """On TOOL content end, kick off async tool execution.""" + if event.content_type == "TOOL": + debug_print("Processing tool") + self._execute_tool_async( + self._tool_name, self._tool_use_content, self._tool_use_id + ) + else: + debug_print("Content end") + + # ------------------------------------------------------------------ + # Tool execution + # ------------------------------------------------------------------ + + def _execute_tool_async( + self, tool_name: str, tool_content: dict, tool_use_id: str + ) -> None: + """Fire-and-forget async tool execution with result sent back to Bedrock.""" + content_name = str(uuid.uuid4()) + task = asyncio.create_task( + self._execute_tool_and_send_result( + tool_name, tool_content, tool_use_id, content_name + ) + ) + self._pending_tool_tasks[content_name] = task + task.add_done_callback( + lambda t: self._handle_tool_completion(t, content_name) + ) + + def _handle_tool_completion(self, task: asyncio.Task, content_name: str) -> None: + """Clean up after a tool task finishes.""" + self._pending_tool_tasks.pop(content_name, None) + if task.done() and not task.cancelled(): + exc = task.exception() + if exc: + debug_print(f"Tool task failed: {exc}") + + async def _execute_tool_and_send_result( + self, + tool_name: str, + tool_content: dict, + tool_use_id: str, + content_name: str, + ) -> None: + """Execute a tool via ToolRegistry and send the result back to Bedrock.""" + try: + debug_print(f"Executing tool: {tool_name}") + result = await self._registry.execute(tool_name, tool_content) + + await self._connection.send( + EventTemplates.tool_content_start( + self._prompt_name, content_name, tool_use_id + ) + ) + await self._connection.send( + EventTemplates.tool_result( + self._prompt_name, content_name, result + ) + ) + await self._connection.send( + EventTemplates.content_end(self._prompt_name, content_name) + ) + debug_print(f"Tool complete: {tool_name}") + except Exception as e: + debug_print(f"Tool error: {e}") + try: + error_result = {"error": f"Tool failed: {e}"} + await self._connection.send( + EventTemplates.tool_content_start( + self._prompt_name, content_name, tool_use_id + ) + ) + await self._connection.send( + EventTemplates.tool_result( + self._prompt_name, content_name, error_result + ) + ) + await self._connection.send( + EventTemplates.content_end(self._prompt_name, content_name) + ) + except Exception as send_error: + debug_print(f"Failed to send error: {send_error}") diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/utils.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/utils.py new file mode 100644 index 00000000..90ac580b --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/src/utils.py @@ -0,0 +1,28 @@ +"""Utility functions for logging and performance monitoring.""" +import logging +import time + +logger = logging.getLogger("sonic") + + +def debug_print(message: str) -> None: + """Log a debug message. Replaces the old print-based approach.""" + logger.debug(message) + + +def time_it(label: str, func): + """Time synchronous function execution.""" + start = time.perf_counter() + result = func() + elapsed = time.perf_counter() - start + logger.debug("Execution time for %s: %.4fs", label, elapsed) + return result + + +async def time_it_async(label: str, func): + """Time asynchronous function execution.""" + start = time.perf_counter() + result = await func() + elapsed = time.perf_counter() - start + logger.debug("Execution time for %s: %.4fs", label, elapsed) + return result diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/__init__.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_bedrock_connection.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_bedrock_connection.py new file mode 100644 index 00000000..e8a0551a --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_bedrock_connection.py @@ -0,0 +1,224 @@ +"""Unit tests for BedrockConnection.""" +import json +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from src.connection.bedrock_connection import BedrockConnection + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_stream_response(): + """Build a mock stream_response with input_stream.send / .close and await_output.""" + stream = MagicMock() + stream.input_stream = MagicMock() + stream.input_stream.send = AsyncMock() + stream.input_stream.close = AsyncMock() + stream.await_output = AsyncMock() + return stream + + +@pytest_asyncio.fixture +async def connection(): + """Return a BedrockConnection with a mocked stream already open.""" + conn = BedrockConnection(model_id="test-model", region="us-east-1") + conn._stream_response = _make_stream_response() + conn._closed = False + return conn + + +# --------------------------------------------------------------------------- +# open() +# --------------------------------------------------------------------------- + +class TestOpen: + @pytest.mark.asyncio + async def test_open_sets_stream_response(self): + conn = BedrockConnection(model_id="test-model", region="us-east-1") + mock_stream = _make_stream_response() + + async def fake_time_it_async(label, func): + return mock_stream + + with patch("src.connection.bedrock_connection.BedrockRuntimeClient") as MockClient, \ + patch("src.connection.bedrock_connection.time_it_async", side_effect=fake_time_it_async): + await conn.open() + + assert conn._stream_response is mock_stream + assert conn.is_open is True + assert conn._closed is False + + @pytest.mark.asyncio + async def test_open_raises_on_failure(self): + conn = BedrockConnection(model_id="test-model", region="us-east-1") + + async def failing_time_it_async(label, func): + raise Exception("Connection refused") + + with patch("src.connection.bedrock_connection.BedrockRuntimeClient") as MockClient, \ + patch("src.connection.bedrock_connection.time_it_async", side_effect=failing_time_it_async): + with pytest.raises(Exception, match="Connection refused"): + await conn.open() + + +# --------------------------------------------------------------------------- +# send() +# --------------------------------------------------------------------------- + +class TestSend: + @pytest.mark.asyncio + async def test_send_encodes_and_forwards(self, connection): + event_json = json.dumps({"event": {"sessionStart": {}}}) + await connection.send(event_json) + connection._stream_response.input_stream.send.assert_awaited_once() + + @pytest.mark.asyncio + async def test_send_on_closed_stream_is_noop(self, connection): + connection._closed = True + await connection.send('{"event":{}}') + connection._stream_response.input_stream.send.assert_not_awaited() + + @pytest.mark.asyncio + async def test_send_on_none_stream_is_noop(self): + conn = BedrockConnection(model_id="m", region="r") + # Should not raise + await conn.send('{"event":{}}') + + @pytest.mark.asyncio + async def test_send_logs_error_on_exception(self, connection): + connection._stream_response.input_stream.send.side_effect = Exception("broken") + # Should not raise — errors are caught and logged + await connection.send('{"event":{}}') + + +# --------------------------------------------------------------------------- +# receive() +# --------------------------------------------------------------------------- + +class TestReceive: + @pytest.mark.asyncio + async def test_receive_yields_decoded_strings(self, connection): + payload = '{"event":{"textOutput":{"content":"hi","role":"ASSISTANT"}}}' + result_mock = MagicMock() + result_mock.value.bytes_ = payload.encode("utf-8") + + output_mock = AsyncMock(return_value=result_mock) + connection._stream_response.await_output = AsyncMock( + side_effect=[(None, MagicMock(receive=output_mock)), StopAsyncIteration] + ) + + results = [] + async for item in connection.receive(): + results.append(item) + break # only one item expected + + assert results == [payload] + + @pytest.mark.asyncio + async def test_receive_stops_on_stop_async_iteration(self, connection): + connection._stream_response.await_output = AsyncMock(side_effect=StopAsyncIteration) + + results = [] + async for item in connection.receive(): + results.append(item) + + assert results == [] + + @pytest.mark.asyncio + async def test_receive_stops_on_cancelled(self, connection): + connection._stream_response.await_output = AsyncMock( + side_effect=Exception("InvalidStateError: CANCELLED") + ) + + results = [] + async for item in connection.receive(): + results.append(item) + + assert results == [] + + @pytest.mark.asyncio + async def test_receive_stops_on_validation_exception(self, connection): + connection._stream_response.await_output = AsyncMock( + side_effect=Exception("ValidationException: bad input") + ) + + results = [] + async for item in connection.receive(): + results.append(item) + + assert results == [] + + @pytest.mark.asyncio + async def test_receive_returns_immediately_when_no_stream(self): + conn = BedrockConnection(model_id="m", region="r") + results = [] + async for item in conn.receive(): + results.append(item) + assert results == [] + + @pytest.mark.asyncio + async def test_receive_skips_empty_bytes(self, connection): + """When result.value.bytes_ is None, nothing is yielded.""" + result_mock = MagicMock() + result_mock.value.bytes_ = None + + output_mock = AsyncMock(return_value=result_mock) + connection._stream_response.await_output = AsyncMock( + side_effect=[(None, MagicMock(receive=output_mock)), StopAsyncIteration] + ) + + results = [] + async for item in connection.receive(): + results.append(item) + + assert results == [] + + +# --------------------------------------------------------------------------- +# close() +# --------------------------------------------------------------------------- + +class TestClose: + @pytest.mark.asyncio + async def test_close_sets_closed_flag(self, connection): + await connection.close() + assert connection._closed is True + assert connection._stream_response is None + + @pytest.mark.asyncio + async def test_close_is_idempotent(self, connection): + await connection.close() + await connection.close() # second call should be a no-op + assert connection._closed is True + + @pytest.mark.asyncio + async def test_close_handles_stream_close_error(self, connection): + connection._stream_response.input_stream.close.side_effect = Exception("oops") + await connection.close() # should not raise + assert connection._closed is True + + @pytest.mark.asyncio + async def test_close_on_fresh_connection(self): + conn = BedrockConnection(model_id="m", region="r") + await conn.close() # no stream to close — should be fine + assert conn._closed is True + + +# --------------------------------------------------------------------------- +# is_open property +# --------------------------------------------------------------------------- + +class TestIsOpen: + def test_is_open_false_initially(self): + conn = BedrockConnection(model_id="m", region="r") + assert conn.is_open is False + + def test_is_open_true_when_stream_active(self, connection): + assert connection.is_open is True + + def test_is_open_false_after_close(self, connection): + connection._closed = True + assert connection.is_open is False diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_connection.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_connection.py new file mode 100644 index 00000000..02db8101 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_connection.py @@ -0,0 +1,140 @@ +"""Minimal test: can we connect to Bedrock and get a response?""" +import asyncio +import json +import os + +from aws_sdk_bedrock_runtime.client import ( + BedrockRuntimeClient, + InvokeModelWithBidirectionalStreamOperationInput, +) +from aws_sdk_bedrock_runtime.models import ( + InvokeModelWithBidirectionalStreamInputChunk, + BidirectionalInputPayloadPart, +) +from aws_sdk_bedrock_runtime.config import Config +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver + + +MODEL_ID = "amazon.nova-2-sonic-v1:0" +REGION = "us-east-1" + + +async def send_event(stream, event_dict): + event_json = json.dumps(event_dict) + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8")) + ) + await stream.input_stream.send(chunk) + print(f" Sent: {list(event_dict['event'].keys())[0]}", flush=True) + + +async def main(): + print("=== Bedrock Connection Test ===", flush=True) + + # Check env vars + key = os.environ.get("AWS_ACCESS_KEY_ID", "") + region = os.environ.get("AWS_DEFAULT_REGION", os.environ.get("AWS_REGION", "")) + print(f"AWS_ACCESS_KEY_ID: {'set (' + key[:4] + '...)' if key else 'NOT SET'}", flush=True) + print(f"AWS_SECRET_ACCESS_KEY: {'set' if os.environ.get('AWS_SECRET_ACCESS_KEY') else 'NOT SET'}", flush=True) + print(f"AWS_SESSION_TOKEN: {'set' if os.environ.get('AWS_SESSION_TOKEN') else 'NOT SET'}", flush=True) + print(f"AWS_REGION env: {region or 'NOT SET'}", flush=True) + print(flush=True) + + # Connect + print("1. Creating client...", flush=True) + config = Config( + endpoint_uri=f"https://bedrock-runtime.{REGION}.amazonaws.com", + region=REGION, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + ) + client = BedrockRuntimeClient(config=config) + + print("2. Opening stream...", flush=True) + stream = await client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=MODEL_ID) + ) + print(" Stream opened", flush=True) + + # Send minimal init sequence + print("3. Sending init events...", flush=True) + + await send_event(stream, { + "event": {"sessionStart": {"inferenceConfiguration": {"maxTokens": 1024, "topP": 0.0, "temperature": 0.0}}} + }) + + prompt_name = "test-prompt" + await send_event(stream, { + "event": {"promptStart": { + "promptName": prompt_name, + "textOutputConfiguration": {"mediaType": "text/plain"}, + "audioOutputConfiguration": { + "mediaType": "audio/lpcm", "sampleRateHertz": 24000, + "sampleSizeBits": 16, "channelCount": 1, + "voiceId": "matthew", "encoding": "base64", "audioType": "SPEECH" + }, + "toolUseOutputConfiguration": {"mediaType": "application/json"}, + "toolConfiguration": {"tools": []} + }} + }) + + content_name = "test-content" + await send_event(stream, { + "event": {"contentStart": { + "promptName": prompt_name, "contentName": content_name, + "type": "TEXT", "role": "USER", "interactive": True, + "textInputConfiguration": {"mediaType": "text/plain"} + }} + }) + + await send_event(stream, { + "event": {"textInput": { + "promptName": prompt_name, "contentName": content_name, + "content": "Say hello in one sentence." + }} + }) + + await send_event(stream, { + "event": {"contentEnd": {"promptName": prompt_name, "contentName": content_name}} + }) + + # Try to receive — loop with timeout to catch all responses + print("4. Waiting for responses (15s timeout)...", flush=True) + got_response = False + try: + deadline = asyncio.get_event_loop().time() + 15.0 + while asyncio.get_event_loop().time() < deadline: + remaining = deadline - asyncio.get_event_loop().time() + output = await asyncio.wait_for(stream.await_output(), timeout=remaining) + result = await output[1].receive() + if result.value and result.value.bytes_: + raw = result.value.bytes_.decode("utf-8") + print(f" RESPONSE: {raw[:100]}...", flush=True) + got_response = True + else: + print(" Got empty response", flush=True) + except asyncio.TimeoutError: + if not got_response: + print(" TIMEOUT — no response from Bedrock after 15s", flush=True) + else: + print(" (timeout after receiving responses — normal)", flush=True) + except StopAsyncIteration: + print(" Stream ended", flush=True) + except Exception as e: + print(f" ERROR: {type(e).__name__}: {e}", flush=True) + + if not got_response: + print(flush=True) + print("=== TROUBLESHOOTING ===", flush=True) + print("1. Check model access: aws bedrock list-foundation-models --region us-east-1 --query 'modelSummaries[?modelId==`amazon.nova-sonic-v1:0`]'", flush=True) + print("2. Check token expiry: aws sts get-caller-identity", flush=True) + print("3. Try a simple Bedrock call: aws bedrock-runtime invoke-model --model-id amazon.nova-lite-v1:0 --region us-east-1 --body '{\"messages\":[{\"role\":\"user\",\"content\":[{\"text\":\"hi\"}]}]}' --cli-binary-format raw-in-base64-out /dev/stdout", flush=True) + + # Cleanup + try: + await stream.input_stream.close() + except: + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_conversation_state.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_conversation_state.py new file mode 100644 index 00000000..9d2cb0c7 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_conversation_state.py @@ -0,0 +1,62 @@ +"""Property-based tests for ConversationState. + +**Validates: Requirements 1.3, 1.4, 1.5** +""" + +import sys +import os + +from hypothesis import given, settings +from hypothesis import strategies as st + +from src.session.conversation_state import ConversationState + + +# Feature: architecture-refactor, Property 1: Agent switch request is recorded atomically +@given(agent_name=st.text(min_size=1)) +@settings(max_examples=200) +def test_request_switch_records_atomically(agent_name: str): + """For any valid agent name string, calling request_switch(agent_name) on a + ConversationState SHALL result in switch_requested being True and + switch_target being equal to the provided agent name. + + **Validates: Requirements 1.3** + """ + state = ConversationState() + state.request_switch(agent_name) + + assert state.switch_requested is True + assert state.switch_target == agent_name + + +# Feature: architecture-refactor, Property 2: Conversation history preserves all appended messages in order +@given( + messages=st.lists( + st.tuples( + st.text(min_size=1), + st.text(min_size=1), + ), + min_size=0, + max_size=50, + ) +) +@settings(max_examples=200) +def test_history_preserves_appended_messages_in_order(messages): + """For any sequence of (role, content) string pairs appended to a + ConversationState, calling get_history() SHALL return a list of the same + length containing dictionaries with matching role and content values in the + same order they were appended. + + **Validates: Requirements 1.4, 1.5** + """ + state = ConversationState() + + for role, content in messages: + state.append_message(role, content) + + history = state.get_history() + + assert len(history) == len(messages) + for (expected_role, expected_content), entry in zip(messages, history): + assert entry["role"] == expected_role + assert entry["content"] == expected_content diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_response_parser.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_response_parser.py new file mode 100644 index 00000000..f9b76d83 --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_response_parser.py @@ -0,0 +1,287 @@ +"""Unit tests for ResponseParser covering each event type, malformed JSON, and missing fields.""" +import json +import pytest + +from src.connection.response_parser import ResponseParser +from src.connection.stream_events import ( + CompletionStartEvent, + ContentStartEvent, + TextOutputEvent, + AudioOutputEvent, + ToolUseEvent, + BargeInEvent, + ContentEndEvent, + CompletionEndEvent, + UsageEvent, + UnknownEvent, +) + + +# --- completionStart --- + +class TestCompletionStart: + def test_basic(self): + data = json.dumps({"event": {"completionStart": {"requestId": "abc123"}}}) + result = ResponseParser.parse(data) + assert isinstance(result, CompletionStartEvent) + assert result.data["completionStart"]["requestId"] == "abc123" + + +# --- contentStart --- + +class TestContentStart: + def test_basic_role(self): + data = json.dumps({"event": {"contentStart": {"role": "ASSISTANT"}}}) + result = ResponseParser.parse(data) + assert isinstance(result, ContentStartEvent) + assert result.role == "ASSISTANT" + assert result.is_final_response is False + + def test_final_response(self): + data = json.dumps({ + "event": { + "contentStart": { + "role": "ASSISTANT", + "additionalModelFields": json.dumps({"generationStage": "FINAL"}), + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ContentStartEvent) + assert result.is_final_response is True + + def test_non_final_generation_stage(self): + data = json.dumps({ + "event": { + "contentStart": { + "role": "ASSISTANT", + "additionalModelFields": json.dumps({"generationStage": "SPECULATIVE"}), + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ContentStartEvent) + assert result.is_final_response is False + + def test_malformed_additional_fields(self): + data = json.dumps({ + "event": { + "contentStart": { + "role": "USER", + "additionalModelFields": "not-valid-json{", + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ContentStartEvent) + assert result.role == "USER" + assert result.is_final_response is False + + def test_missing_role(self): + data = json.dumps({"event": {"contentStart": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, ContentStartEvent) + assert result.role == "" + + +# --- textOutput --- + +class TestTextOutput: + def test_basic(self): + data = json.dumps({ + "event": {"textOutput": {"content": "Hello!", "role": "ASSISTANT"}} + }) + result = ResponseParser.parse(data) + assert isinstance(result, TextOutputEvent) + assert result.content == "Hello!" + assert result.role == "ASSISTANT" + + def test_barge_in(self): + data = json.dumps({ + "event": { + "textOutput": { + "content": '{ "interrupted" : true }', + "role": "ASSISTANT", + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, BargeInEvent) + + def test_barge_in_embedded(self): + data = json.dumps({ + "event": { + "textOutput": { + "content": 'some text { "interrupted" : true } more text', + "role": "ASSISTANT", + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, BargeInEvent) + + def test_missing_fields(self): + data = json.dumps({"event": {"textOutput": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, TextOutputEvent) + assert result.content == "" + assert result.role == "" + + +# --- audioOutput --- + +class TestAudioOutput: + def test_basic(self): + data = json.dumps({ + "event": {"audioOutput": {"content": "base64audiodata=="}} + }) + result = ResponseParser.parse(data) + assert isinstance(result, AudioOutputEvent) + assert result.audio_base64 == "base64audiodata==" + + def test_missing_content(self): + data = json.dumps({"event": {"audioOutput": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, AudioOutputEvent) + assert result.audio_base64 == "" + + +# --- toolUse --- + +class TestToolUse: + def test_basic(self): + tool_content = json.dumps({"role": "sales"}) + data = json.dumps({ + "event": { + "toolUse": { + "toolName": "switch_agent", + "toolUseId": "tool-123", + "content": tool_content, + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ToolUseEvent) + assert result.tool_name == "switch_agent" + assert result.tool_use_id == "tool-123" + assert result.content == {"role": "sales"} + + def test_non_switch_tool(self): + tool_content = json.dumps({"order_id": "ORD-456"}) + data = json.dumps({ + "event": { + "toolUse": { + "toolName": "track_order", + "toolUseId": "tool-789", + "content": tool_content, + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ToolUseEvent) + assert result.tool_name == "track_order" + assert result.content == {"order_id": "ORD-456"} + + def test_malformed_content(self): + data = json.dumps({ + "event": { + "toolUse": { + "toolName": "some_tool", + "toolUseId": "id-1", + "content": "not-valid-json{", + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ToolUseEvent) + assert result.content == {} + + def test_dict_content(self): + """Content already a dict (not a JSON string).""" + data = json.dumps({ + "event": { + "toolUse": { + "toolName": "my_tool", + "toolUseId": "id-2", + "content": {"key": "value"}, + } + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, ToolUseEvent) + assert result.content == {"key": "value"} + + def test_missing_fields(self): + data = json.dumps({"event": {"toolUse": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, ToolUseEvent) + assert result.tool_name == "" + assert result.tool_use_id == "" + assert result.content == {} + + +# --- contentEnd --- + +class TestContentEnd: + def test_with_type(self): + data = json.dumps({"event": {"contentEnd": {"type": "TOOL"}}}) + result = ResponseParser.parse(data) + assert isinstance(result, ContentEndEvent) + assert result.content_type == "TOOL" + + def test_without_type(self): + data = json.dumps({"event": {"contentEnd": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, ContentEndEvent) + assert result.content_type is None + + +# --- completionEnd --- + +class TestCompletionEnd: + def test_basic(self): + data = json.dumps({"event": {"completionEnd": {}}}) + result = ResponseParser.parse(data) + assert isinstance(result, CompletionEndEvent) + + +# --- usageEvent --- + +class TestUsageEvent: + def test_basic(self): + data = json.dumps({ + "event": { + "usageEvent": {"inputTokens": 100, "outputTokens": 50} + } + }) + result = ResponseParser.parse(data) + assert isinstance(result, UsageEvent) + assert result.data["usageEvent"]["inputTokens"] == 100 + + +# --- UnknownEvent / error cases --- + +class TestUnknownEvent: + def test_invalid_json(self): + result = ResponseParser.parse("not json at all {{{") + assert isinstance(result, UnknownEvent) + assert result.raw_data == "not json at all {{{" + + def test_empty_string(self): + result = ResponseParser.parse("") + assert isinstance(result, UnknownEvent) + + def test_valid_json_no_event_key(self): + data = json.dumps({"something": "else"}) + result = ResponseParser.parse(data) + assert isinstance(result, UnknownEvent) + assert data in result.raw_data + + def test_unrecognized_event_type(self): + data = json.dumps({"event": {"unknownType": {"data": 1}}}) + result = ResponseParser.parse(data) + assert isinstance(result, UnknownEvent) + + def test_none_input(self): + result = ResponseParser.parse(None) + assert isinstance(result, UnknownEvent) diff --git a/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_tool_registry.py b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_tool_registry.py new file mode 100644 index 00000000..deeb76cf --- /dev/null +++ b/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/conversation-transfer/tests/test_tool_registry.py @@ -0,0 +1,176 @@ +"""Unit tests for ToolRegistry. + +**Validates: Requirements 2.1, 2.2, 2.3, 2.4, 2.5** +""" + +import sys +import os +import json + +import pytest +from typing import Dict, Any + +from src.agents.tool_registry import ToolRegistry, SWITCH_AGENT_SCHEMA +from src.agents.agent_config import Agent, ToolDefinition + + +# --- Helpers --- + +async def dummy_tool(x: str) -> Dict[str, Any]: + return {"result": x} + + +async def failing_tool(x: str) -> Dict[str, Any]: + raise RuntimeError("boom") + + +def make_agent(name: str, tools: list) -> Agent: + return Agent(voice_id="v1", instruction="test", tools=tools) + + +def make_tool_def(name: str, callable=dummy_tool) -> ToolDefinition: + return ToolDefinition( + name=name, + description=f"Desc for {name}", + input_schema={"type": "object", "properties": {"x": {"type": "string"}}}, + callable=callable, + ) + + +# --- from_agents tests --- + +class TestFromAgents: + def test_builds_registry_from_agents(self): + agents = { + "support": make_agent("support", [make_tool_def("tool_a")]), + "sales": make_agent("sales", [make_tool_def("tool_b")]), + } + registry = ToolRegistry.from_agents(agents) + + assert "tool_a" in registry._tools + assert "tool_b" in registry._tools + assert registry._agent_tool_names["support"] == ["tool_a"] + assert registry._agent_tool_names["sales"] == ["tool_b"] + + def test_empty_agents(self): + registry = ToolRegistry.from_agents({}) + assert registry._tools == {} + assert registry._agent_tool_names == {} + + def test_agent_with_no_tools(self): + agents = {"support": make_agent("support", [])} + registry = ToolRegistry.from_agents(agents) + assert registry._agent_tool_names["support"] == [] + + def test_duplicate_tool_across_agents_registered_once(self): + shared = make_tool_def("shared_tool") + agents = { + "a": make_agent("a", [shared]), + "b": make_agent("b", [shared]), + } + registry = ToolRegistry.from_agents(agents) + assert len(registry._tools) == 1 + assert "shared_tool" in registry._tools + + +# --- register tests --- + +class TestRegister: + def test_register_and_lookup(self): + registry = ToolRegistry() + registry.register("my_tool", dummy_tool, {"description": "d", "input_schema": {}}) + assert "my_tool" in registry._tools + assert registry._tools["my_tool"].callable is dummy_tool + + +# --- get_schemas_for_agent tests --- + +class TestGetSchemasForAgent: + def test_always_includes_switch_agent(self): + registry = ToolRegistry.from_agents({ + "support": make_agent("support", [make_tool_def("tool_a")]), + }) + agents = {"support": make_agent("support", [make_tool_def("tool_a")])} + schemas = registry.get_schemas_for_agent("support", agents) + + names = [s["toolSpec"]["name"] for s in schemas] + assert "switch_agent" in names + + def test_includes_agent_specific_tools(self): + agents = { + "support": make_agent("support", [make_tool_def("tool_a")]), + "sales": make_agent("sales", [make_tool_def("tool_b")]), + } + registry = ToolRegistry.from_agents(agents) + + support_schemas = registry.get_schemas_for_agent("support", agents) + support_names = [s["toolSpec"]["name"] for s in support_schemas] + assert "tool_a" in support_names + assert "tool_b" not in support_names + + sales_schemas = registry.get_schemas_for_agent("sales", agents) + sales_names = [s["toolSpec"]["name"] for s in sales_schemas] + assert "tool_b" in sales_names + assert "tool_a" not in sales_names + + def test_bedrock_compatible_format(self): + agents = {"support": make_agent("support", [make_tool_def("tool_a")])} + registry = ToolRegistry.from_agents(agents) + schemas = registry.get_schemas_for_agent("support", agents) + + # Find the agent tool (not switch_agent) + tool_schema = [s for s in schemas if s["toolSpec"]["name"] == "tool_a"][0] + spec = tool_schema["toolSpec"] + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + assert "json" in spec["inputSchema"] + # inputSchema.json should be a JSON string + parsed = json.loads(spec["inputSchema"]["json"]) + assert parsed["type"] == "object" + + def test_unknown_agent_returns_only_switch_agent(self): + registry = ToolRegistry.from_agents({}) + schemas = registry.get_schemas_for_agent("nonexistent", {}) + assert len(schemas) == 1 + assert schemas[0]["toolSpec"]["name"] == "switch_agent" + + +# --- execute tests --- + +class TestExecute: + @pytest.mark.asyncio + async def test_execute_known_tool(self): + registry = ToolRegistry() + registry.register("my_tool", dummy_tool, {"description": "d", "input_schema": {}}) + result = await registry.execute("my_tool", {"x": "hello"}) + assert result == {"result": "hello"} + + @pytest.mark.asyncio + async def test_execute_unknown_tool(self): + registry = ToolRegistry() + result = await registry.execute("no_such_tool", {}) + assert "error" in result + assert "Unknown tool: no_such_tool" in result["error"] + + @pytest.mark.asyncio + async def test_execute_with_string_content(self): + registry = ToolRegistry() + registry.register("my_tool", dummy_tool, {"description": "d", "input_schema": {}}) + result = await registry.execute("my_tool", {"content": '{"x": "from_string"}'}) + assert result == {"result": "from_string"} + + @pytest.mark.asyncio + async def test_execute_with_dict_content(self): + registry = ToolRegistry() + registry.register("my_tool", dummy_tool, {"description": "d", "input_schema": {}}) + result = await registry.execute("my_tool", {"content": {"x": "from_dict"}}) + assert result == {"result": "from_dict"} + + @pytest.mark.asyncio + async def test_execute_failure_returns_error(self): + registry = ToolRegistry() + registry.register("bad_tool", failing_tool, {"description": "d", "input_schema": {}}) + result = await registry.execute("bad_tool", {"x": "hello"}) + assert "error" in result + assert "Tool execution failed: boom" in result["error"]