1010 Protocol ,
1111 Required ,
1212 TypedDict ,
13+ TypeGuard ,
1314 final ,
1415 override ,
1516 runtime_checkable ,
3132 HttpxRequestFiles ,
3233)
3334from ._utils import (
35+ PropertyInfo ,
3436 is_list ,
3537 is_given ,
3638 is_mapping ,
3941 strip_not_given ,
4042 extract_type_arg ,
4143 is_annotated_type ,
44+ strip_annotated_type ,
4245)
4346from ._compat import (
4447 PYDANTIC_V2 ,
5558)
5659from ._constants import RAW_RESPONSE_HEADER
5760
61+ if TYPE_CHECKING :
62+ from pydantic_core .core_schema import ModelField , ModelFieldsSchema
63+
5864__all__ = ["BaseModel" , "GenericModel" ]
5965
6066_T = TypeVar ("_T" )
@@ -268,14 +274,18 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
268274
269275def is_basemodel (type_ : type ) -> bool :
270276 """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
271- origin = get_origin (type_ ) or type_
272277 if is_union (type_ ):
273278 for variant in get_args (type_ ):
274279 if is_basemodel (variant ):
275280 return True
276281
277282 return False
278283
284+ return is_basemodel_type (type_ )
285+
286+
287+ def is_basemodel_type (type_ : type ) -> TypeGuard [type [BaseModel ] | type [GenericModel ]]:
288+ origin = get_origin (type_ ) or type_
279289 return issubclass (origin , BaseModel ) or issubclass (origin , GenericModel )
280290
281291
@@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
286296 """
287297 # unwrap `Annotated[T, ...]` -> `T`
288298 if is_annotated_type (type_ ):
299+ meta = get_args (type_ )[1 :]
289300 type_ = extract_type_arg (type_ , 0 )
301+ else :
302+ meta = tuple ()
290303
291304 # we need to use the origin class for any types that are subscripted generics
292305 # e.g. Dict[str, object]
@@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
299312 except Exception :
300313 pass
301314
315+ # if the type is a discriminated union then we want to construct the right variant
316+ # in the union, even if the data doesn't match exactly, otherwise we'd break code
317+ # that relies on the constructed class types, e.g.
318+ #
319+ # class FooType:
320+ # kind: Literal['foo']
321+ # value: str
322+ #
323+ # class BarType:
324+ # kind: Literal['bar']
325+ # value: int
326+ #
327+ # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
328+ # we'd end up constructing `FooType` when it should be `BarType`.
329+ discriminator = _build_discriminated_union_meta (union = type_ , meta_annotations = meta )
330+ if discriminator and is_mapping (value ):
331+ variant_value = value .get (discriminator .field_alias_from or discriminator .field_name )
332+ if variant_value and isinstance (variant_value , str ):
333+ variant_type = discriminator .mapping .get (variant_value )
334+ if variant_type :
335+ return construct_type (type_ = variant_type , value = value )
336+
302337 # if the data is not valid, use the first variant that doesn't fail while deserializing
303338 for variant in args :
304339 try :
@@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
356391 return value
357392
358393
394+ @runtime_checkable
395+ class CachedDiscriminatorType (Protocol ):
396+ __discriminator__ : DiscriminatorDetails
397+
398+
399+ class DiscriminatorDetails :
400+ field_name : str
401+ """The name of the discriminator field in the variant class, e.g.
402+
403+ ```py
404+ class Foo(BaseModel):
405+ type: Literal['foo']
406+ ```
407+
408+ Will result in field_name='type'
409+ """
410+
411+ field_alias_from : str | None
412+ """The name of the discriminator field in the API response, e.g.
413+
414+ ```py
415+ class Foo(BaseModel):
416+ type: Literal['foo'] = Field(alias='type_from_api')
417+ ```
418+
419+ Will result in field_alias_from='type_from_api'
420+ """
421+
422+ mapping : dict [str , type ]
423+ """Mapping of discriminator value to variant type, e.g.
424+
425+ {'foo': FooVariant, 'bar': BarVariant}
426+ """
427+
428+ def __init__ (
429+ self ,
430+ * ,
431+ mapping : dict [str , type ],
432+ discriminator_field : str ,
433+ discriminator_alias : str | None ,
434+ ) -> None :
435+ self .mapping = mapping
436+ self .field_name = discriminator_field
437+ self .field_alias_from = discriminator_alias
438+
439+
440+ def _build_discriminated_union_meta (* , union : type , meta_annotations : tuple [Any , ...]) -> DiscriminatorDetails | None :
441+ if isinstance (union , CachedDiscriminatorType ):
442+ return union .__discriminator__
443+
444+ discriminator_field_name : str | None = None
445+
446+ for annotation in meta_annotations :
447+ if isinstance (annotation , PropertyInfo ) and annotation .discriminator is not None :
448+ discriminator_field_name = annotation .discriminator
449+ break
450+
451+ if not discriminator_field_name :
452+ return None
453+
454+ mapping : dict [str , type ] = {}
455+ discriminator_alias : str | None = None
456+
457+ for variant in get_args (union ):
458+ variant = strip_annotated_type (variant )
459+ if is_basemodel_type (variant ):
460+ if PYDANTIC_V2 :
461+ field = _extract_field_schema_pv2 (variant , discriminator_field_name )
462+ if not field :
463+ continue
464+
465+ # Note: if one variant defines an alias then they all should
466+ discriminator_alias = field .get ("serialization_alias" )
467+
468+ field_schema = field ["schema" ]
469+
470+ if field_schema ["type" ] == "literal" :
471+ for entry in field_schema ["expected" ]:
472+ if isinstance (entry , str ):
473+ mapping [entry ] = variant
474+ else :
475+ field_info = cast ("dict[str, FieldInfo]" , variant .__fields__ ).get (discriminator_field_name ) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
476+ if not field_info :
477+ continue
478+
479+ # Note: if one variant defines an alias then they all should
480+ discriminator_alias = field_info .alias
481+
482+ if field_info .annotation and is_literal_type (field_info .annotation ):
483+ for entry in get_args (field_info .annotation ):
484+ if isinstance (entry , str ):
485+ mapping [entry ] = variant
486+
487+ if not mapping :
488+ return None
489+
490+ details = DiscriminatorDetails (
491+ mapping = mapping ,
492+ discriminator_field = discriminator_field_name ,
493+ discriminator_alias = discriminator_alias ,
494+ )
495+ cast (CachedDiscriminatorType , union ).__discriminator__ = details
496+ return details
497+
498+
499+ def _extract_field_schema_pv2 (model : type [BaseModel ], field_name : str ) -> ModelField | None :
500+ schema = model .__pydantic_core_schema__
501+ if schema ["type" ] != "model" :
502+ return None
503+
504+ fields_schema = schema ["schema" ]
505+ if fields_schema ["type" ] != "model-fields" :
506+ return None
507+
508+ fields_schema = cast ("ModelFieldsSchema" , fields_schema )
509+
510+ field = fields_schema ["fields" ].get (field_name )
511+ if not field :
512+ return None
513+
514+ return cast ("ModelField" , field ) # pyright: ignore[reportUnnecessaryCast]
515+
516+
359517def validate_type (* , type_ : type [_T ], value : object ) -> _T :
360518 """Strict validation that the given value matches the expected type"""
361519 if inspect .isclass (type_ ) and issubclass (type_ , pydantic .BaseModel ):
0 commit comments