@@ -47,7 +47,7 @@ def _spec_dtypes(library):
4747 'unsigned integer' : lambda d : d .startswith ('uint' ),
4848 'integral' : lambda d : dtype_categories ['signed integer' ](d ) or
4949 dtype_categories ['unsigned integer' ](d ),
50- 'real floating' : lambda d : d . startswith ( 'float' ) ,
50+ 'real floating' : lambda d : 'float' in d ,
5151 'complex floating' : lambda d : d .startswith ('complex' ),
5252 'numeric' : lambda d : dtype_categories ['integral' ](d ) or
5353 dtype_categories ['real floating' ](d ) or
@@ -90,3 +90,25 @@ def test_isdtype_spec_dtypes(library):
9090
9191 res = isdtype_ (dtype_ , kind1_ ) or isdtype_ (dtype_ , kind2_ )
9292 assert isdtype (dtype , kind ) == res , (dtype_ , (kind1_ , kind2_ ))
93+
94+ additional_dtypes = [
95+ 'float16' ,
96+ 'float128' ,
97+ 'complex256' ,
98+ 'bfloat16' ,
99+ ]
100+
101+ @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" ])
102+ @pytest .mark .parametrize ("dtype_" , additional_dtypes )
103+ def test_isdtype_additional_dtypes (library , dtype_ ):
104+ xp = import_ ('array_api_compat.' + library )
105+
106+ isdtype = xp .isdtype
107+
108+ if not hasattr (xp , dtype_ ):
109+ pytest .skip (f"{ library } doesn't have dtype { dtype_ } " )
110+
111+ dtype = getattr (xp , dtype_ )
112+ for cat in dtype_categories :
113+ res = isdtype_ (dtype_ , cat )
114+ assert isdtype (dtype , cat ) == res , (dtype_ , cat )
0 commit comments