From 37e2953bb57d87dcddc5ca64c3548c8405e8b3c6 Mon Sep 17 00:00:00 2001 From: Bryan Anderson Date: Sat, 8 Feb 2025 04:54:52 +0000 Subject: [PATCH] Add PlayDialogArabic voice engine --- README.md | 7 ++++--- pyht/async_client.py | 4 ++-- pyht/client.py | 10 +++++----- pyht/inference_coordinates.py | 2 +- pyht/utils.py | 18 +++++++++++++++--- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9521108..35e90ea 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,12 @@ The `tts` method takes the following arguments: - `voice_engine`: The voice engine to use for the TTS request; a string (default `Play3.0-mini-http`). - `PlayDialog`: Our large, expressive English model, which also supports multi-turn two-speaker dialogues. - `PlayDialogMultilingual`: Our large, expressive multilingual model, which also supports multi-turn two-speaker dialogues. + - `PlayDialogArabic`: Our large, expressive Arabic model, which also supports multi-turn two-speaker dialogues. - `Play3.0-mini`: Our small, fast multilingual model. - `PlayHT2.0-turbo`: Our legacy English-only model - `protocol`: The protocol to use to communicate with the Play API (`http` by default except for `PlayHT2.0-turbo` which is `grpc` by default). - - `http`: Streaming and non-streaming audio over HTTP (supports `Play3.0-mini`, `PlayDialog`, and `PlayDialogMultilingual`). - - `ws`: Streaming audio over WebSockets (supports `Play3.0-mini`, `PlayDialog`, and `PlayDialogMultilingual`). + - `http`: Streaming and non-streaming audio over HTTP (supports `Play3.0-mini` and `PlayDialog*`). + - `ws`: Streaming audio over WebSockets (supports `Play3.0-mini` and `PlayDialog*`). - `grpc`: Streaming audio over gRPC (supports `PlayHT2.0-turbo` for all, and `Play3.0-mini` ONLY for Play On-Prem customers). - `streaming`: Whether or not to stream the audio in chunks (default True); non-streaming is only enabled for HTTP endpoints. @@ -157,7 +158,7 @@ The `TTSOptions` class is used to specify the options for the TTS request. It ha - `UKRAINIAN` - `URDU` - `XHOSA` -- The following options are additional inference-time hyperparameters which only apply to the `PlayDialog` and `PlayDialogMultilingual` models; if unset, the model will use default values chosen by Play. +- The following options are additional inference-time hyperparameters which only apply to the `PlayDialog*` models; if unset, the model will use default values chosen by Play. - `voice_2` (multi-turn dialogue only): The second voice to use for a multi-turn TTS request; a string. - A URL pointing to a Play voice manifest file. - `turn_prefix` (multi-turn dialogue only): The prefix for the first speaker's turns in a multi-turn TTS request; a string. diff --git a/pyht/async_client.py b/pyht/async_client.py index e376625..419f0f5 100644 --- a/pyht/async_client.py +++ b/pyht/async_client.py @@ -369,7 +369,7 @@ async def _tts_http( streaming: bool = True, context: Optional[AsyncContext] = None ) -> AsyncIterable[bytes]: - supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual"] + supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"] if voice_engine not in supported_voice_engines: raise ValueError(f"Only {supported_voice_engines} are supported in the HTTP API; got {voice_engine}") @@ -436,7 +436,7 @@ async def _tts_ws( metrics: Metrics, context: Optional[AsyncContext] = None ) -> AsyncIterable[bytes]: - supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual"] + supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"] if voice_engine not in supported_voice_engines: raise ValueError(f"Only {supported_voice_engines} are supported in the WebSocket API; got {voice_engine}") diff --git a/pyht/client.py b/pyht/client.py index dc947e7..b5eb08e 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -64,7 +64,7 @@ class HTTPFormat(Enum): FORMAT_PCM = "pcm" -# PlayDialog and PlayDialogMultilingual only +# PlayDialog* only class CandidateRankingMethod(Enum): # non-streaming only DescriptionASRWithMeanProbRank = "description_asr_with_mean_prob" @@ -199,7 +199,7 @@ class TTSOptions: # only applies to Play3.0 and PlayDialogMultilingual language: Optional[Language] = None - # only apply to PlayDialog and PlayDialogMultilingual + # only apply to PlayDialog* # leave the _2 params None if generating single-speaker audio voice_2: Optional[str] = None turn_prefix: Optional[str] = None @@ -293,7 +293,7 @@ def http_prepare_dict(text: List[str], options: TTSOptions, voice_engine: str) - "language": options.language.value if options.language is not None else None, "version": version, - # PlayDialog and PlayDialogMultilingual + # PlayDialog* # leave the _2 params None if generating single-speaker audio "voice_2": options.voice_2, "turn_prefix": options.turn_prefix, @@ -640,7 +640,7 @@ def _tts_http( metrics: Metrics, streaming: bool = True ) -> Iterable[bytes]: - supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual"] + supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"] if voice_engine not in supported_voice_engines: raise ValueError(f"Only {supported_voice_engines} are supported in the HTTP API; got {voice_engine}") @@ -705,7 +705,7 @@ def _tts_ws( voice_engine: Optional[str], metrics: Metrics ) -> Iterable[bytes]: - supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual"] + supported_voice_engines = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"] if voice_engine not in supported_voice_engines: raise ValueError(f"Only {supported_voice_engines} are supported in the WebSocket API; got {voice_engine}") diff --git a/pyht/inference_coordinates.py b/pyht/inference_coordinates.py index 3037661..14faed7 100644 --- a/pyht/inference_coordinates.py +++ b/pyht/inference_coordinates.py @@ -9,7 +9,7 @@ import aiohttp -REQUIRED_MODELS = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual"] +REQUIRED_MODELS = ["Play3.0-mini", "PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"] REQUIRED_URLS = ["http_streaming_url", "websocket_url"] diff --git a/pyht/utils.py b/pyht/utils.py index fc9778c..1c8009a 100644 --- a/pyht/utils.py +++ b/pyht/utils.py @@ -68,8 +68,9 @@ def get_voice_engine_and_protocol(voice_engine: Optional[str], protocol: Optiona voice_engine, protocol = _convert_deprecated_voice_engine(voice_engine, protocol) elif voice_engine in ["PlayDialog", "PlayDialog-http", "PlayDialog-ws", "PlayDialogMultilingual", - "PlayDialogMultilingual-http", "PlayDialogMultilingual-ws"]: - if voice_engine in ["PlayDialog", "PlayDialogMultilingual"]: + "PlayDialogMultilingual-http", "PlayDialogMultilingual-ws", "PlayDialogArabic", + "PlayDialogArabic-http", "PlayDialogArabic-ws"]: + if voice_engine in ["PlayDialog", "PlayDialogMultilingual", "PlayDialogArabic"]: if not protocol: logging.warning("No protocol specified; using http") protocol = "http" @@ -81,7 +82,7 @@ def get_voice_engine_and_protocol(voice_engine: Optional[str], protocol: Optiona else: raise ValueError(f"Invalid voice engine: {voice_engine} (must be Play3.0-mini, PlayDialog, \ - PlayDialogMultilingual, or PlayHT2.0-turbo).") + PlayDialogMultilingual, PlayDialogArabic, or PlayHT2.0-turbo).") return voice_engine, protocol @@ -146,6 +147,17 @@ def main(): assert get_voice_engine_and_protocol("PlayDialogMultilingual-ws", None) == ("PlayDialogMultilingual", "ws") assert get_voice_engine_and_protocol("PlayDialogMultilingual-ws", "") == ("PlayDialogMultilingual", "ws") + assert get_voice_engine_and_protocol("PlayDialogArabic", "http") == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic", "ws") == ("PlayDialogArabic", "ws") + assert get_voice_engine_and_protocol("PlayDialogArabic", None) == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic", "") == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic-http", "http") == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic-http", None) == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic-http", "") == ("PlayDialogArabic", "http") + assert get_voice_engine_and_protocol("PlayDialogArabic-ws", "ws") == ("PlayDialogArabic", "ws") + assert get_voice_engine_and_protocol("PlayDialogArabic-ws", None) == ("PlayDialogArabic", "ws") + assert get_voice_engine_and_protocol("PlayDialogArabic-ws", "") == ("PlayDialogArabic", "ws") + assert get_voice_engine_and_protocol(None, "grpc") == ("PlayHT2.0-turbo", "grpc") assert get_voice_engine_and_protocol("", "grpc") == ("PlayHT2.0-turbo", "grpc") assert get_voice_engine_and_protocol("PlayHT2.0-turbo", "grpc") == ("PlayHT2.0-turbo", "grpc")