2121from dpctl .tensor ._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
2222from dpctl .tensor ._tensor_impl import _copy_usm_ndarray_for_reshape
2323
24+ __doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
25+
2426
2527def _make_unit_indexes (shape ):
2628 """
@@ -67,10 +69,8 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
6769 ]
6870 ]
6971 valid = all (
70- [
71- check_st == old_st or old_dim == 1
72- for check_st , old_st , old_dim in zip (check_sts , old_sts , old_sh )
73- ]
72+ check_st == old_st or old_dim == 1
73+ for check_st , old_st , old_dim in zip (check_sts , old_sts , old_sh )
7474 )
7575 return new_sts if valid else None
7676
@@ -82,7 +82,7 @@ def reshape(X, newshape, order="C", copy=None):
8282 Reshapes given usm_ndarray into new shape. Returns a view, if possible,
8383 a copy otherwise. Memory layout of the copy is controlled by order keyword.
8484 """
85- if type ( X ) is not dpt .usm_ndarray :
85+ if not isinstance ( X , dpt .usm_ndarray ) :
8686 raise TypeError
8787 if not isinstance (newshape , (list , tuple )):
8888 newshape = (newshape ,)
@@ -99,10 +99,10 @@ def reshape(X, newshape, order="C", copy=None):
9999 )
100100 newshape = [operator .index (d ) for d in newshape ]
101101 negative_ones_count = 0
102- for i in range ( len ( newshape )) :
103- if newshape [ i ] == - 1 :
102+ for nshi in newshape :
103+ if nshi == - 1 :
104104 negative_ones_count = negative_ones_count + 1
105- if (newshape [ i ] < - 1 ) or negative_ones_count > 1 :
105+ if (nshi < - 1 ) or negative_ones_count > 1 :
106106 raise ValueError (
107107 "Target shape should have at most 1 negative "
108108 "value which can only be -1"
@@ -111,7 +111,7 @@ def reshape(X, newshape, order="C", copy=None):
111111 v = X .size // (- np .prod (newshape ))
112112 newshape = [v if d == - 1 else d for d in newshape ]
113113 if X .size != np .prod (newshape ):
114- raise ValueError ("Can not reshape into {}" . format ( newshape ) )
114+ raise ValueError (f "Can not reshape into { newshape } " )
115115 if X .size :
116116 newsts = reshaped_strides (X .shape , X .strides , newshape , order = order )
117117 else :
@@ -143,12 +143,11 @@ def reshape(X, newshape, order="C", copy=None):
143143 return dpt .usm_ndarray (
144144 tuple (newshape ), dtype = X .dtype , buffer = flat_res , order = order
145145 )
146- else :
147- # can form a view
148- return dpt .usm_ndarray (
149- newshape ,
150- dtype = X .dtype ,
151- buffer = X ,
152- strides = tuple (newsts ),
153- offset = X .__sycl_usm_array_interface__ .get ("offset" , 0 ),
154- )
146+ # can form a view
147+ return dpt .usm_ndarray (
148+ newshape ,
149+ dtype = X .dtype ,
150+ buffer = X ,
151+ strides = tuple (newsts ),
152+ offset = X .__sycl_usm_array_interface__ .get ("offset" , 0 ),
153+ )
0 commit comments