2525import dpctl .utils
2626from dpctl .tensor ._device import normalize_queue_device
2727
28+ __doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"
29+
2830_empty_tuple = tuple ()
2931_host_set = frozenset ([None ])
3032
@@ -34,45 +36,42 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
3436 if ref_type in [None , float ] or np .issubdtype (ref_type , np .floating ):
3537 dtype = ti .default_device_fp_type (sycl_obj )
3638 return dpt .dtype (dtype )
37- elif ref_type in [bool , np .bool_ ]:
39+ if ref_type in [bool , np .bool_ ]:
3840 dtype = ti .default_device_bool_type (sycl_obj )
3941 return dpt .dtype (dtype )
40- elif ref_type is int or np .issubdtype (ref_type , np .integer ):
42+ if ref_type is int or np .issubdtype (ref_type , np .integer ):
4143 dtype = ti .default_device_int_type (sycl_obj )
4244 return dpt .dtype (dtype )
43- elif ref_type is complex or np .issubdtype (ref_type , np .complexfloating ):
45+ if ref_type is complex or np .issubdtype (ref_type , np .complexfloating ):
4446 dtype = ti .default_device_complex_type (sycl_obj )
4547 return dpt .dtype (dtype )
46- else :
47- raise TypeError (f"Reference type { ref_type } not recognized." )
48- else :
49- return dpt .dtype (dtype )
48+ raise TypeError (f"Reference type { ref_type } not recognized." )
49+ return dpt .dtype (dtype )
5050
5151
5252def _array_info_dispatch (obj ):
5353 if isinstance (obj , dpt .usm_ndarray ):
5454 return obj .shape , obj .dtype , frozenset ([obj .sycl_queue ])
55- elif isinstance (obj , np .ndarray ):
55+ if isinstance (obj , np .ndarray ):
5656 return obj .shape , obj .dtype , _host_set
57- elif isinstance (obj , range ):
57+ if isinstance (obj , range ):
5858 return (len (obj ),), int , _host_set
59- elif isinstance (obj , bool ):
59+ if isinstance (obj , bool ):
6060 return _empty_tuple , bool , _host_set
61- elif isinstance (obj , float ):
61+ if isinstance (obj , float ):
6262 return _empty_tuple , float , _host_set
63- elif isinstance (obj , int ):
63+ if isinstance (obj , int ):
6464 return _empty_tuple , int , _host_set
65- elif isinstance (obj , complex ):
65+ if isinstance (obj , complex ):
6666 return _empty_tuple , complex , _host_set
67- elif isinstance (obj , (list , tuple , range )):
67+ if isinstance (obj , (list , tuple , range )):
6868 return _array_info_sequence (obj )
69- elif any (
69+ if any (
7070 isinstance (obj , s )
7171 for s in [np .integer , np .floating , np .complexfloating , np .bool_ ]
7272 ):
7373 return _empty_tuple , obj .dtype , _host_set
74- else :
75- raise ValueError (type (obj ))
74+ raise ValueError (type (obj ))
7675
7776
7877def _array_info_sequence (li ):
@@ -91,9 +90,7 @@ def _array_info_sequence(li):
9190 dt = np .promote_types (dt , el_dt )
9291 device = device .union (el_dev )
9392 else :
94- raise ValueError (
95- "Inconsistent dimensions, {} and {}" .format (dim , el_dim )
96- )
93+ raise ValueError (f"Inconsistent dimensions, { dim } and { el_dim } " )
9794 if dim is None :
9895 dim = tuple ()
9996 dt = float
@@ -206,18 +203,18 @@ def _map_to_device_dtype(dt, q):
206203 if np .issubdtype (dt , np .floating ):
207204 if dtc == "f" :
208205 return dt
209- else :
210- if dtc == "d" and d .has_aspect_fp64 :
211- return dt
212- if dtc == "h" and d .has_aspect_fp16 :
213- return dt
214- return dpt .dtype ("f4" )
215- elif np .issubdtype (dt , np .complexfloating ):
206+ if dtc == "d" and d .has_aspect_fp64 :
207+ return dt
208+ if dtc == "h" and d .has_aspect_fp16 :
209+ return dt
210+ return dpt .dtype ("f4" )
211+ if np .issubdtype (dt , np .complexfloating ):
216212 if dtc == "F" :
217213 return dt
218214 if dtc == "D" and d .has_aspect_fp64 :
219215 return dt
220216 return dpt .dtype ("c8" )
217+ raise RuntimeError (f"Unrecognized data type '{ dt } ' encountered." )
221218
222219
223220def _asarray_from_numpy_ndarray (
@@ -349,8 +346,7 @@ def asarray(
349346 raise ValueError (
350347 "Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'."
351348 )
352- else :
353- order = order [0 ].upper ()
349+ order = order [0 ].upper ()
354350 # 4. Check that usm_type is None, or a valid value
355351 dpctl .utils .validate_usm_type (usm_type , allow_none = True )
356352 # 5. Normalize device/sycl_queue [keep it None if was None]
@@ -369,7 +365,7 @@ def asarray(
369365 sycl_queue = sycl_queue ,
370366 order = order ,
371367 )
372- elif hasattr (obj , "__sycl_usm_array_interface__" ):
368+ if hasattr (obj , "__sycl_usm_array_interface__" ):
373369 sua_iface = getattr (obj , "__sycl_usm_array_interface__" )
374370 membuf = dpm .as_usm_memory (obj )
375371 ary = dpt .usm_ndarray (
@@ -386,7 +382,7 @@ def asarray(
386382 sycl_queue = sycl_queue ,
387383 order = order ,
388384 )
389- elif isinstance (obj , np .ndarray ):
385+ if isinstance (obj , np .ndarray ):
390386 if copy is False :
391387 raise ValueError (
392388 "Converting numpy.ndarray to usm_ndarray requires a copy"
@@ -398,7 +394,7 @@ def asarray(
398394 sycl_queue = sycl_queue ,
399395 order = order ,
400396 )
401- elif _is_object_with_buffer_protocol (obj ):
397+ if _is_object_with_buffer_protocol (obj ):
402398 if copy is False :
403399 raise ValueError (
404400 f"Converting { type (obj )} to usm_ndarray requires a copy"
@@ -410,12 +406,12 @@ def asarray(
410406 sycl_queue = sycl_queue ,
411407 order = order ,
412408 )
413- elif isinstance (obj , (list , tuple , range )):
409+ if isinstance (obj , (list , tuple , range )):
414410 if copy is False :
415411 raise ValueError (
416412 "Converting Python sequence to usm_ndarray requires a copy"
417413 )
418- _ , dt , devs = _array_info_sequence (obj )
414+ _ , _ , devs = _array_info_sequence (obj )
419415 if devs == _host_set :
420416 return _asarray_from_numpy_ndarray (
421417 np .asarray (obj , dtype = dtype , order = order ),
@@ -474,8 +470,7 @@ def empty(
474470 raise ValueError (
475471 "Unrecognized order keyword value, expecting 'F' or 'C'."
476472 )
477- else :
478- order = order [0 ].upper ()
473+ order = order [0 ].upper ()
479474 dpctl .utils .validate_usm_type (usm_type , allow_none = False )
480475 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
481476 dtype = _get_dtype (dtype , sycl_queue )
@@ -497,14 +492,13 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
497492 dt = _get_dtype (dt , sycl_queue , ref_type = seq_dt )
498493 if np .issubdtype (dt , np .integer ):
499494 return tuple (int (v ) for v in args ), dt
500- elif np .issubdtype (dt , np .floating ):
495+ if np .issubdtype (dt , np .floating ):
501496 return tuple (float (v ) for v in args ), dt
502- elif np .issubdtype (dt , np .complexfloating ):
497+ if np .issubdtype (dt , np .complexfloating ):
503498 return tuple (complex (v ) for v in args ), dt
504- elif allow_bool and dt .char == "?" :
499+ if allow_bool and dt .char == "?" :
505500 return tuple (bool (v ) for v in args ), dt
506- else :
507- raise ValueError (f"Data type { dt } is not supported" )
501+ raise ValueError (f"Data type { dt } is not supported" )
508502
509503
510504def _round_for_arange (tmp ):
@@ -570,7 +564,7 @@ def arange(
570564 is_bool = False
571565 if dtype :
572566 is_bool = (dtype is bool ) or (dpt .dtype (dtype ) == dpt .bool )
573- ( start_ , stop_ , step_ ) , dt = _coerce_and_infer_dt (
567+ _ , dt = _coerce_and_infer_dt (
574568 start ,
575569 stop ,
576570 step ,
@@ -581,9 +575,7 @@ def arange(
581575 )
582576 try :
583577 tmp = _get_arange_length (start , stop , step )
584- sh = int (tmp )
585- if sh < 0 :
586- sh = 0
578+ sh = max (int (tmp ), 0 )
587579 except TypeError :
588580 sh = 0
589581 if is_bool and sh > 2 :
@@ -655,8 +647,7 @@ def zeros(
655647 raise ValueError (
656648 "Unrecognized order keyword value, expecting 'F' or 'C'."
657649 )
658- else :
659- order = order [0 ].upper ()
650+ order = order [0 ].upper ()
660651 dpctl .utils .validate_usm_type (usm_type , allow_none = False )
661652 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
662653 dtype = _get_dtype (dtype , sycl_queue )
@@ -703,8 +694,7 @@ def ones(
703694 raise ValueError (
704695 "Unrecognized order keyword value, expecting 'F' or 'C'."
705696 )
706- else :
707- order = order [0 ].upper ()
697+ order = order [0 ].upper ()
708698 dpctl .utils .validate_usm_type (usm_type , allow_none = False )
709699 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
710700 dtype = _get_dtype (dtype , sycl_queue )
@@ -715,7 +705,7 @@ def ones(
715705 order = order ,
716706 buffer_ctor_kwargs = {"queue" : sycl_queue },
717707 )
718- hev , ev = ti ._full_usm_ndarray (1 , res , sycl_queue )
708+ hev , _ = ti ._full_usm_ndarray (1 , res , sycl_queue )
719709 hev .wait ()
720710 return res
721711
@@ -759,8 +749,7 @@ def full(
759749 raise ValueError (
760750 "Unrecognized order keyword value, expecting 'F' or 'C'."
761751 )
762- else :
763- order = order [0 ].upper ()
752+ order = order [0 ].upper ()
764753 dpctl .utils .validate_usm_type (usm_type , allow_none = False )
765754 sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
766755 dtype = _get_dtype (dtype , sycl_queue , ref_type = type (fill_value ))
@@ -771,7 +760,7 @@ def full(
771760 order = order ,
772761 buffer_ctor_kwargs = {"queue" : sycl_queue },
773762 )
774- hev , ev = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
763+ hev , _ = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
775764 hev .wait ()
776765 return res
777766
@@ -811,8 +800,7 @@ def empty_like(
811800 raise ValueError (
812801 "Unrecognized order keyword value, expecting 'F' or 'C'."
813802 )
814- else :
815- order = order [0 ].upper ()
803+ order = order [0 ].upper ()
816804 if dtype is None :
817805 dtype = x .dtype
818806 if usm_type is None :
@@ -868,8 +856,7 @@ def zeros_like(
868856 raise ValueError (
869857 "Unrecognized order keyword value, expecting 'F' or 'C'."
870858 )
871- else :
872- order = order [0 ].upper ()
859+ order = order [0 ].upper ()
873860 if dtype is None :
874861 dtype = x .dtype
875862 if usm_type is None :
@@ -925,8 +912,7 @@ def ones_like(
925912 raise ValueError (
926913 "Unrecognized order keyword value, expecting 'F' or 'C'."
927914 )
928- else :
929- order = order [0 ].upper ()
915+ order = order [0 ].upper ()
930916 if dtype is None :
931917 dtype = x .dtype
932918 if usm_type is None :
@@ -989,8 +975,7 @@ def full_like(
989975 raise ValueError (
990976 "Unrecognized order keyword value, expecting 'F' or 'C'."
991977 )
992- else :
993- order = order [0 ].upper ()
978+ order = order [0 ].upper ()
994979 if dtype is None :
995980 dtype = x .dtype
996981 if usm_type is None :
@@ -1142,8 +1127,7 @@ def eye(
11421127 raise ValueError (
11431128 "Unrecognized order keyword value, expecting 'F' or 'C'."
11441129 )
1145- else :
1146- order = order [0 ].upper ()
1130+ order = order [0 ].upper ()
11471131 n_rows = operator .index (n_rows )
11481132 n_cols = n_rows if n_cols is None else operator .index (n_cols )
11491133 k = operator .index (k )
@@ -1178,12 +1162,14 @@ def tril(X, k=0):
11781162
11791163 Returns the lower triangular part of a matrix (or a stack of matrices) X.
11801164 """
1181- if type (X ) is not dpt .usm_ndarray :
1182- raise TypeError
1165+ if not isinstance (X , dpt .usm_ndarray ):
1166+ raise TypeError (
1167+ "Expected argument of type dpctl.tensor.usm_ndarray, "
1168+ f"got { type (X )} ."
1169+ )
11831170
11841171 k = operator .index (k )
11851172
1186- # F_CONTIGUOUS = 2
11871173 order = "F" if (X .flags .f_contiguous ) else "C"
11881174
11891175 shape = X .shape
@@ -1219,12 +1205,14 @@ def triu(X, k=0):
12191205
12201206 Returns the upper triangular part of a matrix (or a stack of matrices) X.
12211207 """
1222- if type (X ) is not dpt .usm_ndarray :
1223- raise TypeError
1208+ if not isinstance (X , dpt .usm_ndarray ):
1209+ raise TypeError (
1210+ "Expected argument of type dpctl.tensor.usm_ndarray, "
1211+ f"got { type (X )} ."
1212+ )
12241213
12251214 k = operator .index (k )
12261215
1227- # F_CONTIGUOUS = 2
12281216 order = "F" if (X .flags .f_contiguous ) else "C"
12291217
12301218 shape = X .shape
0 commit comments