[C API]: Compile ONNX bindings with GPU support #9
Comments
Using the cuda branch of the ONNX Rust bindings (https://github.com/radu-matei/onnxruntime-rs/tree/cuda), the following patch works with CUDA 10.2:
Patch:
From 4315f65e7fd816f0568f4f12ee58640a61e6610b Mon Sep 17 00:00:00 2001
From: Radu M <root@radu.sh>
Date: Sun, 27 Jun 2021 11:12:39 +0000
Subject: [PATCH] Trying to enable CUDA
---
Cargo.lock | 6 ++++--
crates/wasi-nn-onnx-wasmtime/Cargo.toml | 2 +-
crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs | 1 +
3 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 5fc7a4d..cf77155 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,5 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
+version = 3
+
[[package]]
name = "addr2line"
version = "0.15.2"
@@ -1489,7 +1491,7 @@ checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
[[package]]
name = "onnxruntime"
version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
dependencies = [
"lazy_static",
"ndarray",
@@ -1501,7 +1503,7 @@ dependencies = [
[[package]]
name = "onnxruntime-sys"
version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
dependencies = [
"flate2",
"tar",
diff --git a/crates/wasi-nn-onnx-wasmtime/Cargo.toml b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
index a307e12..8979c8d 100644
--- a/crates/wasi-nn-onnx-wasmtime/Cargo.toml
+++ b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
@@ -9,7 +9,7 @@ anyhow = "1.0"
byteorder = "1.4"
log = { version = "0.4", default-features = false }
ndarray = "0.15"
-onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "owned-session", optional = true }
+onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "cuda", optional = true }
thiserror = "1.0"
tract-data = "0.14"
tract-linalg = "0.14"
diff --git a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
index 1fd2d7e..71d4a16 100644
--- a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
+++ b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
@@ -141,6 +141,7 @@ impl WasiEphemeralNn for WasiNnOnnxCtx {
.build()?;
let session = environment
.new_owned_session_builder()?
+ .use_cuda()?
.with_optimization_level(GraphOptimizationLevel::All)?
.with_model_from_memory(model_bytes)?;
let session = OnnxSession::with_session(session)?;
--
2.17.1
Environment:
CUDA 10.2 might be hitting this issue: microsoft/onnxruntime#5957. In any case, performance on a Tesla P100 is significantly worse than expected, and I suspect it has to do with the CUDA version.
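To put a number on "worse than expected", it can help to time the CPU and CUDA sessions the same way and to discard the first few calls, since the first GPU runs usually include one-time initialization cost. The following is a minimal sketch using only the standard library. `time_runs` and `median` are hypothetical helper names, and the closure in `main` is a placeholder standing in for the actual inference call.

```rust
use std::time::{Duration, Instant};

/// Calls `f` `warmup` times without timing it, then `iters` times with timing,
/// and returns one duration per timed call.
fn time_runs<F: FnMut()>(mut f: F, warmup: usize, iters: usize) -> Vec<Duration> {
    for _ in 0..warmup {
        f();
    }
    (0..iters)
        .map(|_| {
            let start = Instant::now();
            f();
            start.elapsed()
        })
        .collect()
}

/// Returns the median sample (the upper median for an even count),
/// or `None` when there are no samples.
fn median(samples: &[Duration]) -> Option<Duration> {
    if samples.is_empty() {
        return None;
    }
    let mut sorted = samples.to_vec();
    sorted.sort();
    Some(sorted[sorted.len() / 2])
}

fn main() {
    // Placeholder workload (assumption): replace the closure body with the
    // inference call for the session being measured.
    let mut counter = 0u64;
    let samples = time_runs(|| counter = counter.wrapping_add(1), 3, 10);
    println!("median over {} runs: {:?}", samples.len(), median(&samples));
}
```

Running the same harness once with the CUDA-enabled session and once with the CPU-only session gives two medians that can be compared directly.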
For Windows, we should also try compiling with DirectML support: https://www.onnxruntime.ai/docs/reference/execution-providers/DirectML-ExecutionProvider.html
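If both CUDA and DirectML end up supported, the host needs some rule for which execution provider to try first. The sketch below shows one possible ordering with a CPU fallback. The `Provider` enum and `preferred_providers` function are hypothetical names for illustration and are not part of onnxruntime-rs; the `cuda_built` and `directml_built` flags are assumed to come from however the build ends up enabling each backend.

```rust
/// Hypothetical execution-provider choices. These names are illustrative
/// only and are not types from onnxruntime-rs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Provider {
    Cuda,
    DirectMl,
    Cpu,
}

/// Returns the providers to try, in order of preference.
/// `target_os` uses the same strings as `std::env::consts::OS`
/// (for example "linux", "windows", "macos").
fn preferred_providers(target_os: &str, cuda_built: bool, directml_built: bool) -> Vec<Provider> {
    let mut order = Vec::new();
    if cuda_built {
        order.push(Provider::Cuda);
    }
    // DirectML is only considered on Windows.
    if directml_built && target_os == "windows" {
        order.push(Provider::DirectMl);
    }
    // The CPU provider is always the last fallback.
    order.push(Provider::Cpu);
    order
}

fn main() {
    // Assumption for the demo: CUDA is not built in, and DirectML is
    // treated as built in whenever the target is Windows.
    let order = preferred_providers(std::env::consts::OS, false, cfg!(windows));
    println!("provider order: {:?}", order);
}
```

Keeping the decision in a plain function like this makes it easy to test without a GPU, which matters given that not everyone working on this has a CUDA-enabled machine.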
I've created a PR, nbigaouette/onnxruntime-rs#87, with CUDA 11 for ONNX 1.7, based on nbigaouette/onnxruntime-rs#78. I think it is what you're looking for to test with. Feel free to review the branch and point out any key issues :)
This would add the CUDA and DirectML headers and pull the appropriate shared object.
See nbigaouette/onnxruntime-rs#57
As I am not currently on a CUDA-enabled machine, I am labeling this as help wanted.