diff --git a/src/fishaudio/resources/account.py b/src/fishaudio/resources/account.py index 02beaf3..7ef096d 100644 --- a/src/fishaudio/resources/account.py +++ b/src/fishaudio/resources/account.py @@ -2,7 +2,7 @@ from typing import Optional -from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions +from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions from ..types import Credits, Package @@ -15,12 +15,14 @@ def __init__(self, client_wrapper: ClientWrapper): def get_credits( self, *, + check_free_credit: Optional[bool] = OMIT, request_options: Optional[RequestOptions] = None, ) -> Credits: """ Get API credit balance. Args: + check_free_credit: Whether to check free credit availability request_options: Request-level overrides Returns: @@ -31,11 +33,21 @@ def get_credits( client = FishAudio(api_key="...") credits = client.account.get_credits() print(f"Available credits: {float(credits.credit)}") + + # Check free credit availability + credits = client.account.get_credits(check_free_credit=True) + if credits.has_free_credit: + print("Free credits available!") ``` """ + params = {} + if check_free_credit is not OMIT: + params["check_free_credit"] = check_free_credit + response = self._client.request( "GET", "/wallet/self/api-credit", + params=params, request_options=request_options, ) return Credits.model_validate(response.json()) @@ -78,12 +90,14 @@ def __init__(self, client_wrapper: AsyncClientWrapper): async def get_credits( self, *, + check_free_credit: Optional[bool] = OMIT, request_options: Optional[RequestOptions] = None, ) -> Credits: """ Get API credit balance (async). Args: + check_free_credit: Whether to check free credit availability request_options: Request-level overrides Returns: @@ -94,11 +108,21 @@ async def get_credits( client = AsyncFishAudio(api_key="...") credits = await client.account.get_credits() print(f"Available credits: {float(credits.credit)}") + + # Check free credit availability + credits = await client.account.get_credits(check_free_credit=True) + if credits.has_free_credit: + print("Free credits available!") ``` """ + params = {} + if check_free_credit is not OMIT: + params["check_free_credit"] = check_free_credit + response = await self._client.request( "GET", "/wallet/self/api-credit", + params=params, request_options=request_options, ) return Credits.model_validate(response.json()) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index ea73141..fef1cd4 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -2,7 +2,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import AsyncIterable, Iterable, Iterator, Optional, Union +from typing import AsyncIterable, Iterable, Iterator, List, Optional, Union import ormsgpack from httpx_ws import AsyncWebSocketSession, WebSocketSession, aconnect_ws, connect_ws @@ -10,9 +10,13 @@ from .realtime import aiter_websocket_audio, iter_websocket_audio from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions from ..types import ( + AudioFormat, CloseEvent, FlushEvent, + LatencyMode, Model, + Prosody, + ReferenceAudio, StartEvent, TextEvent, TTSConfig, @@ -58,6 +62,11 @@ def convert( self, *, text: str, + reference_id: Optional[str] = None, + references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -67,6 +76,11 @@ def convert( Args: text: Text to synthesize + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -76,16 +90,40 @@ def convert( Example: ```python - from fishaudio import FishAudio, TTSConfig + from fishaudio import FishAudio, TTSConfig, ReferenceAudio client = FishAudio(api_key="...") # Simple usage with defaults audio = client.tts.convert(text="Hello world") - # Custom configuration - config = TTSConfig(format="wav", mp3_bitrate=192) - audio = client.tts.convert(text="Hello world", config=config) + # With format parameter + audio = client.tts.convert(text="Hello world", format="wav") + + # With speed parameter + audio = client.tts.convert(text="Hello world", speed=1.5) + + # With reference_id parameter + audio = client.tts.convert(text="Hello world", reference_id="your_model_id") + + # With references parameter + audio = client.tts.convert( + text="Hello world", + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ) + + # Combine multiple parameters + audio = client.tts.convert( + text="Hello world", + format="wav", + speed=1.2, + latency="normal" + ) + + # Parameters override config values + config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0)) + audio = client.tts.convert(text="Hello world", format="wav", config=config) + # Result: format="wav" (parameter wins) with open("output.mp3", "wb") as f: for chunk in audio: @@ -94,6 +132,23 @@ def convert( """ # Build request payload from config request = _config_to_tts_request(config, text) + + # Apply direct parameters (always override config when provided) + if reference_id is not None: + request.reference_id = reference_id + + if references is not None: + request.references = references + + if format is not None: + request.format = format + + if latency is not None: + request.latency = latency + + if speed is not None: + request.prosody = Prosody.from_speed_override(speed, base=config.prosody) + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -114,6 +169,11 @@ def stream_websocket( self, text_stream: Iterable[Union[str, TextEvent, FlushEvent]], *, + reference_id: Optional[str] = None, + references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", max_workers: int = 10, @@ -125,6 +185,11 @@ def stream_websocket( Args: text_stream: Iterator of text chunks to stream + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use max_workers: ThreadPoolExecutor workers for concurrent sender @@ -134,7 +199,7 @@ def stream_websocket( Example: ```python - from fishaudio import FishAudio, TTSConfig + from fishaudio import FishAudio, TTSConfig, ReferenceAudio client = FishAudio(api_key="...") @@ -148,16 +213,60 @@ def text_generator(): for audio_chunk in client.tts.stream_websocket(text_generator()): f.write(audio_chunk) - # Custom configuration - config = TTSConfig(format="wav", latency="normal") + # With format and speed parameters with open("output.wav", "wb") as f: - for audio_chunk in client.tts.stream_websocket(text_generator(), config=config): + for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", + speed=1.3 + ): + f.write(audio_chunk) + + # With reference_id parameter + with open("output.mp3", "wb") as f: + for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): + f.write(audio_chunk) + + # With references parameter + with open("output.mp3", "wb") as f: + for audio_chunk in client.tts.stream_websocket( + text_generator(), + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ): + f.write(audio_chunk) + + # Parameters override config values + config = TTSConfig(format="mp3", latency="balanced") + with open("output.wav", "wb") as f: + for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", # Parameter wins + config=config + ): f.write(audio_chunk) ``` """ # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") + # Apply direct parameters (always override config when provided) + if reference_id is not None: + tts_request.reference_id = reference_id + + if references is not None: + tts_request.references = references + + if format is not None: + tts_request.format = format + + if latency is not None: + tts_request.latency = latency + + if speed is not None: + tts_request.prosody = Prosody.from_speed_override( + speed, base=config.prosody + ) + executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -202,6 +311,11 @@ async def convert( self, *, text: str, + reference_id: Optional[str] = None, + references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -211,6 +325,11 @@ async def convert( Args: text: Text to synthesize + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -220,16 +339,40 @@ async def convert( Example: ```python - from fishaudio import AsyncFishAudio, TTSConfig + from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio client = AsyncFishAudio(api_key="...") # Simple usage with defaults audio = await client.tts.convert(text="Hello world") - # Custom configuration - config = TTSConfig(format="wav", mp3_bitrate=192) - audio = await client.tts.convert(text="Hello world", config=config) + # With format parameter + audio = await client.tts.convert(text="Hello world", format="wav") + + # With speed parameter + audio = await client.tts.convert(text="Hello world", speed=1.5) + + # With reference_id parameter + audio = await client.tts.convert(text="Hello world", reference_id="your_model_id") + + # With references parameter + audio = await client.tts.convert( + text="Hello world", + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ) + + # Combine multiple parameters + audio = await client.tts.convert( + text="Hello world", + format="wav", + speed=1.2, + latency="normal" + ) + + # Parameters override config values + config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0)) + audio = await client.tts.convert(text="Hello world", format="wav", config=config) + # Result: format="wav" (parameter wins) async with aiofiles.open("output.mp3", "wb") as f: async for chunk in audio: @@ -238,6 +381,23 @@ async def convert( """ # Build request payload from config request = _config_to_tts_request(config, text) + + # Apply direct parameters (always override config when provided) + if reference_id is not None: + request.reference_id = reference_id + + if references is not None: + request.references = references + + if format is not None: + request.format = format + + if latency is not None: + request.latency = latency + + if speed is not None: + request.prosody = Prosody.from_speed_override(speed, base=config.prosody) + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -258,6 +418,11 @@ async def stream_websocket( self, text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]], *, + reference_id: Optional[str] = None, + references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", ): @@ -268,6 +433,11 @@ async def stream_websocket( Args: text_stream: Async iterator of text chunks to stream + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use @@ -276,7 +446,7 @@ async def stream_websocket( Example: ```python - from fishaudio import AsyncFishAudio, TTSConfig + from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio client = AsyncFishAudio(api_key="...") @@ -290,16 +460,60 @@ async def text_generator(): async for audio_chunk in client.tts.stream_websocket(text_generator()): await f.write(audio_chunk) - # Custom configuration - config = TTSConfig(format="wav", latency="normal") + # With format and speed parameters async with aiofiles.open("output.wav", "wb") as f: - async for audio_chunk in client.tts.stream_websocket(text_generator(), config=config): + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", + speed=1.3 + ): + await f.write(audio_chunk) + + # With reference_id parameter + async with aiofiles.open("output.mp3", "wb") as f: + async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): + await f.write(audio_chunk) + + # With references parameter + async with aiofiles.open("output.mp3", "wb") as f: + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ): + await f.write(audio_chunk) + + # Parameters override config values + config = TTSConfig(format="mp3", latency="balanced") + async with aiofiles.open("output.wav", "wb") as f: + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", # Parameter wins + config=config + ): await f.write(audio_chunk) ``` """ # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") + # Apply direct parameters (always override config when provided) + if reference_id is not None: + tts_request.reference_id = reference_id + + if references is not None: + tts_request.references = references + + if format is not None: + tts_request.format = format + + if latency is not None: + tts_request.latency = latency + + if speed is not None: + tts_request.prosody = Prosody.from_speed_override( + speed, base=config.prosody + ) + ws: AsyncWebSocketSession async with aconnect_ws( "/v1/tts/live", diff --git a/src/fishaudio/types/shared.py b/src/fishaudio/types/shared.py index 28ef594..df7ab4a 100644 --- a/src/fishaudio/types/shared.py +++ b/src/fishaudio/types/shared.py @@ -19,7 +19,7 @@ class PaginatedResponse(BaseModel, Generic[T]): Model = Literal["speech-1.5", "speech-1.6", "s1"] # Audio format types -AudioFormat = Literal["wav", "pcm", "mp3"] +AudioFormat = Literal["wav", "pcm", "mp3", "opus"] # Visibility types Visibility = Literal["public", "unlist", "private"] diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index eb36398..4dd7671 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field -from .shared import LatencyMode +from .shared import AudioFormat, LatencyMode class ReferenceAudio(BaseModel): @@ -17,8 +17,26 @@ class ReferenceAudio(BaseModel): class Prosody(BaseModel): """Speech prosody settings (speed and volume).""" - speed: float = 1.0 - volume: float = 0.0 + speed: Annotated[float, Field(ge=0.5, le=2.0)] = 1.0 + volume: Annotated[float, Field(ge=-20.0, le=20.0)] = 0.0 + + @classmethod + def from_speed_override( + cls, speed: float, base: Optional["Prosody"] = None + ) -> "Prosody": + """ + Create Prosody with speed override, preserving volume from base. + + Args: + speed: Speed value to use + base: Base prosody to preserve volume from (if any) + + Returns: + New Prosody instance with overridden speed + """ + if base: + return cls(speed=speed, volume=base.volume) + return cls(speed=speed) class TTSConfig(BaseModel): @@ -30,7 +48,7 @@ class TTSConfig(BaseModel): """ # Audio output settings - format: Literal["wav", "pcm", "mp3"] = "mp3" + format: AudioFormat = "mp3" sample_rate: Optional[int] = None mp3_bitrate: Literal[64, 128, 192] = 128 opus_bitrate: Literal[-1000, 24, 32, 48, 64] = 32 @@ -46,8 +64,8 @@ class TTSConfig(BaseModel): prosody: Optional[Prosody] = None # Model parameters - top_p: float = 0.7 - temperature: float = 0.7 + top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 class TTSRequest(BaseModel): @@ -60,7 +78,7 @@ class TTSRequest(BaseModel): text: str chunk_length: Annotated[int, Field(ge=100, le=300, strict=True)] = 200 - format: Literal["wav", "pcm", "mp3"] = "mp3" + format: AudioFormat = "mp3" sample_rate: Optional[int] = None mp3_bitrate: Literal[64, 128, 192] = 128 opus_bitrate: Literal[-1000, 24, 32, 48, 64] = 32 @@ -69,8 +87,8 @@ class TTSRequest(BaseModel): normalize: bool = True latency: LatencyMode = "balanced" prosody: Optional[Prosody] = None - top_p: float = 0.7 - temperature: float = 0.7 + top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 # WebSocket event types for streaming TTS diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index d41d8ca..6ddff60 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -81,6 +81,39 @@ def test_convert_with_reference_id(self, tts_client, mock_client_wrapper): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_123" + def test_convert_with_reference_id_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with reference_id as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", reference_id="voice_456")) + + # Verify reference_id in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_456" + + def test_convert_parameter_reference_id_overrides_config( + self, tts_client, mock_client_wrapper + ): + """Test that parameter reference_id overrides config.reference_id.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(reference_id="voice_from_config") + list( + tts_client.convert( + text="Hello", reference_id="voice_from_param", config=config + ) + ) + + # Verify parameter reference_id takes precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_from_param" + def test_convert_with_references(self, tts_client, mock_client_wrapper): """Test TTS with reference audio samples.""" mock_response = Mock() @@ -102,6 +135,46 @@ def test_convert_with_references(self, tts_client, mock_client_wrapper): assert payload["references"][0]["text"] == "Sample 1" assert payload["references"][1]["text"] == "Sample 2" + def test_convert_with_references_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with references as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + list(tts_client.convert(text="Hello", references=references)) + + # Verify references in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 2 + assert payload["references"][0]["text"] == "Sample 1" + assert payload["references"][1]["text"] == "Sample 2" + + def test_convert_parameter_references_overrides_config( + self, tts_client, mock_client_wrapper + ): + """Test that parameter references overrides config.references.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + list(tts_client.convert(text="Hello", references=param_refs, config=config)) + + # Verify parameter references take precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 1 + assert payload["references"][0]["text"] == "Param" + def test_convert_with_different_backend(self, tts_client, mock_client_wrapper): """Test TTS with different backend/model.""" mock_response = Mock() @@ -228,6 +301,110 @@ def test_convert_empty_response(self, tts_client, mock_client_wrapper): assert audio_chunks == [] + def test_convert_with_format_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with format as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", format="wav")) + + # Verify format in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + + def test_convert_with_opus_format(self, tts_client, mock_client_wrapper): + """Test TTS with opus format.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", format="opus")) + + # Verify opus format in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "opus" + + def test_convert_with_latency_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with latency as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", latency="normal")) + + # Verify latency in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["latency"] == "normal" + + def test_convert_with_speed_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with speed as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", speed=1.5)) + + # Verify speed creates prosody in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + + def test_convert_parameter_format_overrides_config( + self, tts_client, mock_client_wrapper + ): + """Test that parameter format overrides config.format.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(format="wav") + list(tts_client.convert(text="Hello", format="pcm", config=config)) + + # Verify parameter format takes precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "pcm" + + def test_convert_parameter_speed_overrides_config_prosody( + self, tts_client, mock_client_wrapper + ): + """Test that parameter speed overrides config.prosody speed but preserves volume.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(prosody=Prosody(speed=2.0, volume=0.5)) + list(tts_client.convert(text="Hello", speed=1.5, config=config)) + + # Verify parameter speed takes precedence but volume is preserved + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + assert payload["prosody"]["volume"] == 0.5 # Preserved from config! + + def test_convert_combined_convenience_parameters( + self, tts_client, mock_client_wrapper + ): + """Test TTS with multiple convenience parameters combined.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list( + tts_client.convert(text="Hello", format="wav", speed=1.3, latency="normal") + ) + + # Verify all parameters in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + assert payload["latency"] == "normal" + assert payload["prosody"]["speed"] == 1.3 + class TestAsyncTTSClient: """Test asynchronous AsyncTTSClient.""" @@ -282,6 +459,115 @@ async def async_iter_bytes(): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_123" + @pytest.mark.asyncio + async def test_convert_with_reference_id_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with reference_id as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", reference_id="voice_456" + ): + audio_chunks.append(chunk) + + # Verify reference_id in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_456" + + @pytest.mark.asyncio + async def test_convert_parameter_reference_id_overrides_config( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter reference_id overrides config.reference_id (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(reference_id="voice_from_config") + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", reference_id="voice_from_param", config=config + ): + audio_chunks.append(chunk) + + # Verify parameter reference_id takes precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_from_param" + + @pytest.mark.asyncio + async def test_convert_with_references_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with references as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", references=references + ): + audio_chunks.append(chunk) + + # Verify references in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 2 + assert payload["references"][0]["text"] == "Sample 1" + assert payload["references"][1]["text"] == "Sample 2" + + @pytest.mark.asyncio + async def test_convert_parameter_references_overrides_config( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter references overrides config.references (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", references=param_refs, config=config + ): + audio_chunks.append(chunk) + + # Verify parameter references take precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 1 + assert payload["references"][0]["text"] == "Param" + @pytest.mark.asyncio async def test_convert_with_prosody( self, async_tts_client, async_mock_client_wrapper @@ -352,3 +638,146 @@ async def async_iter_bytes(): audio_chunks.append(chunk) assert audio_chunks == [] + + @pytest.mark.asyncio + async def test_convert_with_format_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with format as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", format="wav"): + audio_chunks.append(chunk) + + # Verify format in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + + @pytest.mark.asyncio + async def test_convert_with_latency_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with latency as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", latency="normal"): + audio_chunks.append(chunk) + + # Verify latency in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["latency"] == "normal" + + @pytest.mark.asyncio + async def test_convert_with_speed_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with speed as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", speed=1.5): + audio_chunks.append(chunk) + + # Verify speed creates prosody in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + + @pytest.mark.asyncio + async def test_convert_parameter_format_overrides_config( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter format overrides config.format (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(format="wav") + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", format="pcm", config=config + ): + audio_chunks.append(chunk) + + # Verify parameter format takes precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "pcm" + + @pytest.mark.asyncio + async def test_convert_parameter_speed_overrides_config_prosody( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter speed overrides config.prosody speed but preserves volume (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(prosody=Prosody(speed=2.0, volume=0.5)) + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", speed=1.5, config=config + ): + audio_chunks.append(chunk) + + # Verify parameter speed takes precedence but volume is preserved + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + assert payload["prosody"]["volume"] == 0.5 # Preserved from config! + + @pytest.mark.asyncio + async def test_convert_combined_convenience_parameters( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with multiple convenience parameters combined.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", format="wav", speed=1.3, latency="normal" + ): + audio_chunks.append(chunk) + + # Verify all parameters in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + assert payload["latency"] == "normal" + assert payload["prosody"]["speed"] == 1.3 diff --git a/tests/unit/test_tts_realtime.py b/tests/unit/test_tts_realtime.py index 72141ec..27874bb 100644 --- a/tests/unit/test_tts_realtime.py +++ b/tests/unit/test_tts_realtime.py @@ -5,7 +5,8 @@ from fishaudio.core import ClientWrapper, AsyncClientWrapper from fishaudio.resources.tts import TTSClient, AsyncTTSClient -from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent +from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent, ReferenceAudio +import ormsgpack @pytest.fixture @@ -181,6 +182,169 @@ def test_stream_websocket_max_workers( # Verify ThreadPoolExecutor was created with max_workers=5 mock_executor.assert_called_once_with(max_workers=5) + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_with_reference_id_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test WebSocket streaming with reference_id as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + text_stream = iter(["Test"]) + list(tts_client.stream_websocket(text_stream, reference_id="voice_456")) + + # Verify WebSocket was called with StartEvent containing reference_id + assert mock_ws.send_bytes.called + # Get the first call (StartEvent) + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_456" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_parameter_reference_id_overrides_config( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test that parameter reference_id overrides config.reference_id.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + config = TTSConfig(reference_id="voice_from_config") + text_stream = iter(["Test"]) + list( + tts_client.stream_websocket( + text_stream, reference_id="voice_from_param", config=config + ) + ) + + # Verify parameter reference_id takes precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_from_param" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_with_references_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test WebSocket streaming with references as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + text_stream = iter(["Test"]) + list(tts_client.stream_websocket(text_stream, references=references)) + + # Verify references in StartEvent + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 2 + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_parameter_references_overrides_config( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test that parameter references overrides config.references.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + text_stream = iter(["Test"]) + list( + tts_client.stream_websocket( + text_stream, references=param_refs, config=config + ) + ) + + # Verify parameter references take precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 1 + assert start_event_payload["request"]["references"][0]["text"] == "Param" + class TestAsyncTTSRealtimeClient: """Test asynchronous AsyncTTSClient realtime streaming.""" @@ -331,3 +495,157 @@ async def text_stream(): # Should have no audio assert audio_chunks == [] + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_with_reference_id_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test async WebSocket streaming with reference_id as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), reference_id="voice_456" + ): + audio_chunks.append(chunk) + + # Verify WebSocket was called with StartEvent containing reference_id + assert mock_ws.send_bytes.called + # Get the first call (StartEvent) + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_456" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_parameter_reference_id_overrides_config( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter reference_id overrides config.reference_id (async).""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + config = TTSConfig(reference_id="voice_from_config") + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), reference_id="voice_from_param", config=config + ): + audio_chunks.append(chunk) + + # Verify parameter reference_id takes precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_from_param" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_with_references_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test async WebSocket streaming with references as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), references=references + ): + audio_chunks.append(chunk) + + # Verify references in StartEvent + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 2 + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_parameter_references_overrides_config( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter references overrides config.references (async).""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), references=param_refs, config=config + ): + audio_chunks.append(chunk) + + # Verify parameter references take precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 1 + assert start_event_payload["request"]["references"][0]["text"] == "Param"