@@ -76,14 +76,14 @@ def unary_assert_against_refimpl(
7676 in_stype = dh .get_scalar_type (in_ .dtype )
7777 if res_stype is None :
7878 res_stype = in_stype
79- if res .dtype != xp .bool :
80- m , M = dh .dtype_ranges [res .dtype ]
79+ m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
8180 for idx in sh .ndindex (in_ .shape ):
8281 scalar_i = in_stype (in_ [idx ])
8382 if not filter_ (scalar_i ):
8483 continue
8584 expected = refimpl (scalar_i )
8685 if res .dtype != xp .bool :
86+ assert m is not None and M is not None # for mypy
8787 if expected <= m or expected >= M :
8888 continue
8989 scalar_o = res_stype (res [idx ])
@@ -105,7 +105,7 @@ def unary_assert_against_refimpl(
105105def binary_assert_against_refimpl (
106106 func_name : str ,
107107 left : Array ,
108- right : Union [ Scalar , Array ] ,
108+ right : Array ,
109109 res : Array ,
110110 refimpl : Callable [[Scalar , Scalar ], Scalar ],
111111 expr_template : str ,
@@ -120,24 +120,23 @@ def binary_assert_against_refimpl(
120120 in_stype = dh .get_scalar_type (left .dtype )
121121 if res_stype is None :
122122 res_stype = in_stype
123- result_dtype = dh .result_type (left .dtype , right .dtype )
124- if result_dtype != xp .bool :
125- m , M = dh .dtype_ranges [result_dtype ]
123+ m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
126124 for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
127125 scalar_l = in_stype (left [l_idx ])
128126 scalar_r = in_stype (right [r_idx ])
129127 if not (filter_ (scalar_l ) and filter_ (scalar_r )):
130128 continue
131129 expected = refimpl (scalar_l , scalar_r )
132- if result_dtype != xp .bool :
130+ if res .dtype != xp .bool :
131+ assert m is not None and M is not None # for mypy
133132 if expected <= m or expected >= M :
134133 continue
135134 scalar_o = res_stype (res [o_idx ])
136135 f_l = sh .fmt_idx (left_sym , l_idx )
137136 f_r = sh .fmt_idx (right_sym , r_idx )
138137 f_o = sh .fmt_idx (res_name , o_idx )
139138 expr = expr_template .format (f_l , f_r , expected )
140- if dh .is_float_dtype (result_dtype ):
139+ if dh .is_float_dtype (res . dtype ):
141140 assert isclose (scalar_o , expected ), (
142141 f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
143142 f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -357,18 +356,18 @@ def binary_param_assert_against_refimpl(
357356):
358357 if ctx .right_is_scalar :
359358 assert filter_ (right ) # sanity check
360- if left .dtype != xp .bool :
361- m , M = dh .dtype_ranges [left .dtype ]
362359 if in_stype is None :
363360 in_stype = dh .get_scalar_type (left .dtype )
364361 if res_stype is None :
365362 res_stype = in_stype
363+ m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
366364 for idx in sh .ndindex (res .shape ):
367365 scalar_l = in_stype (left [idx ])
368366 if not filter_ (scalar_l ):
369367 continue
370368 expected = refimpl (scalar_l , right )
371369 if left .dtype != xp .bool :
370+ assert m is not None and M is not None # for mypy
372371 if expected <= m or expected >= M :
373372 continue
374373 scalar_o = res_stype (res [idx ])
0 commit comments