1111
1212from typing import NamedTuple
1313from types import ModuleType
14+ import inspect
1415
1516from ._helpers import _check_device , _is_numpy_array , get_namespace
1617
@@ -161,13 +162,23 @@ class UniqueInverseResult(NamedTuple):
161162 inverse_indices : ndarray
162163
163164
165+ def _unique_kwargs (xp ):
166+ # Older versions of NumPy and CuPy do not have equal_nan. Rather than
167+ # trying to parse version numbers, just check if equal_nan is in the
168+ # signature.
169+ s = inspect .signature (xp .unique )
170+ if 'equal_nan' in s .parameters :
171+ return {'equal_nan' : False }
172+ return {}
173+
164174def unique_all (x : ndarray , / , xp ) -> UniqueAllResult :
175+ kwargs = _unique_kwargs (xp )
165176 values , indices , inverse_indices , counts = xp .unique (
166177 x ,
167178 return_counts = True ,
168179 return_index = True ,
169180 return_inverse = True ,
170- equal_nan = False ,
181+ ** kwargs ,
171182 )
172183 # np.unique() flattens inverse indices, but they need to share x's shape
173184 # See https://github.com/numpy/numpy/issues/20638
@@ -181,24 +192,26 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
181192
182193
183194def unique_counts (x : ndarray , / , xp ) -> UniqueCountsResult :
195+ kwargs = _unique_kwargs (xp )
184196 res = xp .unique (
185197 x ,
186198 return_counts = True ,
187199 return_index = False ,
188200 return_inverse = False ,
189- equal_nan = False ,
201+ ** kwargs
190202 )
191203
192204 return UniqueCountsResult (* res )
193205
194206
195207def unique_inverse (x : ndarray , / , xp ) -> UniqueInverseResult :
208+ kwargs = _unique_kwargs (xp )
196209 values , inverse_indices = xp .unique (
197210 x ,
198211 return_counts = False ,
199212 return_index = False ,
200213 return_inverse = True ,
201- equal_nan = False ,
214+ ** kwargs ,
202215 )
203216 # xp.unique() flattens inverse indices, but they need to share x's shape
204217 # See https://github.com/numpy/numpy/issues/20638
@@ -207,12 +220,13 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
207220
208221
209222def unique_values (x : ndarray , / , xp ) -> ndarray :
223+ kwargs = _unique_kwargs (xp )
210224 return xp .unique (
211225 x ,
212226 return_counts = False ,
213227 return_index = False ,
214228 return_inverse = False ,
215- equal_nan = False ,
229+ ** kwargs ,
216230 )
217231
218232def astype (x : ndarray , dtype : Dtype , / , * , copy : bool = True ) -> ndarray :
@@ -295,8 +309,13 @@ def _asarray(
295309 _check_device (xp , device )
296310 if _is_numpy_array (obj ):
297311 import numpy as np
298- COPY_FALSE = (False , np ._CopyMode .IF_NEEDED )
299- COPY_TRUE = (True , np ._CopyMode .ALWAYS )
312+ if hasattr (np , '_CopyMode' ):
313+ # Not present in older NumPys
314+ COPY_FALSE = (False , np ._CopyMode .IF_NEEDED )
315+ COPY_TRUE = (True , np ._CopyMode .ALWAYS )
316+ else :
317+ COPY_FALSE = (False ,)
318+ COPY_TRUE = (True ,)
300319 else :
301320 COPY_FALSE = (False ,)
302321 COPY_TRUE = (True ,)
0 commit comments