
Commit

chore(accelerated_op): use correct Python C type for pybind11 function prototype
XuehaiPan committed Aug 7, 2022
1 parent 5c2b70a commit 83b1c5f
Showing 8 changed files with 154 additions and 150 deletions.
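Why the prototypes change: CPython's float is a C double, and the Adam step counter is a non-negative integer, so bound functions declared with C++ float/int parameters make pybind11 silently narrow every scalar passed in from Python. This commit routes those scalars through the aliases pyfloat_t (= double) and pyuint_t (= std::size_t) added in include/common.h below. A standalone sketch, not part of the commit, of the drift that narrowing introduces:

// narrowing_sketch.cpp -- illustration only; file name and values assumed.
#include <cmath>
#include <cstdio>

int main() {
  const double b2 = 0.999;  // a typical Adam beta2 arriving from Python
  // What a `const float b2` parameter would have received instead:
  const float b2_narrowed = static_cast<float>(b2);

  // Bias-correction factor 1 - b2^t after many steps.
  const int t = 10000;
  const double exact = 1.0 - std::pow(b2, t);
  const double drifted = 1.0 - std::pow(static_cast<double>(b2_narrowed), t);

  std::printf("exact:   %.17g\n", exact);
  std::printf("drifted: %.17g\n", drifted);
  std::printf("abs err: %.3g\n", std::fabs(exact - drifted));
  return 0;
}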
12 changes: 6 additions & 6 deletions conda-recipe.yaml
@@ -19,28 +19,28 @@ dependencies:
   # Learning
   - pytorch::pytorch = 1.12
   - pytorch::torchvision
-  - pytorch::pytorch-mutex = *=*cuda*
+  - pytorch::pytorch-mutex
   - pip:
       - functorch
       - torchviz
       - sphinxcontrib-katex # for documentation
       - jax
-      - jaxlib >= 0.3=*cuda*
+      - jaxlib >= 0.3
       - optax # for tutorials
       - tensorboard # for examples
       - wandb

   # Device select
-  - nvidia::cudatoolkit = 11.6
-  - cudnn
+  # - nvidia::cudatoolkit = 11.6
+  # - cudnn

   # Build toolchain
   - cmake >= 3.4
   - make
   - cxx-compiler
   - gxx = 10
-  - nvidia/label/cuda-11.6.2::cuda-nvcc
-  - nvidia/label/cuda-11.6.2::cuda-cudart-dev
+  # - nvidia/label/cuda-11.6.2::cuda-nvcc
+  # - nvidia/label/cuda-11.6.2::cuda-cudart-dev
   - patchelf >= 0.9
   - pybind11
27 changes: 15 additions & 12 deletions include/adam_op/adam_op.h
@@ -23,32 +23,35 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count);
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count);

 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1);
+                            const torch::Tensor& mu, const pyfloat_t b1);

 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2);
+                            const torch::Tensor& nu, const pyfloat_t b2);

 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count);
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count);

 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1);
+                              const torch::Tensor& mu, const pyfloat_t b1);

 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2);
+                              const torch::Tensor& nu, const pyfloat_t b2);

 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count);
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count);
 }  // namespace torchopt
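For context, these prototypes are what pybind11 binds. A hypothetical binding sketch (module and Python-visible names are assumed here, not taken from the repository):

// binding_sketch.cpp -- illustration only.
#include <pybind11/pybind11.h>

#include "include/adam_op/adam_op.h"  // pulls in torch/extension.h via common.h

PYBIND11_MODULE(adam_op, m) {  // module name assumed
  // pybind11 deduces the Python signatures from the C++ prototypes, so the
  // pyfloat_t / pyuint_t parameters now match CPython's own float (a C
  // double) and non-negative int arguments without narrowing.
  m.def("forward_", &torchopt::adamForwardInplace);
  m.def("forward_mu", &torchopt::adamForwardMu);
  m.def("forward_updates", &torchopt::adamForwardUpdates);
  m.def("backward_updates", &torchopt::adamBackwardUpdates);
}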
29 changes: 15 additions & 14 deletions include/adam_op/adam_op_impl_cpu.h
@@ -21,35 +21,36 @@
 #include "include/common.h"

 namespace torchopt {
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
-                                     const torch::Tensor& mu,
-                                     const torch::Tensor& nu, const float b1,
-                                     const float b2, const float eps,
-                                     const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCPU(
+    const torch::Tensor& updates, const torch::Tensor& mu,
+    const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);

 torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& mu, const float b1);
+                               const torch::Tensor& mu, const pyfloat_t b1);

 torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& nu, const float b2);
+                               const torch::Tensor& nu, const pyfloat_t b2);

 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
-                                    const torch::Tensor& new_nu, const float b1,
-                                    const float b2, const float eps,
-                                    const float eps_root, const int count);
+                                    const torch::Tensor& new_nu,
+                                    const pyfloat_t b1, const pyfloat_t b2,
+                                    const pyfloat_t eps,
+                                    const pyfloat_t eps_root,
+                                    const pyuint_t count);

 TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& mu, const float b1);
+                                 const torch::Tensor& mu, const pyfloat_t b1);

 TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& nu, const float b2);
+                                 const torch::Tensor& nu, const pyfloat_t b2);

 TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
                                       const torch::Tensor& updates,
                                       const torch::Tensor& new_mu,
                                       const torch::Tensor& new_nu,
-                                      const float b1, const float b2,
-                                      const int count);
+                                      const pyfloat_t b1, const pyfloat_t b2,
+                                      const pyuint_t count);
 }  // namespace torchopt
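For reference, these kernels implement the Adam moment recursions and the bias-corrected update. In the optax-style formulation (the placement of eps and eps_root below follows optax's scale_by_adam and is an assumption about this code, not read off the diff):

\mu_t = \beta_1 \mu_{t-1} + (1 - \beta_1)\, g_t, \qquad \nu_t = \beta_2 \nu_{t-1} + (1 - \beta_2)\, g_t^2

\hat{\mu}_t = \frac{\mu_t}{1 - \beta_1^t}, \qquad \hat{\nu}_t = \frac{\nu_t}{1 - \beta_2^t}, \qquad u_t = \frac{\hat{\mu}_t}{\sqrt{\hat{\nu}_t + \varepsilon_{\mathrm{root}}} + \varepsilon}

The count argument is the step index t: a non-negative integer that only grows, which is why pyuint_t (std::size_t) fits it better than int.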
28 changes: 14 additions & 14 deletions include/adam_op/adam_op_impl_cuda.cuh
@@ -21,36 +21,36 @@
 #include "include/common.h"

 namespace torchopt {
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-                                      const torch::Tensor &mu,
-                                      const torch::Tensor &nu, const float b1,
-                                      const float b2, const float eps,
-                                      const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCUDA(
+    const torch::Tensor &updates, const torch::Tensor &mu,
+    const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);

 torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &mu, const float b1);
+                                const torch::Tensor &mu, const pyfloat_t b1);

 torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &nu, const float b2);
+                                const torch::Tensor &nu, const pyfloat_t b2);

 torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
                                      const torch::Tensor &new_nu,
-                                     const float b1, const float b2,
-                                     const float eps, const float eps_root,
-                                     const int count);
+                                     const pyfloat_t b1, const pyfloat_t b2,
+                                     const pyfloat_t eps,
+                                     const pyfloat_t eps_root,
+                                     const pyuint_t count);

 TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &mu, const float b1);
+                                  const torch::Tensor &mu, const pyfloat_t b1);

 TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &nu, const float b2);
+                                  const torch::Tensor &nu, const pyfloat_t b2);

 TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
                                        const torch::Tensor &updates,
                                        const torch::Tensor &new_mu,
                                        const torch::Tensor &new_nu,
-                                       const float b1, const float b2,
-                                       const int count);
+                                       const pyfloat_t b1, const pyfloat_t b2,
+                                       const pyuint_t count);
 }  // namespace torchopt
4 changes: 4 additions & 0 deletions include/common.h
@@ -17,6 +17,10 @@
 #include <torch/extension.h>

 #include <array>
+#include <cstddef>
+
+using pyfloat_t = double;
+using pyuint_t = std::size_t;

 namespace torchopt {
 template <size_t _Nm>
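The hunk cuts off right after the template line, but given how TensorArray<3> and TensorArray<2> are used as multi-output return types in the headers above, the truncated definition is presumably an alias along these lines (a sketch, not the verbatim source):

#include <torch/extension.h>

#include <array>
#include <cstddef>

namespace torchopt {
// Fixed-size bundle of tensors, e.g. TensorArray<3> for (mu, nu, updates).
template <size_t _Nm>
using TensorArray = std::array<torch::Tensor, _Nm>;
}  // namespace torchopt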
27 changes: 15 additions & 12 deletions src/adam_op/adam_op.cpp
@@ -26,9 +26,10 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count) {
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root,
@@ -42,7 +43,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
   }
 }
 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1) {
+                            const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardMuCUDA(updates, mu, b1);
@@ -56,7 +57,7 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates,
   }
 }

 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2) {
+                            const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardNuCUDA(updates, nu, b2);
@@ -70,9 +71,10 @@ torch::Tensor adamForwardNu(const torch::Tensor& updates,
 }

 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count) {
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count) {
 #if defined(__CUDACC__)
   if (new_mu.device().is_cuda()) {
     return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count);
@@ -87,7 +89,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,

 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1) {
+                              const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (dmu.device().is_cuda()) {
     return adamBackwardMuCUDA(dmu, updates, mu, b1);
@@ -102,7 +104,7 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,

 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2) {
+                              const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (dnu.device().is_cuda()) {
     return adamBackwardNuCUDA(dnu, updates, nu, b2);
@@ -118,8 +120,9 @@ TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count) {
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count) {
 #if defined(__CUDACC__)
   if (dupdates.device().is_cuda()) {
     return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2,
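Every wrapper in this file repeats the same device dispatch. A condensed sketch of the pattern (the CPU fallback and the error branch sit outside the visible hunks, so their exact shape here is assumed):

#include <stdexcept>

#include "include/adam_op/adam_op_impl_cpu.h"
#if defined(__CUDACC__)
#include "include/adam_op/adam_op_impl_cuda.cuh"
#endif

namespace torchopt {
torch::Tensor adamForwardMuDispatchSketch(const torch::Tensor& updates,
                                          const torch::Tensor& mu,
                                          const pyfloat_t b1) {
#if defined(__CUDACC__)
  // Only the nvcc-compiled translation unit can reach the CUDA kernels.
  if (updates.device().is_cuda()) {
    return adamForwardMuCUDA(updates, mu, b1);
  }
#endif
  if (updates.device().is_cpu()) {
    return adamForwardMuCPU(updates, mu, b1);
  }
  throw std::runtime_error("unsupported device");  // assumed error handling
}
}  // namespace torchopt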