From a9f18c05931980898eeb3c17fa2d38c2c3d533fa Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 9 Jul 2024 12:29:51 +0100 Subject: [PATCH] Add thread_pool_executor to constructor --- src/cohere/client.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/cohere/client.py b/src/cohere/client.py index 87cd2a159..ba6780581 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -71,6 +71,8 @@ def fn(*args, **kwargs): class Client(BaseCohere, CacheMixin): + _executor: ThreadPoolExecutor + def __init__( self, api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None, @@ -80,10 +82,13 @@ def __init__( client_name: typing.Optional[str] = None, timeout: typing.Optional[float] = None, httpx_client: typing.Optional[httpx.Client] = None, + thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64) ): if api_key is None: api_key = _get_api_key_from_environment() + self._executor = thread_pool_executor + BaseCohere.__init__( self, base_url=base_url, @@ -108,8 +113,6 @@ def __exit__(self, exc_type, exc_value, traceback): wait = wait - _executor = ThreadPoolExecutor(64) - def embed( self, *, @@ -250,6 +253,8 @@ def fetch_tokenizer(self, *, model: str) -> Tokenizer: class AsyncClient(AsyncBaseCohere, CacheMixin): + _executor: ThreadPoolExecutor + def __init__( self, api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None, @@ -259,10 +264,13 @@ def __init__( client_name: typing.Optional[str] = None, timeout: typing.Optional[float] = None, httpx_client: typing.Optional[httpx.AsyncClient] = None, + thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64) ): if api_key is None: api_key = _get_api_key_from_environment() + self._executor = thread_pool_executor + AsyncBaseCohere.__init__( self, base_url=base_url, @@ -287,8 +295,6 @@ async def __aexit__(self, exc_type, exc_value, traceback): wait = async_wait - _executor = ThreadPoolExecutor(64) - async def embed( self, *,