perf(acc_op): add if condition for situations where the element count is small
XuehaiPan committed Nov 2, 2022
1 parent bd4d51e commit cda9fe4
Showing 1 changed file with 76 additions and 38 deletions.
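The pattern behind this change: every OpenMP parallel-for in the file is now gated on the element count n. At or below min_elements_use_omp (1000) elements, the OpenMP if clause forces the loop to run serially; above that, the team size is capped at n / 1000 instead of always requesting omp_get_num_procs() threads, so small tensors no longer pay thread start-up and synchronization overhead. Below is a minimal standalone sketch of the same idea, assuming a hypothetical saxpy kernel in place of the Adam kernels; the std::max guard is extra hardening not present in the committed pragma, since a num_threads argument must stay positive even when the if clause disables parallelism.

#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr std::size_t min_elements_use_omp = 1000;

// Hypothetical element-wise kernel standing in for the Adam kernels:
// y[i] += a * x[i].
void saxpy(const float a, const float *x, float *y, const std::size_t n) {
  // Cap the team so each thread gets at least ~min_elements_use_omp
  // elements; the std::max keeps the requested team size positive even
  // though the if clause already forces serial execution for small n.
  const int nthreads = static_cast<int>(
      std::min(std::max<std::size_t>(n / min_elements_use_omp, 1),
               static_cast<std::size_t>(omp_get_num_procs())));
#pragma omp parallel for num_threads(nthreads) if (n > min_elements_use_omp)
  for (std::size_t i = 0; i < n; ++i) {
    y[i] += a * x[i];
  }
}

int main() {
  std::vector<float> x(4096, 1.0F);
  std::vector<float> y(4096, 2.0F);
  saxpy(0.5F, x.data(), y.data(), x.size());
  std::printf("y[0] = %.2f\n", y[0]);  // prints 2.50
}

Compiled with g++ -fopenmp, the loop runs on one thread for n <= 1000 and on at most min(n / 1000, cores) threads otherwise.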
114 changes: 76 additions & 38 deletions src/adam_op/adam_op_impl_cpu.cpp
@@ -27,6 +27,8 @@ using std::size_t;
 
 namespace adam_op {
 
+constexpr int min_elements_use_omp = 1000;
+
 template <typename scalar_t, typename other_t>
 void adamForwardInplaceCPUKernel(const other_t b1,
                                  const other_t inv_one_minus_pow_b1,
@@ -38,7 +40,8 @@ void adamForwardInplaceCPUKernel(const other_t b1,
                                  scalar_t *__restrict__ updates_ptr,
                                  scalar_t *__restrict__ mu_ptr,
                                  scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t mu = mu_ptr[tid];
@@ -67,18 +70,20 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
   const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
   const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));
 
-  AT_DISPATCH_SCALAR_TYPES(
-      updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
-        mu.mul_(scalar_t(b1)).add_(updates, 1 - scalar_t(b1));
-
-        nu.mul_(scalar_t(b2)).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
-
-        updates.copy_(mu.mul(scalar_t(inv_one_minus_pow_b1))
-                          .div_(nu.mul(inv_one_minus_pow_b2)
-                                    .add_(scalar_t(eps_root))
-                                    .sqrt_()
-                                    .add_(scalar_t(eps))));
-      }));
+  const size_t n = getTensorPlainSize(updates);
+  AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
+    adamForwardInplaceCPUKernel<scalar_t, scalar_t>(
+        scalar_t(b1),
+        scalar_t(inv_one_minus_pow_b1),
+        scalar_t(b2),
+        scalar_t(inv_one_minus_pow_b2),
+        scalar_t(eps),
+        scalar_t(eps_root),
+        n,
+        updates.data_ptr<scalar_t>(),
+        mu.data_ptr<scalar_t>(),
+        nu.data_ptr<scalar_t>());
+  }));
   return TensorArray<3>{updates, mu, nu};
 }
 
@@ -88,7 +93,8 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
                             const other_t b1,
                             const size_t n,
                             scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t mu = mu_ptr[tid];
@@ -100,10 +106,16 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
 torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                                const torch::Tensor &mu,
                                const pyfloat_t b1) {
-  torch::Tensor mu_out;
+  auto mu_out = torch::empty_like(mu);
 
+  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardMuCPU", ([&] {
-    mu_out = mu.mul(b1).add_(updates, 1 - scalar_t(b1));
+    adamForwardMuCPUKernel<scalar_t, scalar_t>(
+        updates.data_ptr<scalar_t>(),
+        mu.data_ptr<scalar_t>(),
+        scalar_t(b1),
+        n,
+        mu_out.data_ptr<scalar_t>());
   }));
   return mu_out;
 }
@@ -114,7 +126,8 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
                             const other_t b2,
                             const size_t n,
                             scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t nu = nu_ptr[tid];
@@ -127,11 +140,16 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
 torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                                const torch::Tensor &nu,
                                const pyfloat_t b2) {
-  torch::Tensor nu_out;
+  auto nu_out = torch::empty_like(nu);
 
+  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardNuCPU", ([&] {
-    nu_out =
-        nu.mul(b2).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
+    adamForwardNuCPUKernel<scalar_t, scalar_t>(
+        updates.data_ptr<scalar_t>(),
+        nu.data_ptr<scalar_t>(),
+        scalar_t(b2),
+        n,
+        nu_out.data_ptr<scalar_t>());
   }));
   return nu_out;
 }
@@ -145,7 +163,8 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
                                  const other_t eps_root,
                                  const size_t n,
                                  scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t new_mu = new_mu_ptr[tid];
     const scalar_t new_nu = new_nu_ptr[tid];
@@ -164,19 +183,24 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const pyuint_t count) {
   using other_t = pyfloat_t;
 
-  torch::Tensor updates_out;
+  auto updates_out = torch::empty_like(new_mu);
 
   const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
   const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
   const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
   const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;
 
+  const size_t n = getTensorPlainSize(new_mu);
   AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] {
-    updates_out = new_mu.mul(scalar_t(inv_one_minus_pow_b1))
-                      .div_(new_nu.mul(scalar_t(inv_one_minus_pow_b2))
-                                .add_(scalar_t(eps_root))
-                                .sqrt_()
-                                .add_(scalar_t(eps)));
+    adamForwardUpdatesCPUKernel<scalar_t, scalar_t>(
+        new_mu.data_ptr<scalar_t>(),
+        new_nu.data_ptr<scalar_t>(),
+        scalar_t(inv_one_minus_pow_b1),
+        scalar_t(inv_one_minus_pow_b2),
+        scalar_t(eps),
+        scalar_t(eps_root),
+        n,
+        updates_out.data_ptr<scalar_t>());
   }));
   return updates_out;
 }
@@ -187,7 +211,8 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
                              const size_t n,
                              scalar_t *__restrict__ dupdates_out_ptr,
                              scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dmu = dmu_ptr[tid];
 
@@ -200,12 +225,17 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                 const torch::Tensor &updates,
                                 const torch::Tensor &mu,
                                 const pyfloat_t b1) {
-  torch::Tensor dupdates_out;
-  torch::Tensor dmu_out;
+  auto dupdates_out = torch::empty_like(updates);
+  auto dmu_out = torch::empty_like(mu);
 
+  const size_t n = getTensorPlainSize(dmu);
   AT_DISPATCH_SCALAR_TYPES(dmu.scalar_type(), "adamBackwardMuCPU", ([&] {
-    dupdates_out = dmu.mul(1 - scalar_t(b1));
-    dmu_out = dmu.mul(scalar_t(b1));
+    adamBackwardMuCPUKernel<scalar_t, scalar_t>(
+        dmu.data_ptr<scalar_t>(),
+        scalar_t(b1),
+        n,
+        dupdates_out.data_ptr<scalar_t>(),
+        dmu_out.data_ptr<scalar_t>());
   }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
@@ -217,7 +247,8 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
                              const size_t n,
                              scalar_t *__restrict__ dupdates_out_ptr,
                              scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dnu = dnu_ptr[tid];
     const scalar_t updates = updates_ptr[tid];
@@ -231,12 +262,18 @@ TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                 const torch::Tensor &updates,
                                 const torch::Tensor &nu,
                                 const pyfloat_t b2) {
-  torch::Tensor dupdates_out;
-  torch::Tensor dnu_out;
+  auto dupdates_out = torch::empty_like(updates);
+  auto dnu_out = torch::empty_like(nu);
 
+  const size_t n = getTensorPlainSize(dnu);
   AT_DISPATCH_SCALAR_TYPES(dnu.scalar_type(), "adamForwardNuCPU", ([&] {
-    dupdates_out = updates.mul(2 - 2 * scalar_t(b2)).mul_(dnu);
-    dnu_out = dnu.mul(scalar_t(b2));
+    adamBackwardNuCPUKernel<scalar_t, scalar_t>(
+        dnu.data_ptr<scalar_t>(),
+        updates.data_ptr<scalar_t>(),
+        scalar_t(b2),
+        n,
+        dupdates_out.data_ptr<scalar_t>(),
+        dnu_out.data_ptr<scalar_t>());
   }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)};
 }
@@ -250,7 +287,8 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
                                  const size_t n,
                                  scalar_t *__restrict__ dnew_mu_out_ptr,
                                  scalar_t *__restrict__ dnew_nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dupdates = dupdates_ptr[tid];
     const scalar_t updates = updates_ptr[tid];