diff --git a/examples/getting-started/01_simple_tts.py b/examples/getting-started/01_simple_tts.py index 2574ce7..0587a3f 100644 --- a/examples/getting-started/01_simple_tts.py +++ b/examples/getting-started/01_simple_tts.py @@ -18,7 +18,8 @@ - Audio file contains the spoken text """ -import os +from pathlib import Path + from fishaudio import FishAudio from fishaudio.utils import save @@ -43,7 +44,7 @@ def main(): save(audio, output_file) print(f"✓ Audio saved to {output_file}") - print(f" File size: {os.path.getsize(output_file) / 1024:.2f} KB") + print(f" File size: {Path(output_file).stat().st_size / 1024:.2f} KB") if __name__ == "__main__": diff --git a/examples/getting_started.ipynb b/examples/getting_started.ipynb index 3951e24..2ae6cfb 100644 --- a/examples/getting_started.ipynb +++ b/examples/getting_started.ipynb @@ -39,7 +39,18 @@ } }, "outputs": [], - "source": "from dotenv import load_dotenv\nfrom fishaudio import FishAudio\nfrom fishaudio.utils import play\n# from fishaudio.utils import save # Uncomment if saving audio to file\n\nload_dotenv()\n\nclient = FishAudio()" + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "from fishaudio import FishAudio\n", + "from fishaudio.utils import play\n", + "\n", + "# from fishaudio.utils import save # Uncomment if saving audio to file\n", + "\n", + "load_dotenv()\n", + "\n", + "client = FishAudio()" + ] }, { "cell_type": "markdown", diff --git a/pyproject.toml b/pyproject.toml index 02f4432..f0a2602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,5 +83,34 @@ pages = [ {title = "Exceptions", name="fishaudio/exceptions", contents = ["fishaudio.exceptions.*"] }, ] +[tool.ruff.lint] +extend-select = [ + "F", # Pyflakes rules + "W", # PyCodeStyle warnings + "E", # PyCodeStyle errors + "I", # Sort imports properly + "UP", # Warn if certain things can changed due to newer Python versions + "C4", # Catch incorrect use of comprehensions, dict, list, etc + "FA", # Enforce from __future__ import annotations + "ISC", # Good use of string concatenation + "ICN", # Use common import conventions + "RET", # Good return practices + "SIM", # Common simplification rules + "TID", # Some good import practices + "TC", # Enforce importing certain types in a TYPE_CHECKING block + "PTH", # Use pathlib instead of os.path + "TD", # Be diligent with TODO comments + "NPY", # Some numpy-specific things +] +ignore = [ + "E501", # Line too long (handled by ruff format) +] + +[tool.ruff.lint.flake8-type-checking] +runtime-evaluated-base-classes = ["pydantic.BaseModel"] + +[tool.ruff.lint.pyupgrade] +keep-runtime-typing = true + [tool.uv.sources] fish-audio-sdk = { workspace = true } diff --git a/scripts/copy_docs.py b/scripts/copy_docs.py index ea82907..3053f0a 100644 --- a/scripts/copy_docs.py +++ b/scripts/copy_docs.py @@ -10,6 +10,8 @@ python scripts/copy_docs.py sdk docs # In CI context """ +from __future__ import annotations + import argparse import shutil from pathlib import Path diff --git a/src/fish_audio_sdk/__init__.py b/src/fish_audio_sdk/__init__.py index d2ca886..d38614b 100644 --- a/src/fish_audio_sdk/__init__.py +++ b/src/fish_audio_sdk/__init__.py @@ -1,18 +1,18 @@ from .apis import Session from .exceptions import HttpCodeErr, WebSocketErr from .schemas import ( + APICreditEntity, ASRRequest, - TTSRequest, - ReferenceAudio, - Prosody, - PaginatedResponse, + CloseEvent, ModelEntity, - APICreditEntity, + PaginatedResponse, + Prosody, + ReferenceAudio, StartEvent, TextEvent, - CloseEvent, + TTSRequest, ) -from .websocket import WebSocketSession, AsyncWebSocketSession +from .websocket import AsyncWebSocketSession, WebSocketSession __all__ = [ "Session", diff --git a/src/fish_audio_sdk/apis.py b/src/fish_audio_sdk/apis.py index 9e1f4b5..2ae4e84 100644 --- a/src/fish_audio_sdk/apis.py +++ b/src/fish_audio_sdk/apis.py @@ -1,4 +1,6 @@ -from typing import Generator, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal import ormsgpack @@ -7,13 +9,16 @@ APICreditEntity, ASRRequest, ASRResponse, - ModelEntity, Backends, + ModelEntity, PackageEntity, PaginatedResponse, TTSRequest, ) +if TYPE_CHECKING: + from collections.abc import Generator + class Session(RemoteCall): @convert_stream diff --git a/src/fish_audio_sdk/io.py b/src/fish_audio_sdk/io.py index 5000ff3..304dadb 100644 --- a/src/fish_audio_sdk/io.py +++ b/src/fish_audio_sdk/io.py @@ -1,21 +1,20 @@ +from __future__ import annotations + import dataclasses import typing +from collections.abc import AsyncGenerator, Awaitable, Generator from http.client import responses as http_responses from typing import ( Any, - AsyncGenerator, - Awaitable, Callable, - Generator, Generic, TypeVar, ) -from typing_extensions import Concatenate, ParamSpec - import httpx import httpx._client import httpx._types +from typing_extensions import Concatenate, ParamSpec from .exceptions import HttpCodeErr @@ -194,8 +193,7 @@ def sync_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R: return exc.value raise RuntimeError("Generator did not stop") - call = IOCallDescriptor(async_wrapper, sync_wrapper) - return call + return IOCallDescriptor(async_wrapper, sync_wrapper) GStream = G[Generator[bytes, bytes, None]] @@ -257,5 +255,4 @@ def sync_wrapper( raise RuntimeError("Generator did not stop") - call = StreamIOCallDescriptor(async_wrapper, sync_wrapper) - return call + return StreamIOCallDescriptor(async_wrapper, sync_wrapper) diff --git a/src/fish_audio_sdk/schemas.py b/src/fish_audio_sdk/schemas.py index ab15f68..81d6323 100644 --- a/src/fish_audio_sdk/schemas.py +++ b/src/fish_audio_sdk/schemas.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import datetime import decimal from typing import Annotated, Generic, Literal, TypeVar from pydantic import BaseModel, Field - Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini"] Item = TypeVar("Item") diff --git a/src/fish_audio_sdk/websocket.py b/src/fish_audio_sdk/websocket.py index 0920d1b..10f972c 100644 --- a/src/fish_audio_sdk/websocket.py +++ b/src/fish_audio_sdk/websocket.py @@ -1,14 +1,13 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterable, Generator, Iterable from concurrent.futures import ThreadPoolExecutor -from typing import AsyncGenerator, AsyncIterable, Generator, Iterable import httpx import ormsgpack -from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws +from httpx_ws import WebSocketDisconnect, aconnect_ws, connect_ws from .exceptions import WebSocketErr - -from .schemas import Backends, CloseEvent, StartEvent, TTSRequest, TextEvent +from .schemas import Backends, CloseEvent, StartEvent, TextEvent, TTSRequest class WebSocketSession: diff --git a/src/fishaudio/client.py b/src/fishaudio/client.py index 53be1ec..11f2aa2 100644 --- a/src/fishaudio/client.py +++ b/src/fishaudio/client.py @@ -6,8 +6,8 @@ from .core import AsyncClientWrapper, ClientWrapper from .resources import ( - ASRClient, AccountClient, + ASRClient, AsyncAccountClient, AsyncASRClient, AsyncTTSClient, diff --git a/src/fishaudio/core/client_wrapper.py b/src/fishaudio/core/client_wrapper.py index fcd8e54..4164ba2 100644 --- a/src/fishaudio/core/client_wrapper.py +++ b/src/fishaudio/core/client_wrapper.py @@ -2,12 +2,12 @@ import os from json import JSONDecodeError -from typing import Any, Dict, Optional +from typing import Any, Optional import httpx -from .._version import __version__ -from ..exceptions import ( +from fishaudio._version import __version__ +from fishaudio.exceptions import ( APIError, AuthenticationError, NotFoundError, @@ -15,6 +15,7 @@ RateLimitError, ServerError, ) + from .request_options import RequestOptions @@ -32,16 +33,15 @@ def _raise_for_status(response: httpx.Response) -> None: # Raise specific exception based on status code if status == 401: raise AuthenticationError(status, message, response.text) - elif status == 403: + if status == 403: raise PermissionError(status, message, response.text) - elif status == 404: + if status == 404: raise NotFoundError(status, message, response.text) - elif status == 429: + if status == 429: raise RateLimitError(status, message, response.text) - elif status >= 500: + if status >= 500: raise ServerError(status, message, response.text) - else: - raise APIError(status, message, response.text) + raise APIError(status, message, response.text) class BaseClientWrapper: @@ -61,8 +61,8 @@ def __init__( self.base_url = base_url def get_headers( - self, additional_headers: Optional[Dict[str, str]] = None - ) -> Dict[str, str]: + self, additional_headers: Optional[dict[str, str]] = None + ) -> dict[str, str]: """Build headers including authentication and user agent.""" headers = { "Authorization": f"Bearer {self.api_key}", @@ -73,7 +73,7 @@ def get_headers( return headers def _prepare_request_kwargs( - self, request_options: Optional[RequestOptions], kwargs: Dict[str, Any] + self, request_options: Optional[RequestOptions], kwargs: dict[str, Any] ) -> None: """Prepare request kwargs by merging headers, timeout, and query params.""" # Merge headers diff --git a/src/fishaudio/core/iterators.py b/src/fishaudio/core/iterators.py index fbd5df8..e82fcd7 100644 --- a/src/fishaudio/core/iterators.py +++ b/src/fishaudio/core/iterators.py @@ -1,6 +1,6 @@ """Audio stream wrappers with collection utilities.""" -from typing import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterator class AudioStream: diff --git a/src/fishaudio/core/request_options.py b/src/fishaudio/core/request_options.py index 6212cce..a01e07a 100644 --- a/src/fishaudio/core/request_options.py +++ b/src/fishaudio/core/request_options.py @@ -1,6 +1,6 @@ """Request-level options for API calls.""" -from typing import Dict, Optional +from typing import Optional import httpx @@ -21,8 +21,8 @@ def __init__( *, timeout: Optional[float] = None, max_retries: Optional[int] = None, - additional_headers: Optional[Dict[str, str]] = None, - additional_query_params: Optional[Dict[str, str]] = None, + additional_headers: Optional[dict[str, str]] = None, + additional_query_params: Optional[dict[str, str]] = None, ): self.timeout = timeout self.max_retries = max_retries diff --git a/src/fishaudio/core/websocket_options.py b/src/fishaudio/core/websocket_options.py index 1403922..4e9132a 100644 --- a/src/fishaudio/core/websocket_options.py +++ b/src/fishaudio/core/websocket_options.py @@ -1,6 +1,6 @@ """WebSocket-level options for WebSocket connections.""" -from typing import Any, Dict, Optional +from typing import Any, Optional class WebSocketOptions: @@ -40,7 +40,7 @@ def __init__( self.max_message_size_bytes = max_message_size_bytes self.queue_size = queue_size - def to_httpx_ws_kwargs(self) -> Dict[str, Any]: + def to_httpx_ws_kwargs(self) -> dict[str, Any]: """Convert to kwargs dict for httpx_ws aconnect_ws/connect_ws.""" kwargs = {} if self.keepalive_ping_timeout_seconds is not None: diff --git a/src/fishaudio/resources/account.py b/src/fishaudio/resources/account.py index 7ef096d..59393f2 100644 --- a/src/fishaudio/resources/account.py +++ b/src/fishaudio/resources/account.py @@ -2,8 +2,8 @@ from typing import Optional -from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions -from ..types import Credits, Package +from fishaudio.core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions +from fishaudio.types import Credits, Package class AccountClient: diff --git a/src/fishaudio/resources/asr.py b/src/fishaudio/resources/asr.py index 13b7f82..2202892 100644 --- a/src/fishaudio/resources/asr.py +++ b/src/fishaudio/resources/asr.py @@ -4,8 +4,8 @@ import ormsgpack -from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions -from ..types import ASRResponse +from fishaudio.core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions +from fishaudio.types import ASRResponse class ASRClient: diff --git a/src/fishaudio/resources/realtime.py b/src/fishaudio/resources/realtime.py index 5036549..1049859 100644 --- a/src/fishaudio/resources/realtime.py +++ b/src/fishaudio/resources/realtime.py @@ -1,14 +1,15 @@ """Real-time WebSocket streaming helpers.""" -from typing import Any, AsyncIterator, Dict, Iterator, Optional +from collections.abc import AsyncIterator, Iterator +from typing import Any, Optional import ormsgpack from httpx_ws import WebSocketDisconnect -from ..exceptions import WebSocketError +from fishaudio.exceptions import WebSocketError -def _should_stop(data: Dict[str, Any]) -> bool: +def _should_stop(data: dict[str, Any]) -> bool: """ Check if WebSocket event signals stream should stop. @@ -21,7 +22,7 @@ def _should_stop(data: Dict[str, Any]) -> bool: return data.get("event") == "finish" and data.get("reason") == "stop" -def _process_audio_event(data: Dict[str, Any]) -> Optional[bytes]: +def _process_audio_event(data: dict[str, Any]) -> Optional[bytes]: """ Process a WebSocket audio event. @@ -36,7 +37,7 @@ def _process_audio_event(data: Dict[str, Any]) -> Optional[bytes]: """ if data.get("event") == "audio": return data.get("audio") - elif data.get("event") == "finish" and data.get("reason") == "error": + if data.get("event") == "finish" and data.get("reason") == "error": raise WebSocketError("WebSocket stream ended with error") return None # Ignore unknown events diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 89be523..3ec2343 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -1,16 +1,21 @@ """TTS (Text-to-Speech) namespace client.""" import asyncio +from collections.abc import AsyncIterable, Iterable, Iterator from concurrent.futures import ThreadPoolExecutor -from typing import AsyncIterable, Iterable, Iterator, List, Optional, Union +from typing import Optional, Union import ormsgpack from httpx_ws import AsyncWebSocketSession, WebSocketSession, aconnect_ws, connect_ws -from .realtime import aiter_websocket_audio, iter_websocket_audio -from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions, WebSocketOptions -from ..core.iterators import AsyncAudioStream, AudioStream -from ..types import ( +from fishaudio.core import ( + AsyncClientWrapper, + ClientWrapper, + RequestOptions, + WebSocketOptions, +) +from fishaudio.core.iterators import AsyncAudioStream, AudioStream +from fishaudio.types import ( AudioFormat, CloseEvent, FlushEvent, @@ -24,6 +29,8 @@ TTSRequest, ) +from .realtime import aiter_websocket_audio, iter_websocket_audio + def _config_to_tts_request(config: TTSConfig, text: str) -> TTSRequest: """Convert TTSConfig to TTSRequest with text.""" @@ -69,7 +76,7 @@ def stream( *, text: str, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, @@ -151,7 +158,7 @@ def convert( *, text: str, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, @@ -213,7 +220,7 @@ def stream_websocket( text_stream: Iterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, @@ -351,8 +358,7 @@ def sender(): sender_future = executor.submit(sender) # Process incoming audio messages - for audio_chunk in iter_websocket_audio(ws): - yield audio_chunk + yield from iter_websocket_audio(ws) sender_future.result() finally: @@ -370,7 +376,7 @@ async def stream( *, text: str, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, @@ -453,7 +459,7 @@ async def convert( *, text: str, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, @@ -516,7 +522,7 @@ async def stream_websocket( text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, - references: Optional[List[ReferenceAudio]] = None, + references: Optional[list[ReferenceAudio]] = None, format: Optional[AudioFormat] = None, latency: Optional[LatencyMode] = None, speed: Optional[float] = None, diff --git a/src/fishaudio/resources/voices.py b/src/fishaudio/resources/voices.py index 629b8fa..205ba27 100644 --- a/src/fishaudio/resources/voices.py +++ b/src/fishaudio/resources/voices.py @@ -1,9 +1,10 @@ """Voice management namespace client.""" -from typing import List, Optional, Union +import builtins +from typing import Optional, Union -from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions -from ..types import PaginatedResponse, Visibility, Voice +from fishaudio.core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions +from fishaudio.types import PaginatedResponse, Visibility, Voice def _filter_none(d: dict) -> dict: @@ -23,11 +24,11 @@ def list( page_size: int = 10, page_number: int = 1, title: Optional[str] = OMIT, - tags: Optional[Union[List[str], str]] = OMIT, + tags: Optional[Union[list[str], str]] = OMIT, self_only: bool = False, author_id: Optional[str] = OMIT, - language: Optional[Union[List[str], str]] = OMIT, - title_language: Optional[Union[List[str], str]] = OMIT, + language: Optional[Union[list[str], str]] = OMIT, + title_language: Optional[Union[list[str], str]] = OMIT, sort_by: str = "task_count", request_options: Optional[RequestOptions] = None, ) -> PaginatedResponse[Voice]: @@ -123,10 +124,10 @@ def create( self, *, title: str, - voices: List[bytes], + voices: builtins.list[bytes], description: Optional[str] = OMIT, - texts: Optional[List[str]] = OMIT, - tags: Optional[List[str]] = OMIT, + texts: Optional[builtins.list[str]] = OMIT, + tags: Optional[builtins.list[str]] = OMIT, cover_image: Optional[bytes] = OMIT, visibility: Visibility = "private", train_mode: str = "fast", @@ -203,7 +204,7 @@ def update( description: Optional[str] = OMIT, cover_image: Optional[bytes] = OMIT, visibility: Optional[Visibility] = OMIT, - tags: Optional[List[str]] = OMIT, + tags: Optional[builtins.list[str]] = OMIT, request_options: Optional[RequestOptions] = None, ) -> None: """ @@ -290,11 +291,11 @@ async def list( page_size: int = 10, page_number: int = 1, title: Optional[str] = OMIT, - tags: Optional[Union[List[str], str]] = OMIT, + tags: Optional[Union[list[str], str]] = OMIT, self_only: bool = False, author_id: Optional[str] = OMIT, - language: Optional[Union[List[str], str]] = OMIT, - title_language: Optional[Union[List[str], str]] = OMIT, + language: Optional[Union[list[str], str]] = OMIT, + title_language: Optional[Union[list[str], str]] = OMIT, sort_by: str = "task_count", request_options: Optional[RequestOptions] = None, ) -> PaginatedResponse[Voice]: @@ -340,10 +341,10 @@ async def create( self, *, title: str, - voices: List[bytes], + voices: builtins.list[bytes], description: Optional[str] = OMIT, - texts: Optional[List[str]] = OMIT, - tags: Optional[List[str]] = OMIT, + texts: Optional[builtins.list[str]] = OMIT, + tags: Optional[builtins.list[str]] = OMIT, cover_image: Optional[bytes] = OMIT, visibility: Visibility = "private", train_mode: str = "fast", @@ -386,7 +387,7 @@ async def update( description: Optional[str] = OMIT, cover_image: Optional[bytes] = OMIT, visibility: Optional[Visibility] = OMIT, - tags: Optional[List[str]] = OMIT, + tags: Optional[builtins.list[str]] = OMIT, request_options: Optional[RequestOptions] = None, ) -> None: """Update voice metadata (async). See sync version for details.""" diff --git a/src/fishaudio/types/asr.py b/src/fishaudio/types/asr.py index db73916..41bec1c 100644 --- a/src/fishaudio/types/asr.py +++ b/src/fishaudio/types/asr.py @@ -1,7 +1,5 @@ """ASR (Automatic Speech Recognition) related types.""" -from typing import List - from pydantic import BaseModel @@ -30,4 +28,4 @@ class ASRResponse(BaseModel): text: str duration: float # Duration in milliseconds - segments: List[ASRSegment] + segments: list[ASRSegment] diff --git a/src/fishaudio/types/shared.py b/src/fishaudio/types/shared.py index 1e756d9..427d5da 100644 --- a/src/fishaudio/types/shared.py +++ b/src/fishaudio/types/shared.py @@ -1,6 +1,6 @@ """Shared types used across the SDK.""" -from typing import Generic, List, Literal, TypeVar +from typing import Generic, Literal, TypeVar from pydantic import BaseModel @@ -17,7 +17,7 @@ class PaginatedResponse(BaseModel, Generic[T]): """ total: int - items: List[T] + items: list[T] # Model types diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index 00f72da..a6bfbc3 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -1,6 +1,6 @@ """TTS-related types.""" -from typing import Annotated, List, Literal, Optional +from typing import Annotated, Literal, Optional from pydantic import BaseModel, Field @@ -95,7 +95,7 @@ class TTSConfig(BaseModel): # Voice/style settings reference_id: Optional[str] = None - references: List[ReferenceAudio] = [] + references: list[ReferenceAudio] = [] prosody: Optional[Prosody] = None # Model parameters @@ -144,7 +144,7 @@ class TTSRequest(BaseModel): sample_rate: Optional[int] = None mp3_bitrate: Literal[64, 128, 192] = 128 opus_bitrate: Literal[-1000, 24, 32, 48, 64] = 32 - references: List[ReferenceAudio] = [] + references: list[ReferenceAudio] = [] reference_id: Optional[str] = None normalize: bool = True latency: LatencyMode = "balanced" diff --git a/src/fishaudio/types/voices.py b/src/fishaudio/types/voices.py index 2fd7347..9ae8c3a 100644 --- a/src/fishaudio/types/voices.py +++ b/src/fishaudio/types/voices.py @@ -1,7 +1,7 @@ """Voice and model management types.""" import datetime -from typing import List, Literal +from typing import Literal from pydantic import BaseModel, Field @@ -75,11 +75,11 @@ class Voice(BaseModel): cover_image: str train_mode: TrainMode state: ModelState - tags: List[str] - samples: List[Sample] + tags: list[str] + samples: list[Sample] created_at: datetime.datetime updated_at: datetime.datetime - languages: List[str] + languages: list[str] visibility: Visibility lock_visibility: bool like_count: int diff --git a/src/fishaudio/utils/play.py b/src/fishaudio/utils/play.py index 90a6390..5973b24 100644 --- a/src/fishaudio/utils/play.py +++ b/src/fishaudio/utils/play.py @@ -2,9 +2,10 @@ import io import subprocess -from typing import Iterable, Union +from collections.abc import Iterable +from typing import Union -from ..exceptions import DependencyError +from fishaudio.exceptions import DependencyError def _is_installed(command: str) -> bool: diff --git a/src/fishaudio/utils/save.py b/src/fishaudio/utils/save.py index 0801520..03b86ae 100644 --- a/src/fishaudio/utils/save.py +++ b/src/fishaudio/utils/save.py @@ -1,6 +1,8 @@ """Audio saving utility.""" -from typing import Iterable, Union +from collections.abc import Iterable +from pathlib import Path +from typing import Union def save(audio: Union[bytes, Iterable[bytes]], filename: str) -> None: @@ -31,5 +33,5 @@ def save(audio: Union[bytes, Iterable[bytes]], filename: str) -> None: audio = b"".join(audio) # Write to file - with open(filename, "wb") as f: + with Path(filename).open("wb") as f: f.write(audio) diff --git a/src/fishaudio/utils/stream.py b/src/fishaudio/utils/stream.py index 887c42f..27db478 100644 --- a/src/fishaudio/utils/stream.py +++ b/src/fishaudio/utils/stream.py @@ -1,9 +1,9 @@ """Audio streaming utility.""" import subprocess -from typing import Iterator +from collections.abc import Iterator -from ..exceptions import DependencyError +from fishaudio.exceptions import DependencyError def _is_installed(command: str) -> bool: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index df76c03..d0c5836 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -80,10 +80,7 @@ def _save(audio: bytes | list[bytes], filename: str) -> Path: Returns: Path to the saved file """ - if isinstance(audio, bytes): - complete_audio = audio - else: - complete_audio = b"".join(audio) + complete_audio = audio if isinstance(audio, bytes) else b"".join(audio) output_file = OUTPUT_DIR / filename output_file.write_bytes(complete_audio) return output_file diff --git a/tests/integration/test_account_integration.py b/tests/integration/test_account_integration.py index de16e85..1bc0a00 100644 --- a/tests/integration/test_account_integration.py +++ b/tests/integration/test_account_integration.py @@ -1,8 +1,9 @@ """Integration tests for Account functionality.""" -import pytest from decimal import Decimal +import pytest + from fishaudio.types import Credits, Package diff --git a/tests/integration/test_tts_websocket_integration.py b/tests/integration/test_tts_websocket_integration.py index 5e80321..1dc1604 100644 --- a/tests/integration/test_tts_websocket_integration.py +++ b/tests/integration/test_tts_websocket_integration.py @@ -5,8 +5,9 @@ import pytest from fishaudio import WebSocketOptions -from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent +from fishaudio.types import FlushEvent, Prosody, TextEvent, TTSConfig from fishaudio.types.shared import Model + from .conftest import TEST_REFERENCE_ID @@ -112,8 +113,7 @@ def text_stream(): "And that all audio is received correctly. ", "Finally, we end the stream here.", ] - for sentence in sentences: - yield sentence + yield from sentences audio_chunks = list(client.tts.stream_websocket(text_stream())) @@ -185,8 +185,7 @@ def text_stream(): "Long-form content generation is now much more reliable. ", "The implementation passes through all necessary parameters to the underlying httpx_ws library. ", ] - for sentence in long_text: - yield sentence + yield from long_text # This should succeed with increased timeout audio_chunks = list( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3a857d9..58c6a58 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,9 @@ """Shared pytest fixtures for unit tests.""" -import pytest -from unittest.mock import Mock, AsyncMock +from unittest.mock import AsyncMock, Mock + import httpx +import pytest @pytest.fixture diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 97a0faf..2b69d42 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,10 +1,11 @@ """Tests for Account namespace client.""" -import pytest -from unittest.mock import Mock, AsyncMock from decimal import Decimal +from unittest.mock import AsyncMock, Mock + +import pytest -from fishaudio.core import ClientWrapper, AsyncClientWrapper, RequestOptions +from fishaudio.core import AsyncClientWrapper, ClientWrapper, RequestOptions from fishaudio.resources.account import AccountClient, AsyncAccountClient from fishaudio.types import Credits, Package diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d288491..15c7508 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,14 +1,15 @@ """Tests for main client classes.""" -import pytest from unittest.mock import patch -from fishaudio import FishAudio, AsyncFishAudio +import pytest + +from fishaudio import AsyncFishAudio, FishAudio from fishaudio.resources import ( - TTSClient, AsyncTTSClient, - VoicesClient, AsyncVoicesClient, + TTSClient, + VoicesClient, ) diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index af56f9a..e979689 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,13 +1,14 @@ """Tests for core components.""" -import pytest from unittest.mock import patch + import httpx +import pytest from fishaudio.core import ( OMIT, - ClientWrapper, AsyncClientWrapper, + ClientWrapper, RequestOptions, WebSocketOptions, ) @@ -100,9 +101,11 @@ def test_init_with_api_key(self, mock_api_key, mock_base_url): assert wrapper.base_url == mock_base_url def test_init_without_api_key_raises(self): - with patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="API key must be provided"): - ClientWrapper() + with ( + patch.dict("os.environ", {}, clear=True), + pytest.raises(ValueError, match="API key must be provided"), + ): + ClientWrapper() def test_init_with_env_var(self, mock_api_key): with patch.dict("os.environ", {"FISH_API_KEY": mock_api_key}): @@ -133,9 +136,11 @@ def test_init_with_api_key(self, mock_api_key, mock_base_url): assert wrapper.base_url == mock_base_url def test_init_without_api_key_raises(self): - with patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="API key must be provided"): - AsyncClientWrapper() + with ( + patch.dict("os.environ", {}, clear=True), + pytest.raises(ValueError, match="API key must be provided"), + ): + AsyncClientWrapper() def test_get_headers(self, mock_api_key): wrapper = AsyncClientWrapper(api_key=mock_api_key) diff --git a/tests/unit/test_realtime.py b/tests/unit/test_realtime.py index f9e47b4..446ae3c 100644 --- a/tests/unit/test_realtime.py +++ b/tests/unit/test_realtime.py @@ -1,17 +1,18 @@ """Tests for realtime WebSocket streaming helpers.""" -import pytest from unittest.mock import Mock + import ormsgpack +import pytest from httpx_ws import WebSocketDisconnect +from fishaudio.exceptions import WebSocketError from fishaudio.resources.realtime import ( - _should_stop, _process_audio_event, - iter_websocket_audio, + _should_stop, aiter_websocket_audio, + iter_websocket_audio, ) -from fishaudio.exceptions import WebSocketError class TestShouldStop: @@ -260,8 +261,7 @@ async def mock_receive_bytes(): call_count[0] += 1 if call_count[0] == 1: return ormsgpack.packb({"event": "audio", "audio": b"chunk1"}) - else: - raise WebSocketDisconnect() + raise WebSocketDisconnect() mock_ws.receive_bytes = mock_receive_bytes diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index d556c2c..f081aee 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -1,12 +1,13 @@ """Tests for TTS namespace client.""" -import pytest -from unittest.mock import Mock, AsyncMock +from unittest.mock import AsyncMock, Mock + import ormsgpack +import pytest -from fishaudio.core import ClientWrapper, AsyncClientWrapper, RequestOptions -from fishaudio.resources.tts import TTSClient, AsyncTTSClient -from fishaudio.types import ReferenceAudio, Prosody, TTSConfig +from fishaudio.core import AsyncClientWrapper, ClientWrapper, RequestOptions +from fishaudio.resources.tts import AsyncTTSClient, TTSClient +from fishaudio.types import Prosody, ReferenceAudio, TTSConfig @pytest.fixture diff --git a/tests/unit/test_tts_realtime.py b/tests/unit/test_tts_realtime.py index 6ca2fa0..eaf9210 100644 --- a/tests/unit/test_tts_realtime.py +++ b/tests/unit/test_tts_realtime.py @@ -1,12 +1,13 @@ """Tests for TTS realtime streaming.""" -import pytest -from unittest.mock import Mock, AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch -from fishaudio.core import ClientWrapper, AsyncClientWrapper, WebSocketOptions -from fishaudio.resources.tts import TTSClient, AsyncTTSClient -from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent, ReferenceAudio import ormsgpack +import pytest + +from fishaudio.core import AsyncClientWrapper, ClientWrapper, WebSocketOptions +from fishaudio.resources.tts import AsyncTTSClient, TTSClient +from fishaudio.types import FlushEvent, Prosody, ReferenceAudio, TextEvent, TTSConfig @pytest.fixture diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index fcc509c..cfc8446 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -3,16 +3,16 @@ from decimal import Decimal from fishaudio.types import ( - Voice, - PaginatedResponse, ASRResponse, ASRSegment, Credits, Package, - ReferenceAudio, + PaginatedResponse, Prosody, + ReferenceAudio, TTSConfig, TTSRequest, + Voice, ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8effd61..86f8b5b 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,36 +1,35 @@ """Tests for utility functions.""" -import pytest -from unittest.mock import Mock, patch, mock_open import subprocess +from unittest.mock import Mock, patch + +import pytest -from fishaudio.utils import play, save, stream from fishaudio.exceptions import DependencyError +from fishaudio.utils import play, save, stream class TestSave: """Test save() function.""" - def test_save_bytes(self): + def test_save_bytes(self, tmp_path): """Test saving bytes to file.""" audio = b"fake audio data" + output_file = tmp_path / "output.mp3" - with patch("builtins.open", mock_open()) as m: - save(audio, "output.mp3") + save(audio, str(output_file)) - m.assert_called_once_with("output.mp3", "wb") - m().write.assert_called_once_with(audio) + assert output_file.read_bytes() == audio - def test_save_iterator(self): + def test_save_iterator(self, tmp_path): """Test saving iterator to file.""" audio = iter([b"chunk1", b"chunk2", b"chunk3"]) + output_file = tmp_path / "output.mp3" - with patch("builtins.open", mock_open()) as m: - save(audio, "output.mp3") + save(audio, str(output_file)) - m.assert_called_once_with("output.mp3", "wb") - # Should consolidate chunks - m().write.assert_called_once_with(b"chunk1chunk2chunk3") + # Should consolidate chunks + assert output_file.read_bytes() == b"chunk1chunk2chunk3" class TestPlay: diff --git a/tests/unit/test_voices.py b/tests/unit/test_voices.py index 8a1fb5d..68b6b62 100644 --- a/tests/unit/test_voices.py +++ b/tests/unit/test_voices.py @@ -1,11 +1,12 @@ """Tests for voices namespace client.""" -import pytest from unittest.mock import Mock +import pytest + from fishaudio.core import ClientWrapper from fishaudio.resources.voices import VoicesClient -from fishaudio.types import Voice, PaginatedResponse +from fishaudio.types import PaginatedResponse, Voice @pytest.fixture