|
28 | 28 |
|
29 | 29 | #pragma once |
30 | 30 |
|
31 | | -#include "dpctl_capi.h" |
| 31 | +// TODO: Enable dpctl_capi.h once dpctl.tensor is removed. |
| 32 | +// Also call `import_dpctl_ext__tensor___usmarray();` right after |
| 33 | +// `import_dpctl()` (line 334) to initialize the dpctl_ext tensor C-API. |
| 34 | +// |
| 35 | +// Now we include dpctl C-API headers explicitly in order to |
| 36 | +// integrate dpctl_ext tensor C-API. |
| 37 | + |
| 38 | +// #include "dpctl_capi.h" |
| 39 | + |
| 40 | +// clang-format off |
| 41 | +// Ordering of includes is important here. dpctl_sycl_types and |
| 42 | +// dpctl_sycl_extension_interface define types used by dpctl's Python |
| 43 | +// C-API headers. |
| 44 | +#include "syclinterface/dpctl_sycl_types.h" |
| 45 | +#include "syclinterface/dpctl_sycl_extension_interface.h" |
| 46 | +#ifdef __cplusplus |
| 47 | +#define CYTHON_EXTERN_C extern "C" |
| 48 | +#else |
| 49 | +#define CYTHON_EXTERN_C |
| 50 | +#endif |
| 51 | +#include "dpctl/_sycl_device.h" |
| 52 | +#include "dpctl/_sycl_device_api.h" |
| 53 | +#include "dpctl/_sycl_context.h" |
| 54 | +#include "dpctl/_sycl_context_api.h" |
| 55 | +#include "dpctl/_sycl_event.h" |
| 56 | +#include "dpctl/_sycl_event_api.h" |
| 57 | +#include "dpctl/_sycl_queue.h" |
| 58 | +#include "dpctl/_sycl_queue_api.h" |
| 59 | +#include "dpctl/memory/_memory.h" |
| 60 | +#include "dpctl/memory/_memory_api.h" |
| 61 | +#include "dpctl/program/_program.h" |
| 62 | +#include "dpctl/program/_program_api.h" |
| 63 | + |
| 64 | +// clang-format on |
| 65 | + |
| 66 | +#include "../../../dpctl_ext/include/dpctl_ext/tensor/_usmarray.h" |
| 67 | +#include "../../../dpctl_ext/include/dpctl_ext/tensor/_usmarray_api.h" |
| 68 | + |
| 69 | +/* |
| 70 | + * Function to import dpctl and make C-API functions available. |
| 71 | + * C functions can use dpctl's C-API functions without linking to |
| 72 | + * shared objects defining this symbols, if they call `import_dpctl()` |
| 73 | + * prior to using those symbols. |
| 74 | + * |
| 75 | + * It is declared inline to allow multiple definitions in |
| 76 | + * different translation units |
| 77 | + */ |
| 78 | +static inline void import_dpctl(void) |
| 79 | +{ |
| 80 | + import_dpctl___sycl_device(); |
| 81 | + import_dpctl___sycl_context(); |
| 82 | + import_dpctl___sycl_event(); |
| 83 | + import_dpctl___sycl_queue(); |
| 84 | + import_dpctl__memory___memory(); |
| 85 | + import_dpctl_ext__tensor___usmarray(); |
| 86 | + import_dpctl__program___program(); |
| 87 | + return; |
| 88 | +} |
32 | 89 |
|
33 | 90 | #include <complex> |
34 | 91 | #include <cstddef> // for std::size_t for C++ linkage |
@@ -410,8 +467,10 @@ class dpctl_capi |
410 | 467 | default_usm_memory_ = std::shared_ptr<py::object>( |
411 | 468 | new py::object{py_default_usm_memory}, Deleter{}); |
412 | 469 |
|
| 470 | + // TODO: revert to `py::module_::import("dpctl.tensor._usmarray");` |
| 471 | + // when dpnp fully migrates dpctl/tensor |
413 | 472 | py::module_ mod_usmarray = |
414 | | - py::module_::import("dpctl.tensor._usmarray"); |
| 473 | + py::module_::import("dpctl_ext.tensor._usmarray"); |
415 | 474 | auto tensor_kl = mod_usmarray.attr("usm_ndarray"); |
416 | 475 |
|
417 | 476 | const py::object &py_default_usm_ndarray = |
|
0 commit comments