@@ -66,14 +66,12 @@ def unary_assert_against_refimpl(
6666 res : Array ,
6767 refimpl : Callable [[Scalar ], Scalar ],
6868 expr_template : str ,
69- in_stype : Optional [ScalarType ] = None ,
7069 res_stype : Optional [ScalarType ] = None ,
7170 filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
7271):
7372 if in_ .shape != res .shape :
7473 raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
75- if in_stype is None :
76- in_stype = dh .get_scalar_type (in_ .dtype )
74+ in_stype = dh .get_scalar_type (in_ .dtype )
7775 if res_stype is None :
7876 res_stype = in_stype
7977 m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
@@ -109,15 +107,13 @@ def binary_assert_against_refimpl(
109107 res : Array ,
110108 refimpl : Callable [[Scalar , Scalar ], Scalar ],
111109 expr_template : str ,
112- in_stype : Optional [ScalarType ] = None ,
113110 res_stype : Optional [ScalarType ] = None ,
114111 left_sym : str = "x1" ,
115112 right_sym : str = "x2" ,
116113 res_name : str = "out" ,
117114 filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
118115):
119- if in_stype is None :
120- in_stype = dh .get_scalar_type (left .dtype )
116+ in_stype = dh .get_scalar_type (left .dtype )
121117 if res_stype is None :
122118 res_stype = in_stype
123119 m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
@@ -350,14 +346,12 @@ def binary_param_assert_against_refimpl(
350346 res : Array ,
351347 refimpl : Callable [[Scalar , Scalar ], Scalar ],
352348 expr_template : str ,
353- in_stype : Optional [ScalarType ] = None ,
354349 res_stype : Optional [ScalarType ] = None ,
355350 filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
356351):
357352 if ctx .right_is_scalar :
358353 assert filter_ (right ) # sanity check
359- if in_stype is None :
360- in_stype = dh .get_scalar_type (left .dtype )
354+ in_stype = dh .get_scalar_type (left .dtype )
361355 if res_stype is None :
362356 res_stype = in_stype
363357 m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
@@ -389,7 +383,6 @@ def binary_param_assert_against_refimpl(
389383 else :
390384 binary_assert_against_refimpl (
391385 func_name = ctx .func_name ,
392- in_stype = in_stype ,
393386 left_sym = ctx .left_sym ,
394387 left = left ,
395388 right_sym = ctx .right_sym ,
0 commit comments