diff --git a/README.md b/README.md index 68acc6e..7944e69 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,8 @@ asyncio.run(main()) - šŸŽ„ **Real-time Audio/Video streaming** - Receive synchronized audio/video frames from the avatar (as PyAV AudioFrame/VideoFrame objects) - šŸ’¬ **Two-way communication** - Send text messages (like transcribed user speech) and receive generated responses +- šŸ“ **Real-time transcriptions** - Receive incremental message stream events for user and persona text as it's generated +- šŸ“š **Message history tracking** - Automatic conversation history with incremental updates - šŸŽ¤ **Audio-passthrough** - Send TTS generated audio input and receive rendered synchronized audio/video avatar - šŸ—£ļø **Direct text-to-speech** - Send text directly to TTS for immediate speech output (bypasses LLM processing) - šŸŽÆ **Async iterator API** - Clean, Pythonic async/await patterns for continuous stream of audio/video frames @@ -130,22 +132,59 @@ async with client.connect() as session: Register callbacks for connection and message events using the `@client.on()` decorator: ```python -from anam import AnamEvent - -@client.on(AnamEvent.MESSAGE_RECEIVED) -async def on_message(message: Message): - """Called when a chat message is received.""" - print(f"{message.role}: {message.content}") +from anam import AnamEvent, Message,MessageRole, MessageStreamEvent @client.on(AnamEvent.CONNECTION_ESTABLISHED) async def on_connected(): """Called when the connection is established.""" - pass + print("āœ… Connected!") @client.on(AnamEvent.CONNECTION_CLOSED) async def on_closed(code: str, reason: str | None): """Called when the connection is closed.""" - pass + print(f"Connection closed: {code} - {reason or 'No reason'}") + +@client.on(AnamEvent.MESSAGE_STREAM_EVENT_RECEIVED) +async def on_message_stream_event(event: MessageStreamEvent): + """Called for each incremental chunk of transcribed text or persona response. + + This event fires for both user transcriptions and persona responses as they stream in. + This can be used for real-time captions or transcriptions. + """ + if event.role == MessageRole.USER: + # User transcription (from their speech) + if event.content_index == 0: + print(f"šŸ‘¤ User: ", end="", flush=True) + print(event.content, end="", flush=True) + if event.end_of_speech: + print() # New line when transcription completes + else: + # Persona response + if event.content_index == 0: + print(f"šŸ¤– Persona: ", end="", flush=True) + print(event.content, end="", flush=True) + if event.end_of_speech: + status = "āœ“" if not event.interrupted else "āœ— INTERRUPTED" + print(f" {status}") + +@client.on(AnamEvent.MESSAGE_RECEIVED) +async def on_message(message: Message): + """Called when a complete message is received (after end_of_speech). + + This is fired after MESSAGE_STREAM_EVENT_RECEIVED with end_of_speech=True. + Useful for backward compatibility or when you only need complete messages. + """ + print(f"{message.role}: {message.content}") + +@client.on(AnamEvent.MESSAGE_HISTORY_UPDATED) +async def on_message_history_updated(messages: list[Message]): + """Called when the message history is updated (after a message completes). + + The messages list contains the complete conversation history. + """ + print(f"šŸ“ Conversation history: {len(messages)} messages") + for msg in messages: + print(f" {msg.role}: {msg.content[:50]}...") ``` ### Session @@ -157,9 +196,29 @@ async with client.connect() as session: # Send a text message (simulates user speech) await session.send_message("Hello, how are you?") + # Send text directly to TTS (bypasses LLM) + await session.talk("This will be spoken immediately") + + # Stream text to TTS incrementally (for streaming scenarios) + await session.send_talk_stream( + content="Hello", + start_of_speech=True, + end_of_speech=False, + ) + await session.send_talk_stream( + content=" world!", + start_of_speech=False, + end_of_speech=True, + ) + # Interrupt the avatar if speaking await session.interrupt() + # Get message history + history = client.get_message_history() + for msg in history: + print(f"{msg.role}: {msg.content}") + # Wait until the session ends await session.wait_until_closed() ``` diff --git a/examples/persona_interactive_video.py b/examples/persona_interactive_video.py index 8f7d452..89f991c 100644 --- a/examples/persona_interactive_video.py +++ b/examples/persona_interactive_video.py @@ -27,7 +27,7 @@ from dotenv import load_dotenv from anam import AnamClient, AnamEvent, ClientOptions -from anam.types import AgentAudioInputConfig, PersonaConfig +from anam.types import AgentAudioInputConfig, MessageRole, PersonaConfig # Add parent directory to path to allow importing from examples sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -38,7 +38,7 @@ # Configure logging - reduced verbosity logging.basicConfig( - level=logging.WARNING, # Reduced from INFO to WARNING + level=logging.WARNING, format="%(levelname)s: %(message)s", # Simplified format ) logger = logging.getLogger(__name__) @@ -47,19 +47,30 @@ logging.getLogger("anam").setLevel(logging.WARNING) logging.getLogger("websockets").setLevel(logging.WARNING) logging.getLogger("aiohttp").setLevel(logging.WARNING) +logging.getLogger("aiortc").setLevel(logging.WARNING) +logging.getLogger("aioice").setLevel(logging.WARNING) + +# Global state for captions toggle +show_captions = False +print_conversation_history = False async def interactive_loop(session, display: VideoDisplay) -> None: """Interactive command loop.""" + global show_captions + global print_conversation_history print("\n" + "=" * 60) print("Interactive Session Started!") print("=" * 60) print("Available commands:") - print(" f [filename] - Send audio file (defaults to input.wav)") - print(" m - Send text message (user input for the conversation.)") - print(" t|ts - Send talk command (bypasses LLM and sends text directly to TTS). t: REST API, ts: WebSocket)") - print(" i - Interrupt current audio") - print(" q - Quit and stop session") + print(" f [filename] - Send audio file (defaults to input.wav)") + print(" m - Send text message (user input for the conversation.)") + print(" t - Send talk command (bypasses LLM and sends text to TTS) usingREST API)") + print(" ts - Send talk stream (bypasses LLM and sends text to TTS) using WebSocket)") + print(" i - Interrupt current audio") + print(" c - Toggle live captions. Default: disabled") + print(" h - Toggle conversation history at session end. Default: disabled.") + print(" q - Quit and stop session") print("=" * 60 + "\n") while True: @@ -79,6 +90,10 @@ async def interactive_loop(session, display: VideoDisplay) -> None: display.stop() break + elif command == "c": + show_captions = not show_captions + print(f"Captions {'enabled' if show_captions else 'disabled'}") + elif command == "f": # Default to input.wav if no filename provided wav_file = parts[1] if len(parts) > 1 else "input.wav" @@ -92,6 +107,12 @@ async def interactive_loop(session, display: VideoDisplay) -> None: else: print(f"āŒ File not found: {wav_file}") + elif command == "h": + print_conversation_history = not print_conversation_history + print( + f"Conversation history {'enabled' if print_conversation_history else 'disabled'}" + ) + elif command == "m": # Get the rest of the input as the message text if len(parts) < 2: @@ -115,7 +136,10 @@ async def interactive_loop(session, display: VideoDisplay) -> None: await session.talk(message_text) elif command == "ts": await session.send_talk_stream( - message_text, start_of_speech=True, end_of_speech=True, correlation_id=None + message_text, + start_of_speech=True, + end_of_speech=True, + correlation_id=None, ) print(f"āœ… Sent talk command: {message_text}") except Exception as e: @@ -145,6 +169,7 @@ async def stream_session( audio_player: AudioPlayer, ) -> None: """Run the streaming session.""" + global show_captions # Register connection event handlers @client.on(AnamEvent.CONNECTION_ESTABLISHED) @@ -153,8 +178,44 @@ async def on_connected() -> None: @client.on(AnamEvent.CONNECTION_CLOSED) async def on_closed(code: str, reason: str | None) -> None: + global print_conversation_history + if print_conversation_history: + print("Conversation transcript:") + print("=" * 24) + print( + "\n".join( + [ + f"{m.role.value.capitalize()}: {m.content}" + for m in client.get_message_history() + ] + ) + ) print(f"Connection closed: {code} - {reason or 'User initiated'}") + # Register message stream event handlers + @client.on(AnamEvent.MESSAGE_STREAM_EVENT_RECEIVED) + async def on_message_stream_event(event) -> None: + """Handle incremental message stream events.""" + global show_captions + if show_captions: + role_emoji = "šŸ‘¤" if event.role == MessageRole.USER else "šŸ¤–" + role_name = "USER" if event.role == MessageRole.USER else "PERSONA" + + if event.content_index == 0: + # content_index 0 denotes the start of a new message + print(f"{role_emoji} {role_name}: ", end="", flush=True) + # Show incremental updates (you can customize this) + print(f"{event.content}", end="", flush=True) + if event.end_of_speech: + # end_of_speech is fired when the message is complete + status = "āœ“" if not event.interrupted else "āœ— INTERRUPTED" + print(f"{status}\n") + + @client.on(AnamEvent.MESSAGE_HISTORY_UPDATED) + async def on_message_history_updated(messages) -> None: + """Handle message history updates.""" + logger.debug(f"\nšŸ“ Message history updated: {len(messages)} messages total") + async def consume_video_frames(session) -> None: """Consume video frames from iterator.""" try: diff --git a/src/anam/__init__.py b/src/anam/__init__.py index ec68043..5b835f2 100644 --- a/src/anam/__init__.py +++ b/src/anam/__init__.py @@ -59,6 +59,7 @@ async def consume_audio(): ConnectionClosedCode, Message, MessageRole, + MessageStreamEvent, PersonaConfig, ) @@ -74,6 +75,7 @@ async def consume_audio(): "ConnectionClosedCode", "Message", "MessageRole", + "MessageStreamEvent", "PersonaConfig", "VideoFrame", # Errors diff --git a/src/anam/_streaming.py b/src/anam/_streaming.py index 0511c3d..8406bd8 100644 --- a/src/anam/_streaming.py +++ b/src/anam/_streaming.py @@ -341,6 +341,13 @@ async def _setup_data_channel(self) -> None: ordered=True, ) + # Initialize to False in case there's a stale value from a previous session + self._data_channel_open = False + + # Check if channel is already open + if self._data_channel.readyState == "open": + self._data_channel_open = True + @self._data_channel.on("open") def on_open() -> None: logger.info("Data channel opened") @@ -355,14 +362,11 @@ def on_close() -> None: async def on_message(message: str) -> None: try: data = json.loads(message) - logger.debug("Data channel message: %s", data.get("messageType", "unknown")) if self._on_message: await self._on_message(data) except json.JSONDecodeError as e: logger.error("Failed to parse data channel message: %s", e) - self._data_channel_open = False - async def video_frames(self) -> AsyncIterator[VideoFrame]: """Get video frames as an async iterator. diff --git a/src/anam/client.py b/src/anam/client.py index 06ad5bf..432adf5 100644 --- a/src/anam/client.py +++ b/src/anam/client.py @@ -20,6 +20,7 @@ ClientOptions, Message, MessageRole, + MessageStreamEvent, PersonaConfig, SessionInfo, ) @@ -131,6 +132,7 @@ def __init__( self._session_info: SessionInfo | None = None self._streaming_client: StreamingClient | None = None self._is_streaming = False + self._message_history: list[Message] = [] def on(self, event: AnamEvent) -> Callable[[T], T]: """Decorator to register an event handler. @@ -257,15 +259,83 @@ async def _handle_data_message(self, data: dict[str, Any]) -> None: """Handle data channel message.""" message_type = data.get("messageType", "") - if message_type == "speech_text": - # Convert to Message object + if message_type == "speechText": + # Convert to MessageStreamEvent for incremental updates msg_data = data.get("data", {}) - message = Message( - role=MessageRole(msg_data.get("role", "assistant")), - content=msg_data.get("content", ""), - timestamp=msg_data.get("timestamp", ""), + message_id = msg_data.get("message_id", "") + role_str = msg_data.get("role", "assistant") + content = msg_data.get("content", "") + content_index = msg_data.get("content_index", 0) + end_of_speech = msg_data.get("end_of_speech", False) + interrupted = msg_data.get("interrupted", False) + timestamp = msg_data.get("timestamp", "") + + # Create message ID similar to JS SDK: "{role}::{message_id}" + stream_event_id = f"{role_str}::{message_id}" + + # Determine role + if role_str.lower() == "user": + role = MessageRole.USER + elif role_str.lower() == "persona": + role = MessageRole.ASSISTANT + else: + role = MessageRole.ASSISTANT + + # Emit incremental stream event + stream_event = MessageStreamEvent( + id=stream_event_id, + content=content, + role=role, + content_index=content_index, + end_of_speech=end_of_speech, + interrupted=interrupted, ) - await self._emit(AnamEvent.MESSAGE_RECEIVED, message) + await self._emit(AnamEvent.MESSAGE_STREAM_EVENT_RECEIVED, stream_event) + + # Update message history + self._process_message_stream_event(stream_event, timestamp) + + # Emit final message when speech ends (for backward compatibility) + if end_of_speech: + # Find the complete message in history + complete_message = next( + (msg for msg in self._message_history if msg.id == stream_event_id), + None, + ) + if complete_message: + await self._emit(AnamEvent.MESSAGE_RECEIVED, complete_message) + await self._emit( + AnamEvent.MESSAGE_HISTORY_UPDATED, self._message_history.copy() + ) + + def _process_message_stream_event(self, event: MessageStreamEvent, timestamp: str) -> None: + """Process a message stream event and update message history.""" + # Find existing message with same ID (for both user and persona messages) + existing_index = next( + (i for i, msg in enumerate(self._message_history) if msg.id == event.id), + None, + ) + + if existing_index is not None: + # Update existing message by appending new content + existing = self._message_history[existing_index] + self._message_history[existing_index] = Message( + id=existing.id, + role=existing.role, + content=existing.content + event.content, + timestamp=existing.timestamp or timestamp, + interrupted=existing.interrupted or event.interrupted, + ) + else: + # Add new message (first chunk) + new_message = Message( + id=event.id, + role=event.role, + content=event.content, + timestamp=timestamp, + interrupted=event.interrupted, + ) + self._message_history.append(new_message) async def _handle_connection_established(self) -> None: """Handle connection established.""" @@ -316,6 +386,14 @@ def session_id(self) -> str | None: """Get the current session ID.""" return self._session_info.session_id if self._session_info else None + def get_message_history(self) -> list[Message]: + """Get the current message history. + + Returns: + A list of messages in the conversation history. + """ + return self._message_history.copy() + def set_persona_config(self, persona_config: PersonaConfig) -> None: """Set the persona configuration. diff --git a/src/anam/types.py b/src/anam/types.py index af99540..0f4fddb 100644 --- a/src/anam/types.py +++ b/src/anam/types.py @@ -15,6 +15,7 @@ class AnamEvent(str, Enum): # Message events MESSAGE_RECEIVED = "message_received" + MESSAGE_STREAM_EVENT_RECEIVED = "message_stream_event_received" MESSAGE_HISTORY_UPDATED = "message_history_updated" # Persona events @@ -126,14 +127,42 @@ class Message: """A message in the conversation. Attributes: + id: Unique identifier for the message. role: Who sent the message (user, assistant, system). content: The text content of the message. timestamp: When the message was sent (ISO format). + interrupted: Whether the message was interrupted (for persona messages). """ + id: str role: MessageRole content: str - timestamp: str + timestamp: str = "" + interrupted: bool = False + + +@dataclass +class MessageStreamEvent: + """A streaming message event for incremental updates. + + This represents a chunk of a message that may be part of a larger message. + Similar to the JavaScript SDK's MessageStreamEvent. + + Attributes: + id: Unique identifier for the message (same for all chunks of the same message). + content: The text content of this chunk. + role: Who sent the message (user or persona). + content_index: Index of this chunk in the message (0 = first chunk/start of speech). + end_of_speech: Whether this is the final chunk of the message. + interrupted: Whether the message was interrupted (for persona messages). + """ + + id: str + content: str + role: MessageRole + content_index: int + end_of_speech: bool + interrupted: bool = False @dataclass diff --git a/uv.lock b/uv.lock index 5ff1981..de09101 100644 --- a/uv.lock +++ b/uv.lock @@ -187,8 +187,8 @@ wheels = [ ] [[package]] -name = "anam-ai" -version = "1.0.0" +name = "anam" +version = "0.1.0" source = { editable = "." } dependencies = [ { name = "aiohttp" },