@@ -472,10 +472,49 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
472472 res = x1_ [..., None , :] @ x2_ [..., None ]
473473 return res [..., 0 , 0 ]
474474
475+ # isdtype is a new function in the 2022.12 array API specification.
476+
477+ def isdtype (
478+ dtype : Dtype , kind : Union [Dtype , str , Tuple [Union [Dtype , str ], ...]], xp ,
479+ * , _tuple = True , # Disallow nested tuples
480+ ) -> bool :
481+ """
482+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
483+
484+ See
485+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
486+ for more details
487+ """
488+ if isinstance (kind , tuple ) and _tuple :
489+ return any (isdtype (dtype , k , xp , _tuple = False ) for k in kind )
490+ elif isinstance (kind , str ):
491+ if kind == 'bool' :
492+ return dtype == xp .bool_
493+ elif kind == 'signed integer' :
494+ return xp .issubdtype (dtype , xp .signedinteger )
495+ elif kind == 'unsigned integer' :
496+ return xp .issubdtype (dtype , xp .unsignedinteger )
497+ elif kind == 'integral' :
498+ return xp .issubdtype (dtype , xp .integer )
499+ elif kind == 'real floating' :
500+ return xp .issubdtype (dtype , xp .floating )
501+ elif kind == 'complex floating' :
502+ return xp .issubdtype (dtype , xp .complexfloating )
503+ elif kind == 'numeric' :
504+ return xp .issubdtype (dtype , xp .number )
505+ else :
506+ raise ValueError (f"Unrecognized data type kind: { kind !r} " )
507+ else :
508+ # This will allow things that aren't required by the spec, like
509+ # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
510+ # more strict here to match the type annotation? Note that the
511+ # numpy.array_api implementation will be very strict.
512+ return dtype == kind
513+
475514__all__ = ['arange' , 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' ,
476515 'linspace' , 'ones' , 'ones_like' , 'zeros' , 'zeros_like' ,
477516 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
478517 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
479518 'astype' , 'std' , 'var' , 'permute_dims' , 'reshape' , 'argsort' ,
480519 'sort' , 'sum' , 'prod' , 'ceil' , 'floor' , 'trunc' , 'matmul' ,
481- 'matrix_transpose' , 'tensordot' , 'vecdot' ]
520+ 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ]
0 commit comments