Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(acc_op): add if condition for the element number small situations #105

Merged
merged 5 commits into from
Nov 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
Expand Down
30 changes: 23 additions & 7 deletions src/adam_op/adam_op_impl_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ using std::size_t;

namespace adam_op {

constexpr size_t MIN_NUMEL_USE_OMP = 1000;

template <typename scalar_t, typename other_t>
void adamForwardInplaceCPUKernel(const other_t b1,
const other_t inv_one_minus_pow_b1,
Expand All @@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
scalar_t *__restrict__ updates_ptr,
scalar_t *__restrict__ mu_ptr,
scalar_t *__restrict__ nu_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
Expand Down Expand Up @@ -90,7 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b1,
const size_t n,
scalar_t *__restrict__ mu_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
Expand Down Expand Up @@ -122,7 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b2,
const size_t n,
scalar_t *__restrict__ nu_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t nu = nu_ptr[tid];
Expand Down Expand Up @@ -158,7 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
const other_t eps_root,
const size_t n,
scalar_t *__restrict__ updates_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t new_mu = new_mu_ptr[tid];
const scalar_t new_nu = new_nu_ptr[tid];
Expand Down Expand Up @@ -205,7 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dmu_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dmu = dmu_ptr[tid];

Expand Down Expand Up @@ -240,7 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dnu_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dnu = dnu_ptr[tid];
const scalar_t updates = updates_ptr[tid];
Expand Down Expand Up @@ -279,7 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
const size_t n,
scalar_t *__restrict__ dnew_mu_out_ptr,
scalar_t *__restrict__ dnew_nu_out_ptr) {
#pragma omp parallel for num_threads(omp_get_num_procs())
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dupdates = dupdates_ptr[tid];
const scalar_t updates = updates_ptr[tid];
Expand Down