Commit 3208acf

feat: allow custom-client for OpenAIModel and GeminiModel

1 parent 4342fda commit 3208acf
File tree

4 files changed: +276 -44 lines changed

src/strands/models/gemini.py

Lines changed: 38 additions & 5 deletions
@@ -48,23 +48,41 @@ class GeminiConfig(TypedDict, total=False):
     def __init__(
         self,
         *,
+        client: Optional[genai.Client] = None,
         client_args: Optional[dict[str, Any]] = None,
         **model_config: Unpack[GeminiConfig],
     ) -> None:
         """Initialize provider instance.
 
         Args:
+            client: Pre-configured Gemini client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                Note: The client should not be shared across different asyncio event loops.
             client_args: Arguments for the underlying Gemini client (e.g., api_key).
                 For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
+                Note: If `client` is provided, this parameter is ignored.
             **model_config: Configuration options for the Gemini model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, GeminiModel.GeminiConfig)
         self.config = GeminiModel.GeminiConfig(**model_config)
 
-        logger.debug("config=<%s> | initializing", self.config)
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
 
+        self._injected_client = client
         self.client_args = client_args or {}
 
+        logger.debug("config=<%s> | initializing", self.config)
+
     @override
     def update_config(self, **model_config: Unpack[GeminiConfig]) -> None:  # type: ignore[override]
         """Update the Gemini model configuration with the provided arguments.
@@ -365,9 +383,16 @@ async def stream(
         """
         request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
 
-        client = genai.Client(**self.client_args).aio
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_aio = self._injected_client.aio
+        else:
+            # Create a new client from client_args
+            client_aio = genai.Client(**self.client_args).aio
+
         try:
-            response = await client.models.generate_content_stream(**request)
+            response = await client_aio.models.generate_content_stream(**request)
 
             yield self._format_chunk({"chunk_type": "message_start"})
             yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"})
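The same select-or-construct branch reappears verbatim in `structured_output` below. For reference, it is equivalent to a small helper along these lines (a sketch, not part of the commit; the function name is invented):

```python
from typing import Any, Optional

from google import genai


def _resolve_client_aio(injected: Optional[genai.Client], client_args: dict[str, Any]) -> Any:
    """Mirror of the branch added in stream()/structured_output().

    An injected client is reused as-is (the caller owns its lifecycle);
    otherwise a fresh per-request client is built from client_args.
    """
    if injected is not None:
        return injected.aio
    return genai.Client(**client_args).aio
```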
@@ -448,6 +473,14 @@ async def structured_output(
             "response_schema": output_model.model_json_schema(),
         }
         request = self._format_request(prompt, None, system_prompt, params)
-        client = genai.Client(**self.client_args).aio
-        response = await client.models.generate_content(**request)
+
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_aio = self._injected_client.aio
+        else:
+            # Create a new client from client_args
+            client_aio = genai.Client(**self.client_args).aio
+
+        response = await client_aio.models.generate_content(**request)
         yield {"output": output_model.model_validate(response.parsed)}

src/strands/models/openai.py

Lines changed: 88 additions & 38 deletions
@@ -55,16 +55,40 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]
 
-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.
 
         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
+                Note: If `client` is provided, this parameter is ignored.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._injected_client = client
         self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -454,12 +478,20 @@ async def stream(
 
         logger.debug("invoking model")
 
-        # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
-        # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
-        # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_to_use = self._injected_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            client_to_use = openai.AsyncOpenAI(**self.client_args)
+
+        try:
             try:
-                response = await client.chat.completions.create(**request)
+                response = await client_to_use.chat.completions.create(**request)
             except openai.BadRequestError as e:
                 # Check if this is a context length exceeded error
                 if hasattr(e, "code") and e.code == "context_length_exceeded":
@@ -532,6 +564,11 @@ async def stream(
             if event and hasattr(event, "usage") and event.usage:
                 yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
 
+        finally:
+            # Only close the client if we created it (not injected)
+            if self._injected_client is None and hasattr(client_to_use, "close"):
+                await client_to_use.close()
+
         logger.debug("finished streaming response from model")
 
     def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
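The `try`/`finally` pair encodes the ownership rule: close only what you created. Combined with the httpx caveat in the relocated comment, the safe injected-client pattern is one client per event loop, closed by the code that made it. A sketch (names and model id illustrative):

```python
import asyncio

import openai

from strands.models.openai import OpenAIModel


async def worker() -> None:
    # Create the client inside the loop that will use it and keep it there;
    # per the note above, clients must not cross asyncio event loops.
    client = openai.AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    try:
        model = OpenAIModel(client=client, model_id="gpt-4o")  # illustrative id
        async for event in model.stream([{"role": "user", "content": [{"text": "hi"}]}]):
            print(event)
    finally:
        await client.close()  # the model never closes injected clients


asyncio.run(worker())
```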
@@ -573,40 +610,53 @@ async def structured_output(
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the request is throttled by OpenAI (rate limits).
         """
-        # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
-        # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
-        # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
-            try:
-                response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_to_use = self._injected_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            client_to_use = openai.AsyncOpenAI(**self.client_args)
+
+        try:
+            if hasattr(client_to_use, "beta"):
+                response: ParsedChatCompletion = await client_to_use.beta.chat.completions.parse(
                     model=self.get_config()["model_id"],
                     messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
                     response_format=output_model,
                 )
-            except openai.BadRequestError as e:
-                # Check if this is a context length exceeded error
-                if hasattr(e, "code") and e.code == "context_length_exceeded":
-                    logger.warning("OpenAI threw context window overflow error")
-                    raise ContextWindowOverflowException(str(e)) from e
-                # Re-raise other BadRequestError exceptions
-                raise
-            except openai.RateLimitError as e:
-                # All rate limit errors should be treated as throttling, not context overflow
-                # Rate limits (including TPM) require waiting/retrying, not context reduction
-                logger.warning("OpenAI threw rate limit error")
-                raise ModelThrottledException(str(e)) from e
-
-            parsed: T | None = None
-            # Find the first choice with tool_calls
-            if len(response.choices) > 1:
-                raise ValueError("Multiple choices found in the OpenAI response.")
+        except openai.BadRequestError as e:
+            # Check if this is a context length exceeded error
+            if hasattr(e, "code") and e.code == "context_length_exceeded":
+                logger.warning("OpenAI threw context window overflow error")
+                raise ContextWindowOverflowException(str(e)) from e
+            # Re-raise other BadRequestError exceptions
+            raise
+        except openai.RateLimitError as e:
+            # All rate limit errors should be treated as throttling, not context overflow
+            # Rate limits (including TPM) require waiting/retrying, not context reduction
+            logger.warning("OpenAI threw rate limit error")
+            raise ModelThrottledException(str(e)) from e
+        else:
+            parsed: T | None = None
+            # Find the first choice with tool_calls
+            if len(response.choices) > 1:
+                raise ValueError("Multiple choices found in the OpenAI response.")
+
+            for choice in response.choices:
+                if isinstance(choice.message.parsed, output_model):
+                    parsed = choice.message.parsed
+                    break
 
-            for choice in response.choices:
-                if isinstance(choice.message.parsed, output_model):
-                    parsed = choice.message.parsed
-                    break
+            if parsed:
+                yield {"output": parsed}
+            else:
+                raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
 
-            if parsed:
-                yield {"output": parsed}
-            else:
-                raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
+        finally:
+            # Only close the client if we created it (not injected)
+            if self._injected_client is None and hasattr(client_to_use, "close"):
+                await client_to_use.close()
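Note the new `hasattr(client_to_use, "beta")` check: the parse call is only attempted on clients that expose the `beta` surface, presumably to tolerate minimal wrapper clients. A structured-output sketch with an injected client (the Pydantic model and model id are illustrative):

```python
import asyncio

import openai
from pydantic import BaseModel

from strands.models.openai import OpenAIModel


class Answer(BaseModel):
    value: int


async def main() -> None:
    client = openai.AsyncOpenAI()
    try:
        model = OpenAIModel(client=client, model_id="gpt-4o")  # illustrative id
        prompt = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]
        async for event in model.structured_output(Answer, prompt):
            print(event["output"].value)
    finally:
        await client.close()


asyncio.run(main())
```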

tests/strands/models/test_gemini.py

Lines changed: 74 additions & 0 deletions
@@ -637,3 +637,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap
 
     assert "Gemini API returned non-JSON error" in caplog.text
     assert f"error_message=<{error_message}>" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
+        [
+            genai.types.GenerateContentResponse(
+                candidates=[
+                    genai.types.Candidate(
+                        content=genai.types.Content(
+                            parts=[genai.types.Part(text="Hello")],
+                        ),
+                        finish_reason="STOP",
+                    ),
+                ],
+                usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
+                    prompt_token_count=1,
+                    total_token_count=3,
+                ),
+            ),
+        ]
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content_stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, weather_output, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
+        parsed=weather_output.model_dump()
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
+    stream = model.structured_output(type(weather_output), messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": weather_output}
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content.assert_called_once()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.Mock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
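Both test docstrings promise the injected client "doesn't close it", but neither asserts that. Since the injected client is a plain `Mock`, every call is recorded, so the promise could be checked directly. A sketch of such a test (not in the commit; it reuses this module's imports and fixtures, and `close` is an assumed shutdown-method name):

```python
@pytest.mark.asyncio
async def test_structured_output_leaves_injected_client_open(model_id, weather_output, alist):
    """Sketch: the model must not call a shutdown method on an injected client."""
    mock_injected_client = unittest.mock.Mock()
    mock_injected_client.aio = unittest.mock.AsyncMock()
    mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
        parsed=weather_output.model_dump()
    )

    model = GeminiModel(client=mock_injected_client, model_id=model_id)
    await alist(model.structured_output(type(weather_output), [{"role": "user", "content": [{"text": "hi"}]}]))

    # A Mock records every call, so this fails if the model ever closed the client.
    mock_injected_client.close.assert_not_called()  # assumed shutdown-method name
```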
