diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2d43b32..6426e17 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -44,9 +44,13 @@ def client(api_key): @pytest.fixture async def async_client(api_key): """Async Fish Audio client.""" + import asyncio + client = AsyncFishAudio(api_key=api_key) yield client await client.close() + # Brief delay to avoid API rate limits on WebSocket connections + await asyncio.sleep(0.3) @pytest.fixture diff --git a/tests/integration/test_tts_integration.py b/tests/integration/test_tts_integration.py index f6b4fc2..19e511a 100644 --- a/tests/integration/test_tts_integration.py +++ b/tests/integration/test_tts_integration.py @@ -50,15 +50,11 @@ def test_tts_with_different_models(self, client, save_audio): models = get_args(Model) for model in models: - try: - audio = client.tts.convert(text=f"Testing model {model}", model=model) - assert len(audio) > 0, f"Failed for model: {model}" - - # Write to output directory - save_audio(audio, f"test_model_{model}.mp3") - except Exception as e: - # Some models might not be available - pytest.skip(f"Model {model} not available: {e}") + audio = client.tts.convert(text=f"Testing model {model}", model=model) + assert len(audio) > 0, f"Failed for model: {model}" + + # Write to output directory + save_audio(audio, f"test_model_{model}.mp3") def test_tts_longer_text(self, client, save_audio): """Test TTS with longer text.""" diff --git a/tests/integration/test_tts_websocket_integration.py b/tests/integration/test_tts_websocket_integration.py index 8341a18..f521a2d 100644 --- a/tests/integration/test_tts_websocket_integration.py +++ b/tests/integration/test_tts_websocket_integration.py @@ -1,9 +1,12 @@ """Integration tests for TTS WebSocket streaming functionality.""" +from typing import get_args + import pytest from fishaudio import WebSocketOptions from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent +from fishaudio.types.shared import Model from .conftest import TEST_REFERENCE_ID @@ -31,6 +34,26 @@ def text_stream(): # Save the audio save_audio(audio_chunks, "test_websocket_streaming.mp3") + def test_websocket_streaming_with_different_models(self, client, save_audio): + """Test WebSocket streaming with different models.""" + import time + + models = get_args(Model) + + for model in models: + + def text_stream(): + yield f"Testing model {model} via WebSocket." + + audio_chunks = list(client.tts.stream_websocket(text_stream(), model=model)) + assert len(audio_chunks) > 0, f"Failed for model: {model}" + + # Write to output directory + save_audio(audio_chunks, f"test_websocket_model_{model}.mp3") + + # Brief delay to avoid SSL errors when opening next WebSocket connection + time.sleep(0.3) + def test_websocket_streaming_with_wav_format(self, client, save_audio): """Test WebSocket streaming with WAV format.""" config = TTSConfig(format="wav", chunk_length=200) @@ -195,6 +218,34 @@ async def text_stream(): save_audio(audio_chunks, "test_async_websocket_streaming.mp3") + @pytest.mark.asyncio + async def test_async_websocket_streaming_with_different_models( + self, async_client, save_audio + ): + """Test async WebSocket streaming with different models.""" + import asyncio + + models = get_args(Model) + + for model in models: + + async def text_stream(): + yield f"Testing model {model} via async WebSocket." + + audio_chunks = [] + async for chunk in async_client.tts.stream_websocket( + text_stream(), model=model + ): + audio_chunks.append(chunk) + + assert len(audio_chunks) > 0, f"Failed for model: {model}" + + # Write to output directory + save_audio(audio_chunks, f"test_async_websocket_model_{model}.mp3") + + # Brief delay to avoid SSL errors when opening next WebSocket connection + await asyncio.sleep(0.3) + @pytest.mark.asyncio async def test_async_websocket_streaming_with_format( self, async_client, save_audio @@ -285,6 +336,8 @@ async def test_async_websocket_streaming_multiple_calls( self, async_client, save_audio ): """Test multiple async WebSocket streaming calls in sequence.""" + import asyncio + for i in range(3): async def text_stream(): @@ -297,6 +350,9 @@ async def text_stream(): assert len(audio_chunks) > 0, f"Call {i + 1} should return audio" save_audio(audio_chunks, f"test_async_websocket_call_{i + 1}.mp3") + # Brief delay to avoid SSL errors when opening next WebSocket connection + await asyncio.sleep(0.3) + @pytest.mark.asyncio async def test_async_websocket_streaming_empty_text(self, async_client, save_audio): """Test async WebSocket streaming with empty text stream raises error."""