@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168168 stdlib.free(dlmv_tensor)
169169
170170
171- cdef object _get_default_context(c_dpctl.SyclDevice dev) except * :
171+ cdef object _get_default_context(c_dpctl.SyclDevice dev):
172172 try :
173173 default_context = dev.sycl_platform.default_context
174174 except RuntimeError :
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178178 return default_context
179179
180180
181- cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except * :
181+ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except - 1 :
182182 cdef DPCTLSyclDeviceRef pDRef = NULL
183183 cdef DPCTLSyclDeviceRef tDRef = NULL
184184 cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201201
202202cdef int get_array_dlpack_device_id(
203203 usm_ndarray usm_ary
204- ) except * :
204+ ) except - 1 :
205205 """ Finds ordinal number of the parent of device where array
206206 was allocated.
207207 """
@@ -935,6 +935,32 @@ cpdef object from_dlpack_capsule(object py_caps):
935935 " The DLPack tensor resides on unsupported device."
936936 )
937937
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
    """Copy data of a host-accessible array into a new USM-device array.

    Args:
        host_blob (object):
            Object that ``numpy.asarray`` can convert to a NumPy array.
        dev (Device):
            Target device; the allocation and the copy use
            ``dev.sycl_queue``.

    Returns:
        usm_ndarray:
            Array with the shape of the host array, backed by a fresh
            ``MemoryUSMDevice`` allocation. Double precision real/complex
            data is converted to single precision when the queue's device
            reports ``has_aspect_fp64 is False``.
    """
    q = dev.sycl_queue
    np_ary = np.asarray(host_blob)
    dt = np_ary.dtype
    if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
        # target device can not hold double precision data
        Xusm_dtype = (
            "float32" if dt.char == "d" else "complex64"
        )
    else:
        Xusm_dtype = dt
    # Convert the *values* to the target dtype (not only the dtype tag of the
    # result), and force C-contiguous layout so that the byte image copied
    # below matches the default strides of the newly created usm_ndarray.
    np_ary = np.asarray(np_ary, dtype=Xusm_dtype, order="C")
    usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
    usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
    # Flatten before re-viewing as bytes: changing the item size of a view
    # is not supported for 0-d arrays.
    usm_mem.copy_from_host(np.reshape(np_ary, -1).view(dtype="u1"))
    return usm_ary
952+
953+
954+ # only cdef to make it private
cdef object _create_device(object device, object dl_device):
    """Return a ``Device`` for the requested target.

    A ``Device`` instance is returned unchanged. A ``dpctl.SyclDevice`` is
    wrapped with ``Device.create_device``. For anything else, the second
    element of ``dl_device`` is cast to ``int``, converted to a string and
    passed to ``dpctl.SyclDevice`` to select the device to wrap.
    """
    if isinstance(device, Device):
        return device
    if isinstance(device, dpctl.SyclDevice):
        sycl_dev = device
    else:
        sycl_dev = dpctl.SyclDevice(str(<int>dl_device[1]))
    return Device.create_device(sycl_dev)
963+
938964
939965def from_dlpack (x , /, *, device = None , copy = None ):
940966 """ from_dlpack(x, /, *, device=None, copy=None)
@@ -943,7 +969,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
943969 object ``x`` that implements ``__dlpack__`` protocol.
944970
945971 Args:
946- x (Python object):
972+ x (object):
947973 A Python object representing an array that supports
948974 ``__dlpack__`` protocol.
949975 device (Optional[str,
@@ -959,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
959985 returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
960986 2-tuple matching the format of the output of the ``__dlpack_device__``
961987 method, an integer enumerator representing the device type followed by
962- an integer representing the index of the device.
988+ an integer representing the index of the device. The only supported
989+ :enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
963990 Default: ``None``.
964991 copy (bool, optional)
965992 Boolean indicating whether or not to copy the input.
@@ -1008,33 +1035,130 @@ def from_dlpack(x, /, *, device=None, copy=None):
10081035
10091036 C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
10101037 X = dpt.from_dlpack(C)
1038+ Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
10111039
10121040 """
1013- if not hasattr (x, " __dlpack__" ):
1014- raise TypeError (
1015- f" The argument of type {type(x)} does not implement "
1016- " `__dlpack__` method."
1017- )
1018- dlpack_attr = getattr (x, " __dlpack__" )
1019- if not callable (dlpack_attr):
1041+ dlpack_attr = getattr (x, " __dlpack__" , None )
1042+ dlpack_dev_attr = getattr (x, " __dlpack_device__" , None )
1043+ if not callable (dlpack_attr) or not callable (dlpack_dev_attr):
10201044 raise TypeError (
10211045 f" The argument of type {type(x)} does not implement "
1022- " `__dlpack__` method ."
1046+ " `__dlpack__` and `__dlpack_device__` methods ."
10231047 )
1024- try :
1025- # device is converted to a dlpack_device if necessary
1026- dl_device = None
1027- if device:
1028- if isinstance (device, tuple ):
1029- dl_device = device
1048+ # device is converted to a dlpack_device if necessary
1049+ dl_device = None
1050+ if device:
1051+ if isinstance (device, tuple ):
1052+ dl_device = device
1053+ if len (dl_device) != 2 :
1054+ raise ValueError (
1055+ " Argument `device` specified as a tuple must have length 2"
1056+ )
1057+ else :
1058+ if not isinstance (device, dpctl.SyclDevice):
1059+ device = Device.create_device(device)
1060+ d = device.sycl_device
10301061 else :
1031- if not isinstance (device, dpctl.SyclDevice):
1032- d = Device.create_device(device).sycl_device
1033- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1034- else :
1035- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> device))
1036- dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1037- return from_dlpack_capsule(dlpack_capsule)
1062+ d = device
1063+ dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1064+ if dl_device is not None :
1065+ if (dl_device[0 ] not in [device_OneAPI, device_CPU]):
1066+ raise ValueError (
1067+ f" Argument `device`={device} is not supported."
1068+ )
1069+ got_type_error = False
1070+ got_buffer_error = False
1071+ got_other_error = False
1072+ saved_exception = None
1073+ # First DLPack version supporting dl_device, and copy
1074+ requested_ver = (1 , 0 )
1075+ cpu_dev = (device_CPU, 0 )
1076+ try :
1077+ # setting max_version to minimal version that supports dl_device/copy keywords
1078+ dlpack_capsule = dlpack_attr(
1079+ max_version = requested_ver,
1080+ dl_device = dl_device,
1081+ copy = copy
1082+ )
10381083 except TypeError :
1039- dlpack_capsule = dlpack_attr()
1084+ # exporter does not support max_version keyword
1085+ got_type_error = True
1086+ except (BufferError, NotImplementedError ):
1087+ # Either dl_device or copy request cannot be satisfied by the exporter
1088+ got_buffer_error = True
1089+ except Exception as e:
1090+ got_other_error = True
1091+ saved_exception = e
1092+ else :
1093+ # execution did not raise exceptions
10401094 return from_dlpack_capsule(dlpack_capsule)
1095+ finally :
1096+ if got_type_error:
1097+ # max_version/dl_device, copy keywords are not supported by __dlpack__
1098+ x_dldev = dlpack_dev_attr()
1099+ if (dl_device is None ) or (dl_device == x_dldev):
1100+ dlpack_capsule = dlpack_attr()
1101+ return from_dlpack_capsule(dlpack_capsule)
1102+ # must copy via host
1103+ if copy is False :
1104+ raise BufferError(
1105+ " Importing data via DLPack requires copying, but copy=False was provided"
1106+ )
1107+ # when max_version/dl_device/copy are not supported
1108+ # we can only support importing to OneAPI devices
1109+ # from host, or from another oneAPI device
1110+ is_supported_x_dldev = (
1111+ x_dldev == cpu_dev or
1112+ (x_dldev[0 ] == device_OneAPI)
1113+ )
1114+ is_supported_dl_device = (
1115+ dl_device == cpu_dev or
1116+ dl_device[0 ] == device_OneAPI
1117+ )
1118+ if is_supported_x_dldev and is_supported_dl_device:
1119+ dlpack_capsule = dlpack_attr()
1120+ blob = from_dlpack_capsule(dlpack_capsule)
1121+ else :
1122+ raise BufferError(f" Can not import to requested device {dl_device}" )
1123+ dev = _create_device(device, dl_device)
1124+ if x_dldev == cpu_dev and dl_device == cpu_dev:
1125+ # both source and destination are CPU
1126+ return blob
1127+ elif x_dldev == cpu_dev:
1128+ # source is CPU, destination is oneAPI
1129+ return _to_usm_ary_from_host_blob(blob, dev)
1130+ elif dl_device == cpu_dev:
1131+ # source is oneAPI, destination is CPU
1132+ cpu_caps = blob.__dlpack__(
1133+ max_version = get_build_dlpack_version(),
1134+ dl_device = cpu_dev
1135+ )
1136+ return from_dlpack_capsule(cpu_caps)
1137+ else :
1138+ import dpctl.tensor as dpt
1139+ return dpt.asarray(blob, device = dev)
1140+ elif got_buffer_error:
1141+ # we are here, because dlpack_attr could not deal with requested dl_device,
1142+ # or copying was required
1143+ if copy is False :
1144+ raise BufferError(
1145+ " Importing data via DLPack requires copying, but copy=False was provided"
1146+ )
1147+ # must copy via host
1148+ if dl_device[0 ] != device_OneAPI:
1149+ raise BufferError(f" Can not import to requested device {dl_device}" )
1150+ x_dldev = dlpack_dev_attr()
1151+ if x_dldev == cpu_dev:
1152+ dlpack_capsule = dlpack_attr()
1153+ host_blob = from_dlpack_capsule(dlpack_capsule)
1154+ else :
1155+ dlpack_capsule = dlpack_attr(
1156+ max_version = requested_ver,
1157+ dl_device = cpu_dev,
1158+ copy = copy
1159+ )
1160+ host_blob = from_dlpack_capsule(dlpack_capsule)
1161+ dev = _create_device(device, dl_device)
1162+ return _to_usm_ary_from_host_blob(host_blob, dev)
1163+ elif got_other_error:
1164+ raise saved_exception
0 commit comments