@@ -55,16 +55,40 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]

-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.

         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client is reused for all requests and is NOT closed
+                by the model; the caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
+                Note: Cannot be combined with `client`; passing both raises ValueError.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._injected_client = client
         self.client_args = client_args or {}

         logger.debug("config=<%s> | initializing", self.config)
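
As a usage sketch of the two construction paths (the provider class name and import path are not visible in this hunk; `OpenAIModel` and `strands.models.openai` are assumptions):

```python
import openai

from strands.models.openai import OpenAIModel  # assumed import path, not shown in this diff

# Legacy path: the model builds and closes its own AsyncOpenAI client around each request.
model_from_args = OpenAIModel(
    client_args={"api_key": "sk-placeholder", "base_url": "https://my-gateway.example.com/v1"},
    model_id="gpt-4o",
)

# Injection path: the caller owns the client; the model reuses it and never closes it.
shared_client = openai.AsyncOpenAI(api_key="sk-placeholder")
model_from_client = OpenAIModel(client=shared_client, model_id="gpt-4o")

# Passing both configuration methods is rejected at construction time.
# OpenAIModel(client=shared_client, client_args={"api_key": "sk-placeholder"})  # ValueError
```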
@@ -454,12 +478,20 @@ async def stream(

         logger.debug("invoking model")

-        # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
-        # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
-        # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_to_use = self._injected_client
+        else:
+            # Create a new client from client_args.
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            client_to_use = openai.AsyncOpenAI(**self.client_args)
+
+        try:
             try:
-                response = await client.chat.completions.create(**request)
+                response = await client_to_use.chat.completions.create(**request)
             except openai.BadRequestError as e:
                 # Check if this is a context length exceeded error
                 if hasattr(e, "code") and e.code == "context_length_exceeded":
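
The removed comment (now kept only on the `else` branch) refers to httpx binding a connection pool to the event loop it is first used on; a small sketch of the pitfall that per-request client creation avoids (placeholder key, the actual calls are left commented out):

```python
import asyncio
import openai

# Anti-pattern: one AsyncOpenAI client shared across event loops. Each asyncio.run()
# starts a fresh loop, but the underlying httpx pool stays tied to the first loop,
# so the second run may fail or hang (see the httpx discussion 2959 linked above).
client = openai.AsyncOpenAI(api_key="sk-placeholder")

async def ask(question: str) -> str:
    response = await client.chat.completions.create(
        model="gpt-4o", messages=[{"role": "user", "content": question}]
    )
    return response.choices[0].message.content or ""

# asyncio.run(ask("first question"))   # runs on loop #1
# asyncio.run(ask("second question"))  # loop #2 would reuse loop #1's connections

# Safe default (what the model does when no client is injected): create the client
# inside the coroutine so its connections live and die with the current loop.
async def ask_safely(question: str) -> str:
    async with openai.AsyncOpenAI(api_key="sk-placeholder") as local_client:
        response = await local_client.chat.completions.create(
            model="gpt-4o", messages=[{"role": "user", "content": question}]
        )
        return response.choices[0].message.content or ""
```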
@@ -532,6 +564,11 @@ async def stream(
             if event and hasattr(event, "usage") and event.usage:
                 yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

+        finally:
+            # Only close the client if we created it (not injected)
+            if self._injected_client is None and hasattr(client_to_use, "close"):
+                await client_to_use.close()
+
         logger.debug("finished streaming response from model")

     def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
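
Conversely, within a single long-lived event loop an injected client lets every request share one connection pool, retry policy, and networking configuration, and its owner closes it exactly once (sketch, reusing the assumed `OpenAIModel` name from above):

```python
import asyncio
import openai

from strands.models.openai import OpenAIModel  # assumed import path, as above

async def main() -> None:
    # One client, one event loop: pooling, retries, and proxy settings are
    # configured here once and reused by every request the model issues.
    client = openai.AsyncOpenAI(api_key="sk-placeholder", max_retries=5)
    model = OpenAIModel(client=client, model_id="gpt-4o")
    try:
        ...  # run as many model requests here as needed; none of them closes the client
    finally:
        # The model never closes an injected client, so its owner does, exactly once.
        await client.close()

asyncio.run(main())
```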
@@ -573,40 +610,56 @@ async def structured_output(
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the request is throttled by OpenAI (rate limits).
         """
-        # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
-        # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
-        # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
-            try:
-                response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_to_use = self._injected_client
+        else:
+            # Create a new client from client_args.
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            client_to_use = openai.AsyncOpenAI(**self.client_args)
+
+        try:
+            if hasattr(client_to_use, "beta"):
+                response: ParsedChatCompletion = await client_to_use.beta.chat.completions.parse(
                     model=self.get_config()["model_id"],
                     messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
                     response_format=output_model,
                 )
-            except openai.BadRequestError as e:
-                # Check if this is a context length exceeded error
-                if hasattr(e, "code") and e.code == "context_length_exceeded":
-                    logger.warning("OpenAI threw context window overflow error")
-                    raise ContextWindowOverflowException(str(e)) from e
-                # Re-raise other BadRequestError exceptions
-                raise
-            except openai.RateLimitError as e:
-                # All rate limit errors should be treated as throttling, not context overflow
-                # Rate limits (including TPM) require waiting/retrying, not context reduction
-                logger.warning("OpenAI threw rate limit error")
-                raise ModelThrottledException(str(e)) from e
-
-            parsed: T | None = None
-            # Find the first choice with tool_calls
-            if len(response.choices) > 1:
-                raise ValueError("Multiple choices found in the OpenAI response.")
+            else:
+                # Without the beta API we cannot request parsed output from this client.
+                raise ValueError("The provided client does not expose the beta structured output API.")
+        except openai.BadRequestError as e:
+            # Check if this is a context length exceeded error
+            if hasattr(e, "code") and e.code == "context_length_exceeded":
+                logger.warning("OpenAI threw context window overflow error")
+                raise ContextWindowOverflowException(str(e)) from e
+            # Re-raise other BadRequestError exceptions
+            raise
+        except openai.RateLimitError as e:
+            # All rate limit errors should be treated as throttling, not context overflow
+            # Rate limits (including TPM) require waiting/retrying, not context reduction
+            logger.warning("OpenAI threw rate limit error")
+            raise ModelThrottledException(str(e)) from e
+        else:
+            parsed: T | None = None
+            # Find the first choice with tool_calls
+            if len(response.choices) > 1:
+                raise ValueError("Multiple choices found in the OpenAI response.")
+
+            for choice in response.choices:
+                if isinstance(choice.message.parsed, output_model):
+                    parsed = choice.message.parsed
+                    break

-            for choice in response.choices:
-                if isinstance(choice.message.parsed, output_model):
-                    parsed = choice.message.parsed
-                    break
+            if parsed:
+                yield {"output": parsed}
+            else:
+                raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")

-            if parsed:
-                yield {"output": parsed}
-            else:
-                raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
+        finally:
+            # Only close the client if we created it (not injected)
+            if self._injected_client is None and hasattr(client_to_use, "close"):
+                await client_to_use.close()
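
A structured-output usage sketch against the injected-client path; the full `structured_output` signature and message shape are not shown in this diff, so the parameter names and prompt format below are assumptions:

```python
import asyncio
import openai
from pydantic import BaseModel

from strands.models.openai import OpenAIModel  # assumed import path, as above

class WeatherReport(BaseModel):
    city: str
    temperature_celsius: float

async def main() -> None:
    client = openai.AsyncOpenAI(api_key="sk-placeholder")
    model = OpenAIModel(client=client, model_id="gpt-4o")
    try:
        # structured_output() forwards the Pydantic model as response_format to
        # client.beta.chat.completions.parse() and yields the parsed instance.
        async for event in model.structured_output(
            WeatherReport,
            prompt=[{"role": "user", "content": [{"text": "What's the weather in Paris?"}]}],
        ):
            if "output" in event:
                report = event["output"]
                print(report.city, report.temperature_celsius)
    finally:
        await client.close()  # injected client, so the caller closes it

asyncio.run(main())
```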