From b1d0129dbc21bd925fa5bce9c0d58adf02ca62f8 Mon Sep 17 00:00:00 2001 From: James Ding Date: Tue, 11 Nov 2025 01:56:08 -0600 Subject: [PATCH 01/10] feat: add reference_id parameter to TTS conversion methods Signed-off-by: James Ding --- src/fishaudio/resources/tts.py | 42 +++++++++++++++++ tests/unit/test_tts.py | 82 ++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index ea73141..5fbd7bf 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -58,6 +58,7 @@ def convert( self, *, text: str, + reference_id: Optional[str] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -67,6 +68,7 @@ def convert( Args: text: Text to synthesize + reference_id: Voice reference ID (overridden by config.reference_id if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -83,6 +85,9 @@ def convert( # Simple usage with defaults audio = client.tts.convert(text="Hello world") + # With reference_id parameter + audio = client.tts.convert(text="Hello world", reference_id="your_model_id") + # Custom configuration config = TTSConfig(format="wav", mp3_bitrate=192) audio = client.tts.convert(text="Hello world", config=config) @@ -94,6 +99,11 @@ def convert( """ # Build request payload from config request = _config_to_tts_request(config, text) + + # Use parameter reference_id only if config doesn't have one + if request.reference_id is None and reference_id is not None: + request.reference_id = reference_id + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -114,6 +124,7 @@ def stream_websocket( self, text_stream: Iterable[Union[str, TextEvent, FlushEvent]], *, + reference_id: Optional[str] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", max_workers: int = 10, @@ -125,6 +136,7 @@ def stream_websocket( Args: text_stream: Iterator of text chunks to stream + reference_id: Voice reference ID (overridden by config.reference_id if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use max_workers: ThreadPoolExecutor workers for concurrent sender @@ -148,6 +160,11 @@ def text_generator(): for audio_chunk in client.tts.stream_websocket(text_generator()): f.write(audio_chunk) + # With reference_id parameter + with open("output.mp3", "wb") as f: + for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): + f.write(audio_chunk) + # Custom configuration config = TTSConfig(format="wav", latency="normal") with open("output.wav", "wb") as f: @@ -158,6 +175,10 @@ def text_generator(): # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") + # Use parameter reference_id only if config doesn't have one + if tts_request.reference_id is None and reference_id is not None: + tts_request.reference_id = reference_id + executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -202,6 +223,7 @@ async def convert( self, *, text: str, + reference_id: Optional[str] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -211,6 +233,7 @@ async def convert( Args: text: Text to synthesize + reference_id: Voice reference ID (overridden by config.reference_id if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -227,6 +250,9 @@ async def convert( # Simple usage with defaults audio = await client.tts.convert(text="Hello world") + # With reference_id parameter + audio = await client.tts.convert(text="Hello world", reference_id="your_model_id") + # Custom configuration config = TTSConfig(format="wav", mp3_bitrate=192) audio = await client.tts.convert(text="Hello world", config=config) @@ -238,6 +264,11 @@ async def convert( """ # Build request payload from config request = _config_to_tts_request(config, text) + + # Use parameter reference_id only if config doesn't have one + if request.reference_id is None and reference_id is not None: + request.reference_id = reference_id + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -258,6 +289,7 @@ async def stream_websocket( self, text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]], *, + reference_id: Optional[str] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", ): @@ -268,6 +300,7 @@ async def stream_websocket( Args: text_stream: Async iterator of text chunks to stream + reference_id: Voice reference ID (overridden by config.reference_id if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use @@ -290,6 +323,11 @@ async def text_generator(): async for audio_chunk in client.tts.stream_websocket(text_generator()): await f.write(audio_chunk) + # With reference_id parameter + async with aiofiles.open("output.mp3", "wb") as f: + async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): + await f.write(audio_chunk) + # Custom configuration config = TTSConfig(format="wav", latency="normal") async with aiofiles.open("output.wav", "wb") as f: @@ -300,6 +338,10 @@ async def text_generator(): # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") + # Use parameter reference_id only if config doesn't have one + if tts_request.reference_id is None and reference_id is not None: + tts_request.reference_id = reference_id + ws: AsyncWebSocketSession async with aconnect_ws( "/v1/tts/live", diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index d41d8ca..7c36ae8 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -81,6 +81,39 @@ def test_convert_with_reference_id(self, tts_client, mock_client_wrapper): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_123" + def test_convert_with_reference_id_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with reference_id as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", reference_id="voice_456")) + + # Verify reference_id in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_456" + + def test_convert_config_reference_id_overrides_parameter( + self, tts_client, mock_client_wrapper + ): + """Test that config.reference_id overrides parameter reference_id.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(reference_id="voice_from_config") + list( + tts_client.convert( + text="Hello", reference_id="voice_from_param", config=config + ) + ) + + # Verify config reference_id takes precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_from_config" + def test_convert_with_references(self, tts_client, mock_client_wrapper): """Test TTS with reference audio samples.""" mock_response = Mock() @@ -282,6 +315,55 @@ async def async_iter_bytes(): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_123" + @pytest.mark.asyncio + async def test_convert_with_reference_id_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with reference_id as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", reference_id="voice_456" + ): + audio_chunks.append(chunk) + + # Verify reference_id in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_456" + + @pytest.mark.asyncio + async def test_convert_config_reference_id_overrides_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that config.reference_id overrides parameter reference_id (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(reference_id="voice_from_config") + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", reference_id="voice_from_param", config=config + ): + audio_chunks.append(chunk) + + # Verify config reference_id takes precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["reference_id"] == "voice_from_config" + @pytest.mark.asyncio async def test_convert_with_prosody( self, async_tts_client, async_mock_client_wrapper From c65fcaa7a4ef509647d963c695a5b9d21ae73c08 Mon Sep 17 00:00:00 2001 From: James Ding Date: Tue, 11 Nov 2025 11:51:55 -0600 Subject: [PATCH 02/10] feat: add support for references parameter in TTS conversion methods Signed-off-by: James Ding --- src/fishaudio/resources/tts.py | 63 +++++++++++++++++++-- tests/unit/test_tts.py | 100 +++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 5 deletions(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 5fbd7bf..63407cc 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -2,7 +2,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import AsyncIterable, Iterable, Iterator, Optional, Union +from typing import AsyncIterable, Iterable, Iterator, List, Optional, Union import ormsgpack from httpx_ws import AsyncWebSocketSession, WebSocketSession, aconnect_ws, connect_ws @@ -13,6 +13,7 @@ CloseEvent, FlushEvent, Model, + ReferenceAudio, StartEvent, TextEvent, TTSConfig, @@ -59,6 +60,7 @@ def convert( *, text: str, reference_id: Optional[str] = None, + references: List[ReferenceAudio] = [], config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -69,6 +71,7 @@ def convert( Args: text: Text to synthesize reference_id: Voice reference ID (overridden by config.reference_id if set) + references: Reference audio samples (overridden by config.references if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -78,7 +81,7 @@ def convert( Example: ```python - from fishaudio import FishAudio, TTSConfig + from fishaudio import FishAudio, TTSConfig, ReferenceAudio client = FishAudio(api_key="...") @@ -88,6 +91,12 @@ def convert( # With reference_id parameter audio = client.tts.convert(text="Hello world", reference_id="your_model_id") + # With references parameter + audio = client.tts.convert( + text="Hello world", + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ) + # Custom configuration config = TTSConfig(format="wav", mp3_bitrate=192) audio = client.tts.convert(text="Hello world", config=config) @@ -104,6 +113,10 @@ def convert( if request.reference_id is None and reference_id is not None: request.reference_id = reference_id + # Use parameter references only if config doesn't have any + if not request.references and references: + request.references = references + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -125,6 +138,7 @@ def stream_websocket( text_stream: Iterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, + references: List[ReferenceAudio] = [], config: TTSConfig = TTSConfig(), model: Model = "s1", max_workers: int = 10, @@ -137,6 +151,7 @@ def stream_websocket( Args: text_stream: Iterator of text chunks to stream reference_id: Voice reference ID (overridden by config.reference_id if set) + references: Reference audio samples (overridden by config.references if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use max_workers: ThreadPoolExecutor workers for concurrent sender @@ -146,7 +161,7 @@ def stream_websocket( Example: ```python - from fishaudio import FishAudio, TTSConfig + from fishaudio import FishAudio, TTSConfig, ReferenceAudio client = FishAudio(api_key="...") @@ -165,6 +180,14 @@ def text_generator(): for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): f.write(audio_chunk) + # With references parameter + with open("output.mp3", "wb") as f: + for audio_chunk in client.tts.stream_websocket( + text_generator(), + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ): + f.write(audio_chunk) + # Custom configuration config = TTSConfig(format="wav", latency="normal") with open("output.wav", "wb") as f: @@ -179,6 +202,10 @@ def text_generator(): if tts_request.reference_id is None and reference_id is not None: tts_request.reference_id = reference_id + # Use parameter references only if config doesn't have any + if not tts_request.references and references: + tts_request.references = references + executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -224,6 +251,7 @@ async def convert( *, text: str, reference_id: Optional[str] = None, + references: List[ReferenceAudio] = [], config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -234,6 +262,7 @@ async def convert( Args: text: Text to synthesize reference_id: Voice reference ID (overridden by config.reference_id if set) + references: Reference audio samples (overridden by config.references if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -243,7 +272,7 @@ async def convert( Example: ```python - from fishaudio import AsyncFishAudio, TTSConfig + from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio client = AsyncFishAudio(api_key="...") @@ -253,6 +282,12 @@ async def convert( # With reference_id parameter audio = await client.tts.convert(text="Hello world", reference_id="your_model_id") + # With references parameter + audio = await client.tts.convert( + text="Hello world", + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ) + # Custom configuration config = TTSConfig(format="wav", mp3_bitrate=192) audio = await client.tts.convert(text="Hello world", config=config) @@ -269,6 +304,10 @@ async def convert( if request.reference_id is None and reference_id is not None: request.reference_id = reference_id + # Use parameter references only if config doesn't have any + if not request.references and references: + request.references = references + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -290,6 +329,7 @@ async def stream_websocket( text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, + references: List[ReferenceAudio] = [], config: TTSConfig = TTSConfig(), model: Model = "s1", ): @@ -301,6 +341,7 @@ async def stream_websocket( Args: text_stream: Async iterator of text chunks to stream reference_id: Voice reference ID (overridden by config.reference_id if set) + references: Reference audio samples (overridden by config.references if set) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use @@ -309,7 +350,7 @@ async def stream_websocket( Example: ```python - from fishaudio import AsyncFishAudio, TTSConfig + from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio client = AsyncFishAudio(api_key="...") @@ -328,6 +369,14 @@ async def text_generator(): async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): await f.write(audio_chunk) + # With references parameter + async with aiofiles.open("output.mp3", "wb") as f: + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + references=[ReferenceAudio(audio=audio_bytes, text="sample")] + ): + await f.write(audio_chunk) + # Custom configuration config = TTSConfig(format="wav", latency="normal") async with aiofiles.open("output.wav", "wb") as f: @@ -342,6 +391,10 @@ async def text_generator(): if tts_request.reference_id is None and reference_id is not None: tts_request.reference_id = reference_id + # Use parameter references only if config doesn't have any + if not tts_request.references and references: + tts_request.references = references + ws: AsyncWebSocketSession async with aconnect_ws( "/v1/tts/live", diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index 7c36ae8..eba871a 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -135,6 +135,46 @@ def test_convert_with_references(self, tts_client, mock_client_wrapper): assert payload["references"][0]["text"] == "Sample 1" assert payload["references"][1]["text"] == "Sample 2" + def test_convert_with_references_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with references as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + list(tts_client.convert(text="Hello", references=references)) + + # Verify references in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 2 + assert payload["references"][0]["text"] == "Sample 1" + assert payload["references"][1]["text"] == "Sample 2" + + def test_convert_config_references_overrides_parameter( + self, tts_client, mock_client_wrapper + ): + """Test that config.references overrides parameter references.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + list(tts_client.convert(text="Hello", references=param_refs, config=config)) + + # Verify config references take precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 1 + assert payload["references"][0]["text"] == "Config" + def test_convert_with_different_backend(self, tts_client, mock_client_wrapper): """Test TTS with different backend/model.""" mock_response = Mock() @@ -364,6 +404,66 @@ async def async_iter_bytes(): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_from_config" + @pytest.mark.asyncio + async def test_convert_with_references_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with references as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", references=references + ): + audio_chunks.append(chunk) + + # Verify references in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 2 + assert payload["references"][0]["text"] == "Sample 1" + assert payload["references"][1]["text"] == "Sample 2" + + @pytest.mark.asyncio + async def test_convert_config_references_overrides_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that config.references overrides parameter references (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", references=param_refs, config=config + ): + audio_chunks.append(chunk) + + # Verify config references take precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert len(payload["references"]) == 1 + assert payload["references"][0]["text"] == "Config" + @pytest.mark.asyncio async def test_convert_with_prosody( self, async_tts_client, async_mock_client_wrapper From 77f73ac1fa3670b805e0fd1837b7daf4e67c39aa Mon Sep 17 00:00:00 2001 From: James Ding Date: Tue, 11 Nov 2025 15:27:46 -0600 Subject: [PATCH 03/10] feat: add tests for WebSocket streaming with reference_id and references parameters Signed-off-by: James Ding --- tests/unit/test_tts_realtime.py | 320 +++++++++++++++++++++++++++++++- 1 file changed, 319 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_tts_realtime.py b/tests/unit/test_tts_realtime.py index 72141ec..b51ba32 100644 --- a/tests/unit/test_tts_realtime.py +++ b/tests/unit/test_tts_realtime.py @@ -5,7 +5,8 @@ from fishaudio.core import ClientWrapper, AsyncClientWrapper from fishaudio.resources.tts import TTSClient, AsyncTTSClient -from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent +from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent, ReferenceAudio +import ormsgpack @pytest.fixture @@ -181,6 +182,169 @@ def test_stream_websocket_max_workers( # Verify ThreadPoolExecutor was created with max_workers=5 mock_executor.assert_called_once_with(max_workers=5) + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_with_reference_id_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test WebSocket streaming with reference_id as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + text_stream = iter(["Test"]) + list(tts_client.stream_websocket(text_stream, reference_id="voice_456")) + + # Verify WebSocket was called with StartEvent containing reference_id + assert mock_ws.send_bytes.called + # Get the first call (StartEvent) + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_456" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_config_reference_id_overrides_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test that config.reference_id overrides parameter reference_id.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + config = TTSConfig(reference_id="voice_from_config") + text_stream = iter(["Test"]) + list( + tts_client.stream_websocket( + text_stream, reference_id="voice_from_param", config=config + ) + ) + + # Verify config reference_id takes precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_from_config" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_with_references_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test WebSocket streaming with references as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + text_stream = iter(["Test"]) + list(tts_client.stream_websocket(text_stream, references=references)) + + # Verify references in StartEvent + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 2 + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" + + @patch("fishaudio.resources.tts.connect_ws") + @patch("fishaudio.resources.tts.ThreadPoolExecutor") + def test_stream_websocket_config_references_overrides_parameter( + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper + ): + """Test that config.references overrides parameter references.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__enter__ = Mock(return_value=mock_ws) + mock_ws.__exit__ = Mock(return_value=None) + mock_ws.send_bytes = Mock() + mock_connect_ws.return_value = mock_ws + + # Make executor.submit actually run the function + def submit_side_effect(fn): + fn() # Execute the sender function + mock_future = Mock() + mock_future.result.return_value = None + return mock_future + + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = submit_side_effect + mock_executor.return_value = mock_executor_instance + + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: + mock_receiver.return_value = iter([b"audio"]) + + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + text_stream = iter(["Test"]) + list( + tts_client.stream_websocket( + text_stream, references=param_refs, config=config + ) + ) + + # Verify config references take precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 1 + assert start_event_payload["request"]["references"][0]["text"] == "Config" + class TestAsyncTTSRealtimeClient: """Test asynchronous AsyncTTSClient realtime streaming.""" @@ -331,3 +495,157 @@ async def text_stream(): # Should have no audio assert audio_chunks == [] + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_with_reference_id_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test async WebSocket streaming with reference_id as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), reference_id="voice_456" + ): + audio_chunks.append(chunk) + + # Verify WebSocket was called with StartEvent containing reference_id + assert mock_ws.send_bytes.called + # Get the first call (StartEvent) + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_456" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_config_reference_id_overrides_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test that config.reference_id overrides parameter reference_id (async).""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + config = TTSConfig(reference_id="voice_from_config") + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), reference_id="voice_from_param", config=config + ): + audio_chunks.append(chunk) + + # Verify config reference_id takes precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert start_event_payload["request"]["reference_id"] == "voice_from_config" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_with_references_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test async WebSocket streaming with references as direct parameter.""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + references = [ + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), + ] + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), references=references + ): + audio_chunks.append(chunk) + + # Verify references in StartEvent + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 2 + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" + + @pytest.mark.asyncio + @patch("fishaudio.resources.tts.aconnect_ws") + async def test_stream_websocket_config_references_overrides_parameter( + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper + ): + """Test that config.references overrides parameter references (async).""" + # Setup mocks + mock_ws = MagicMock() + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) + mock_ws.__aexit__ = AsyncMock(return_value=None) + mock_ws.send_bytes = AsyncMock() + mock_aconnect_ws.return_value = mock_ws + + async def mock_audio_receiver(ws): + yield b"audio" + + with patch( + "fishaudio.resources.tts.aiter_websocket_audio", + return_value=mock_audio_receiver(mock_ws), + ): + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] + + config = TTSConfig(references=config_refs) + + async def text_stream(): + yield "Test" + + audio_chunks = [] + async for chunk in async_tts_client.stream_websocket( + text_stream(), references=param_refs, config=config + ): + audio_chunks.append(chunk) + + # Verify config references take precedence + first_call = mock_ws.send_bytes.call_args_list[0] + start_event_payload = ormsgpack.unpackb(first_call[0][0]) + assert len(start_event_payload["request"]["references"]) == 1 + assert start_event_payload["request"]["references"][0]["text"] == "Config" From 87a4adaf43a7617f097699723dfef0398c2d8d78 Mon Sep 17 00:00:00 2001 From: James Ding Date: Tue, 11 Nov 2025 19:36:52 -0600 Subject: [PATCH 04/10] feat: change references parameter to Optional in TTS methods Signed-off-by: James Ding --- src/fishaudio/resources/tts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 63407cc..e5e8eda 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -60,7 +60,7 @@ def convert( *, text: str, reference_id: Optional[str] = None, - references: List[ReferenceAudio] = [], + references: Optional[List[ReferenceAudio]] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -138,7 +138,7 @@ def stream_websocket( text_stream: Iterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, - references: List[ReferenceAudio] = [], + references: Optional[List[ReferenceAudio]] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", max_workers: int = 10, @@ -251,7 +251,7 @@ async def convert( *, text: str, reference_id: Optional[str] = None, - references: List[ReferenceAudio] = [], + references: Optional[List[ReferenceAudio]] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -329,7 +329,7 @@ async def stream_websocket( text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]], *, reference_id: Optional[str] = None, - references: List[ReferenceAudio] = [], + references: Optional[List[ReferenceAudio]] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", ): From e31b4c9b0d526172e1a0e3ece145336ac49441a0 Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 01:00:08 -0600 Subject: [PATCH 05/10] feat: add support for additional parameters in TTS methods (format, latency, speed) Signed-off-by: James Ding --- src/fishaudio/resources/tts.py | 187 +++++++++++++++++----- tests/unit/test_tts.py | 266 ++++++++++++++++++++++++++++++-- tests/unit/test_tts_realtime.py | 32 ++-- 3 files changed, 417 insertions(+), 68 deletions(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index e5e8eda..afa1448 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -10,9 +10,12 @@ from .realtime import aiter_websocket_audio, iter_websocket_audio from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions from ..types import ( + AudioFormat, CloseEvent, FlushEvent, + LatencyMode, Model, + Prosody, ReferenceAudio, StartEvent, TextEvent, @@ -61,6 +64,9 @@ def convert( text: str, reference_id: Optional[str] = None, references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -70,8 +76,11 @@ def convert( Args: text: Text to synthesize - reference_id: Voice reference ID (overridden by config.reference_id if set) - references: Reference audio samples (overridden by config.references if set) + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -88,6 +97,12 @@ def convert( # Simple usage with defaults audio = client.tts.convert(text="Hello world") + # With format parameter + audio = client.tts.convert(text="Hello world", format="wav") + + # With speed parameter + audio = client.tts.convert(text="Hello world", speed=1.5) + # With reference_id parameter audio = client.tts.convert(text="Hello world", reference_id="your_model_id") @@ -97,9 +112,18 @@ def convert( references=[ReferenceAudio(audio=audio_bytes, text="sample")] ) - # Custom configuration - config = TTSConfig(format="wav", mp3_bitrate=192) - audio = client.tts.convert(text="Hello world", config=config) + # Combine multiple parameters + audio = client.tts.convert( + text="Hello world", + format="wav", + speed=1.2, + latency="normal" + ) + + # Parameters override config values + config = TTSConfig(format="mp3", speed=1.0) + audio = client.tts.convert(text="Hello world", format="wav", config=config) + # Result: format="wav" (parameter wins) with open("output.mp3", "wb") as f: for chunk in audio: @@ -109,14 +133,22 @@ def convert( # Build request payload from config request = _config_to_tts_request(config, text) - # Use parameter reference_id only if config doesn't have one - if request.reference_id is None and reference_id is not None: + # Apply direct parameters (always override config when provided) + if reference_id is not None: request.reference_id = reference_id - # Use parameter references only if config doesn't have any - if not request.references and references: + if references is not None: request.references = references + if format is not None: + request.format = format + + if latency is not None: + request.latency = latency + + if speed is not None: + request.prosody = Prosody(speed=speed) + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -139,6 +171,9 @@ def stream_websocket( *, reference_id: Optional[str] = None, references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", max_workers: int = 10, @@ -150,8 +185,11 @@ def stream_websocket( Args: text_stream: Iterator of text chunks to stream - reference_id: Voice reference ID (overridden by config.reference_id if set) - references: Reference audio samples (overridden by config.references if set) + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use max_workers: ThreadPoolExecutor workers for concurrent sender @@ -175,6 +213,15 @@ def text_generator(): for audio_chunk in client.tts.stream_websocket(text_generator()): f.write(audio_chunk) + # With format and speed parameters + with open("output.wav", "wb") as f: + for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", + speed=1.3 + ): + f.write(audio_chunk) + # With reference_id parameter with open("output.mp3", "wb") as f: for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): @@ -188,24 +235,36 @@ def text_generator(): ): f.write(audio_chunk) - # Custom configuration - config = TTSConfig(format="wav", latency="normal") + # Parameters override config values + config = TTSConfig(format="mp3", latency="balanced") with open("output.wav", "wb") as f: - for audio_chunk in client.tts.stream_websocket(text_generator(), config=config): + for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", # Parameter wins + config=config + ): f.write(audio_chunk) ``` """ # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") - # Use parameter reference_id only if config doesn't have one - if tts_request.reference_id is None and reference_id is not None: + # Apply direct parameters (always override config when provided) + if reference_id is not None: tts_request.reference_id = reference_id - # Use parameter references only if config doesn't have any - if not tts_request.references and references: + if references is not None: tts_request.references = references + if format is not None: + tts_request.format = format + + if latency is not None: + tts_request.latency = latency + + if speed is not None: + tts_request.prosody = Prosody(speed=speed) + executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -252,6 +311,9 @@ async def convert( text: str, reference_id: Optional[str] = None, references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", request_options: Optional[RequestOptions] = None, @@ -261,8 +323,11 @@ async def convert( Args: text: Text to synthesize - reference_id: Voice reference ID (overridden by config.reference_id if set) - references: Reference audio samples (overridden by config.references if set) + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use request_options: Request-level overrides @@ -279,6 +344,12 @@ async def convert( # Simple usage with defaults audio = await client.tts.convert(text="Hello world") + # With format parameter + audio = await client.tts.convert(text="Hello world", format="wav") + + # With speed parameter + audio = await client.tts.convert(text="Hello world", speed=1.5) + # With reference_id parameter audio = await client.tts.convert(text="Hello world", reference_id="your_model_id") @@ -288,9 +359,18 @@ async def convert( references=[ReferenceAudio(audio=audio_bytes, text="sample")] ) - # Custom configuration - config = TTSConfig(format="wav", mp3_bitrate=192) - audio = await client.tts.convert(text="Hello world", config=config) + # Combine multiple parameters + audio = await client.tts.convert( + text="Hello world", + format="wav", + speed=1.2, + latency="normal" + ) + + # Parameters override config values + config = TTSConfig(format="mp3", speed=1.0) + audio = await client.tts.convert(text="Hello world", format="wav", config=config) + # Result: format="wav" (parameter wins) async with aiofiles.open("output.mp3", "wb") as f: async for chunk in audio: @@ -300,14 +380,22 @@ async def convert( # Build request payload from config request = _config_to_tts_request(config, text) - # Use parameter reference_id only if config doesn't have one - if request.reference_id is None and reference_id is not None: + # Apply direct parameters (always override config when provided) + if reference_id is not None: request.reference_id = reference_id - # Use parameter references only if config doesn't have any - if not request.references and references: + if references is not None: request.references = references + if format is not None: + request.format = format + + if latency is not None: + request.latency = latency + + if speed is not None: + request.prosody = Prosody(speed=speed) + payload = request.model_dump(exclude_none=True) # Make request with streaming @@ -330,6 +418,9 @@ async def stream_websocket( *, reference_id: Optional[str] = None, references: Optional[List[ReferenceAudio]] = None, + format: Optional[AudioFormat] = None, + latency: Optional[LatencyMode] = None, + speed: Optional[float] = None, config: TTSConfig = TTSConfig(), model: Model = "s1", ): @@ -340,8 +431,11 @@ async def stream_websocket( Args: text_stream: Async iterator of text chunks to stream - reference_id: Voice reference ID (overridden by config.reference_id if set) - references: Reference audio samples (overridden by config.references if set) + reference_id: Voice reference ID (overrides config.reference_id if provided) + references: Reference audio samples (overrides config.references if provided) + format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) + speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) model: TTS model to use @@ -364,6 +458,15 @@ async def text_generator(): async for audio_chunk in client.tts.stream_websocket(text_generator()): await f.write(audio_chunk) + # With format and speed parameters + async with aiofiles.open("output.wav", "wb") as f: + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", + speed=1.3 + ): + await f.write(audio_chunk) + # With reference_id parameter async with aiofiles.open("output.mp3", "wb") as f: async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"): @@ -377,24 +480,36 @@ async def text_generator(): ): await f.write(audio_chunk) - # Custom configuration - config = TTSConfig(format="wav", latency="normal") + # Parameters override config values + config = TTSConfig(format="mp3", latency="balanced") async with aiofiles.open("output.wav", "wb") as f: - async for audio_chunk in client.tts.stream_websocket(text_generator(), config=config): + async for audio_chunk in client.tts.stream_websocket( + text_generator(), + format="wav", # Parameter wins + config=config + ): await f.write(audio_chunk) ``` """ # Build TTSRequest from config tts_request = _config_to_tts_request(config, text="") - # Use parameter reference_id only if config doesn't have one - if tts_request.reference_id is None and reference_id is not None: + # Apply direct parameters (always override config when provided) + if reference_id is not None: tts_request.reference_id = reference_id - # Use parameter references only if config doesn't have any - if not tts_request.references and references: + if references is not None: tts_request.references = references + if format is not None: + tts_request.format = format + + if latency is not None: + tts_request.latency = latency + + if speed is not None: + tts_request.prosody = Prosody(speed=speed) + ws: AsyncWebSocketSession async with aconnect_ws( "/v1/tts/live", diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index eba871a..7d39785 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -94,10 +94,10 @@ def test_convert_with_reference_id_parameter(self, tts_client, mock_client_wrapp payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["reference_id"] == "voice_456" - def test_convert_config_reference_id_overrides_parameter( + def test_convert_parameter_reference_id_overrides_config( self, tts_client, mock_client_wrapper ): - """Test that config.reference_id overrides parameter reference_id.""" + """Test that parameter reference_id overrides config.reference_id.""" mock_response = Mock() mock_response.iter_bytes.return_value = iter([b"audio"]) mock_client_wrapper.request.return_value = mock_response @@ -109,10 +109,10 @@ def test_convert_config_reference_id_overrides_parameter( ) ) - # Verify config reference_id takes precedence + # Verify parameter reference_id takes precedence call_args = mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) - assert payload["reference_id"] == "voice_from_config" + assert payload["reference_id"] == "voice_from_param" def test_convert_with_references(self, tts_client, mock_client_wrapper): """Test TTS with reference audio samples.""" @@ -155,10 +155,10 @@ def test_convert_with_references_parameter(self, tts_client, mock_client_wrapper assert payload["references"][0]["text"] == "Sample 1" assert payload["references"][1]["text"] == "Sample 2" - def test_convert_config_references_overrides_parameter( + def test_convert_parameter_references_overrides_config( self, tts_client, mock_client_wrapper ): - """Test that config.references overrides parameter references.""" + """Test that parameter references overrides config.references.""" mock_response = Mock() mock_response.iter_bytes.return_value = iter([b"audio"]) mock_client_wrapper.request.return_value = mock_response @@ -169,11 +169,11 @@ def test_convert_config_references_overrides_parameter( config = TTSConfig(references=config_refs) list(tts_client.convert(text="Hello", references=param_refs, config=config)) - # Verify config references take precedence + # Verify parameter references take precedence call_args = mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) assert len(payload["references"]) == 1 - assert payload["references"][0]["text"] == "Config" + assert payload["references"][0]["text"] == "Param" def test_convert_with_different_backend(self, tts_client, mock_client_wrapper): """Test TTS with different backend/model.""" @@ -301,6 +301,97 @@ def test_convert_empty_response(self, tts_client, mock_client_wrapper): assert audio_chunks == [] + def test_convert_with_format_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with format as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", format="wav")) + + # Verify format in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + + def test_convert_with_latency_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with latency as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", latency="normal")) + + # Verify latency in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["latency"] == "normal" + + def test_convert_with_speed_parameter(self, tts_client, mock_client_wrapper): + """Test TTS with speed as direct parameter.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", speed=1.5)) + + # Verify speed creates prosody in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + + def test_convert_parameter_format_overrides_config( + self, tts_client, mock_client_wrapper + ): + """Test that parameter format overrides config.format.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(format="wav") + list(tts_client.convert(text="Hello", format="pcm", config=config)) + + # Verify parameter format takes precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "pcm" + + def test_convert_parameter_speed_overrides_config_prosody( + self, tts_client, mock_client_wrapper + ): + """Test that parameter speed overrides config.prosody.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig(prosody=Prosody(speed=2.0, volume=0.5)) + list(tts_client.convert(text="Hello", speed=1.5, config=config)) + + # Verify parameter speed takes precedence + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + # Note: volume from config.prosody is lost when speed parameter is used + + def test_convert_combined_convenience_parameters( + self, tts_client, mock_client_wrapper + ): + """Test TTS with multiple convenience parameters combined.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list( + tts_client.convert(text="Hello", format="wav", speed=1.3, latency="normal") + ) + + # Verify all parameters in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + assert payload["latency"] == "normal" + assert payload["prosody"]["speed"] == 1.3 + class TestAsyncTTSClient: """Test asynchronous AsyncTTSClient.""" @@ -380,10 +471,10 @@ async def async_iter_bytes(): assert payload["reference_id"] == "voice_456" @pytest.mark.asyncio - async def test_convert_config_reference_id_overrides_parameter( + async def test_convert_parameter_reference_id_overrides_config( self, async_tts_client, async_mock_client_wrapper ): - """Test that config.reference_id overrides parameter reference_id (async).""" + """Test that parameter reference_id overrides config.reference_id (async).""" mock_response = Mock() async def async_iter_bytes(): @@ -399,10 +490,10 @@ async def async_iter_bytes(): ): audio_chunks.append(chunk) - # Verify config reference_id takes precedence + # Verify parameter reference_id takes precedence call_args = async_mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) - assert payload["reference_id"] == "voice_from_config" + assert payload["reference_id"] == "voice_from_param" @pytest.mark.asyncio async def test_convert_with_references_parameter( @@ -436,10 +527,10 @@ async def async_iter_bytes(): assert payload["references"][1]["text"] == "Sample 2" @pytest.mark.asyncio - async def test_convert_config_references_overrides_parameter( + async def test_convert_parameter_references_overrides_config( self, async_tts_client, async_mock_client_wrapper ): - """Test that config.references overrides parameter references (async).""" + """Test that parameter references overrides config.references (async).""" mock_response = Mock() async def async_iter_bytes(): @@ -458,11 +549,11 @@ async def async_iter_bytes(): ): audio_chunks.append(chunk) - # Verify config references take precedence + # Verify parameter references take precedence call_args = async_mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) assert len(payload["references"]) == 1 - assert payload["references"][0]["text"] == "Config" + assert payload["references"][0]["text"] == "Param" @pytest.mark.asyncio async def test_convert_with_prosody( @@ -534,3 +625,146 @@ async def async_iter_bytes(): audio_chunks.append(chunk) assert audio_chunks == [] + + @pytest.mark.asyncio + async def test_convert_with_format_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with format as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", format="wav"): + audio_chunks.append(chunk) + + # Verify format in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + + @pytest.mark.asyncio + async def test_convert_with_latency_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with latency as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", latency="normal"): + audio_chunks.append(chunk) + + # Verify latency in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["latency"] == "normal" + + @pytest.mark.asyncio + async def test_convert_with_speed_parameter( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with speed as direct parameter.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert(text="Hello", speed=1.5): + audio_chunks.append(chunk) + + # Verify speed creates prosody in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + + @pytest.mark.asyncio + async def test_convert_parameter_format_overrides_config( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter format overrides config.format (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(format="wav") + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", format="pcm", config=config + ): + audio_chunks.append(chunk) + + # Verify parameter format takes precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "pcm" + + @pytest.mark.asyncio + async def test_convert_parameter_speed_overrides_config_prosody( + self, async_tts_client, async_mock_client_wrapper + ): + """Test that parameter speed overrides config.prosody (async).""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + config = TTSConfig(prosody=Prosody(speed=2.0, volume=0.5)) + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", speed=1.5, config=config + ): + audio_chunks.append(chunk) + + # Verify parameter speed takes precedence + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["prosody"]["speed"] == 1.5 + # Note: volume from config.prosody is lost when speed parameter is used + + @pytest.mark.asyncio + async def test_convert_combined_convenience_parameters( + self, async_tts_client, async_mock_client_wrapper + ): + """Test async TTS with multiple convenience parameters combined.""" + mock_response = Mock() + + async def async_iter_bytes(): + yield b"audio" + + mock_response.aiter_bytes = async_iter_bytes + async_mock_client_wrapper.request = AsyncMock(return_value=mock_response) + + audio_chunks = [] + async for chunk in async_tts_client.convert( + text="Hello", format="wav", speed=1.3, latency="normal" + ): + audio_chunks.append(chunk) + + # Verify all parameters in payload + call_args = async_mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "wav" + assert payload["latency"] == "normal" + assert payload["prosody"]["speed"] == 1.3 diff --git a/tests/unit/test_tts_realtime.py b/tests/unit/test_tts_realtime.py index b51ba32..27874bb 100644 --- a/tests/unit/test_tts_realtime.py +++ b/tests/unit/test_tts_realtime.py @@ -221,10 +221,10 @@ def submit_side_effect(fn): @patch("fishaudio.resources.tts.connect_ws") @patch("fishaudio.resources.tts.ThreadPoolExecutor") - def test_stream_websocket_config_reference_id_overrides_parameter( + def test_stream_websocket_parameter_reference_id_overrides_config( self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper ): - """Test that config.reference_id overrides parameter reference_id.""" + """Test that parameter reference_id overrides config.reference_id.""" # Setup mocks mock_ws = MagicMock() mock_ws.__enter__ = Mock(return_value=mock_ws) @@ -254,10 +254,10 @@ def submit_side_effect(fn): ) ) - # Verify config reference_id takes precedence + # Verify parameter reference_id takes precedence first_call = mock_ws.send_bytes.call_args_list[0] start_event_payload = ormsgpack.unpackb(first_call[0][0]) - assert start_event_payload["request"]["reference_id"] == "voice_from_config" + assert start_event_payload["request"]["reference_id"] == "voice_from_param" @patch("fishaudio.resources.tts.connect_ws") @patch("fishaudio.resources.tts.ThreadPoolExecutor") @@ -303,10 +303,10 @@ def submit_side_effect(fn): @patch("fishaudio.resources.tts.connect_ws") @patch("fishaudio.resources.tts.ThreadPoolExecutor") - def test_stream_websocket_config_references_overrides_parameter( + def test_stream_websocket_parameter_references_overrides_config( self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper ): - """Test that config.references overrides parameter references.""" + """Test that parameter references overrides config.references.""" # Setup mocks mock_ws = MagicMock() mock_ws.__enter__ = Mock(return_value=mock_ws) @@ -339,11 +339,11 @@ def submit_side_effect(fn): ) ) - # Verify config references take precedence + # Verify parameter references take precedence first_call = mock_ws.send_bytes.call_args_list[0] start_event_payload = ormsgpack.unpackb(first_call[0][0]) assert len(start_event_payload["request"]["references"]) == 1 - assert start_event_payload["request"]["references"][0]["text"] == "Config" + assert start_event_payload["request"]["references"][0]["text"] == "Param" class TestAsyncTTSRealtimeClient: @@ -535,10 +535,10 @@ async def text_stream(): @pytest.mark.asyncio @patch("fishaudio.resources.tts.aconnect_ws") - async def test_stream_websocket_config_reference_id_overrides_parameter( + async def test_stream_websocket_parameter_reference_id_overrides_config( self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper ): - """Test that config.reference_id overrides parameter reference_id (async).""" + """Test that parameter reference_id overrides config.reference_id (async).""" # Setup mocks mock_ws = MagicMock() mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) @@ -564,10 +564,10 @@ async def text_stream(): ): audio_chunks.append(chunk) - # Verify config reference_id takes precedence + # Verify parameter reference_id takes precedence first_call = mock_ws.send_bytes.call_args_list[0] start_event_payload = ormsgpack.unpackb(first_call[0][0]) - assert start_event_payload["request"]["reference_id"] == "voice_from_config" + assert start_event_payload["request"]["reference_id"] == "voice_from_param" @pytest.mark.asyncio @patch("fishaudio.resources.tts.aconnect_ws") @@ -612,10 +612,10 @@ async def text_stream(): @pytest.mark.asyncio @patch("fishaudio.resources.tts.aconnect_ws") - async def test_stream_websocket_config_references_overrides_parameter( + async def test_stream_websocket_parameter_references_overrides_config( self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper ): - """Test that config.references overrides parameter references (async).""" + """Test that parameter references overrides config.references (async).""" # Setup mocks mock_ws = MagicMock() mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) @@ -644,8 +644,8 @@ async def text_stream(): ): audio_chunks.append(chunk) - # Verify config references take precedence + # Verify parameter references take precedence first_call = mock_ws.send_bytes.call_args_list[0] start_event_payload = ormsgpack.unpackb(first_call[0][0]) assert len(start_event_payload["request"]["references"]) == 1 - assert start_event_payload["request"]["references"][0]["text"] == "Config" + assert start_event_payload["request"]["references"][0]["text"] == "Param" From 2756bccbb713e517692c1f0f937069247c843ecd Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 01:36:57 -0600 Subject: [PATCH 06/10] feat: preserve volume when overriding speed in Prosody configuration Signed-off-by: James Ding --- src/fishaudio/resources/tts.py | 12 ++++++++---- src/fishaudio/types/tts.py | 18 ++++++++++++++++++ tests/unit/test_tts.py | 12 ++++++------ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index afa1448..b7e4c88 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -147,7 +147,7 @@ def convert( request.latency = latency if speed is not None: - request.prosody = Prosody(speed=speed) + request.prosody = Prosody.from_speed_override(speed, base=config.prosody) payload = request.model_dump(exclude_none=True) @@ -263,7 +263,9 @@ def text_generator(): tts_request.latency = latency if speed is not None: - tts_request.prosody = Prosody(speed=speed) + tts_request.prosody = Prosody.from_speed_override( + speed, base=config.prosody + ) executor = ThreadPoolExecutor(max_workers=max_workers) @@ -394,7 +396,7 @@ async def convert( request.latency = latency if speed is not None: - request.prosody = Prosody(speed=speed) + request.prosody = Prosody.from_speed_override(speed, base=config.prosody) payload = request.model_dump(exclude_none=True) @@ -508,7 +510,9 @@ async def text_generator(): tts_request.latency = latency if speed is not None: - tts_request.prosody = Prosody(speed=speed) + tts_request.prosody = Prosody.from_speed_override( + speed, base=config.prosody + ) ws: AsyncWebSocketSession async with aconnect_ws( diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index eb36398..94d1165 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -20,6 +20,24 @@ class Prosody(BaseModel): speed: float = 1.0 volume: float = 0.0 + @classmethod + def from_speed_override( + cls, speed: float, base: Optional["Prosody"] = None + ) -> "Prosody": + """ + Create Prosody with speed override, preserving volume from base. + + Args: + speed: Speed value to use + base: Base prosody to preserve volume from (if any) + + Returns: + New Prosody instance with overridden speed + """ + if base: + return cls(speed=speed, volume=base.volume) + return cls(speed=speed) + class TTSConfig(BaseModel): """ diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index 7d39785..d3053fc 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -359,7 +359,7 @@ def test_convert_parameter_format_overrides_config( def test_convert_parameter_speed_overrides_config_prosody( self, tts_client, mock_client_wrapper ): - """Test that parameter speed overrides config.prosody.""" + """Test that parameter speed overrides config.prosody speed but preserves volume.""" mock_response = Mock() mock_response.iter_bytes.return_value = iter([b"audio"]) mock_client_wrapper.request.return_value = mock_response @@ -367,11 +367,11 @@ def test_convert_parameter_speed_overrides_config_prosody( config = TTSConfig(prosody=Prosody(speed=2.0, volume=0.5)) list(tts_client.convert(text="Hello", speed=1.5, config=config)) - # Verify parameter speed takes precedence + # Verify parameter speed takes precedence but volume is preserved call_args = mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["prosody"]["speed"] == 1.5 - # Note: volume from config.prosody is lost when speed parameter is used + assert payload["prosody"]["volume"] == 0.5 # Preserved from config! def test_convert_combined_convenience_parameters( self, tts_client, mock_client_wrapper @@ -721,7 +721,7 @@ async def async_iter_bytes(): async def test_convert_parameter_speed_overrides_config_prosody( self, async_tts_client, async_mock_client_wrapper ): - """Test that parameter speed overrides config.prosody (async).""" + """Test that parameter speed overrides config.prosody speed but preserves volume (async).""" mock_response = Mock() async def async_iter_bytes(): @@ -737,11 +737,11 @@ async def async_iter_bytes(): ): audio_chunks.append(chunk) - # Verify parameter speed takes precedence + # Verify parameter speed takes precedence but volume is preserved call_args = async_mock_client_wrapper.request.call_args payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["prosody"]["speed"] == 1.5 - # Note: volume from config.prosody is lost when speed parameter is used + assert payload["prosody"]["volume"] == 0.5 # Preserved from config! @pytest.mark.asyncio async def test_convert_combined_convenience_parameters( From 6c2620db985f07121e374c44b95745b147ce6d54 Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 02:11:48 -0600 Subject: [PATCH 07/10] feat: add check_free_credit parameter and support for opus audio format Signed-off-by: James Ding --- src/fishaudio/resources/account.py | 26 +++++++++++++++++++++++++- src/fishaudio/resources/tts.py | 8 ++++---- src/fishaudio/types/shared.py | 2 +- src/fishaudio/types/tts.py | 10 +++++----- tests/unit/test_tts.py | 13 +++++++++++++ 5 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/fishaudio/resources/account.py b/src/fishaudio/resources/account.py index 02beaf3..7ef096d 100644 --- a/src/fishaudio/resources/account.py +++ b/src/fishaudio/resources/account.py @@ -2,7 +2,7 @@ from typing import Optional -from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions +from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions from ..types import Credits, Package @@ -15,12 +15,14 @@ def __init__(self, client_wrapper: ClientWrapper): def get_credits( self, *, + check_free_credit: Optional[bool] = OMIT, request_options: Optional[RequestOptions] = None, ) -> Credits: """ Get API credit balance. Args: + check_free_credit: Whether to check free credit availability request_options: Request-level overrides Returns: @@ -31,11 +33,21 @@ def get_credits( client = FishAudio(api_key="...") credits = client.account.get_credits() print(f"Available credits: {float(credits.credit)}") + + # Check free credit availability + credits = client.account.get_credits(check_free_credit=True) + if credits.has_free_credit: + print("Free credits available!") ``` """ + params = {} + if check_free_credit is not OMIT: + params["check_free_credit"] = check_free_credit + response = self._client.request( "GET", "/wallet/self/api-credit", + params=params, request_options=request_options, ) return Credits.model_validate(response.json()) @@ -78,12 +90,14 @@ def __init__(self, client_wrapper: AsyncClientWrapper): async def get_credits( self, *, + check_free_credit: Optional[bool] = OMIT, request_options: Optional[RequestOptions] = None, ) -> Credits: """ Get API credit balance (async). Args: + check_free_credit: Whether to check free credit availability request_options: Request-level overrides Returns: @@ -94,11 +108,21 @@ async def get_credits( client = AsyncFishAudio(api_key="...") credits = await client.account.get_credits() print(f"Available credits: {float(credits.credit)}") + + # Check free credit availability + credits = await client.account.get_credits(check_free_credit=True) + if credits.has_free_credit: + print("Free credits available!") ``` """ + params = {} + if check_free_credit is not OMIT: + params["check_free_credit"] = check_free_credit + response = await self._client.request( "GET", "/wallet/self/api-credit", + params=params, request_options=request_options, ) return Credits.model_validate(response.json()) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index b7e4c88..019c1a0 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -78,7 +78,7 @@ def convert( text: Text to synthesize reference_id: Voice reference ID (overrides config.reference_id if provided) references: Reference audio samples (overrides config.references if provided) - format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) @@ -187,7 +187,7 @@ def stream_websocket( text_stream: Iterator of text chunks to stream reference_id: Voice reference ID (overrides config.reference_id if provided) references: Reference audio samples (overrides config.references if provided) - format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) @@ -327,7 +327,7 @@ async def convert( text: Text to synthesize reference_id: Voice reference ID (overrides config.reference_id if provided) references: Reference audio samples (overrides config.references if provided) - format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) @@ -435,7 +435,7 @@ async def stream_websocket( text_stream: Async iterator of text chunks to stream reference_id: Voice reference ID (overrides config.reference_id if provided) references: Reference audio samples (overrides config.references if provided) - format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided) + format: Audio format - "mp3", "wav", "pcm", or "opus" (overrides config.format if provided) latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided) speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (overrides config.prosody.speed if provided) config: TTS configuration (audio settings, voice, model parameters) diff --git a/src/fishaudio/types/shared.py b/src/fishaudio/types/shared.py index 28ef594..df7ab4a 100644 --- a/src/fishaudio/types/shared.py +++ b/src/fishaudio/types/shared.py @@ -19,7 +19,7 @@ class PaginatedResponse(BaseModel, Generic[T]): Model = Literal["speech-1.5", "speech-1.6", "s1"] # Audio format types -AudioFormat = Literal["wav", "pcm", "mp3"] +AudioFormat = Literal["wav", "pcm", "mp3", "opus"] # Visibility types Visibility = Literal["public", "unlist", "private"] diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index 94d1165..63eeb92 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field -from .shared import LatencyMode +from .shared import AudioFormat, LatencyMode class ReferenceAudio(BaseModel): @@ -17,8 +17,8 @@ class ReferenceAudio(BaseModel): class Prosody(BaseModel): """Speech prosody settings (speed and volume).""" - speed: float = 1.0 - volume: float = 0.0 + speed: Annotated[float, Field(ge=0.5, le=2.0)] = 1.0 + volume: Annotated[float, Field(ge=-20.0, le=20.0)] = 0.0 @classmethod def from_speed_override( @@ -48,7 +48,7 @@ class TTSConfig(BaseModel): """ # Audio output settings - format: Literal["wav", "pcm", "mp3"] = "mp3" + format: AudioFormat = "mp3" sample_rate: Optional[int] = None mp3_bitrate: Literal[64, 128, 192] = 128 opus_bitrate: Literal[-1000, 24, 32, 48, 64] = 32 @@ -78,7 +78,7 @@ class TTSRequest(BaseModel): text: str chunk_length: Annotated[int, Field(ge=100, le=300, strict=True)] = 200 - format: Literal["wav", "pcm", "mp3"] = "mp3" + format: AudioFormat = "mp3" sample_rate: Optional[int] = None mp3_bitrate: Literal[64, 128, 192] = 128 opus_bitrate: Literal[-1000, 24, 32, 48, 64] = 32 diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index d3053fc..6ddff60 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -314,6 +314,19 @@ def test_convert_with_format_parameter(self, tts_client, mock_client_wrapper): payload = ormsgpack.unpackb(call_args[1]["content"]) assert payload["format"] == "wav" + def test_convert_with_opus_format(self, tts_client, mock_client_wrapper): + """Test TTS with opus format.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + list(tts_client.convert(text="Hello", format="opus")) + + # Verify opus format in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["format"] == "opus" + def test_convert_with_latency_parameter(self, tts_client, mock_client_wrapper): """Test TTS with latency as direct parameter.""" mock_response = Mock() From 2c037dd3764b76ada00635a6829992d390f45efe Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 02:17:07 -0600 Subject: [PATCH 08/10] feat: enforce value constraints on top_p and temperature parameters in TTS models Signed-off-by: James Ding --- src/fishaudio/types/tts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index 63eeb92..4dd7671 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -64,8 +64,8 @@ class TTSConfig(BaseModel): prosody: Optional[Prosody] = None # Model parameters - top_p: float = 0.7 - temperature: float = 0.7 + top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 class TTSRequest(BaseModel): @@ -87,8 +87,8 @@ class TTSRequest(BaseModel): normalize: bool = True latency: LatencyMode = "balanced" prosody: Optional[Prosody] = None - top_p: float = 0.7 - temperature: float = 0.7 + top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 # WebSocket event types for streaming TTS From da11c285560e06157f83d0daf897652b8d72849d Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 02:30:51 -0600 Subject: [PATCH 09/10] Update src/fishaudio/resources/tts.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fishaudio/resources/tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 019c1a0..8ee070d 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -121,7 +121,7 @@ def convert( ) # Parameters override config values - config = TTSConfig(format="mp3", speed=1.0) + config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0)) audio = client.tts.convert(text="Hello world", format="wav", config=config) # Result: format="wav" (parameter wins) From 2bb2a048f0c37b32a36e4631c20ea4424dd638fa Mon Sep 17 00:00:00 2001 From: James Ding Date: Wed, 12 Nov 2025 02:30:59 -0600 Subject: [PATCH 10/10] Update src/fishaudio/resources/tts.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fishaudio/resources/tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index 8ee070d..fef1cd4 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -370,7 +370,7 @@ async def convert( ) # Parameters override config values - config = TTSConfig(format="mp3", speed=1.0) + config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0)) audio = await client.tts.convert(text="Hello world", format="wav", config=config) # Result: format="wav" (parameter wins)