Skip to content

Commit ade264e

Browse files
Integrate dpctl_ext.tensor C-API to dpnp4pybind11.hpp
1 parent 5e7123d commit ade264e

1 file changed

Lines changed: 61 additions & 2 deletions

File tree

dpnp/backend/include/dpnp4pybind11.hpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,64 @@
2828

2929
#pragma once
3030

31-
#include "dpctl_capi.h"
31+
// TODO: Enable dpctl_capi.h once dpctl.tensor is removed.
32+
// Also call `import_dpctl_ext__tensor___usmarray();` right after
33+
// `import_dpctl()` (line 334) to initialize the dpctl_ext tensor C-API.
34+
//
35+
// Now we include dpctl C-API headers explicitly in order to
36+
// integrate dpctl_ext tensor C-API.
37+
38+
// #include "dpctl_capi.h"
39+
40+
// clang-format off
41+
// Ordering of includes is important here. dpctl_sycl_types and
42+
// dpctl_sycl_extension_interface define types used by dpctl's Python
43+
// C-API headers.
44+
#include "syclinterface/dpctl_sycl_types.h"
45+
#include "syclinterface/dpctl_sycl_extension_interface.h"
46+
#ifdef __cplusplus
47+
#define CYTHON_EXTERN_C extern "C"
48+
#else
49+
#define CYTHON_EXTERN_C
50+
#endif
51+
#include "dpctl/_sycl_device.h"
52+
#include "dpctl/_sycl_device_api.h"
53+
#include "dpctl/_sycl_context.h"
54+
#include "dpctl/_sycl_context_api.h"
55+
#include "dpctl/_sycl_event.h"
56+
#include "dpctl/_sycl_event_api.h"
57+
#include "dpctl/_sycl_queue.h"
58+
#include "dpctl/_sycl_queue_api.h"
59+
#include "dpctl/memory/_memory.h"
60+
#include "dpctl/memory/_memory_api.h"
61+
#include "dpctl/program/_program.h"
62+
#include "dpctl/program/_program_api.h"
63+
64+
// clang-format on
65+
66+
#include "../../../dpctl_ext/include/dpctl_ext/tensor/_usmarray.h"
67+
#include "../../../dpctl_ext/include/dpctl_ext/tensor/_usmarray_api.h"
68+
69+
/*
70+
* Function to import dpctl and make C-API functions available.
71+
* C functions can use dpctl's C-API functions without linking to
72+
* shared objects defining this symbols, if they call `import_dpctl()`
73+
* prior to using those symbols.
74+
*
75+
* It is declared inline to allow multiple definitions in
76+
* different translation units
77+
*/
78+
static inline void import_dpctl(void)
79+
{
80+
import_dpctl___sycl_device();
81+
import_dpctl___sycl_context();
82+
import_dpctl___sycl_event();
83+
import_dpctl___sycl_queue();
84+
import_dpctl__memory___memory();
85+
import_dpctl_ext__tensor___usmarray();
86+
import_dpctl__program___program();
87+
return;
88+
}
3289

3390
#include <complex>
3491
#include <cstddef> // for std::size_t for C++ linkage
@@ -410,8 +467,10 @@ class dpctl_capi
410467
default_usm_memory_ = std::shared_ptr<py::object>(
411468
new py::object{py_default_usm_memory}, Deleter{});
412469

470+
// TODO: revert to `py::module_::import("dpctl.tensor._usmarray");`
471+
// when dpnp fully migrates dpctl/tensor
413472
py::module_ mod_usmarray =
414-
py::module_::import("dpctl.tensor._usmarray");
473+
py::module_::import("dpctl_ext.tensor._usmarray");
415474
auto tensor_kl = mod_usmarray.attr("usm_ndarray");
416475

417476
const py::object &py_default_usm_ndarray =

0 commit comments

Comments
 (0)