Add support to dlpack arrays as kernel and dpjit arguments #1088

Open
ZzEeKkAa opened this issue Jul 7, 2023 · 0 comments

ZzEeKkAa (Contributor) commented Jul 7, 2023

The following example fails with a TypingError because a JAX array is passed as an argument to a dpjit function.

Command to run:

LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<python environment folder>/lib/python3.10/site-packages/jaxlib \
PJRT_NAMES_AND_LIBRARY_PATHS='xpu:<path to libitex_xla_extension with jax support>/libitex_xla_extension.so' \
TF_CPP_MIN_LOG_LEVEL=0 \
ONEAPI_DEVICE_SELECTOR=ext_oneapi_level_zero:gpu \
python example.py

example.py:

import jax.numpy as jnp
from numba import prange
import numba as nb
from numba_dpex import dpjit

@dpjit
def _sum_nomask(nops, w):
    tot = nb.float32(1.0)

    for i in prange(nops):
        if w[i] > 0:
            tot += w[i]
    
    return tot

if __name__ == "__main__":
    arr = jnp.arange(10, dtype=jnp.float32)
    res = jnp.zeros(2, dtype=jnp.float32)

    res[1] = _sum_nomask(10, arr)
    print("jax:", res)

The output looks like this:

/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so' TF_CPP_MIN_LOG_LEVEL=0 ONEAPI_DEVICE_SELECTOR=ext_oneapi_level_zero:gpu python numba_dpex_jax.py
dpex: [45.  0.]
2023-07-06 22:08:29.180337: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x55904aa17ba0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-07-06 22:08:29.180371: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Interpreter, <undefined>
2023-07-06 22:08:29.186010: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-07-06 22:08:29.186791: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-07-06 22:08:29.186938: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
2023-07-06 22:08:29.241398: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:85] GetPjrtApi was found for xpu at /home/yevhenii/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so
2023-07-06 22:08:29.241443: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type xpu
2023-07-06 22:08:29.242084: I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
2023-07-06 22:08:29.242429: I itex/core/devices/gpu/itex_gpu_runtime.cc:154] number of sub-devices is zero, expose root device.
2023-07-06 22:08:29.248729: I itex/core/compiler/xla/service/service.cc:176] XLA service 0x55904c85d4c0 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
2023-07-06 22:08:29.248766: I itex/core/compiler/xla/service/service.cc:184]   StreamExecutor device (0): <undefined>, <undefined>
2023-07-06 22:08:29.250558: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc:83] PjRtCApiClient created.
Traceback (most recent call last):
  File "/home/yevhenii/Projects/users.yevhenii/examples/numba-dpex-jax/numba_dpex_jax.py", line 49, in <module>
    res[1] = _sum_nomask(10, arr)
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in dpex_dpjit_nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /home/yevhenii/Projects/users.yevhenii/examples/numba-dpex-jax/numba_dpex_jax.py (15)

File "numba_dpex_jax.py", line 15:

@dpjit
^

This error may have been caused by the following argument(s):
- argument 1: Cannot determine Numba type of <class 'jaxlib.xla_extension.Array'>

2023-07-06 22:08:30.365912: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient destroyed.

We need to accept arrays as arguments if they implement the __dlpack__() method. Such an array can then be converted to a dpnp array with dpnp.from_dlpack(arr), and the result can be passed to dpjit functions.
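
Until that is supported natively, the conversion described above could be approximated at the call site. The following is only a rough sketch, not the proposed numba-dpex implementation: accept_dlpack is a hypothetical helper name, and the converter (for example dpnp.from_dlpack) is passed in as a parameter so the sketch itself does not depend on dpnp. Whether the converted array lands on a device that dpjit accepts is an assumption that depends on the producer library.

```python
import functools


def accept_dlpack(from_dlpack, native_types=()):
    """Hypothetical decorator factory (not part of numba-dpex).

    from_dlpack:  converter such as dpnp.from_dlpack, supplied by the caller.
    native_types: types the wrapped function already accepts; instances of
                  these are passed through without conversion.
    """

    def convert(arg):
        # Leave already-supported array types untouched.
        if isinstance(arg, native_types):
            return arg
        # Convert anything that exposes the DLPack protocol.
        if hasattr(arg, "__dlpack__"):
            return from_dlpack(arg)
        # Scalars and other objects are passed through as-is.
        return arg

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args = tuple(convert(a) for a in args)
            kwargs = {k: convert(v) for k, v in kwargs.items()}
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Assumed usage with the example from this issue (not run here):
#   import dpnp
#   sum_nomask = accept_dlpack(dpnp.from_dlpack, (dpnp.ndarray,))(_sum_nomask)
#   sum_nomask(10, arr)  # arr is the jax array from example.py
```

Doing the conversion in a wrapper keeps the dpjit function itself unchanged; native support would instead have to happen where argument types are resolved, which is where the TypingError above is raised.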

@ZzEeKkAa ZzEeKkAa self-assigned this Jul 7, 2023
@diptorupd diptorupd added this to the 0.22 milestone Dec 19, 2023
@ZzEeKkAa ZzEeKkAa changed the title from "Dlpack arrays are not supported as dpjit function arguments" to "Add support to dlpack arrays as kernel and dpjit arguments" Dec 20, 2023
@diptorupd diptorupd added the enhancement (New feature or request) label and removed the feature label Jan 22, 2024