Skip to content

Commit 98c7417

Browse files
committed
Introduce utility function _cast_fill_val to reduce code duplication in full and full_like
1 parent 4138cb4 commit 98c7417

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

dpctl/tensor/_ctors.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,22 @@ def ones(
10211021
return res
10221022

10231023

1024+
def _cast_fill_val(fill_val, dt):
1025+
"""
1026+
Casts the Python scalar `fill_val` to another Python type coercible to the
1027+
requested data type `dt`, if necessary.
1028+
"""
1029+
val_type = type(fill_val)
1030+
if val_type in [float, complex] and np.issubdtype(dt, np.integer):
1031+
return int(fill_val.real)
1032+
elif val_type is complex and np.issubdtype(dt, np.floating):
1033+
return fill_val.real
1034+
elif val_type is int and np.issubdtype(dt, np.integer):
1035+
return _to_scalar(fill_val, dt)
1036+
else:
1037+
return fill_val
1038+
1039+
10241040
def full(
10251041
shape,
10261042
fill_value,
@@ -1097,21 +1113,15 @@ def full(
10971113

10981114
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
10991115
usm_type = usm_type if usm_type is not None else "device"
1100-
fill_value_type = type(fill_value)
1101-
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
1116+
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
11021117
res = dpt.usm_ndarray(
11031118
shape,
11041119
dtype=dtype,
11051120
buffer=usm_type,
11061121
order=order,
11071122
buffer_ctor_kwargs={"queue": sycl_queue},
11081123
)
1109-
if fill_value_type in [float, complex] and np.issubdtype(dtype, np.integer):
1110-
fill_value = int(fill_value.real)
1111-
elif fill_value_type is complex and np.issubdtype(dtype, np.floating):
1112-
fill_value = fill_value.real
1113-
elif fill_value_type is int and np.issubdtype(dtype, np.integer):
1114-
fill_value = _to_scalar(fill_value, dtype)
1124+
fill_value = _cast_fill_val(fill_value, dtype)
11151125

11161126
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
11171127
# populating new allocation, no dependent events
@@ -1479,26 +1489,15 @@ def full_like(
14791489
)
14801490
_manager.add_event_pair(hev, copy_ev)
14811491
return res
1482-
else:
1483-
fill_value_type = type(fill_value)
1484-
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
1485-
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1486-
if fill_value_type in [float, complex] and np.issubdtype(
1487-
dtype, np.integer
1488-
):
1489-
fill_value = int(fill_value.real)
1490-
elif fill_value_type is complex and np.issubdtype(
1491-
dtype, np.floating
1492-
):
1493-
fill_value = fill_value.real
1494-
elif fill_value_type is int and np.issubdtype(dtype, np.integer):
1495-
fill_value = _to_scalar(fill_value, dtype)
14961492

1497-
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1498-
# populating new allocation, no dependent events
1499-
hev, full_ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
1500-
_manager.add_event_pair(hev, full_ev)
1501-
return res
1493+
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
1494+
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1495+
fill_value = _cast_fill_val(fill_value, dtype)
1496+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1497+
# populating new allocation, no dependent events
1498+
hev, full_ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
1499+
_manager.add_event_pair(hev, full_ev)
1500+
return res
15021501
else:
15031502
return full(
15041503
sh,

0 commit comments

Comments
 (0)