@@ -33,12 +33,18 @@ from .._backend cimport (
3333)
3434from ._usmarray cimport usm_ndarray
3535
36+ from platform import system as sys_platform
37+
3638import numpy as np
3739
3840import dpctl
3941import dpctl.memory as dpmem
4042
4143
44+ cdef bint _IS_LINUX = sys_platform() == " Linux"
45+
46+ del sys_platform
47+
4248cdef extern from ' dlpack/dlpack.h' nogil:
4349 cdef int DLPACK_VERSION
4450
@@ -140,6 +146,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140146 cdef c_dpctl.SyclQueue ary_sycl_queue
141147 cdef c_dpctl.SyclDevice ary_sycl_device
142148 cdef DPCTLSyclDeviceRef pDRef = NULL
149+ cdef DPCTLSyclDeviceRef tDRef = NULL
143150 cdef DLManagedTensor * dlm_tensor = NULL
144151 cdef DLTensor * dl_tensor = NULL
145152 cdef int nd = usm_ary.get_ndim()
@@ -157,19 +164,45 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157164 ary_sycl_queue = usm_ary.get_sycl_queue()
158165 ary_sycl_device = ary_sycl_queue.get_sycl_device()
159166
160- # check that ary_sycl_device is a non-partitioned device
161- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162- if pDRef is not NULL :
163- DPCTLDevice_Delete(pDRef)
164- raise DLPackCreationError(
165- " to_dlpack_capsule: DLPack can only export arrays allocated on "
166- " non-partitioned SYCL devices."
167- )
168- default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
169- if not usm_ary.sycl_context == default_context:
167+ try :
168+ if _IS_LINUX:
169+ default_context = ary_sycl_device.sycl_platform.default_context
170+ else :
171+ default_context = None
172+ except RuntimeError :
173+ # RT does not support default_context, e.g. Windows
174+ default_context = None
175+ if default_context is None :
176+ # check that ary_sycl_device is a non-partitioned device
177+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
178+ if pDRef is not NULL :
179+ DPCTLDevice_Delete(pDRef)
180+ raise DLPackCreationError(
181+ " to_dlpack_capsule: DLPack can only export arrays allocated "
182+ " on non-partitioned SYCL devices on platforms where "
183+ " default_context oneAPI extension is not supported."
184+ )
185+ else :
186+ if not usm_ary.sycl_context == default_context:
187+ raise DLPackCreationError(
188+ " to_dlpack_capsule: DLPack can only export arrays based on USM "
189+ " allocations bound to a default platform SYCL context"
190+ )
191+ # Find the unpartitioned parent of the allocation device
192+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
193+ if pDRef is not NULL :
194+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
195+ while tDRef is not NULL :
196+ DPCTLDevice_Delete(pDRef)
197+ pDRef = tDRef
198+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
199+ ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
200+
201+ # Find ordinal number of the parent device
202+ device_id = ary_sycl_device.get_overall_ordinal()
203+ if device_id < 0 :
170204 raise DLPackCreationError(
171- " to_dlpack_capsule: DLPack can only export arrays based on USM "
172- " allocations bound to a default platform SYCL context"
205+ " to_dlpack_capsule: failed to determine device_id"
173206 )
174207
175208 dlm_tensor = < DLManagedTensor * > stdlib.malloc(
@@ -192,14 +225,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192225 for i in range (nd):
193226 shape_strides_ptr[nd + i] = strides_ptr[i]
194227
195- device_id = ary_sycl_device.get_overall_ordinal()
196- if device_id < 0 :
197- stdlib.free(shape_strides_ptr)
198- stdlib.free(dlm_tensor)
199- raise DLPackCreationError(
200- " to_dlpack_capsule: failed to determine device_id"
201- )
202-
203228 ary_dt = usm_ary.dtype
204229 ary_dtk = ary_dt.kind
205230 element_offset = usm_ary.get_offset()
@@ -278,15 +303,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278303 success.
279304 Raises:
280305 TypeError: if argument is not a "dltensor" capsule.
281- ValueError: if argument is "used_dltensor" capsule,
282- if the USM pointer is not bound to the reconstructed
306+ ValueError: if argument is "used_dltensor" capsule
307+ BufferError: if the USM pointer is not bound to the reconstructed
283308 sycl context, or the DLPack's device_type is not supported
284309 by dpctl.
285310 """
286311 cdef DLManagedTensor * dlm_tensor = NULL
287312 cdef bytes usm_type
288313 cdef size_t sz = 1
289314 cdef int i
315+ cdef int device_id = - 1
290316 cdef int element_bytesize = 0
291317 cdef Py_ssize_t offset_min = 0
292318 cdef Py_ssize_t offset_max = 0
@@ -308,26 +334,40 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308334 py_caps, " dltensor" )
309335 # Verify that we can work with this device
310336 if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311- q = dpctl.SyclQueue(str (< int > dlm_tensor.dl_tensor.device.device_id))
337+ device_id = dlm_tensor.dl_tensor.device.device_id
338+ root_device = dpctl.SyclDevice(str (< int > device_id))
339+ try :
340+ if _IS_LINUX:
341+ default_context = root_device.sycl_platform.default_context
342+ else :
343+ default_context = dpctl.SyclQueue(root_device).sycl_context
344+ except RuntimeError :
345+ default_context = dpctl.SyclQueue(root_device).sycl_context
312346 if dlm_tensor.dl_tensor.data is NULL :
313347 usm_type = b" device"
348+ q = dpctl.SyclQueue(default_context, root_device)
314349 else :
315350 usm_type = c_dpmem._Memory.get_pointer_type(
316351 < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317- < c_dpctl.SyclContext> q.sycl_context)
318- if usm_type == b" unknown" :
319- raise ValueError (
320- f" Data pointer in DLPack is not bound to default sycl "
321- " context of device '{device_id}', translated to "
322- " {q.sycl_device.filter_string}"
352+ < c_dpctl.SyclContext> default_context)
353+ if usm_type == b" unknown" :
354+ raise BufferError(
355+ " Data pointer in DLPack is not bound to default sycl "
356+ f" context of device '{device_id}', translated to "
357+ f" {root_device.filter_string}"
358+ )
359+ alloc_device = c_dpmem._Memory.get_pointer_device(
360+ < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
361+ < c_dpctl.SyclContext> default_context
323362 )
363+ q = dpctl.SyclQueue(default_context, alloc_device)
324364 if dlm_tensor.dl_tensor.dtype.bits % 8 :
325- raise ValueError (
365+ raise BufferError (
326366 " Can not import DLPack tensor whose element's "
327367 " bitsize is not a multiple of 8"
328368 )
329369 if dlm_tensor.dl_tensor.dtype.lanes != 1 :
330- raise ValueError (
370+ raise BufferError (
331371 " Can not import DLPack tensor with lanes != 1"
332372 )
333373 if dlm_tensor.dl_tensor.strides is NULL :
0 commit comments