Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/fish_audio_sdk/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ def __init__(self, apikey: str, *, base_url: str = "https://api.fish.audio"):
def init_async_client(self):
self._async_client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
timeout=None,
)

def init_sync_client(self):
self._sync_client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
timeout=None,
)

Expand Down
10 changes: 8 additions & 2 deletions src/fish_audio_sdk/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
)

def __enter__(self):
Expand Down Expand Up @@ -97,7 +100,10 @@ def __init__(
self._base_url = base_url
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
)

async def __aenter__(self):
Expand Down
10 changes: 5 additions & 5 deletions src/fishaudio/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def __init__(
)
self.base_url = base_url

def _get_headers(
def get_headers(
self, additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""Build headers including authentication."""
"""Build headers including authentication and user agent."""
headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": f"fish-audio/python/{__version__}",
Expand All @@ -77,7 +77,7 @@ def _prepare_request_kwargs(
) -> None:
"""Prepare request kwargs by merging headers, timeout, and query params."""
# Merge headers
headers = self._get_headers()
headers = self.get_headers()
if request_options and request_options.additional_headers:
headers.update(request_options.additional_headers)
kwargs["headers"] = {**headers, **kwargs.get("headers", {})}
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
self._client = httpx.Client(
base_url=base_url,
timeout=httpx.Timeout(timeout),
headers=self._get_headers(),
headers=self.get_headers(),
)

def request(
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(
self._client = httpx.AsyncClient(
base_url=base_url,
timeout=httpx.Timeout(timeout),
headers=self._get_headers(),
headers=self.get_headers(),
)

async def request(
Expand Down
7 changes: 2 additions & 5 deletions src/fishaudio/resources/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,7 @@ def text_generator():
with connect_ws(
"/v1/tts/live",
client=self._client.client,
headers={
"model": model,
"Authorization": f"Bearer {self._client.api_key}",
},
headers=self._client.get_headers({"model": model}),
**ws_kwargs,
) as ws:

Expand Down Expand Up @@ -630,7 +627,7 @@ async def text_generator():
async with aconnect_ws(
"/v1/tts/live",
client=self._client.client,
headers={"model": model, "Authorization": f"Bearer {self._client.api_key}"},
headers=self._client.get_headers({"model": model}),
**ws_kwargs,
) as ws:

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_tts_websocket_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def text_stream():
# Save the audio
save_audio(audio_chunks, "test_websocket_streaming.mp3")

@pytest.mark.flaky(reruns=2, reruns_delay=1)
@pytest.mark.flaky(reruns=9, reruns_delay=1)
def test_websocket_streaming_with_different_models(self, client, save_audio):
"""Test WebSocket streaming with different models."""
import time
Expand All @@ -53,7 +53,7 @@ def text_stream():
save_audio(audio_chunks, f"test_websocket_model_{model}.mp3")

# Brief delay to avoid SSL errors when opening next WebSocket connection
time.sleep(0.3)
time.sleep(1.0)

def test_websocket_streaming_with_wav_format(self, client, save_audio):
"""Test WebSocket streaming with WAV format."""
Expand Down Expand Up @@ -220,7 +220,7 @@ async def text_stream():
save_audio(audio_chunks, "test_async_websocket_streaming.mp3")

@pytest.mark.asyncio
@pytest.mark.flaky(reruns=2, reruns_delay=1)
@pytest.mark.flaky(reruns=9, reruns_delay=1)
async def test_async_websocket_streaming_with_different_models(
self, async_client, save_audio
):
Expand All @@ -246,7 +246,7 @@ async def text_stream():
save_audio(audio_chunks, f"test_async_websocket_model_{model}.mp3")

# Brief delay to avoid SSL errors when opening next WebSocket connection
await asyncio.sleep(0.3)
await asyncio.sleep(1.0)

@pytest.mark.asyncio
async def test_async_websocket_streaming_with_format(
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def test_init_with_env_var(self, mock_api_key):

def test_get_headers(self, mock_api_key):
wrapper = ClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers()
headers = wrapper.get_headers()
assert headers["Authorization"] == f"Bearer {mock_api_key}"
assert "User-Agent" in headers

def test_get_headers_with_additional(self, mock_api_key):
wrapper = ClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers({"X-Custom": "value"})
headers = wrapper.get_headers({"X-Custom": "value"})
assert headers["X-Custom"] == "value"
assert headers["Authorization"] == f"Bearer {mock_api_key}"

Expand All @@ -139,6 +139,6 @@ def test_init_without_api_key_raises(self):

def test_get_headers(self, mock_api_key):
wrapper = AsyncClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers()
headers = wrapper.get_headers()
assert headers["Authorization"] == f"Bearer {mock_api_key}"
assert "User-Agent" in headers
12 changes: 12 additions & 0 deletions tests/unit/test_tts_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def mock_client_wrapper(mock_api_key):
wrapper.api_key = mock_api_key
# Mock the underlying httpx.Client
wrapper._client = Mock()
# Mock get_headers to return a dict with the additional headers merged
wrapper.get_headers = lambda additional=None: {
"Authorization": f"Bearer {mock_api_key}",
"User-Agent": "fish-audio/python/test",
**(additional or {}),
}
return wrapper


Expand All @@ -26,6 +32,12 @@ def async_mock_client_wrapper(mock_api_key):
wrapper.api_key = mock_api_key
# Mock the underlying httpx.AsyncClient
wrapper._client = Mock()
# Mock get_headers to return a dict with the additional headers merged
wrapper.get_headers = lambda additional=None: {
"Authorization": f"Bearer {mock_api_key}",
"User-Agent": "fish-audio/python/test",
**(additional or {}),
}
return wrapper


Expand Down