diff --git a/README.md b/README.md
index 7944e69..f8cbeb7 100644
--- a/README.md
+++ b/README.md
@@ -69,12 +69,13 @@ asyncio.run(main())
 - 💬 **Two-way communication** - Send text messages (like transcribed user speech) and receive generated responses
 - 📝 **Real-time transcriptions** - Receive incremental message stream events for user and persona text as it's generated
 - 📚 **Message history tracking** - Automatic conversation history with incremental updates
-- 🎤 **Audio-passthrough** - Send TTS generated audio input and receive rendered synchronized audio/video avatar
+- 🤖 **Audio-passthrough** - Send TTS-generated audio input and receive a rendered, synchronized audio/video avatar
 - 🗣️ **Direct text-to-speech** - Send text directly to TTS for immediate speech output (bypasses LLM processing)
-- 🎯 **Async iterator API** - Clean, Pythonic async/await patterns for continuous stream of audio/video frames
+- 🎤 **Real-time user audio input** - Send raw audio samples (e.g. from a microphone) to Anam for processing (turnkey solution: STT → LLM → TTS → Face)
+- 📡 **Async iterator API** - Clean, Pythonic async/await patterns for a continuous stream of audio/video frames
 - 🎯 **Event-driven API** - Simple decorator-based event handlers for discrete events
 - 📝 **Fully typed** - Complete type hints for IDE support
-- 🔒 **Server-side ready** - Designed for server-side Python applications (e.g. for use in a web application)
+- 🔒 **Server-side ready** - Designed for server-side Python applications (e.g. for backend pipelines)
 
 ## API Reference
 
@@ -101,9 +102,6 @@ client = AnamClient(
         voice_id="emma",
         language_code="en",
     ),
-    options=ClientOptions(
-        disable_input_audio=True,  # Don't capture microphone
-    ),
 )
 ```
 
@@ -126,6 +124,17 @@ async with client.connect() as session:
     # Both streams run concurrently
     await asyncio.gather(process_video(), process_audio())
 ```
 
+### User Audio Input
+
+User audio input is real-time audio, such as microphone capture.
+User audio must be 16-bit PCM samples, mono or stereo, at any sample rate. The sample rate must be provided so the audio can be processed correctly.
+The audio is forwarded in real time as a WebRTC audio track. To reduce latency, any audio provided before the WebRTC audio track is created is dropped.
+
+### TTS Audio (Audio Passthrough)
+
+TTS audio is generated by a TTS engine and should be provided in chunks through the `send_audio_chunk` method. Each chunk can be a byte array or a base64-encoded string (the SDK converts byte arrays to base64). The audio is ingested by the backend at maximum bandwidth. The sample rate and channel count must be provided through the `AgentAudioInputConfig` object.
+
+For best performance, we suggest using 24kHz mono audio. The provided audio is returned in sync with the avatar without any resampling. Sample rates lower than 24kHz will result in poor avatar performance; sample rates higher than 24kHz may increase latency without any noticeable improvement in audio quality.
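+
+Minimal sketches of both flows follow. These are illustrative, not verbatim SDK
+signatures: they assume the session object exposes `send_user_audio()` and
+`send_audio_chunk()` directly, that `AgentAudioInputConfig` takes `sample_rate`
+and `channels` fields as described above, and that `microphone_chunks()` /
+`tts_chunks()` stand in for your own audio sources.
+
+```python
+# User audio input: forward raw 16-bit PCM (e.g. microphone) to Anam.
+async with client.connect() as session:
+    async for pcm_chunk in microphone_chunks():  # your own capture loop
+        session.send_user_audio(
+            audio_bytes=pcm_chunk,  # raw 16-bit PCM bytes
+            sample_rate=16000,      # must match the actual capture rate
+            num_channels=1,         # 1=mono, 2=stereo
+        )
+```
+
+```python
+# TTS audio passthrough: 24kHz mono is the recommended format.
+from anam import AgentAudioInputConfig
+
+# Field names assumed from the description above; check the API reference.
+config = AgentAudioInputConfig(sample_rate=24000, channels=1)
+async with client.connect() as session:
+    for chunk in tts_chunks():  # byte chunks produced by your TTS engine
+        session.send_audio_chunk(chunk, config)
+```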
 
 ### Events
@@ -313,7 +322,6 @@ from anam import ClientOptions
 options = ClientOptions(
     api_base_url="https://api.anam.ai",  # API base URL
     api_version="v1",  # API version
-    disable_input_audio=False,  # Disable microphone input
     ice_servers=None,  # Custom ICE servers
 )
 ```
@@ -359,6 +367,7 @@ except AnamError as e:
 - `aiohttp` - HTTP client
 - `websockets` - WebSocket client
 - `numpy` - Array handling
+- `av` - Video and audio handling (PyAV)
 
 Optional for display utilities:
 - `opencv-python` - Video display
diff --git a/examples/avatar_audio_passthrough.py b/examples/avatar_audio_passthrough.py
index 2b0d652..7655659 100644
--- a/examples/avatar_audio_passthrough.py
+++ b/examples/avatar_audio_passthrough.py
@@ -198,7 +198,7 @@ def main() -> None:
     client = AnamClient(
         api_key=api_key,
         persona_config=persona_config,
-        options=ClientOptions(disable_input_audio=True, api_base_url=api_base_url),
+        options=ClientOptions(api_base_url=api_base_url),
     )
 
     # Create display and audio player
diff --git a/examples/persona_interactive_video.py b/examples/persona_interactive_video.py
index 89f991c..e9e78d5 100644
--- a/examples/persona_interactive_video.py
+++ b/examples/persona_interactive_video.py
@@ -294,7 +294,7 @@ def main() -> None:
     client = AnamClient(
         api_key=api_key,
         persona_config=persona_config,
-        options=ClientOptions(disable_input_audio=False, api_base_url=api_base_url),
+        options=ClientOptions(api_base_url=api_base_url),
     )
 
     # Create display and audio player
diff --git a/examples/save_recording.py b/examples/save_recording.py
index d245c0e..1ae8d51 100644
--- a/examples/save_recording.py
+++ b/examples/save_recording.py
@@ -131,7 +131,7 @@ async def main() -> None:
     client = AnamClient(
         api_key=api_key,
         persona_id=persona_id,
-        options=ClientOptions(disable_input_audio=True, api_base_url=api_base_url),
+        options=ClientOptions(api_base_url=api_base_url),
     )
 
     # Register connection event handler
diff --git a/examples/text_to_video.py b/examples/text_to_video.py
index 0d0d5a9..3527f4e 100644
--- a/examples/text_to_video.py
+++ b/examples/text_to_video.py
@@ -268,7 +268,7 @@ async def text_to_video(
     client = AnamClient(
         api_key=api_key,
         persona_config=persona_config,
-        options=ClientOptions(disable_input_audio=True, api_base_url=api_base_url),
+        options=ClientOptions(api_base_url=api_base_url),
     )
 
     # Temp files for video and audio (keep extensions for format detection)
diff --git a/pyproject.toml b/pyproject.toml
index 91faea0..8f987ec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "websockets>=12.0",
     "numpy>=1.26.0",
     "python-dotenv>=1.2.1",
+    "av>=16.0.1",
 ]
 
 [project.optional-dependencies]
@@ -84,6 +85,7 @@ disallow_incomplete_defs = true
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
+pythonpath = ["src"]
 
 [tool.semantic_release]
 version_toml = ["pyproject.toml:project.version"]
diff --git a/src/anam/__init__.py b/src/anam/__init__.py
index 5b835f2..8a9d7bb 100644
--- a/src/anam/__init__.py
+++ b/src/anam/__init__.py
@@ -39,6 +39,7 @@ async def consume_audio():
 For more information, see https://docs.anam.ai
 """
 
+from av.audio.frame import AudioFrame
 from av.video.frame import VideoFrame
 
 from ._agent_audio_input_stream import AgentAudioInputStream
@@ -71,6 +72,7 @@ async def consume_audio():
     "AgentAudioInputConfig",
     "AgentAudioInputStream",
     "AnamEvent",
+    "AudioFrame",
     "ClientOptions",
     "ConnectionClosedCode",
     "Message",
diff --git a/src/anam/_streaming.py b/src/anam/_streaming.py
index 8406bd8..7db9990 100644
--- a/src/anam/_streaming.py
+++ b/src/anam/_streaming.py
@@ -21,6 +21,7 @@ from ._agent_audio_input_stream import AgentAudioInputStream
 from ._signalling import SignalAction, SignallingClient
+from ._user_audio_input_track import UserAudioInputTrack
 from .types import AgentAudioInputConfig, SessionInfo
 
 logger = logging.getLogger(__name__)
@@ -40,7 +41,6 @@ def __init__(
         on_message: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
         on_connection_established: Callable[[], Awaitable[None]] | None = None,
         on_connection_closed: Callable[[str, str | None], Awaitable[None]] | None = None,
-        disable_input_audio: bool = False,
         custom_ice_servers: list[dict[str, Any]] | None = None,
     ):
         """Initialize the streaming client.
@@ -50,7 +50,6 @@ def __init__(
             on_message: Callback for data channel messages.
            on_connection_established: Callback when connected.
             on_connection_closed: Callback when disconnected.
-            disable_input_audio: If True, don't send microphone audio.
             custom_ice_servers: Custom ICE servers (optional).
         """
         self._session_info = session_info
@@ -62,7 +61,6 @@ def __init__(
         self._on_connection_closed = on_connection_closed
 
         # Configuration
-        self._disable_input_audio = disable_input_audio
         self._ice_servers = custom_ice_servers or session_info.ice_servers
 
         # State
@@ -75,6 +73,8 @@ def __init__(
         self._audio_track: MediaStreamTrack | None = None
         self._is_connected = False
         self._agent_audio_input_stream: AgentAudioInputStream | None = None
+        self._user_audio_input_track: UserAudioInputTrack | None = None
+        self._audio_transceiver = None  # Store transceiver for lazy track creation
 
     async def connect(self, timeout: float = 30.0) -> None:
         """Start the streaming connection.
@@ -322,12 +322,10 @@ def on_track(track: MediaStreamTrack) -> None:
         # Video: receive only
         self._peer_connection.addTransceiver("video", direction="recvonly")
 
-        # Audio: send/receive or receive only
-        if self._disable_input_audio:
-            self._peer_connection.addTransceiver("audio", direction="recvonly")
-        else:
-            self._peer_connection.addTransceiver("audio", direction="sendrecv")
-            # Note: Audio input track would be added here if needed
+        # Audio: send/receive (track created lazily when the first audio arrives via send_user_audio())
+        self._audio_transceiver = self._peer_connection.addTransceiver(
+            "audio", direction="sendrecv"
+        )
 
         logger.debug("Peer connection initialized")
 
@@ -642,6 +640,17 @@ async def close(self) -> None:
         finally:
             self._signalling_client = None
 
+        # Close the user audio input track before closing the peer connection.
+        # This clears the audio queue and prevents recv() from generating more frames.
+        if self._user_audio_input_track:
+            try:
+                self._user_audio_input_track.close()
+                logger.debug("Closed user audio input track")
+            except Exception as e:
+                logger.warning("Error closing user audio input track: %s", e)
+            finally:
+                self._user_audio_input_track = None
+
         # Close peer connection
         if self._peer_connection:
             try:
@@ -654,6 +663,53 @@
         self._is_connected = False
         logger.info("Streaming client closed")
 
+    def send_user_audio(
+        self,
+        audio_bytes: bytes,
+        sample_rate: int,
+        num_channels: int,
+    ) -> None:
+        """Send raw user audio samples to Anam for processing.
+
+        This method accepts 16-bit PCM samples and adds them to the audio buffer for transmission via WebRTC.
+        The audio track is created lazily when the first audio arrives.
+        Audio is only added to the buffer after the connection is established, to avoid accumulating stale audio.
+
+        Args:
+            audio_bytes: Raw audio data (16-bit PCM).
+            sample_rate: Sample rate of the input audio (Hz).
+            num_channels: Number of channels in the input audio (1=mono, 2=stereo).
+
+        Raises:
+            RuntimeError: If the peer connection is not initialized, the channel
+                count is not 1 or 2, or the audio track cannot be attached.
+        """
+        if not self._peer_connection:
+            raise RuntimeError("Peer connection not initialized. Call connect() first.")
+        if num_channels not in (1, 2):
+            raise RuntimeError("Invalid number of channels. Must be 1 or 2.")
+
+        # Create track lazily when the first audio arrives
+        if self._user_audio_input_track is None:
+            logger.info(
+                f"Creating user audio input track: sample_rate={sample_rate}Hz, "
+                f"channels={num_channels}"
+            )
+
+            self._user_audio_input_track = UserAudioInputTrack(sample_rate, num_channels)
+
+            # Attach the lazily created track to the transceiver
+            if self._audio_transceiver and self._audio_transceiver.sender:
+                try:
+                    self._audio_transceiver.sender.replaceTrack(self._user_audio_input_track)
+                    logger.info("Added user audio track to transceiver")
+                except Exception as e:
+                    raise RuntimeError(f"Failed to add user audio track: {e}") from e
+            else:
+                raise RuntimeError("Audio transceiver not properly initialized")
+
+        if self._peer_connection.connectionState == "connected":
+            # Only queue audio once the connection is established, to avoid accumulating stale audio.
+            self._user_audio_input_track.add_audio_samples(audio_bytes, sample_rate, num_channels)
+
     def __del__(self) -> None:
         """Cleanup on destruction to prevent warnings."""
         # Clear peer connection reference if close() wasn't called explicitly.
diff --git a/src/anam/_user_audio_input_track.py b/src/anam/_user_audio_input_track.py
new file mode 100644
index 0000000..d8e52f0
--- /dev/null
+++ b/src/anam/_user_audio_input_track.py
@@ -0,0 +1,171 @@
+"""User audio input track for sending raw audio samples to Anam via WebRTC.
+
+This module provides a mechanism for accepting raw audio samples and
+converting them to a WebRTC-compatible format for transmission.
+User audio is real-time audio such as microphone audio.
+"""
+
+import asyncio
+import fractions
+import logging
+
+import numpy as np
+from aiortc.mediastreams import AUDIO_PTIME, AudioStreamTrack, MediaStreamError
+from av.audio.frame import AudioFrame
+
+logger = logging.getLogger(__name__)
+
+
+class UserAudioInputTrack(AudioStreamTrack):
+    """AudioStreamTrack that accepts raw audio samples and converts them to WebRTC format.
+
+    This track accepts raw audio bytes (16-bit PCM) and converts them to AudioFrames
+    for WebRTC transmission. Audio is stored in a byte buffer and converted to an
+    AudioFrame only when recv() is called.
+
+    To stay close to the live point, the buffer is flushed on the first recv() call,
+    keeping only the most recent chunk. This handles the case where audio accumulates
+    between track connection and WebRTC starting to pull frames.
+    """
+
+    def __init__(self, sample_rate: int, num_channels: int):
+        """Initialize the user audio input track.
+
+        Args:
+            sample_rate: Sample rate of the audio (Hz), e.g., 16000 or 48000.
+            num_channels: Number of channels (1=mono, 2=stereo).
+ """ + super().__init__() + self._sample_rate = sample_rate + self._num_channels = num_channels + + # Byte buffer for raw 16-bit PCM audio + self._audio_buffer = bytearray() + + # Calculate samples per chunk (20ms frame) + self._samples_per_chunk = int(sample_rate * AUDIO_PTIME) + # 16-bit = 2 bytes per sample, per channel + self._bytes_per_chunk = self._samples_per_chunk * 2 * num_channels + + # Timestamp for frame pts (in samples) + self._timestamp = 0 + + # Flag to indicate if track is closed + self._is_closed = False + + # Flag to flush buffer on first recv() - handles audio that accumulated + # between track connection and WebRTC starting to pull frames + self._first_recv = True + + # Lock for thread-safe buffer access + self._lock = asyncio.Lock() + + # Maximum buffer size for backpressure (~500ms of audio) + # Drop oldest audio if buffer exceeds this to prevent unbounded growth + self._max_buffer_bytes = self._bytes_per_chunk * 50 + + logger.info( + f"UserAudioInputTrack initialized: {sample_rate}Hz, {num_channels} channel(s), " + f"{self._bytes_per_chunk} bytes per chunk" + ) + + def close(self) -> None: + """Mark track as closed and clear audio buffer. + + After this is called, recv() will raise MediaStreamError to signal + WebRTC to stop calling it. + """ + self._is_closed = True + self._audio_buffer = bytearray() + logger.debug("UserAudioInputTrack closed") + + def add_audio_samples( + self, + audio_bytes: bytes, + sample_rate: int, + num_channels: int, + ) -> None: + """Add raw audio samples to the track buffer. + + Args: + audio_bytes: Raw audio data (16-bit PCM). + sample_rate: Sample rate of the input audio (Hz). + num_channels: Number of channels in the input audio. + """ + if self._is_closed: + return + + # Validate format matches initialization + if sample_rate != self._sample_rate: + logger.warning( + f"Sample rate mismatch: expected {self._sample_rate}Hz, got {sample_rate}Hz. " + "Discarding audio." + ) + return + + if num_channels != self._num_channels: + logger.warning( + f"Channel count mismatch: expected {self._num_channels}, got {num_channels}. " + "Discarding audio." + ) + return + + # Append to buffer + self._audio_buffer.extend(audio_bytes) + + # Backpressure: drop oldest audio if buffer is too large + if len(self._audio_buffer) > self._max_buffer_bytes: + excess = len(self._audio_buffer) - self._max_buffer_bytes + # Align to frame boundary + excess = (excess // self._bytes_per_chunk) * self._bytes_per_chunk + if excess > 0: + logger.warning(f"Dropping {excess} bytes of old audio due to buffer overflow") + self._audio_buffer = self._audio_buffer[excess:] + + async def recv(self) -> AudioFrame: + """Return the next audio frame for WebRTC transmission. + + Returns: + An AudioFrame containing chunk of audio data. + + Raises: + MediaStreamError: If the track has been closed. 
+ """ + if self._is_closed: + raise MediaStreamError("Track has been closed") + + # Wait for enough data (chunk) + while len(self._audio_buffer) < self._bytes_per_chunk: + if self._is_closed: + raise MediaStreamError("Track has been closed") + await asyncio.sleep(0.001) # 1ms poll + + # Extract one chunk from buffer + async with self._lock: + if self._is_closed: + raise MediaStreamError("Track has been closed") + + chunk_bytes = bytes(self._audio_buffer[: self._bytes_per_chunk]) + self._audio_buffer = self._audio_buffer[self._bytes_per_chunk :] + + # Convert bytes to numpy array (16-bit PCM) + samples = np.frombuffer(chunk_bytes, dtype=np.int16) + + # Shape for AudioFrame + if self._num_channels == 1: + audio_data = samples[None, :] # Shape: (1, num_samples) + layout = "mono" + else: + # Reshape interleaved stereo to (2, num_samples) + audio_data = samples.reshape((-1, self._num_channels)).T + layout = "stereo" + + # Create AudioFrame + frame = AudioFrame.from_ndarray(audio_data, layout=layout) + frame.sample_rate = self._sample_rate + frame.pts = self._timestamp + frame.time_base = fractions.Fraction(1, self._sample_rate) + + self._timestamp += self._samples_per_chunk + + return frame diff --git a/src/anam/client.py b/src/anam/client.py index 432adf5..73075d3 100644 --- a/src/anam/client.py +++ b/src/anam/client.py @@ -245,7 +245,6 @@ async def connect_async(self) -> "Session": on_message=self._handle_data_message, on_connection_established=self._handle_connection_established, on_connection_closed=self._handle_connection_closed, - disable_input_audio=self._options.disable_input_audio, custom_ice_servers=self._options.ice_servers, ) diff --git a/src/anam/types.py b/src/anam/types.py index 0f4fddb..dcd4018 100644 --- a/src/anam/types.py +++ b/src/anam/types.py @@ -109,7 +109,6 @@ class ClientOptions: Args: api_base_url: Base URL for the Anam API. api_version: API version to use. - disable_input_audio: If True, don't capture/send microphone audio. ice_servers: Custom ICE servers for WebRTC (optional). client_label: Custom label for session tracking (optional). Defaults to 'python-sdk' if not specified. @@ -117,7 +116,6 @@ class ClientOptions: api_base_url: str = "https://api.anam.ai" api_version: str = "v1" - disable_input_audio: bool = False ice_servers: list[dict[str, Any]] | None = None client_label: str | None = None diff --git a/tests/test_client.py b/tests/test_client.py index 0743c82..ea97f89 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -55,7 +55,6 @@ def test_init_with_options(self) -> None: """Test initialization with ClientOptions.""" options = ClientOptions( api_base_url="https://custom.api.com", - disable_input_audio=True, ) client = AnamClient( api_key="test-key", @@ -63,7 +62,6 @@ def test_init_with_options(self) -> None: options=options, ) assert client._options.api_base_url == "https://custom.api.com" - assert client._options.disable_input_audio is True class TestAnamClientEvents: diff --git a/uv.lock b/uv.lock index de09101..9d44fe9 100644 --- a/uv.lock +++ b/uv.lock @@ -193,6 +193,7 @@ source = { editable = "." 
 dependencies = [
     { name = "aiohttp" },
     { name = "aiortc" },
+    { name = "av" },
     { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
     { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
     { name = "python-dotenv" },
@@ -215,6 +216,7 @@ display = [
 requires-dist = [
     { name = "aiohttp", specifier = ">=3.9.0" },
     { name = "aiortc", specifier = ">=1.14.0" },
+    { name = "av", specifier = ">=16.0.1" },
     { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10.0" },
     { name = "numpy", specifier = ">=1.26.0" },
     { name = "opencv-python", marker = "extra == 'display'", specifier = ">=4.9.0" },