2020import dpctl .tensor ._tensor_impl as ti
2121from dpctl .tensor ._device import normalize_queue_device
2222
23+ __doc__ = (
24+ "Implementation module for copy- and cast- operations on "
25+ ":class:`dpctl.tensor.usm_ndarray`."
26+ )
27+
2328
2429def _has_memory_overlap (x1 , x2 ):
2530 if x1 .size and x2 .size :
@@ -33,15 +38,13 @@ def _has_memory_overlap(x1, x2):
3338 p2_end = p2_beg + m2 .nbytes
3439 # may intersect if not ((p1_beg >= p2_end) or (p2_beg >= p2_end))
3540 return (p1_beg < p2_end ) and (p2_beg < p1_end )
36- else :
37- return False
38- else :
39- # zero element array do not overlap anything
4041 return False
42+ # zero element array do not overlap anything
43+ return False
4144
4245
4346def _copy_to_numpy (ary ):
44- if type (ary ) is not dpt .usm_ndarray :
47+ if not isinstance (ary , dpt .usm_ndarray ) :
4548 raise TypeError
4649 h = ary .usm_data .copy_to_host ().view (ary .dtype )
4750 itsz = ary .itemsize
@@ -78,9 +81,9 @@ def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
7881def _copy_from_numpy_into (dst , np_ary ):
7982 "Copies `np_ary` into `dst` of type :class:`dpctl.tensor.usm_ndarray"
8083 if not isinstance (np_ary , np .ndarray ):
81- raise TypeError ("Expected numpy.ndarray, got {}" . format ( type (np_ary )) )
84+ raise TypeError (f "Expected numpy.ndarray, got { type (np_ary )} " )
8285 if not isinstance (dst , dpt .usm_ndarray ):
83- raise TypeError ("Expected usm_ndarray, got {}" . format ( type (dst )) )
86+ raise TypeError (f "Expected usm_ndarray, got { type (dst )} " )
8487 src_ary = np .broadcast_to (np_ary , dst .shape )
8588 copy_q = dst .sycl_queue
8689 if copy_q .sycl_device .has_aspect_fp64 is False :
@@ -143,6 +146,8 @@ def asnumpy(usm_ary):
143146
144147
145148class Dummy :
149+ "Helper class with specified __sycl_usm_array_interface__ attribute"
150+
146151 def __init__ (self , iface ):
147152 self .__sycl_usm_array_interface__ = iface
148153
@@ -160,7 +165,7 @@ def _copy_overlapping(dst, src):
160165 hcp1 , cp1 = ti ._copy_usm_ndarray_into_usm_ndarray (
161166 src = src , dst = tmp , sycl_queue = q
162167 )
163- hcp2 , cp2 = ti ._copy_usm_ndarray_into_usm_ndarray (
168+ hcp2 , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
164169 src = tmp , dst = dst , sycl_queue = q , depends = [cp1 ]
165170 )
166171 hcp2 .wait ()
@@ -174,7 +179,7 @@ def _copy_same_shape(dst, src):
174179 _copy_overlapping (src = src , dst = dst )
175180 return
176181
177- hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (
182+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
178183 src = src , dst = dst , sycl_queue = dst .sycl_queue
179184 )
180185 hev .wait ()
@@ -197,7 +202,13 @@ def _broadcast_shapes(sh1, sh2):
197202
198203
199204def _copy_from_usm_ndarray_to_usm_ndarray (dst , src ):
200- if type (dst ) is not dpt .usm_ndarray or type (src ) is not dpt .usm_ndarray :
205+ if any (
206+ not isinstance (arg , dpt .usm_ndarray )
207+ for arg in (
208+ dst ,
209+ src ,
210+ )
211+ ):
201212 raise TypeError (
202213 "Both types are expected to be dpctl.tensor.usm_ndarray, "
203214 f"got { type (dst )} and { type (src )} ."
@@ -209,8 +220,8 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
209220
210221 try :
211222 common_shape = _broadcast_shapes (dst .shape , src .shape )
212- except ValueError :
213- raise ValueError ("Shapes of two arrays are not compatible" )
223+ except ValueError as exc :
224+ raise ValueError ("Shapes of two arrays are not compatible" ) from exc
214225
215226 if dst .size < src .size :
216227 raise ValueError ("Destination is smaller " )
@@ -251,9 +262,7 @@ def copy(usm_ary, order="K"):
251262 """
252263 if not isinstance (usm_ary , dpt .usm_ndarray ):
253264 return TypeError (
254- "Expected object of type dpt.usm_ndarray, got {}" .format (
255- type (usm_ary )
256- )
265+ f"Expected object of type dpt.usm_ndarray, got { type (usm_ary )} "
257266 )
258267 copy_order = "C"
259268 if order == "C" :
@@ -308,9 +317,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
308317 """
309318 if not isinstance (usm_ary , dpt .usm_ndarray ):
310319 return TypeError (
311- "Expected object of type dpt.usm_ndarray, got {}" .format (
312- type (usm_ary )
313- )
320+ f"Expected object of type dpt.usm_ndarray, got { type (usm_ary )} "
314321 )
315322 if not isinstance (order , str ) or order not in ["A" , "C" , "F" , "K" ]:
316323 raise ValueError (
@@ -321,56 +328,54 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
321328 target_dtype = dpt .dtype (newdtype )
322329 if not dpt .can_cast (ary_dtype , target_dtype , casting = casting ):
323330 raise TypeError (
324- "Can not cast from {} to {} according to rule {}" .format (
325- ary_dtype , newdtype , casting
326- )
331+ f"Can not cast from { ary_dtype } to { newdtype } "
332+ f"according to rule { casting } ."
327333 )
328334 c_contig = usm_ary .flags .c_contiguous
329335 f_contig = usm_ary .flags .f_contiguous
330- needs_copy = copy or not ( ary_dtype == target_dtype )
336+ needs_copy = copy or not ary_dtype == target_dtype
331337 if not needs_copy and (order != "K" ):
332338 needs_copy = (c_contig and order not in ["A" , "C" ]) or (
333339 f_contig and order not in ["A" , "F" ]
334340 )
335- if needs_copy :
336- copy_order = "C"
337- if order == "C" :
338- pass
339- elif order == "F" :
340- copy_order = order
341- elif order == "A" :
342- if usm_ary .flags .f_contiguous :
343- copy_order = "F"
344- elif order == "K" :
345- if usm_ary .flags .f_contiguous :
346- copy_order = "F"
347- else :
348- raise ValueError (
349- "Unrecognized value of the order keyword. "
350- "Recognized values are 'A', 'C', 'F', or 'K'"
351- )
341+ if not needs_copy :
342+ return usm_ary
343+ copy_order = "C"
344+ if order == "C" :
345+ pass
346+ elif order == "F" :
347+ copy_order = order
348+ elif order == "A" :
349+ if usm_ary .flags .f_contiguous :
350+ copy_order = "F"
351+ elif order == "K" :
352+ if usm_ary .flags .f_contiguous :
353+ copy_order = "F"
354+ else :
355+ raise ValueError (
356+ "Unrecognized value of the order keyword. "
357+ "Recognized values are 'A', 'C', 'F', or 'K'"
358+ )
359+ R = dpt .usm_ndarray (
360+ usm_ary .shape ,
361+ dtype = target_dtype ,
362+ buffer = usm_ary .usm_type ,
363+ order = copy_order ,
364+ buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
365+ )
366+ if order == "K" and (not c_contig and not f_contig ):
367+ original_strides = usm_ary .strides
368+ ind = sorted (
369+ range (usm_ary .ndim ),
370+ key = lambda i : abs (original_strides [i ]),
371+ reverse = True ,
372+ )
373+ new_strides = tuple (R .strides [ind [i ]] for i in ind )
352374 R = dpt .usm_ndarray (
353375 usm_ary .shape ,
354376 dtype = target_dtype ,
355- buffer = usm_ary .usm_type ,
356- order = copy_order ,
357- buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
377+ buffer = R .usm_data ,
378+ strides = new_strides ,
358379 )
359- if order == "K" and (not c_contig and not f_contig ):
360- original_strides = usm_ary .strides
361- ind = sorted (
362- range (usm_ary .ndim ),
363- key = lambda i : abs (original_strides [i ]),
364- reverse = True ,
365- )
366- new_strides = tuple (R .strides [ind [i ]] for i in ind )
367- R = dpt .usm_ndarray (
368- usm_ary .shape ,
369- dtype = target_dtype ,
370- buffer = R .usm_data ,
371- strides = new_strides ,
372- )
373- _copy_from_usm_ndarray_to_usm_ndarray (R , usm_ary )
374- return R
375- else :
376- return usm_ary
380+ _copy_from_usm_ndarray_to_usm_ndarray (R , usm_ary )
381+ return R
0 commit comments